Spaces:
Running
on
Zero
Running
on
Zero
from diffusers import AutoPipelineForImage2Image | |
from diffusers import DDPMScheduler | |
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import retrieve_timesteps, retrieve_latents | |
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput | |
import torch | |
from PIL import Image | |
num_steps_inversion = 5 | |
strngth = 0.8 | |
generator = None | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
image_path = "edit_dataset/01.jpg" | |
src_prompt = "butterfly perched on purple flower" | |
tgt_prompt = "dragonfly perched on purple flower" | |
ws1 = [1.5, 1.5, 1.5, 1.5] | |
ws2 = [1, 1, 1, 1] | |
def encode_image(image, pipe): | |
image = pipe.image_processor.preprocess(image) | |
image = image.to(device=device, dtype=pipeline.dtype) | |
if pipe.vae.config.force_upcast: | |
image = image.float() | |
pipe.vae.to(dtype=torch.float32) | |
if isinstance(generator, list): | |
init_latents = [ | |
retrieve_latents(pipe.vae.encode(image[i : i + 1]), generator=generator[i]) | |
for i in range(1) | |
] | |
init_latents = torch.cat(init_latents, dim=0) | |
else: | |
init_latents = retrieve_latents(pipe.vae.encode(image), generator=generator) | |
if pipe.vae.config.force_upcast: | |
pipe.vae.to(pipeline.dtype) | |
init_latents = init_latents.to(pipeline.dtype) | |
init_latents = pipe.vae.config.scaling_factor * init_latents | |
return init_latents.to(dtype=torch.float16) | |
# def create_xts(scheduler, timesteps, x_0, noise_shift_delta=1, generator=None): | |
# noising_delta = noise_shift_delta * (timesteps[0] - timesteps[1]) | |
# noise_timesteps = [timestep - int(noising_delta) for timestep in timesteps] | |
# noise_timesteps = noise_timesteps[:3] | |
# x_0_expanded = x_0.expand(len(noise_timesteps), -1, -1, -1) | |
# noise = torch.randn(x_0_expanded.size(), generator=generator, device="cpu", dtype=x_0.dtype).to(x_0.device) | |
# x_ts = scheduler.add_noise(x_0_expanded, noise, torch.IntTensor(noise_timesteps)) | |
# x_ts = [t.unsqueeze(dim=0) for t in list(x_ts)] | |
# x_ts += [x_0] | |
# return x_ts | |
def deterministic_ddpm_step( | |
model_output: torch.FloatTensor, | |
timestep, | |
sample: torch.FloatTensor, | |
eta, | |
use_clipped_model_output, | |
generator, | |
variance_noise, | |
return_dict, | |
scheduler, | |
): | |
""" | |
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion | |
process from the learned model outputs (most often the predicted noise). | |
Args: | |
model_output (`torch.FloatTensor`): | |
The direct output from learned diffusion model. | |
timestep (`float`): | |
The current discrete timestep in the diffusion chain. | |
sample (`torch.FloatTensor`): | |
A current instance of a sample created by the diffusion process. | |
generator (`torch.Generator`, *optional*): | |
A random number generator. | |
return_dict (`bool`, *optional*, defaults to `True`): | |
Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`. | |
Returns: | |
[`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`: | |
If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a | |
tuple is returned where the first element is the sample tensor. | |
""" | |
t = timestep | |
prev_t = scheduler.previous_timestep(t) | |
if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in [ | |
"learned", | |
"learned_range", | |
]: | |
model_output, predicted_variance = torch.split( | |
model_output, sample.shape[1], dim=1 | |
) | |
else: | |
predicted_variance = None | |
# 1. compute alphas, betas | |
alpha_prod_t = scheduler.alphas_cumprod[t] | |
alpha_prod_t_prev = ( | |
scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one | |
) | |
beta_prod_t = 1 - alpha_prod_t | |
beta_prod_t_prev = 1 - alpha_prod_t_prev | |
current_alpha_t = alpha_prod_t / alpha_prod_t_prev | |
current_beta_t = 1 - current_alpha_t | |
# 2. compute predicted original sample from predicted noise also called | |
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf | |
if scheduler.config.prediction_type == "epsilon": | |
pred_original_sample = ( | |
sample - beta_prod_t ** (0.5) * model_output | |
) / alpha_prod_t ** (0.5) | |
elif scheduler.config.prediction_type == "sample": | |
pred_original_sample = model_output | |
elif scheduler.config.prediction_type == "v_prediction": | |
pred_original_sample = (alpha_prod_t**0.5) * sample - ( | |
beta_prod_t**0.5 | |
) * model_output | |
else: | |
raise ValueError( | |
f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or" | |
" `v_prediction` for the DDPMScheduler." | |
) | |
# 3. Clip or threshold "predicted x_0" | |
if scheduler.config.thresholding: | |
pred_original_sample = scheduler._threshold_sample(pred_original_sample) | |
elif scheduler.config.clip_sample: | |
pred_original_sample = pred_original_sample.clamp( | |
-scheduler.config.clip_sample_range, scheduler.config.clip_sample_range | |
) | |
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t | |
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf | |
pred_original_sample_coeff = ( | |
alpha_prod_t_prev ** (0.5) * current_beta_t | |
) / beta_prod_t | |
current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t | |
# 5. Compute predicted previous sample µ_t | |
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf | |
pred_prev_sample = ( | |
pred_original_sample_coeff * pred_original_sample | |
+ current_sample_coeff * sample | |
) | |
return pred_prev_sample | |
def normalize( | |
z_t, | |
i, | |
max_norm_zs, | |
): | |
max_norm = max_norm_zs[i] | |
if max_norm < 0: | |
return z_t, 1 | |
norm = torch.norm(z_t) | |
if norm < max_norm: | |
return z_t, 1 | |
coeff = max_norm / norm | |
z_t = z_t * coeff | |
return z_t, coeff | |
def step_save_latents( | |
self, | |
model_output: torch.FloatTensor, | |
timestep: int, | |
sample: torch.FloatTensor, | |
eta: float = 0.0, | |
use_clipped_model_output: bool = False, | |
generator=None, | |
variance_noise= None, | |
return_dict: bool = True, | |
): | |
timestep_index = self._inner_index | |
next_timestep_index = timestep_index + 1 | |
u_hat_t = deterministic_ddpm_step( | |
model_output=model_output, | |
timestep=timestep, | |
sample=sample, | |
eta=eta, | |
use_clipped_model_output=use_clipped_model_output, | |
generator=generator, | |
variance_noise=variance_noise, | |
return_dict=False, | |
scheduler=self, | |
) | |
x_t_minus_1 = self.x_ts[timestep_index] | |
self.x_ts_c_hat.append(u_hat_t) | |
z_t = x_t_minus_1 - u_hat_t | |
self.latents.append(z_t) | |
z_t, _ = normalize(z_t, timestep_index, [-1, -1, -1, 15.5]) | |
x_t_minus_1_predicted = u_hat_t + z_t | |
if not return_dict: | |
return (x_t_minus_1_predicted,) | |
return DDIMSchedulerOutput(prev_sample=x_t_minus_1, pred_original_sample=None) | |
def step_use_latents( | |
self, | |
model_output: torch.FloatTensor, | |
timestep: int, | |
sample: torch.FloatTensor, | |
eta: float = 0.0, | |
use_clipped_model_output: bool = False, | |
generator=None, | |
variance_noise= None, | |
return_dict: bool = True, | |
): | |
print(f'_inner_index: {self._inner_index}') | |
timestep_index = self._inner_index | |
next_timestep_index = timestep_index + 1 | |
z_t = self.latents[timestep_index] # + 1 because latents[0] is X_T | |
_, normalize_coefficient = normalize( | |
z_t, | |
timestep_index, | |
[-1, -1, -1, 15.5], | |
) | |
if normalize_coefficient == 0: | |
eta = 0 | |
# eta = normalize_coefficient | |
x_t_hat_c_hat = deterministic_ddpm_step( | |
model_output=model_output, | |
timestep=timestep, | |
sample=sample, | |
eta=eta, | |
use_clipped_model_output=use_clipped_model_output, | |
generator=generator, | |
variance_noise=variance_noise, | |
return_dict=False, | |
scheduler=self, | |
) | |
w1 = ws1[timestep_index] | |
w2 = ws2[timestep_index] | |
x_t_minus_1_exact = self.x_ts[timestep_index] | |
x_t_minus_1_exact = x_t_minus_1_exact.expand_as(x_t_hat_c_hat) | |
x_t_c_hat: torch.Tensor = self.x_ts_c_hat[timestep_index] | |
x_t_c = x_t_c_hat[0].expand_as(x_t_hat_c_hat) | |
zero_index_reconstruction = 0 | |
edit_prompts_num = (model_output.size(0) - zero_index_reconstruction) // 2 | |
x_t_hat_c_indices = (zero_index_reconstruction, edit_prompts_num + zero_index_reconstruction) | |
edit_images_indices = ( | |
edit_prompts_num + zero_index_reconstruction, | |
model_output.size(0) | |
) | |
x_t_hat_c = torch.zeros_like(x_t_hat_c_hat) | |
x_t_hat_c[edit_images_indices[0] : edit_images_indices[1]] = x_t_hat_c_hat[ | |
x_t_hat_c_indices[0] : x_t_hat_c_indices[1] | |
] | |
v1 = x_t_hat_c_hat - x_t_hat_c | |
v2 = x_t_hat_c - normalize_coefficient * x_t_c | |
x_t_minus_1 = normalize_coefficient * x_t_minus_1_exact + w1 * v1 + w2 * v2 | |
x_t_minus_1[x_t_hat_c_indices[0] : x_t_hat_c_indices[1]] = x_t_minus_1[ | |
edit_images_indices[0] : edit_images_indices[1] | |
] # update x_t_hat_c to be x_t_hat_c_hat | |
if not return_dict: | |
return (x_t_minus_1,) | |
return DDIMSchedulerOutput( | |
prev_sample=x_t_minus_1, | |
pred_original_sample=None, | |
) | |
class myDDPMScheduler(DDPMScheduler): | |
def step( | |
self, | |
model_output: torch.FloatTensor, | |
timestep: int, | |
sample: torch.FloatTensor, | |
eta: float = 0.0, | |
use_clipped_model_output: bool = False, | |
generator=None, | |
variance_noise= None, | |
return_dict: bool = True, | |
): | |
print(f"timestep: {timestep}") | |
res_inv = step_save_latents( | |
self, | |
model_output[:1, :, :, :], | |
timestep, | |
sample[:1, :, :, :], | |
eta, | |
use_clipped_model_output, | |
generator, | |
variance_noise, | |
return_dict, | |
) | |
res_inf = step_use_latents( | |
self, | |
model_output[1:, :, :, :], | |
timestep, | |
sample[1:, :, :, :], | |
eta, | |
use_clipped_model_output, | |
generator, | |
variance_noise, | |
return_dict, | |
) | |
self._inner_index+=1 | |
res = (torch.cat((res_inv[0], res_inf[0]), dim=0),) | |
return res | |
pipeline = AutoPipelineForImage2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", safety_checker = None) | |
pipeline = pipeline.to(device) | |
pipeline.scheduler = DDPMScheduler.from_pretrained( # type: ignore | |
'stabilityai/sdxl-turbo', | |
subfolder="scheduler", | |
# cache_dir="/home/joberant/NLP_2223/giladd/test_dir/sdxl-turbo/models_cache", | |
) | |
# pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config) | |
denoising_start = 0.2 | |
timesteps, num_inference_steps = retrieve_timesteps( | |
pipeline.scheduler, num_steps_inversion, device, None | |
) | |
timesteps, num_inference_steps = pipeline.get_timesteps( | |
num_inference_steps=num_inference_steps, | |
device=device, | |
denoising_start=denoising_start, | |
strength=0, | |
) | |
timesteps = timesteps.type(torch.int64) | |
from functools import partial | |
timesteps = [torch.tensor(t) for t in timesteps.tolist()] | |
pipeline.__call__ = partial( | |
pipeline.__call__, | |
num_inference_steps=num_steps_inversion, | |
guidance_scale=0, | |
generator=generator, | |
denoising_start=denoising_start, | |
strength=0, | |
) | |
# timesteps, num_inference_steps = retrieve_timesteps(pipeline.scheduler, num_steps_inversion, device, None) | |
# timesteps, num_inference_steps = pipeline.get_timesteps(num_inference_steps=num_inference_steps, device=device, strength=strngth) | |
from utils import get_ddpm_inversion_scheduler, create_xts | |
from config import get_config, get_config_name | |
import argparse | |
# parser = argparse.ArgumentParser() | |
# parser.add_argument("--images_paths", type=str, default=None) | |
# parser.add_argument("--images_folder", type=str, default=None) | |
# parser.set_defaults(force_use_cpu=False) | |
# parser.add_argument("--force_use_cpu", action="store_true") | |
# parser.add_argument("--folder_name", type=str, default='test_measure_time') | |
# parser.add_argument("--config_from_file", type=str, default='run_configs/noise_shift_guidance_1_5.yaml') | |
# parser.set_defaults(save_intermediate_results=False) | |
# parser.add_argument("--save_intermediate_results", action="store_true") | |
# parser.add_argument("--batch_size", type=int, default=None) | |
# parser.set_defaults(skip_p_to_p=False) | |
# parser.add_argument("--skip_p_to_p", action="store_true", default=True) | |
# parser.set_defaults(only_p_to_p=False) | |
# parser.add_argument("--only_p_to_p", action="store_true") | |
# parser.set_defaults(fp16=False) | |
# parser.add_argument("--fp16", action="store_true", default=False) | |
# parser.add_argument("--prompts_file", type=str, default='dataset_measure_time/dataset.json') | |
# parser.add_argument("--images_in_prompts_file", type=str, default=None) | |
# parser.add_argument("--seed", type=int, default=2) | |
# parser.add_argument("--time_measure_n", type=int, default=1) | |
# args = parser.parse_args() | |
class Object(object): | |
pass | |
args = Object() | |
args.images_paths = None | |
args.images_folder = None | |
args.force_use_cpu = False | |
args.folder_name = 'test_measure_time' | |
args.config_from_file = 'run_configs/noise_shift_guidance_1_5.yaml' | |
args.save_intermediate_results = False | |
args.batch_size = None | |
args.skip_p_to_p = True | |
args.only_p_to_p = False | |
args.fp16 = False | |
args.prompts_file = 'dataset_measure_time/dataset.json' | |
args.images_in_prompts_file = None | |
args.seed = 986 | |
args.time_measure_n = 1 | |
assert ( | |
args.batch_size is None or args.save_intermediate_results is False | |
), "save_intermediate_results is not implemented for batch_size > 1" | |
config = get_config(args) | |
# latent = latents[0].expand(3, -1, -1, -1) | |
# prompt = [src_prompt, src_prompt, tgt_prompt] | |
# image = pipeline.__call__(image=latent, prompt=prompt, eta=1).images | |
# for i, im in enumerate(image): | |
# im.save(f"output_{i}.png") | |
def run(image_path, src_prompt, tgt_prompt, seed, w1, w2): | |
generator = torch.Generator().manual_seed(seed) | |
x_0_image = Image.open(image_path).convert("RGB").resize((512, 512), Image.LANCZOS) | |
x_0 = encode_image(x_0_image, pipeline) | |
# x_ts = create_xts(pipeline.scheduler, timesteps, x_0, noise_shift_delta=1, generator=generator) | |
x_ts = create_xts(1, None, 0, generator, pipeline.scheduler, timesteps, x_0, no_add_noise=False) | |
x_ts = [xt.to(dtype=torch.float16) for xt in x_ts] | |
latents = [x_ts[0]] | |
x_ts_c_hat = [None] | |
config.ws1 = [w1] * 4 | |
config.ws2 = [w2] * 4 | |
pipeline.scheduler = get_ddpm_inversion_scheduler( | |
pipeline.scheduler, | |
config.step_function, | |
config, | |
timesteps, | |
config.save_timesteps, | |
latents, | |
x_ts, | |
x_ts_c_hat, | |
args.save_intermediate_results, | |
pipeline, | |
x_0, | |
v1s_images := [], | |
v2s_images := [], | |
deltas_images := [], | |
v1_x0s := [], | |
v2_x0s := [], | |
deltas_x0s := [], | |
"res12", | |
image_name="im_name", | |
time_measure_n=args.time_measure_n, | |
) | |
latent = latents[0].expand(3, -1, -1, -1) | |
prompt = [src_prompt, src_prompt, tgt_prompt] | |
image = pipeline.__call__(image=latent, prompt=prompt, eta=1).images | |
return image[2] | |
if __name__ == "__main__": | |
res = run(image_path, src_prompt, tgt_prompt, args.seed, 1.5, 1.0) | |
res.save("output.png") | |