from typing import Any, Dict, Optional, Union

import torch
import torch.nn as nn
import numpy as np
import math

from diffusers.models.activations import get_activation
from einops import rearrange

def get_1d_sincos_pos_embed(
    embed_dim, num_frames, cls_token=False, extra_tokens=0,
):
    """
    Build 1D sin-cos positional embeddings for `num_frames` temporal positions.
    Returns an array of shape [num_frames, embed_dim], optionally prefixed with
    `extra_tokens` zero rows when `cls_token` is set.
    """
    t = np.arange(num_frames, dtype=np.float32)
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, t)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed

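# Illustrative usage (sketch, not part of the original module); the sizes are arbitrary examples:
#
#     temp_pos = get_1d_sincos_pos_embed(embed_dim=64, num_frames=16)
#     # temp_pos.shape == (16, 64): one sin-cos row per frame index
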
def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    grid_size: int (or tuple) of the grid height and width.
    return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim]
    (with or without cls_token)
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed

def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

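# Illustrative usage (sketch, not part of the original module); the sizes are arbitrary examples:
#
#     pos_2d = get_2d_sincos_pos_embed(embed_dim=64, grid_size=8)
#     # pos_2d.shape == (8 * 8, 64): half of the channels encode one spatial axis,
#     # the other half encode the other (see get_2d_sincos_pos_embed_from_grid).
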
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: create sinusoidal
    timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb

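# Illustrative usage (sketch, not part of the original module); the timestep values and width
# below are arbitrary examples:
#
#     t = torch.tensor([0, 250, 500, 999])
#     t_emb = get_timestep_embedding(t, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0)
#     # t_emb.shape == (4, 256)
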
class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb

class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        sample_proj_bias: bool = True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
        self.act = get_activation(act_fn)

        # project to `out_dim` when given, otherwise keep `time_embed_dim`
        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample):
        sample = self.linear_1(sample)
        sample = self.act(sample)
        sample = self.linear_2(sample)
        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample

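# Illustrative usage (sketch) of the two-stage pattern used by the combined embeddings below:
# a fixed sinusoidal projection (Timesteps) followed by a learned MLP (TimestepEmbedding).
# The widths are arbitrary examples:
#
#     time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
#     time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=1536)
#     t_emb = time_embed(time_proj(torch.tensor([10.0, 999.0])))
#     # t_emb.shape == (2, 1536)
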
class TextProjection(nn.Module):
    def __init__(self, in_features, hidden_size, act_fn="silu"):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        self.act_1 = get_activation(act_fn)
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states

class CombinedTimestepConditionEmbeddings(nn.Module):
    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = TextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
        pooled_projections = self.text_embedder(pooled_projection)
        conditioning = timesteps_emb + pooled_projections
        return conditioning

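# Illustrative usage (sketch, not part of the original module); `pooled_projection_dim=2048` is an
# arbitrary example value:
#
#     cond_embed = CombinedTimestepConditionEmbeddings(embedding_dim=1536, pooled_projection_dim=2048)
#     conditioning = cond_embed(torch.tensor([500.0]), torch.randn(1, 2048))
#     # conditioning.shape == (1, 1536): timestep embedding plus projected pooled text embedding
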
class CombinedTimestepEmbeddings(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timestep):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj)
        return timesteps_emb

class PatchEmbed3D(nn.Module):
    """Patch embedding with support for 3D (video) tensor inputs of shape (B, C, T, H, W)."""

    def __init__(
        self,
        height=128,
        width=128,
        patch_size=2,
        in_channels=16,
        embed_dim=1536,
        layer_norm=False,
        bias=True,
        interpolation_scale=1,
        pos_embed_type="sincos",
        temp_pos_embed_type="rope",
        pos_embed_max_size=192,
        max_num_frames=64,
        add_temp_pos_embed=False,
        interp_condition_pos=False,
    ):
        super().__init__()

        num_patches = (height // patch_size) * (width // patch_size)
        self.layer_norm = layer_norm
        self.pos_embed_max_size = pos_embed_max_size

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        self.add_temp_pos_embed = add_temp_pos_embed

        if pos_embed_max_size:
            grid_size = pos_embed_max_size
        else:
            grid_size = int(num_patches**0.5)

        if pos_embed_type is None:
            self.pos_embed = None

        elif pos_embed_type == "sincos":
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
            )
            persistent = True if pos_embed_max_size else False
            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)

            if add_temp_pos_embed and temp_pos_embed_type == "sincos":
                time_pos_embed = get_1d_sincos_pos_embed(embed_dim, max_num_frames)
                self.register_buffer(
                    "temp_pos_embed", torch.from_numpy(time_pos_embed).float().unsqueeze(0), persistent=True
                )

        elif pos_embed_type == "rope":
            print("Using the rotary position embedding")

        else:
            raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")

        self.pos_embed_type = pos_embed_type
        self.temp_pos_embed_type = temp_pos_embed_type
        self.interp_condition_pos = interp_condition_pos

    def cropped_pos_embed(self, height, width, ori_height, ori_width):
        """Crops positional embeddings for SD3 compatibility."""
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        height = height // self.patch_size
        width = width // self.patch_size
        ori_height = ori_height // self.patch_size
        ori_width = ori_width // self.patch_size

        assert ori_height >= height, "ori_height must be >= height"
        assert ori_width >= width, "ori_width must be >= width"

        if height > self.pos_embed_max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > self.pos_embed_max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        if self.interp_condition_pos:
            # crop at the original (highest) resolution, then resize to the requested size
            top = (self.pos_embed_max_size - ori_height) // 2
            left = (self.pos_embed_max_size - ori_width) // 2
            spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
            spatial_pos_embed = spatial_pos_embed[:, top : top + ori_height, left : left + ori_width, :]
            if ori_height != height or ori_width != width:
                spatial_pos_embed = spatial_pos_embed.permute(0, 3, 1, 2)
                spatial_pos_embed = torch.nn.functional.interpolate(
                    spatial_pos_embed, size=(height, width), mode="bilinear"
                )
                spatial_pos_embed = spatial_pos_embed.permute(0, 2, 3, 1)
        else:
            top = (self.pos_embed_max_size - height) // 2
            left = (self.pos_embed_max_size - width) // 2
            spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
            spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]

        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])

        return spatial_pos_embed

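    # NOTE: with `interp_condition_pos=True`, the table is cropped at the original (highest)
    # resolution and bilinearly resized to the current latent size, so latents at a lower
    # resolution share positions aligned with the original resolution; otherwise a plain
    # center crop of the table is used.
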
    def forward_func(self, latent, time_index=0, ori_height=None, ori_width=None):
        if self.pos_embed_max_size is not None:
            height, width = latent.shape[-2:]
        else:
            height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

        bs = latent.shape[0]
        temp = latent.shape[2]

        # fold the temporal axis into the batch and patchify each frame with the 2D conv
        latent = rearrange(latent, 'b c t h w -> (b t) c h w')
        latent = self.proj(latent)
        latent = latent.flatten(2).transpose(1, 2)  # (BT, N, C)

        if self.layer_norm:
            latent = self.norm(latent)

        if self.pos_embed_type == 'sincos':
            if self.pos_embed_max_size:
                pos_embed = self.cropped_pos_embed(height, width, ori_height, ori_width)
            else:
                raise NotImplementedError("Not implemented sincos pos embed without sd3 max pos crop")
                # NOTE: unreachable fallback retained from the reference implementation; it would
                # recompute the sincos table when the runtime grid differs from the configured one.
                if self.height != height or self.width != width:
                    pos_embed = get_2d_sincos_pos_embed(
                        embed_dim=self.pos_embed.shape[-1],
                        grid_size=(height, width),
                        base_size=self.base_size,
                        interpolation_scale=self.interpolation_scale,
                    )
                    pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
                else:
                    pos_embed = self.pos_embed

            if self.add_temp_pos_embed and self.temp_pos_embed_type == 'sincos':
                latent_dtype = latent.dtype
                latent = latent + pos_embed
                latent = rearrange(latent, '(b t) n c -> (b n) t c', t=temp)
                latent = latent + self.temp_pos_embed[:, time_index:time_index + temp, :]
                latent = latent.to(latent_dtype)
                latent = rearrange(latent, '(b n) t c -> b t n c', b=bs)
            else:
                latent = (latent + pos_embed).to(latent.dtype)
                latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp)

        else:
            assert self.pos_embed_type == "rope", "Only supporting the sincos and rope embedding"
            latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp)

        return latent

    def forward(self, latent):
        """
        Arguments:
            latent (`torch.Tensor` or `List`): latent of shape (B, C, T, H, W), or a (possibly
                nested) list of such latents, e.g. the past condition latents together with the
                current latent during generation. List entries are patchified one by one, flattened
                into 1D token sequences, and concatenated along the token dimension.
        """

        if isinstance(latent, list):
            output_list = []

            for latent_ in latent:
                if not isinstance(latent_, list):
                    latent_ = [latent_]

                output_latent = []
                time_index = 0
                # the last entry carries the original (target) spatial resolution
                ori_height, ori_width = latent_[-1].shape[-2:]
                for each_latent in latent_:
                    hidden_state = self.forward_func(
                        each_latent, time_index=time_index, ori_height=ori_height, ori_width=ori_width
                    )
                    time_index += each_latent.shape[2]
                    hidden_state = rearrange(hidden_state, "b t n c -> b (t n) c")
                    output_latent.append(hidden_state)

                output_latent = torch.cat(output_latent, dim=1)
                output_list.append(output_latent)

            return output_list
        else:
            # use the latent's own resolution as the reference resolution for the pos-embed crop
            ori_height, ori_width = latent.shape[-2:]
            hidden_states = self.forward_func(latent, ori_height=ori_height, ori_width=ori_width)
            hidden_states = rearrange(hidden_states, "b t n c -> b (t n) c")
            return hidden_states
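

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only; the small embed_dim and pos_embed_max_size are
    # example values, not the module defaults). A random video latent of shape
    # (B, C, T, H, W) = (1, 16, 4, 64, 64) with patch_size=2 yields 4 * 32 * 32 = 4096 tokens.
    patch_embed = PatchEmbed3D(
        height=64, width=64, patch_size=2, in_channels=16, embed_dim=64, pos_embed_max_size=64
    )
    dummy_latent = torch.randn(1, 16, 4, 64, 64)
    tokens = patch_embed([dummy_latent])[0]
    print(tokens.shape)  # expected: torch.Size([1, 4096, 64])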