Spaces:
Running
on
A10G
Running
on
A10G
File size: 2,175 Bytes
26555ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps
# !important
def create_diffusion(
    timestep_respacing="",
    noise_schedule="linear",  # 'linear' for training
    use_kl=False,
    rescale_learned_sigmas=False,
    prediction_type='v_prediction',
    variance_type='fixed_small',
    beta_start=0.0001,
    beta_end=0.02,
    # Alternative range kept from the original source:
    # beta_start=0.00085, beta_end=0.012
    diffusion_steps=1000
):
    """Build a ``SpacedDiffusion`` object from string/boolean settings.

    Args:
        timestep_respacing: Respacing spec forwarded to ``space_timesteps``.
            ``None`` or ``""`` is replaced by ``[diffusion_steps]``.
        noise_schedule: Schedule name forwarded to
            ``gd.get_named_beta_schedule``.
        use_kl: If true, the loss type is ``gd.LossType.RESCALED_KL``.
            Takes precedence over ``rescale_learned_sigmas``.
        rescale_learned_sigmas: If true (and ``use_kl`` is false), the loss
            type is ``gd.LossType.RESCALED_MSE``; otherwise ``MSE``.
        prediction_type: One of ``'epsilon'``, ``'xstart'`` or
            ``'v_prediction'``; selects the ``gd.ModelMeanType`` member.
        variance_type: One of ``'fixed_small'``, ``'fixed_large'`` or
            ``'learned_range'``; selects the ``gd.ModelVarType`` member.
        beta_start: Forwarded to ``gd.get_named_beta_schedule``.
        beta_end: Forwarded to ``gd.get_named_beta_schedule``.
        diffusion_steps: Number of steps used for the beta schedule and
            passed to ``space_timesteps``.

    Returns:
        The ``SpacedDiffusion`` instance built from the resolved settings.

    Raises:
        ValueError: If ``prediction_type`` or ``variance_type`` is not one
            of the supported names.
    """
    # Validate the string selectors first. Previously an unknown value fell
    # through the if/elif chains and surfaced later as an UnboundLocalError
    # that did not name the offending argument.
    valid_prediction_types = ('epsilon', 'xstart', 'v_prediction')
    valid_variance_types = ('fixed_small', 'fixed_large', 'learned_range')
    if prediction_type not in valid_prediction_types:
        raise ValueError(
            f"Unknown prediction_type {prediction_type!r}; "
            f"expected one of {valid_prediction_types}"
        )
    if variance_type not in valid_variance_types:
        raise ValueError(
            f"Unknown variance_type {variance_type!r}; "
            f"expected one of {valid_variance_types}"
        )

    betas = gd.get_named_beta_schedule(
        noise_schedule, diffusion_steps, beta_start=beta_start, beta_end=beta_end
    )

    if prediction_type == 'epsilon':
        # EPSILON type for stable-diffusion-2-1 512
        model_mean_type = gd.ModelMeanType.EPSILON
    elif prediction_type == 'xstart':
        model_mean_type = gd.ModelMeanType.START_X
    else:
        # 'v_prediction' (guaranteed by the validation above);
        # PREVIOUS_V for stable-diffusion-2-1 768/x4-upscaler
        model_mean_type = gd.ModelMeanType.PREVIOUS_V

    if variance_type == 'fixed_small':
        model_var_type = gd.ModelVarType.FIXED_SMALL
    elif variance_type == 'fixed_large':
        model_var_type = gd.ModelVarType.FIXED_LARGE
    else:
        # 'learned_range' (guaranteed by the validation above)
        model_var_type = gd.ModelVarType.LEARNED_RANGE

    # use_kl wins over rescale_learned_sigmas; plain MSE is the fallback.
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE

    # With no respacing requested, pass a single section of diffusion_steps.
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]

    return SpacedDiffusion(
        use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
        betas=betas,
        model_mean_type=model_mean_type,
        model_var_type=model_var_type,
        loss_type=loss_type,
    )
|