|
import torch |
|
import os |
|
import sys |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from collections import OrderedDict |
|
from einops import rearrange |
|
from diffusers.utils.torch_utils import randn_tensor |
|
import numpy as np |
|
import math |
|
import random |
|
import PIL |
|
from PIL import Image |
|
from tqdm import tqdm |
|
from torchvision import transforms |
|
from copy import deepcopy |
|
from typing import Any, Callable, Dict, List, Optional, Union |
|
from accelerate import Accelerator |
|
from diffusion_schedulers import PyramidFlowMatchEulerDiscreteScheduler |
|
from video_vae.modeling_causal_vae import CausalVideoVAE |
|
|
|
from trainer_misc import ( |
|
all_to_all, |
|
is_sequence_parallel_initialized, |
|
get_sequence_parallel_group, |
|
get_sequence_parallel_group_rank, |
|
get_sequence_parallel_rank, |
|
get_sequence_parallel_world_size, |
|
get_rank, |
|
) |
|
|
|
from .modeling_pyramid_mmdit import PyramidDiffusionMMDiT |
|
from .modeling_text_encoder import SD3TextEncoderWithMask |
|
|
|
|
|
def compute_density_for_timestep_sampling(
    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
    """
    Draw ``batch_size`` samples ``u`` according to ``weighting_scheme``.

    * ``"logit_normal"``: a normal sample with mean ``logit_mean`` and std
      ``logit_std`` passed through a sigmoid.
    * ``"mode"``: a uniform sample ``r`` in [0, 1) mapped to
      ``1 - r - mode_scale * (cos(pi * r / 2) ** 2 - 1 + r)``.
    * anything else: a plain uniform sample in [0, 1).

    The returned 1-D tensor of length ``batch_size`` is always created on the CPU.
    ``logit_mean`` / ``logit_std`` are only read for ``"logit_normal"`` and
    ``mode_scale`` only for ``"mode"``.
    """
    sample_shape = (batch_size,)

    if weighting_scheme == "logit_normal":
        gaussian = torch.normal(mean=logit_mean, std=logit_std, size=sample_shape, device="cpu")
        return torch.nn.functional.sigmoid(gaussian)

    # Both remaining schemes start from the same uniform draw.
    uniform = torch.rand(size=sample_shape, device="cpu")
    if weighting_scheme == "mode":
        correction = torch.cos(math.pi * uniform / 2) ** 2 - 1 + uniform
        return 1 - uniform - mode_scale * correction

    return uniform
|
|
|
|
|
class PyramidDiTForVideoGeneration:
    """
    Running wrapper around the pyramid DiT for both image and video generation.

    A video is produced unit by unit ("1 + n + n + n": a first single frame, then
    units of ``frame_per_unit`` latent frames). Every unit is denoised through
    ``len(stages)`` resolution stages, starting from the lowest resolution and
    doubling height/width between stages.
    """

    def __init__(
        self,
        model_path,
        model_dtype='bf16',
        use_gradient_checkpointing=False,
        return_log=True,
        model_variant="diffusion_transformer_768p",
        timestep_shift=1.0,
        stage_range=[0, 1/3, 2/3, 1],
        sample_ratios=[1, 1, 1],
        scheduler_gamma=1/3,
        use_mixed_training=False,
        use_flash_attn=False,
        load_text_encoder=True,
        load_vae=True,
        max_temporal_length=31,
        frame_per_unit=1,
        use_temporal_causal=True,
        corrupt_ratio=1/3,
        interp_condition_pos=True,
        stages=[1, 2, 4],
        **kwargs,
    ):
        """
        Load the DiT (and optionally the text encoder and the VAE) from ``model_path``.

        Args:
            model_path: directory holding the ``model_variant`` DiT sub-folder and,
                when ``load_vae`` is set, a ``causal_video_vae`` sub-folder.
            model_dtype: ``'bf16'`` or ``'fp16'``; any other value selects float32.
            use_mixed_training: when true the DiT is loaded without an explicit
                ``torch_dtype``; otherwise it is loaded in ``model_dtype``.
            stage_range, scheduler_gamma, timestep_shift: forwarded to
                ``PyramidFlowMatchEulerDiscreteScheduler``.
            stages: one entry per pyramid stage; only its length is used in this class.
            max_temporal_length, frame_per_unit: ``max_temporal_length - 1`` must be a
                multiple of ``frame_per_unit``.
            **kwargs: accepted and ignored.

        Raises:
            AssertionError: if ``(max_temporal_length - 1) % frame_per_unit != 0``.
        """
        super().__init__()

        # Validate before touching any weights so a bad configuration fails fast.
        assert (max_temporal_length - 1) % frame_per_unit == 0, "The frame number should be divided by the frame number per unit"

        torch_dtype = {'bf16': torch.bfloat16, 'fp16': torch.float16}.get(model_dtype, torch.float32)

        # Copy the list arguments: the signature defaults are shared list objects and
        # must not end up aliased by instance state (or by the scheduler).
        self.stages = list(stages)
        self.sample_ratios = list(sample_ratios)
        self.corrupt_ratio = corrupt_ratio

        dit_path = os.path.join(model_path, model_variant)

        # The diffusion transformer. Both precision modes share every option except
        # the explicit dtype cast.
        dit_kwargs = dict(
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_flash_attn=use_flash_attn, use_t5_mask=True,
            add_temp_pos_embed=True, temp_pos_embed_type='rope',
            use_temporal_causal=use_temporal_causal, interp_condition_pos=interp_condition_pos,
        )
        if use_mixed_training:
            print("using mixed precision training, do not explicitly casting models")
        else:
            print("using half precision")
            dit_kwargs['torch_dtype'] = torch_dtype
        self.dit = PyramidDiffusionMMDiT.from_pretrained(dit_path, **dit_kwargs)

        # The text encoder (optional).
        if load_text_encoder:
            self.text_encoder = SD3TextEncoderWithMask(model_path, torch_dtype=torch_dtype)
        else:
            self.text_encoder = None

        # The video VAE (optional); its parameters are frozen.
        if load_vae:
            self.vae = CausalVideoVAE.from_pretrained(os.path.join(model_path, 'causal_video_vae'), torch_dtype=torch_dtype, interpolate=False)
            for parameter in self.vae.parameters():
                parameter.requires_grad = False
        else:
            self.vae = None

        # Latent normalisation used for the first (image) latent frame.
        self.vae_shift_factor = 0.1490
        self.vae_scale_factor = 1 / 1.8415

        # Latent normalisation used for every later (video) latent frame.
        self.vae_video_shift_factor = -0.2343
        self.vae_video_scale_factor = 1 / 3.0986

        # Pixel -> latent spatial reduction applied in `prepare_latents`.
        self.downsample = 8

        # The video sequence: one frame + N * unit.
        self.frame_per_unit = frame_per_unit
        self.max_temporal_length = max_temporal_length
        self.num_units_per_video = 1 + ((max_temporal_length - 1) // frame_per_unit) + int(sum(sample_ratios))

        self.scheduler = PyramidFlowMatchEulerDiscreteScheduler(
            shift=timestep_shift, stages=len(self.stages),
            stage_range=list(stage_range), gamma=scheduler_gamma,
        )
        print(f"The start sigmas and end sigmas of each stage is Start: {self.scheduler.start_sigmas}, End: {self.scheduler.end_sigmas}, Ori_start: {self.scheduler.ori_start_sigmas}")

        self.cfg_rate = 0.1
        self.return_log = return_log
        self.use_flash_attn = use_flash_attn

    def load_checkpoint(self, checkpoint_path, model_key='model', **kwargs):
        """
        Load DiT weights from ``checkpoint_path`` into ``self.dit`` (strict).

        Keys starting with ``vae`` or ``text_encoder`` are skipped. Keys starting with
        ``dit`` have their first dot-separated component removed; all other keys are
        used unchanged. ``model_key`` and ``**kwargs`` are accepted but not used here.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints from
        # trusted sources (or pass weights_only=True where the torch version allows).
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        dit_checkpoint = OrderedDict()
        for key in checkpoint:
            if key.startswith(('vae', 'text_encoder')):
                continue
            if key.startswith('dit'):
                target_key = '.'.join(key.split('.')[1:])
            else:
                target_key = key
            dit_checkpoint[target_key] = checkpoint[key]

        load_result = self.dit.load_state_dict(dit_checkpoint, strict=True)
        print(f"Load checkpoint from {checkpoint_path}, load result: {load_result}")

    def load_vae_checkpoint(self, vae_checkpoint_path, model_key='model'):
        """
        Load VAE weights from ``vae_checkpoint_path`` into ``self.vae``.

        The state dict is read from ``checkpoint[model_key]``; only keys starting with
        ``vae.`` are kept, with that first component removed.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(vae_checkpoint_path, map_location='cpu')
        checkpoint = checkpoint[model_key]

        loaded_checkpoint = OrderedDict(
            ('.'.join(key.split('.')[1:]), value)
            for key, value in checkpoint.items()
            if key.startswith('vae.')
        )

        load_result = self.vae.load_state_dict(loaded_checkpoint)
        print(f"Load the VAE from {vae_checkpoint_path}, load result: {load_result}")

    @torch.no_grad()
    def get_pyramid_latent(self, x, stage_num):
        """
        Build a resolution pyramid from ``x`` (laid out as ``b c t h w``).

        ``x`` is bilinearly down-sampled ``stage_num`` times, halving height and
        width each time. Returns ``stage_num + 1`` tensors ordered from the lowest
        resolution to the original ``x``.
        """
        vae_latent_list = [x]

        temp, height, width = x.shape[-3], x.shape[-2], x.shape[-1]
        for _ in range(stage_num):
            height //= 2
            width //= 2
            x = rearrange(x, 'b c t h w -> (b t) c h w')
            x = torch.nn.functional.interpolate(x, size=(height, width), mode='bilinear')
            x = rearrange(x, '(b t) c h w -> b c t h w', t=temp)
            vae_latent_list.append(x)

        return list(reversed(vae_latent_list))

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        temp,
        height,
        width,
        dtype,
        device,
        generator,
    ):
        """
        Sample Gaussian noise of shape
        ``(batch_size, num_channels_latents, temp, height // downsample, width // downsample)``.
        ``height`` and ``width`` are given in pixels.
        """
        shape = (
            batch_size,
            num_channels_latents,
            int(temp),
            int(height) // self.downsample,
            int(width) // self.downsample,
        )
        return randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    def sample_block_noise(self, bs, ch, temp, height, width):
        """
        Sample noise of shape ``(bs, ch, temp, height, width)`` built from 2x2 blocks.

        The four values of every block are drawn jointly from a zero-mean normal whose
        covariance is ``(1 + gamma) * I - gamma * ones`` with ``gamma`` taken from the
        scheduler config. The result is a CPU tensor; the global torch RNG is used.
        """
        gamma = self.scheduler.config.gamma
        dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma)
        block_number = bs * ch * temp * (height // 2) * (width // 2)
        # One batched draw instead of a Python-level loop over every block.
        noise = dist.sample((block_number,))
        noise = rearrange(noise, '(b c t h w) (p q) -> b c t (h p) (w q)', b=bs, c=ch, t=temp, h=height//2, w=width//2, p=2, q=2)
        return noise

    @staticmethod
    def _with_quality_suffix(prompt):
        """Append the fixed quality suffix to ``prompt``; return ``(prompt, batch_size)``."""
        suffix = ", hyper quality, Ultra HD, 8K"
        if isinstance(prompt, str):
            return prompt + suffix, 1
        assert isinstance(prompt, list)
        return [_ + suffix for _ in prompt], len(prompt)

    def _prepare_start_noise(self, batch_size, temp, height, width, dtype, device, generator):
        """
        Sample full-resolution noise and shrink it to the lowest pyramid resolution.

        Returns ``(latents, temp, height, width)`` where the three sizes describe the
        returned (lowest-resolution) latents.
        """
        num_channels_latents = self.dit.config.in_channels
        latents = self.prepare_latents(
            batch_size,
            num_channels_latents,
            temp,
            height,
            width,
            dtype,
            device,
            generator,
        )

        temp, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]

        latents = rearrange(latents, 'b c t h w -> (b t) c h w')
        for _ in range(len(self.stages) - 1):
            height //= 2
            width //= 2
            # NOTE(review): the factor 2 presumably restores unit variance after
            # averaging 2x2 noise values -- confirm.
            latents = F.interpolate(latents, size=(height, width), mode='bilinear') * 2
        latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)

        return latents, temp, height, width

    def _build_past_conditions(self, generated_latents_list, unit_index):
        """
        Assemble, for every stage, the list of clean history latents fed to the DiT.

        For stage ``i_s`` the list ends with the most recent unit at pyramid level
        ``i_s``. Each earlier unit is taken one level lower than the one after it;
        once level 0 (the lowest resolution) is reached, all remaining older frames
        are taken from level 0 as a single tensor. Entries are ordered oldest first
        and are duplicated along the batch axis when classifier-free guidance is on.
        """
        num_stages = len(self.stages)
        fpu = self.frame_per_unit
        use_cfg = self.do_classifier_free_guidance
        clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), num_stages - 1)

        def _for_cfg(tensor):
            return torch.cat([tensor] * 2) if use_cfg else tensor

        past_condition_latents = []
        for i_s in range(num_stages):
            stage_input = [_for_cfg(clean_latents_list[i_s][:, :, -fpu:])]

            cur_stage = i_s
            cur_unit_ptx = 1
            while cur_unit_ptx < unit_index:
                cur_stage = max(cur_stage - 1, 0)
                if cur_stage == 0:
                    break
                cur_unit_ptx += 1
                cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * fpu): -((cur_unit_ptx - 1) * fpu)]
                stage_input.append(_for_cfg(cond_latents))

            if cur_stage == 0 and cur_unit_ptx < unit_index:
                cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * fpu)]
                stage_input.append(_for_cfg(cond_latents))

            past_condition_latents.append(list(reversed(stage_input)))

        return past_condition_latents

    @torch.no_grad()
    def generate_one_unit(
        self,
        latents,
        past_conditions,
        prompt_embeds,
        prompt_attention_mask,
        pooled_prompt_embeds,
        num_inference_steps,
        height,
        width,
        temp,
        device,
        dtype,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        is_first_frame: bool = False,
    ):
        """
        Denoise one unit through every pyramid stage.

        Args:
            latents: start noise at the lowest resolution (``b c t h w``).
            past_conditions: per stage, the list of history latents that is prepended
                to the current latents before calling the DiT.
            num_inference_steps: per-stage step counts.
            height, width: spatial size of ``latents``; doubled before each later stage.
            temp: number of latent frames in this unit.
            is_first_frame: selects ``guidance_scale`` instead of
                ``video_guidance_scale`` for classifier-free guidance.

        Returns:
            A list with the latents at the end of each stage; the last entry is the
            highest-resolution result.
        """
        intermed_latents = []

        for i_s in range(len(self.stages)):
            self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device)
            timesteps = self.scheduler.timesteps

            if i_s > 0:
                # Move to the next resolution with nearest-neighbour upsampling.
                height *= 2
                width *= 2
                latents = rearrange(latents, 'b c t h w -> (b t) c h w')
                latents = F.interpolate(latents, size=(height, width), mode='nearest')
                latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)

                # Rescale the upsampled latents and add 2x2 block noise.
                # NOTE(review): presumably this re-noising removes the block artifacts
                # left by nearest upsampling -- confirm against the scheduler/paper.
                ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s]
                gamma = self.scheduler.config.gamma
                alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
                beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)

                bs, ch, temp, height, width = latents.shape
                noise = self.sample_block_noise(bs, ch, temp, height, width)
                noise = noise.to(device=device, dtype=dtype)
                latents = alpha * latents + beta * noise

            for t in timesteps:
                # Duplicate the latents when doing classifier-free guidance.
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                # Broadcast the timestep over the batch dimension.
                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)

                latent_model_input = past_conditions[i_s] + [latent_model_input]

                noise_pred = self.dit(
                    sample=[latent_model_input],
                    timestep_ratio=timestep,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=prompt_attention_mask,
                    pooled_projections=pooled_prompt_embeds,
                )
                noise_pred = noise_pred[0]

                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    scale = self.guidance_scale if is_first_frame else self.video_guidance_scale
                    noise_pred = noise_pred_uncond + scale * (noise_pred_text - noise_pred_uncond)

                latents = self.scheduler.step(
                    model_output=noise_pred,
                    timestep=timestep,
                    sample=latents,
                    generator=generator,
                ).prev_sample

            intermed_latents.append(latents)

        return intermed_latents

    @torch.no_grad()
    def generate_i2v(
        self,
        prompt: Union[str, List[str]] = '',
        input_image: "PIL.Image.Image" = None,
        temp: int = 1,
        num_inference_steps: Optional[Union[int, List[int]]] = 28,
        guidance_scale: float = 7.0,
        video_guidance_scale: float = 4.0,
        min_guidance_scale: float = 2.0,
        use_linear_guidance: bool = False,
        alpha: float = 0.5,
        negative_prompt: Optional[Union[str, List[str]]]="cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror",
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        save_memory: bool = True,
    ):
        """
        Generate a video that continues from ``input_image``.

        The image is VAE-encoded into the first latent frame; ``temp`` further latent
        frames are then generated unit by unit, each conditioned on everything
        generated so far. The output size is taken from the image.

        Args:
            prompt: a string or list of strings; a fixed quality suffix is appended.
            input_image: PIL image giving the first frame (required).
            temp: number of latent frames to generate; must be a multiple of
                ``frame_per_unit``.
            num_inference_steps: an int (used for every stage) or a per-stage list.
            guidance_scale, video_guidance_scale: classifier-free guidance scales.
            use_linear_guidance: when set, both scales are replaced per unit by
                ``max(guidance_scale - alpha * unit_index, min_guidance_scale)``.
            output_type: ``"latent"`` returns the latent tensor; anything else returns
                the decoded list of PIL images.
            save_memory: forwarded to ``decode_latent``.

        Raises:
            AssertionError: if ``temp`` is not a multiple of ``frame_per_unit`` or
                ``prompt`` is neither a string nor a list.
        """
        device = self.device
        dtype = self.dtype

        width = input_image.width
        height = input_image.height

        assert temp % self.frame_per_unit == 0, "The frames should be divided by frame_per unit"

        prompt, batch_size = self._with_quality_suffix(prompt)

        if isinstance(num_inference_steps, int):
            num_inference_steps = [num_inference_steps] * len(self.stages)

        negative_prompt = negative_prompt or ""

        # Text embeddings for the prompt and the negative prompt.
        prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
        negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)

        if use_linear_guidance:
            max_guidance_scale = guidance_scale
            guidance_scale_list = [max(max_guidance_scale - alpha * t_, min_guidance_scale) for t_ in range(temp+1)]
            print(guidance_scale_list)

        self._guidance_scale = guidance_scale
        self._video_guidance_scale = video_guidance_scale

        if self.do_classifier_free_guidance:
            # Unconditional half first, conditional half second (matches `chunk(2)`).
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

        # Start noise at the lowest pyramid resolution.
        latents, temp, height, width = self._prepare_start_noise(
            batch_size * num_images_per_prompt, temp, height, width,
            prompt_embeds.dtype, device, generator,
        )

        num_units = temp // self.frame_per_unit

        # Encode the input image into the first latent frame: [b c 1 h w].
        image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
        input_image_tensor = image_transform(input_image).unsqueeze(0).unsqueeze(2)
        input_image_latent = (self.vae.encode(input_image_tensor.to(device)).latent_dist.sample() - self.vae_shift_factor) * self.vae_scale_factor

        generated_latents_list = [input_image_latent]

        for unit_index in tqdm(range(1, num_units + 1)):
            if use_linear_guidance:
                self._guidance_scale = guidance_scale_list[unit_index]
                self._video_guidance_scale = guidance_scale_list[unit_index]

            past_condition_latents = self._build_past_conditions(generated_latents_list, unit_index)

            intermed_latents = self.generate_one_unit(
                latents[:, :, (unit_index - 1) * self.frame_per_unit:unit_index * self.frame_per_unit],
                past_condition_latents,
                prompt_embeds,
                prompt_attention_mask,
                pooled_prompt_embeds,
                num_inference_steps,
                height,
                width,
                self.frame_per_unit,
                device,
                dtype,
                generator,
                is_first_frame=False,
            )

            generated_latents_list.append(intermed_latents[-1])

        generated_latents = torch.cat(generated_latents_list, dim=2)

        if output_type == "latent":
            image = generated_latents
        else:
            image = self.decode_latent(generated_latents, save_memory=save_memory)

        return image

    @torch.no_grad()
    def generate(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        temp: int = 1,
        num_inference_steps: Optional[Union[int, List[int]]] = 28,
        video_num_inference_steps: Optional[Union[int, List[int]]] = 28,
        guidance_scale: float = 7.0,
        video_guidance_scale: float = 7.0,
        min_guidance_scale: float = 2.0,
        use_linear_guidance: bool = False,
        alpha: float = 0.5,
        negative_prompt: Optional[Union[str, List[str]]]="cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror",
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        save_memory: bool = True,
    ):
        """
        Generate an image (``temp == 1``) or a video from text.

        The first unit is a single latent frame generated without history using
        ``num_inference_steps`` and ``guidance_scale``. Every later unit holds
        ``frame_per_unit`` latent frames, is conditioned on all previously generated
        latents, and uses ``video_num_inference_steps`` and ``video_guidance_scale``.

        Args:
            prompt: a string or list of strings; a fixed quality suffix is appended.
            height, width: output size in pixels.
            temp: total number of latent frames; ``temp - 1`` must be a multiple of
                ``frame_per_unit``.
            num_inference_steps, video_num_inference_steps: an int (used for every
                stage) or a per-stage list.
            use_linear_guidance: when set, both scales are replaced per unit by
                ``max(guidance_scale - alpha * unit_index, min_guidance_scale)``.
            output_type: ``"latent"`` returns the latent tensor; anything else returns
                the decoded list of PIL images.
            save_memory: forwarded to ``decode_latent``.

        Raises:
            AssertionError: if ``temp - 1`` is not a multiple of ``frame_per_unit`` or
                ``prompt`` is neither a string nor a list.
        """
        device = self.device
        dtype = self.dtype

        assert (temp - 1) % self.frame_per_unit == 0, "The frames should be divided by frame_per unit"

        prompt, batch_size = self._with_quality_suffix(prompt)

        if isinstance(num_inference_steps, int):
            num_inference_steps = [num_inference_steps] * len(self.stages)

        if isinstance(video_num_inference_steps, int):
            video_num_inference_steps = [video_num_inference_steps] * len(self.stages)

        negative_prompt = negative_prompt or ""

        # Text embeddings for the prompt and the negative prompt.
        prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
        negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)

        if use_linear_guidance:
            max_guidance_scale = guidance_scale
            guidance_scale_list = [max(max_guidance_scale - alpha * t_, min_guidance_scale) for t_ in range(temp)]
            print(guidance_scale_list)

        self._guidance_scale = guidance_scale
        self._video_guidance_scale = video_guidance_scale

        if self.do_classifier_free_guidance:
            # Unconditional half first, conditional half second (matches `chunk(2)`).
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

        # Start noise at the lowest pyramid resolution.
        latents, temp, height, width = self._prepare_start_noise(
            batch_size * num_images_per_prompt, temp, height, width,
            prompt_embeds.dtype, device, generator,
        )

        num_units = 1 + (temp - 1) // self.frame_per_unit

        generated_latents_list = []

        for unit_index in tqdm(range(num_units)):
            if use_linear_guidance:
                self._guidance_scale = guidance_scale_list[unit_index]
                self._video_guidance_scale = guidance_scale_list[unit_index]

            if unit_index == 0:
                # First frame: no history at any stage.
                past_condition_latents = [[] for _ in range(len(self.stages))]
                intermed_latents = self.generate_one_unit(
                    latents[:, :, :1],
                    past_condition_latents,
                    prompt_embeds,
                    prompt_attention_mask,
                    pooled_prompt_embeds,
                    num_inference_steps,
                    height,
                    width,
                    1,
                    device,
                    dtype,
                    generator,
                    is_first_frame=True,
                )
            else:
                past_condition_latents = self._build_past_conditions(generated_latents_list, unit_index)
                intermed_latents = self.generate_one_unit(
                    latents[:, :, 1 + (unit_index - 1) * self.frame_per_unit:1 + unit_index * self.frame_per_unit],
                    past_condition_latents,
                    prompt_embeds,
                    prompt_attention_mask,
                    pooled_prompt_embeds,
                    video_num_inference_steps,
                    height,
                    width,
                    self.frame_per_unit,
                    device,
                    dtype,
                    generator,
                    is_first_frame=False,
                )

            generated_latents_list.append(intermed_latents[-1])

        generated_latents = torch.cat(generated_latents_list, dim=2)

        if output_type == "latent":
            image = generated_latents
        else:
            image = self.decode_latent(generated_latents, save_memory=save_memory)

        return image

    def decode_latent(self, latents, save_memory=True):
        """
        Decode latents (``b c t h w``) into a list of PIL images, one per frame.

        The first latent frame is de-normalised with the image factors and all later
        frames with the video factors. The input tensor is left unmodified.
        ``save_memory`` selects the smaller VAE decode window and tile size.
        """
        if latents.shape[2] == 1:
            latents = (latents / self.vae_scale_factor) + self.vae_shift_factor
        else:
            # Build a new tensor rather than writing into the caller's latents.
            first_frame = (latents[:, :, :1] / self.vae_scale_factor) + self.vae_shift_factor
            later_frames = (latents[:, :, 1:] / self.vae_video_scale_factor) + self.vae_video_shift_factor
            latents = torch.cat([first_frame, later_frames], dim=2)

        if save_memory:
            image = self.vae.decode(latents, temporal_chunk=True, window_size=1, tile_sample_min_size=256).sample
        else:
            image = self.vae.decode(latents, temporal_chunk=True, window_size=2, tile_sample_min_size=512).sample

        image = image.float()
        image = (image / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
        image = rearrange(image, "B C T H W -> (B T) C H W")
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        image = self.numpy_to_pil(image)
        return image

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a numpy image or a batch of images to a PIL image.

        ``images`` holds values in [0, 1] laid out as ``(H, W, C)`` or
        ``(N, H, W, C)``; single-channel images become mode ``"L"``.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # Special case for grayscale (single channel) images.
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    @property
    def device(self):
        """Device of the first DiT parameter."""
        return next(self.dit.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first DiT parameter."""
        return next(self.dit.parameters()).dtype

    @property
    def guidance_scale(self):
        """Guidance scale used for the first frame; set by ``generate`` / ``generate_i2v``."""
        return self._guidance_scale

    @property
    def video_guidance_scale(self):
        """Guidance scale used for later units; set by ``generate`` / ``generate_i2v``."""
        return self._video_guidance_scale

    @property
    def do_classifier_free_guidance(self):
        """True when the current ``guidance_scale`` is positive."""
        return self._guidance_scale > 0
|
|