File size: 4,914 Bytes
bc088da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
from ml_collections import config_dict
import yaml
from diffusers.schedulers import (
DDIMScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
DDPMScheduler,
)
from inversion_utils import (
deterministic_ddim_step,
deterministic_ddpm_step,
deterministic_euler_step,
deterministic_non_ancestral_euler_step,
)
# Accepted values for ``cfg.breakdown``; the default config uses index 1.
BREAKDOWNS = [
    "x_t_c_hat",
    "x_t_hat_c",
    "no_breakdown",
    "x_t_hat_c_with_zeros",
]

# Accepted values for ``cfg.scheduler_type``; ``get_config`` maps each one to a
# diffusers scheduler class plus a deterministic step function. The default
# config uses index 0.
SCHEDULERS = [
    "ddpm",
    "ddim",
    "euler",
    "euler_non_ancestral",
]

# Candidate values for ``cfg.model`` (presumably Hugging Face Hub model ids —
# confirm against the pipeline loader); the default config uses index 1.
MODELS = [
    "stabilityai/sdxl-turbo",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "CompVis/stable-diffusion-v1-4",
]
def get_num_steps_actual(cfg):
    """Return how many per-step entries the configuration must provide.

    The base count is ``len(cfg.timesteps)`` when explicit timesteps are
    given, otherwise ``cfg.num_steps_inversion - cfg.step_start``. One extra
    step is added when ``cfg.clean_step_timestep`` is positive.

    Args:
        cfg: Any object exposing ``timesteps``, ``clean_step_timestep``,
            ``num_steps_inversion`` and ``step_start`` as attributes.

    Returns:
        int: The number of steps that per-step lists (``ws1``, ``ws2``,
        ``max_norm_zs``) must have.
    """
    has_explicit_schedule = cfg.timesteps is not None
    if has_explicit_schedule:
        base_steps = len(cfg.timesteps)
    else:
        base_steps = cfg.num_steps_inversion - cfg.step_start
    clean_step_bonus = int(cfg.clean_step_timestep > 0)
    return base_steps + clean_step_bonus
def get_config(args):
    """Build and validate the run configuration.

    If ``args.config_from_file`` is a non-empty path, the configuration is
    loaded from that YAML file. Otherwise a built-in default configuration is
    constructed. In both cases this function then:

    * attaches ``scheduler_class`` and ``step_function`` for
      ``cfg.scheduler_type``;
    * broadcasts scalar ``max_norm_zs`` / ``ws1`` / ``ws2`` values to one
      entry per actual step;
    * fills in ``update_eta``, ``save_timesteps`` and ``scheduler_timesteps``
      when they are absent;
    * checks that all per-step lists have consistent lengths.

    Args:
        args: Parsed command-line arguments; only ``config_from_file`` is read.

    Returns:
        config_dict.ConfigDict: The resolved configuration.

    Raises:
        ValueError: If ``cfg.scheduler_type`` is not a known scheduler.
        AssertionError: If ``timesteps`` is set for a non-DDPM scheduler, or a
            per-step list has the wrong length.
    """
    if args.config_from_file:
        # safe_load: the file is treated as plain data, never as Python objects.
        with open(args.config_from_file, "r", encoding="utf-8") as f:
            cfg = config_dict.ConfigDict(yaml.safe_load(f))
        num_steps_actual = get_num_steps_actual(cfg)
    else:
        cfg = config_dict.ConfigDict()
        cfg.seed = 2
        cfg.self_r = 0.5
        cfg.cross_r = 0.9
        cfg.eta = 1
        cfg.scheduler_type = SCHEDULERS[0]
        cfg.num_steps_inversion = 50  # timesteps: 999, 799, 599, 399, 199
        cfg.step_start = 20
        cfg.timesteps = None
        cfg.noise_timesteps = None
        # get_num_steps_actual() reads clean_step_timestep, and ConfigDict
        # raises AttributeError for unset keys, so it must be assigned first.
        cfg.clean_step_timestep = 0
        num_steps_actual = get_num_steps_actual(cfg)
        cfg.ws1 = [2] * num_steps_actual
        cfg.ws2 = [1] * num_steps_actual
        cfg.real_cfg_scale = 0
        cfg.real_cfg_scale_save = 0
        cfg.breakdown = BREAKDOWNS[1]
        cfg.noise_shift_delta = 1
        # -1 for every step except the last, which gets 15.5. How -1 is
        # interpreted is up to the consumer (presumably "no limit" — confirm).
        cfg.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
        cfg.model = MODELS[1]

    # scheduler_type -> (diffusers scheduler class, deterministic step function)
    scheduler_table = {
        "ddim": (DDIMScheduler, deterministic_ddim_step),
        "ddpm": (DDPMScheduler, deterministic_ddpm_step),
        "euler": (EulerAncestralDiscreteScheduler, deterministic_euler_step),
        "euler_non_ancestral": (
            EulerDiscreteScheduler,
            deterministic_non_ancestral_euler_step,
        ),
    }
    if cfg.scheduler_type not in scheduler_table:
        raise ValueError(f"Unknown scheduler type: {cfg.scheduler_type}")
    cfg.scheduler_class, cfg.step_function = scheduler_table[cfg.scheduler_type]

    # A scalar (e.g. from YAML) is shorthand for "same value at every step".
    # ignore_type() is required because the field changes from number to list.
    with cfg.ignore_type():
        if isinstance(cfg.max_norm_zs, (int, float)):
            cfg.max_norm_zs = [cfg.max_norm_zs] * num_steps_actual
        if isinstance(cfg.ws1, (int, float)):
            cfg.ws1 = [cfg.ws1] * num_steps_actual
        if isinstance(cfg.ws2, (int, float)):
            cfg.ws2 = [cfg.ws2] * num_steps_actual

    # Optional keys that a loaded config file may omit.
    if not hasattr(cfg, "update_eta"):
        cfg.update_eta = False
    if not hasattr(cfg, "save_timesteps"):
        cfg.save_timesteps = None
    if not hasattr(cfg, "scheduler_timesteps"):
        cfg.scheduler_timesteps = None

    assert (
        cfg.scheduler_type == "ddpm" or cfg.timesteps is None
    ), "timesteps must be None for ddim/euler"
    # NOTE: max_norm_zs is intentionally NOT overwritten here; a value loaded
    # from the config file (or broadcast from a scalar above) is respected.
    assert (
        len(cfg.max_norm_zs) == num_steps_actual
    ), f"len(cfg.max_norm_zs) ({len(cfg.max_norm_zs)}) != num_steps_actual ({num_steps_actual})"
    assert (
        len(cfg.ws1) == num_steps_actual
    ), f"len(cfg.ws1) ({len(cfg.ws1)}) != num_steps_actual ({num_steps_actual})"
    assert (
        len(cfg.ws2) == num_steps_actual
    ), f"len(cfg.ws2) ({len(cfg.ws2)}) != num_steps_actual ({num_steps_actual})"

    # noise/save timesteps exclude the extra clean step, if any.
    num_noised_steps = num_steps_actual - (1 if cfg.clean_step_timestep > 0 else 0)
    assert (
        cfg.noise_timesteps is None or len(cfg.noise_timesteps) == num_noised_steps
    ), f"len(cfg.noise_timesteps) ({len(cfg.noise_timesteps)}) != num_steps_actual without clean step ({num_noised_steps})"
    assert (
        cfg.save_timesteps is None or len(cfg.save_timesteps) == num_noised_steps
    ), f"len(cfg.save_timesteps) ({len(cfg.save_timesteps)}) != num_steps_actual without clean step ({num_noised_steps})"
    return cfg
def get_config_name(config, args):
    """Return a name identifying this run.

    An explicit, non-empty ``args.folder_name`` wins. Otherwise the name is
    built from key hyperparameters as space-separated ``key value`` pairs.

    Args:
        config: Configuration object exposing the fields referenced below.
        args: Parsed command-line arguments; ``folder_name`` and ``fp16``
            are read.

    Returns:
        str: The run name.
    """
    explicit_name = args.folder_name
    if explicit_name not in (None, ""):
        return explicit_name

    # Explicit timesteps take precedence over the step_start description.
    if config.timesteps is None:
        schedule_desc = f"step_start {config.step_start}"
    else:
        schedule_desc = f"timesteps {config.timesteps}"

    name_parts = [
        f"ws1 {config.ws1[0]}",
        f"ws2 {config.ws2[0]}",
        f"real_cfg_scale {config.real_cfg_scale}",
        schedule_desc,
        f"real_cfg_scale_save {config.real_cfg_scale_save}",
        f"seed {config.seed}",
        f"max_norm_zs {config.max_norm_zs[-1]}",
        f"noise_shift_delta {config.noise_shift_delta}",
        f"scheduler_type {config.scheduler_type}",
        f"fp16 {args.fp16}",
    ]
    return " ".join(name_parts)
|