jadechoghari's picture
Update audioldm_train/modules/diffusionmodules/PixArt.py
4993d3c verified
raw
history blame
134 kB
from abc import abstractmethod
from functools import partial
import math
from typing import Iterable
import numpy as np
import torch as th
#from .utils_pos_embedding.pos_embed import RoPE2D
import torch.nn as nn
import torch.nn.functional as F
import sys
from fairscale.nn.model_parallel.layers import (
ColumnParallelLinear,
ParallelEmbedding,
RowParallelLinear,
)
from timm.models.layers import DropPath
from .utils import auto_grad_checkpoint, to_2tuple
from .PixArt_blocks import t2i_modulate, WindowAttention, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, FinalLayer
import xformers.ops
import math
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(
self,
img_size=(256, 16),
patch_size=(16, 4),
overlap = (0, 0),
in_chans=128,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
):
super().__init__()
# img_size=(256, 16)
# patch_size=(16, 4)
# overlap = (2, 2)
# in_chans=128
# embed_dim=768
# import pdb
# pdb.set_trace()
self.img_size = img_size
self.patch_size = patch_size
self.ol = overlap
self.grid_size = (math.ceil((img_size[0] - patch_size[0]) / (patch_size[0]-overlap[0])) + 1,
math.ceil((img_size[1] - patch_size[1]) / (patch_size[1]-overlap[1])) + 1)
self.pad_size = ((self.grid_size[0]-1) * (self.patch_size[0]-overlap[0])+self.patch_size[0]-self.img_size[0],
+(self.grid_size[1]-1)*(self.patch_size[1]-overlap[1])+self.patch_size[1]-self.img_size[1])
self.pad_size = (self.pad_size[0] // 2, self.pad_size[1] // 2)
# self.p-ad_size = (((img_size[0] - patch_size[0]) // ), )
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0]-overlap[0], patch_size[1]-overlap[1]), bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
# B, C, H, W = x.shape
# _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
# _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
x = F.pad(x, (self.pad_size[-1], self.pad_size[-1], self.pad_size[-2], self.pad_size[-2]), "constant", 0)
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x
class PatchEmbed_1D(nn.Module):
def __init__(
self,
img_size=(256, 16),
# patch_size=(16, 4),
# overlap = (0, 0),
in_chans=8,
embed_dim=1152,
norm_layer=None,
# flatten=True,
bias=True,
):
super().__init__()
self.proj = nn.Linear(in_chans*img_size[1], embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
# B, C, H, W = x.shape
# _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
# _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
# x = F.pad(x, (self.pad_size[-1], self.pad_size[-1], self.pad_size[-2], self.pad_size[-2]), "constant", 0)
# x = self.proj(x)
# if self.flatten:
# x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = th.einsum('bctf->btfc', x)
x = x.flatten(2) # BTFC -> BTD
x = self.proj(x)
x = self.norm(x)
return x
# if __name__ == '__main__':
# x = th.rand(1, 256, 16).unsqueeze(0)
# model = PatchEmbed(in_chans=1)
# y = model(x)
from timm.models.vision_transformer import Attention, Mlp
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
from positional_encodings.torch_encodings import PositionalEncoding1D
def t2i_modulate(x, shift, scale):
return x * (1 + scale) + shift
class PixArtBlock(nn.Module):
"""
A PixArt block with adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None, use_rel_pos=False, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True,
input_size=input_size if window_size == 0 else (window_size, window_size),
use_rel_pos=use_rel_pos, **block_kwargs)
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.window_size = window_size
self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, **kwargs):
B, N, C = x.shape
# x [B, T, D]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
x = x + self.cross_attn(x, y, mask)
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
from qa_mdt.audioldm_train.modules.diffusionmodules.attention import CrossAttention_1D
class PixArtBlock_Slow(nn.Module):
"""
A PixArt block with adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None, use_rel_pos=False, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = CrossAttention_1D(query_dim=hidden_size, context_dim=hidden_size, heads=num_heads, dim_head=int(hidden_size/num_heads))
self.cross_attn = CrossAttention_1D(query_dim=hidden_size, context_dim=hidden_size, heads=num_heads, dim_head=int(hidden_size/num_heads))
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.window_size = window_size
self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, **kwargs):
B, N, C = x.shape
# x [B, T, D]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
x = x + self.cross_attn(x, y, mask)
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
class PixArt(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, **kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.lewei_scale = lewei_scale,
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
# self.x_embedder = PatchEmbed_1D(input)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size[0] // self.patch_size[0] * 2
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size))
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = nn.Linear(cond_dim, hidden_size)
drop_path = [x.item() for x in th.linspace(0, drop_path, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
PixArtBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
self.initialize_weights()
# if config:
# logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
# logger.warning(f"lewei scale: {self.lewei_scale}, base size: {self.base_size}")
# else:
# print(f'Warning: lewei scale: {self.lewei_scale}, base size: {self.base_size}')
def forward(self, x, timestep, context_list, context_mask_list=None, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
mask = context_mask_list[0] # (N, L)
assert mask is not None
# if mask is not None:
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
y_lens = [int(_) for _ in y_lens]
for block in self.blocks:
x = auto_grad_checkpoint(block, x, y, t0, y_lens) # (N, T, D) #support grad checkpoint
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
dpm solver donnot need variance prediction
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, y, mask)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :8], model_out[:, 8:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify(self, x):
"""
x: (N, T, patch_size 0 * patch_size 1 * C)
imgs: (Bs. 256. 16. 8)
"""
# torch_map = th.zeros(self.x_embedder.img_size[0]+2*self.x_embedder.pad_size[0],
# self.x_embedder.img_size[1]+2*self.x_embedder.pad_size[1]).to(x.device)
# lf = self.x_embedder.grid_size[0]
# rf = self.x_embedder.grid_size[1]
# for i in range(lf):
# for j in range(rf):
# xx = (i) * (self.x_embedder.patch_size[0]-self.x_embedder.ol[0])
# yy = (j) * (self.x_embedder.patch_size[1]-self.x_embedder.ol[1])
# torch_map[xx:(xx+self.x_embedder.patch_size[0]), yy:(yy+self.x_embedder.patch_size[1])]+=1
# torch_map = torch_map[self.x_embedder.pad_size[0]:self.x_embedder.pad_size[0]+self.x_embedder.img_size[0],
# self.x_embedder.pad_size[1]:self.x_embedder.pad_size[1]+self.x_embedder.img_size[1]]
# torch_map = th.reciprocal(torch_map)
# c = self.out_channels
# p0, p1 = self.x_embedder.patch_size[0], self.x_embedder.patch_size[1]
# x = x.reshape(shape=(x.shape[0], self.x_embedder.grid_size[0],
# self.x_embedder.grid_size[1], p0, p1, c))
# x = th.einsum('nhwpqc->nchwpq', x)
# added_map = th.zeros(x.shape[0], c,
# self.x_embedder.img_size[0]+2*self.x_embedder.pad_size[0],
# self.x_embedder.img_size[1]+2*self.x_embedder.pad_size[1]).to(x.device)
# for b_id in range(x.shape[0]):
# for i in range(lf):
# for j in range(rf):
# for c_id in range(c):
# xx = (i) * (self.x_embedder.patch_size[0]-self.x_embedder.ol[0])
# yy = (j) * (self.x_embedder.patch_size[1]-self.x_embedder.ol[1])
# added_map[b_id][c_id][xx:(xx+self.x_embedder.patch_size[0]), yy:(yy+self.x_embedder.patch_size[1])] += \
# x[b_id, c_id, i, j]
# ret_map = th.zeros(x.shape[0], c, self.x_embedder.img_size[0],
# self.x_embedder.img_size[1]).to(x.device)
# for b_id in range(x.shape[0]):
# for id_c in range(c):
# ret_map[b_id, id_c, :, :] = th.mul(added_map[b_id][id_c][self.x_embedder.pad_size[0]:self.x_embedder.pad_size[0]+self.x_embedder.img_size[0],
# self.x_embedder.pad_size[1]:self.x_embedder.pad_size[1]+self.x_embedder.img_size[1]], torch_map)
c = self.out_channels
p0 = self.x_embedder.patch_size[0]
p1 = self.x_embedder.patch_size[1]
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
# h = w = int(x.shape[1] ** 0.5)
# print(x.shape, h, w, p0, p1)
# import pdb
# pdb.set_trace()
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c))
x = th.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1))
return imgs
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@property
def dtype(self):
return next(self.parameters()).dtype
class SwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
)
self.w2 = RowParallelLinear(
hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
)
self.w3 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class DEBlock(nn.Module):
"""
Decoder block with added SpecTNT transformer
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, FFN_type='SwiGLU', drop_path=0., window_size=0, input_size=None, use_rel_pos=False, skip=False, num_f=None, num_t=None, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True,
input_size=input_size if window_size == 0 else (window_size, window_size),
use_rel_pos=use_rel_pos, **block_kwargs)
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
# self.cross_attn_f = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
# self.cross_attn_t = MultiHeadCrossAttention(hidden_size*num_f, num_heads, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm4 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm5 = nn.LayerNorm(hidden_size * num_f, elementwise_affine=False, eps=1e-6)
self.norm6 = nn.LayerNorm(hidden_size * num_f, elementwise_affine=False, eps=1e-6)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
if FFN_type == 'mlp':
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
# self.mlp2 = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
# self.mlp3 = Mlp(in_features=hidden_size*num_f, hidden_features=int(hidden_size*num_f * mlp_ratio), act_layer=approx_gelu, drop=0)
elif FFN_type == 'SwiGLU':
self.mlp = SwiGLU(hidden_size, int(hidden_size * mlp_ratio), 1)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.window_size = window_size
self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5)
# self.scale_shift_table_2 = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5)
# self.scale_shift_table_3 = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5)
self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) if skip else None
self.F_transformer = WindowAttention(hidden_size, num_heads=4, qkv_bias=True,
input_size=input_size if window_size == 0 else (window_size, window_size),
use_rel_pos=use_rel_pos, **block_kwargs)
self.T_transformer = WindowAttention(hidden_size * num_f, num_heads=16, qkv_bias=True,
input_size=input_size if window_size == 0 else (window_size, window_size),
use_rel_pos=use_rel_pos, **block_kwargs)
self.f_pos = nn.Embedding(num_f, hidden_size)
self.t_pos = nn.Embedding(num_t, hidden_size * num_f)
self.num_f = num_f
self.num_t = num_t
def forward(self, x_normal, end, y, t, mask=None, skip=None, ids_keep=None, **kwargs):
# import pdb
# pdb.set_trace()
B, D, C = x_normal.shape
T = self.num_t
F_add_1 = self.num_f
# B, T, F_add_1, C = x.shape
# F_add_1 = F_add_1 + 1
# x_normal = th.reshape()
# # x_end [B, T, 1, C]
# x_end = x[:, :, -1, :].unsqueeze(2)
if self.skip_linear is not None:
x_normal = self.skip_linear(th.cat([x_normal, skip], dim=-1))
D = T * (F_add_1 - 1)
# x_normal [B, D, C]
# import pdb
# pdb.set_trace()
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
x_normal = x_normal + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x_normal), shift_msa, scale_msa)).reshape(B, D, C))
x_normal = x_normal.reshape(B, T, F_add_1-1, C)
x_normal = th.cat((x_normal, end), 2)
# x_normal [B*T, F+1, C]
x_normal = x_normal.reshape(B*T, F_add_1, C)
pos_f = th.arange(self.num_f, device=x.device).unsqueeze(0).expand(B*T, -1)
# import pdb; pdb.set_trace()
x_normal = x_normal + self.f_pos(pos_f)
x_normal = x_normal + self.F_transformer(self.norm3(x_normal))
# x_normal = x_normal + self.cross_attn_f(x_normal, y, mask)
# x_normal = x_normal + self.mlp2(self.norm4(x_normal))
# x_normal [B, T, (F+1) * C]
x_normal = x_normal.reshape(B, T, F_add_1 * C)
pos_t = th.arange(self.num_t, device=x.device).unsqueeze(0).expand(B, -1)
x_normal = x_normal + self.t_pos(pos_t)
x_normal = x_normal + self.T_transformer(self.norm5(x_normal))
# x_normal = x_normal + self.cross_attn_t(x_normal, y, mask)
x_normal = x_normal.reshape(B, T ,F_add_1, C)
end = x_normal[:, :, -1, :].unsqueeze(2)
x_normal = x_normal[:, :, :-1, :]
x_normal = x_normal.reshape(B, T*(F_add_1 - 1), C)
x_normal = x_normal + self.cross_attn(x_normal, y, mask)
x_normal = x_normal + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x_normal), shift_mlp, scale_mlp)))
# x_normal = th.cat
return x_normal, end #.reshape(B, )
class MDTBlock(nn.Module):
"""
A PixArt block with adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, FFN_type='mlp', drop_path=0., window_size=0, input_size=None, use_rel_pos=False, skip=False, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True,
input_size=input_size if window_size == 0 else (window_size, window_size),
use_rel_pos=use_rel_pos, **block_kwargs)
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
if FFN_type == 'mlp':
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
elif FFN_type == 'SwiGLU':
self.mlp = SwiGLU(hidden_size, int(hidden_size * mlp_ratio), 1)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.window_size = window_size
self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5)
self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) if skip else None
def forward(self, x, y, t, mask=None, skip=None, ids_keep=None, **kwargs):
B, N, C = x.shape
if self.skip_linear is not None:
x = self.skip_linear(th.cat([x, skip], dim=-1))
# x [B, T, D]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
x = x + self.cross_attn(x, y, mask)
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
class PixArt_MDT_MASK_TF(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_t=0.17, mask_f=0.15, decode_layer=4,**kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.lewei_scale = lewei_scale,
decode_layer = int(decode_layer)
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
# self.x_embedder = PatchEmbed_1D(input)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size[0] // self.patch_size[0] * 2
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size))
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = nn.Linear(cond_dim, hidden_size)
half_depth = (depth - decode_layer)//2
self.half_depth=half_depth
drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] # stochastic depth decay rule
drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)]
self.en_inblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, FFN_type='mlp') for i in range(half_depth)
])
self.en_outblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(half_depth)
])
self.de_blocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(decode_layer)
])
self.sideblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, FFN_type='mlp') for _ in range(1)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
self.decoder_pos_embed = nn.Parameter(th.zeros(
1, num_patches, hidden_size), requires_grad=True)
# if mask_ratio is not None:
# self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size))
# self.mask_ratio = float(mask_ratio)
# self.decode_layer = int(decode_layer)
# else:
# self.mask_token = nn.Parameter(th.zeros(
# 1, 1, hidden_size), requires_grad=False)
# self.mask_ratio = None
# self.decode_layer = int(decode_layer)
assert mask_t != 0 and mask_f != 0
self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size))
self.mask_t = mask_t
self.mask_f = mask_f
self.decode_layer = int(decode_layer)
print(f"mask ratio: T-{self.mask_t} F-{self.mask_f}", "decode_layer:", self.decode_layer)
self.initialize_weights()
def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
# import pdb
# pdb.set_trace()
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
# if not self.training:
try:
mask = context_mask_list[0] # (N, L)
except:
mask = th.ones(x.shape[0], 1).to(x.device)
print("MASK !!!!!!!!!!!!!!!!!!!!!!!!!")
assert mask is not None
# if mask is not None:
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
y_lens = [int(_) for _ in y_lens]
input_skip = x
masked_stage = False
skips = []
# TODO : masking op for training
if self.mask_t is not None and self.training:
# masking: length -> length * mask_ratio
rand_mask_ratio = th.rand(1, device=x.device) # noise in [0, 1]
rand_mask_ratio_t = rand_mask_ratio * 0.13 + self.mask_t # mask_ratio, mask_ratio + 0.2
rand_mask_ratio_f = rand_mask_ratio * 0.13 + self.mask_f
# print(rand_mask_ratio)
x, mask, ids_restore, ids_keep = self.random_masking_2d(
x, rand_mask_ratio_t, rand_mask_ratio_f)
masked_stage = True
for block in self.en_inblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None)
skips.append(x)
for block in self.en_outblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None)
if self.mask_t is not None and self.mask_f is not None and self.training:
x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore)
masked_stage = False
else:
# add pos embed
x = x + self.decoder_pos_embed
for i in range(len(self.de_blocks)):
block = self.de_blocks[i]
this_skip = input_skip
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=this_skip, ids_keep=None)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
dpm solver donnot need variance prediction
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, context_list, context_mask_list=None, cfg_scale=4.0, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
# import pdb
# pdb.set_trace()
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, context_list, context_mask_list=None)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :8], model_out[:, 8:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify(self, x):
"""
x: (N, T, patch_size 0 * patch_size 1 * C)
imgs: (Bs. 256. 16. 8)
"""
if self.x_embedder.ol == (0, 0) or self.x_embedder.ol == [0, 0]:
c = self.out_channels
p0 = self.x_embedder.patch_size[0]
p1 = self.x_embedder.patch_size[1]
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c))
x = th.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1))
return imgs
lf = self.x_embedder.grid_size[0]
rf = self.x_embedder.grid_size[1]
lp = self.x_embedder.patch_size[0]
rp = self.x_embedder.patch_size[1]
lo = self.x_embedder.ol[0]
ro = self.x_embedder.ol[1]
lm = self.x_embedder.img_size[0]
rm = self.x_embedder.img_size[1]
lpad = self.x_embedder.pad_size[0]
rpad = self.x_embedder.pad_size[1]
bs = x.shape[0]
torch_map = self.torch_map
c = self.out_channels
x = x.reshape(shape=(bs, lf, rf, lp, rp, c))
x = th.einsum('nhwpqc->nchwpq', x)
added_map = th.zeros(bs, c, lm+2*lpad, rm+2*rpad).to(x.device)
for i in range(lf):
for j in range(rf):
xx = (i) * (lp - lo)
yy = (j) * (rp - ro)
added_map[:, :, xx:(xx+lp), yy:(yy+rp)] += \
x[:, :, i, j, :, :]
# import pdb
# pdb.set_trace()
added_map = added_map[:][:][lpad:lm+lpad, rpad:rm+rpad]
return th.mul(added_map.to(x.device), torch_map.to(x.device))
def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
"""
2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob)
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
# if self.use_custom_patch: # overlapped patch
# T=101
# F=12
# else:
# T=64
# F=8
T = self.x_embedder.grid_size[0]
F = self.x_embedder.grid_size[1]
#x = x.reshape(N, T, F, D)
len_keep_t = int(T * (1 - mask_t_prob))
len_keep_f = int(F * (1 - mask_f_prob))
# noise for mask in time
noise_t = th.rand(N, T, device=x.device) # noise in [0, 1]
# sort noise for each sample aling time
ids_shuffle_t = th.argsort(noise_t, dim=1) # ascend: small is keep, large is remove
ids_restore_t = th.argsort(ids_shuffle_t, dim=1)
ids_keep_t = ids_shuffle_t[:,:len_keep_t]
# noise mask in freq
noise_f = th.rand(N, F, device=x.device) # noise in [0, 1]
ids_shuffle_f = th.argsort(noise_f, dim=1) # ascend: small is keep, large is remove
ids_restore_f = th.argsort(ids_shuffle_f, dim=1)
ids_keep_f = ids_shuffle_f[:,:len_keep_f] #
# generate the binary mask: 0 is keep, 1 is remove
# mask in freq
mask_f = th.ones(N, F, device=x.device)
mask_f[:,:len_keep_f] = 0
mask_f = th.gather(mask_f, dim=1, index=ids_restore_f).unsqueeze(1).repeat(1,T,1) # N,T,F
# mask in time
mask_t = th.ones(N, T, device=x.device)
mask_t[:,:len_keep_t] = 0
mask_t = th.gather(mask_t, dim=1, index=ids_restore_t).unsqueeze(1).repeat(1,F,1).permute(0,2,1) # N,T,F
mask = 1-(1-mask_t)*(1-mask_f) # N, T, F
# get masked x
id2res=th.Tensor(list(range(N*T*F))).reshape(N,T,F).to(x.device)
id2res = id2res + 999*mask # add a large value for masked elements
id2res2 = th.argsort(id2res.flatten(start_dim=1))
ids_keep=id2res2.flatten(start_dim=1)[:,:len_keep_f*len_keep_t]
x_masked = th.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
ids_restore = th.argsort(id2res2.flatten(start_dim=1))
mask = mask.flatten(start_dim=1)
return x_masked, mask, ids_restore, ids_keep
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = th.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
# ascend: small is keep, large is remove
ids_shuffle = th.argsort(noise, dim=1)
ids_restore = th.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = th.gather(
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = th.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = th.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore, ids_keep
def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore):
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
x_ = th.cat([x, mask_tokens], dim=1)
x = th.gather(
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# add pos embed
x = x + self.decoder_pos_embed
# pass to the basic block
x_before = x
for sideblock in self.sideblocks:
x = sideblock(x, y, t0, y_lens, ids_keep=None)
# masked shortcut
mask = mask.unsqueeze(dim=-1)
x = x*mask + (1-mask)*x_before
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.en_inblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.en_outblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.de_blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.sideblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
if self.x_embedder.ol == [0, 0] or self.x_embedder.ol == (0, 0):
return
lf = self.x_embedder.grid_size[0]
rf = self.x_embedder.grid_size[1]
lp = self.x_embedder.patch_size[0]
rp = self.x_embedder.patch_size[1]
lo = self.x_embedder.ol[0]
ro = self.x_embedder.ol[1]
lm = self.x_embedder.img_size[0]
rm = self.x_embedder.img_size[1]
lpad = self.x_embedder.pad_size[0]
rpad = self.x_embedder.pad_size[1]
torch_map = th.zeros(lm+2*lpad, rm+2*rpad).to('cuda')
for i in range(lf):
for j in range(rf):
xx = (i) * (lp - lo)
yy = (j) * (rp - ro)
torch_map[xx:(xx+lp), yy:(yy+rp)]+=1
torch_map = torch_map[lpad:lm+lpad, rpad:rm+rpad]
self.torch_map = th.reciprocal(torch_map)
@property
def dtype(self):
return next(self.parameters()).dtype
class PixArt_MDT_MOS_AS_TOKEN(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_ratio=None, decode_layer=4,**kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.lewei_scale = lewei_scale,
decode_layer = int(decode_layer)
self.mos_embed = nn.Embedding(num_embeddings=5, embedding_dim=hidden_size)
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
# self.x_embedder = PatchEmbed_1D(input)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size[0] // self.patch_size[0] * 2
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size))
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
# self.mos_block = nn.Sequential(
# )
self.y_embedder = nn.Linear(cond_dim, hidden_size)
half_depth = (depth - decode_layer)//2
self.half_depth=half_depth
drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] # stochastic depth decay rule
drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)]
self.en_inblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, FFN_type='mlp') for i in range(half_depth)
])
self.en_outblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(half_depth)
])
self.de_blocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(decode_layer)
])
self.sideblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, FFN_type='mlp') for _ in range(1)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
self.decoder_pos_embed = nn.Parameter(th.zeros(
1, num_patches, hidden_size), requires_grad=True)
if mask_ratio is not None:
self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size))
self.mask_ratio = float(mask_ratio)
self.decode_layer = int(decode_layer)
else:
self.mask_token = nn.Parameter(th.zeros(
1, 1, hidden_size), requires_grad=False)
self.mask_ratio = None
self.decode_layer = int(decode_layer)
print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer)
self.initialize_weights()
def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, mos=None, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
# mos = th.ones(x.shape[0], dtype=th.int).to(x.device)
#print(f'DEBUG! {x}, {mos}')
assert mos.shape[0] == x.shape[0]
#import pdb; pdb.set_trace()
mos = mos - 1
mos = self.mos_embed(mos.to(x.device).to(th.int)) # [N, dim]
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
# import pdb
# pdb.set_trace()
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
# if not self.training:
try:
mask = context_mask_list[0] # (N, L)
except:
mask = th.ones(x.shape[0], 1).to(x.device)
assert mask is not None
# if mask is not None:
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
y_lens = [int(_) for _ in y_lens]
masked_stage = False
skips = []
# TODO : masking op for training
try:
x = th.cat([mos, x], dim=1) # [N, L+1, dim]
except:
x = th.cat([mos.unsqueeze(1), x], dim=1)
input_skip = x
if self.mask_ratio is not None and self.training:
# masking: length -> length * mask_ratio
rand_mask_ratio = th.rand(1, device=x.device) # noise in [0, 1]
rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio # mask_ratio, mask_ratio + 0.2
# print(rand_mask_ratio)
x, mask, ids_restore, ids_keep = self.random_masking(
x, rand_mask_ratio)
masked_stage = True
for block in self.en_inblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None)
skips.append(x)
for block in self.en_outblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None)
if self.mask_ratio is not None and self.training:
x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore)
masked_stage = False
else:
# add pos embed
x[:, 1:, :] = x[:, 1:, :] + self.decoder_pos_embed
# x = x + self.decoder_pos_embed
# import pdb
# pdb.set_trace()
for i in range(len(self.de_blocks)):
block = self.de_blocks[i]
this_skip = input_skip
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=this_skip, ids_keep=None)
x = x[:, 1:, :]
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
# import pdb
# pdb.set_trace()
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
dpm solver donnot need variance prediction
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, context_list, context_mask_list=None, cfg_scale=4.0, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
# import pdb
# pdb.set_trace()
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, context_list, context_mask_list=None)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :8], model_out[:, 8:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify(self, x):
"""
x: (N, T, patch_size 0 * patch_size 1 * C)
imgs: (Bs. 256. 16. 8)
"""
if self.x_embedder.ol == (0, 0) or self.x_embedder.ol == [0, 0]:
c = self.out_channels
p0 = self.x_embedder.patch_size[0]
p1 = self.x_embedder.patch_size[1]
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c))
x = th.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1))
return imgs
lf = self.x_embedder.grid_size[0]
rf = self.x_embedder.grid_size[1]
lp = self.x_embedder.patch_size[0]
rp = self.x_embedder.patch_size[1]
lo = self.x_embedder.ol[0]
ro = self.x_embedder.ol[1]
lm = self.x_embedder.img_size[0]
rm = self.x_embedder.img_size[1]
lpad = self.x_embedder.pad_size[0]
rpad = self.x_embedder.pad_size[1]
bs = x.shape[0]
torch_map = self.torch_map
c = self.out_channels
x = x.reshape(shape=(bs, lf, rf, lp, rp, c))
x = th.einsum('nhwpqc->nchwpq', x)
added_map = th.zeros(bs, c, lm+2*lpad, rm+2*rpad).to(x.device)
for i in range(lf):
for j in range(rf):
xx = (i) * (lp - lo)
yy = (j) * (rp - ro)
added_map[:, :, xx:(xx+lp), yy:(yy+rp)] += \
x[:, :, i, j, :, :]
# import pdb
# pdb.set_trace()
added_map = added_map[:][:][lpad:lm+lpad, rpad:rm+rpad]
return th.mul(added_map.to(x.device), torch_map.to(x.device))
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
L = L - 1
len_keep = int(L * (1 - mask_ratio))
noise = th.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
# ascend: small is keep, large is remove
ids_shuffle = th.argsort(noise, dim=1)
ids_restore = th.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = th.gather(
x[:, 1:, :], dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
x_masked = th.cat([x[:, 0, :].unsqueeze(1), x_masked], dim=1)
# import pdb
# pdb.set_trace()
# generate the binary mask: 0 is keep, 1 is remove
mask = th.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = th.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore, ids_keep
def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore):
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] - x.shape[1] + 1, 1)
x_ = th.cat([x[:, 1:, :], mask_tokens], dim=1)
x_ = th.gather(
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# add pos embed
x_ = x_ + self.decoder_pos_embed
x = th.cat([x[:, 0, :].unsqueeze(1), x_], dim=1)
# import pdb
# pdb.set_trace()
# pass to the basic block
x_before = x
for sideblock in self.sideblocks:
x = sideblock(x, y, t0, y_lens, ids_keep=None)
# masked shortcut
mask = mask.unsqueeze(dim=-1)
# import pdb;pdb.set_trace()
mask = th.cat([th.ones(mask.shape[0], 1, 1).to(mask.device), mask], dim=1)
x = x*mask + (1-mask)*x_before
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.en_inblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.en_outblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.de_blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.sideblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
if self.x_embedder.ol == [0, 0] or self.x_embedder.ol == (0, 0):
return
lf = self.x_embedder.grid_size[0]
rf = self.x_embedder.grid_size[1]
lp = self.x_embedder.patch_size[0]
rp = self.x_embedder.patch_size[1]
lo = self.x_embedder.ol[0]
ro = self.x_embedder.ol[1]
lm = self.x_embedder.img_size[0]
rm = self.x_embedder.img_size[1]
lpad = self.x_embedder.pad_size[0]
rpad = self.x_embedder.pad_size[1]
torch_map = th.zeros(lm+2*lpad, rm+2*rpad).to('cuda')
for i in range(lf):
for j in range(rf):
xx = (i) * (lp - lo)
yy = (j) * (rp - ro)
torch_map[xx:(xx+lp), yy:(yy+rp)]+=1
torch_map = torch_map[lpad:lm+lpad, rpad:rm+rpad]
self.torch_map = th.reciprocal(torch_map)
@property
def dtype(self):
return next(self.parameters()).dtype
class PixArt_MDT_LC(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_ratio=None, decode_layer=4,**kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.lewei_scale = lewei_scale,
decode_layer = int(decode_layer)
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
# self.x_embedder = PatchEmbed_1D(input)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size[0] // self.patch_size[0] * 2
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size))
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = nn.Linear(cond_dim, hidden_size)
half_depth = (depth - decode_layer)//2
self.half_depth=half_depth
drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] # stochastic depth decay rule
drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)]
self.en_inblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, FFN_type='mlp') for i in range(half_depth)
])
self.en_outblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(half_depth)
])
self.de_blocks = nn.ModuleList([
DEBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True, FFN_type='mlp', num_f=self.x_embedder.grid_size[1]+1, num_t=self.x_embedder.grid_size[0]) for i in range(decode_layer)
])
self.sideblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, FFN_type='mlp') for _ in range(1)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
self.decoder_pos_embed = nn.Parameter(th.zeros(
1, num_patches, hidden_size), requires_grad=True)
if mask_ratio is not None:
self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size))
self.mask_ratio = float(mask_ratio)
self.decode_layer = int(decode_layer)
else:
self.mask_token = nn.Parameter(th.zeros(
1, 1, hidden_size), requires_grad=False)
self.mask_ratio = None
self.decode_layer = int(decode_layer)
print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer)
self.initialize_weights()
def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
# import pdb
# pdb.set_trace()
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
# if not self.training:
try:
mask = context_mask_list[0] # (N, L)
except:
mask = th.ones(x.shape[0], 1).to(x.device)
print("MASK !!!!!!!!!!!!!!!!!!!!!!!!!")
assert mask is not None
# if mask is not None:
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
y_lens = [int(_) for _ in y_lens]
input_skip = x
masked_stage = False
skips = []
# TODO : masking op for training
if self.mask_ratio is not None and self.training:
# masking: length -> length * mask_ratio
rand_mask_ratio = th.rand(1, device=x.device) # noise in [0, 1]
rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio # mask_ratio, mask_ratio + 0.2
# print(rand_mask_ratio)
x, mask, ids_restore, ids_keep = self.random_masking(
x, rand_mask_ratio)
masked_stage = True
for block in self.en_inblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None)
skips.append(x)
for block in self.en_outblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None)
if self.mask_ratio is not None and self.training:
x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore)
masked_stage = False
else:
# add pos embed
x = x + self.decoder_pos_embed
bs = x.shape[0]
bs, D, L = x.shape
T, F = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
# reshaped = x.reshape(bs, T, F, L).to(x.device)
end = th.zeros(bs, T, 1, L).to(x.device)
# x = th.cat((reshaped, zero_tensor), 2)
# import pdb;pdb.set_trace()
# assert x.shape == [bs, T, F + 1, L]
for i in range(len(self.de_blocks)):
block = self.de_blocks[i]
this_skip = input_skip
x, end = auto_grad_checkpoint(block, x, end, y, t0, y_lens, skip=this_skip, ids_keep=None)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
dpm solver donnot need variance prediction
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, context_list, context_mask_list=None, cfg_scale=4.0, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
# import pdb
# pdb.set_trace()
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, context_list, context_mask_list=None)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :8], model_out[:, 8:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify(self, x):
"""
x: (N, T, patch_size 0 * patch_size 1 * C)
imgs: (Bs. 256. 16. 8)
"""
if self.x_embedder.ol == (0, 0) or self.x_embedder.ol == [0, 0]:
c = self.out_channels
p0 = self.x_embedder.patch_size[0]
p1 = self.x_embedder.patch_size[1]
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c))
x = th.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1))
return imgs
lf = self.x_embedder.grid_size[0]
rf = self.x_embedder.grid_size[1]
lp = self.x_embedder.patch_size[0]
rp = self.x_embedder.patch_size[1]
lo = self.x_embedder.ol[0]
ro = self.x_embedder.ol[1]
lm = self.x_embedder.img_size[0]
rm = self.x_embedder.img_size[1]
lpad = self.x_embedder.pad_size[0]
rpad = self.x_embedder.pad_size[1]
bs = x.shape[0]
torch_map = self.torch_map
c = self.out_channels
x = x.reshape(shape=(bs, lf, rf, lp, rp, c))
x = th.einsum('nhwpqc->nchwpq', x)
added_map = th.zeros(bs, c, lm+2*lpad, rm+2*rpad).to(x.device)
for i in range(lf):
for j in range(rf):
xx = (i) * (lp - lo)
yy = (j) * (rp - ro)
added_map[:, :, xx:(xx+lp), yy:(yy+rp)] += \
x[:, :, i, j, :, :]
# import pdb
# pdb.set_trace()
added_map = added_map[:][:][lpad:lm+lpad, rpad:rm+rpad]
return th.mul(added_map.to(x.device), torch_map.to(x.device))
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = th.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
# ascend: small is keep, large is remove
ids_shuffle = th.argsort(noise, dim=1)
ids_restore = th.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = th.gather(
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = th.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = th.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore, ids_keep
def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore):
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
x_ = th.cat([x, mask_tokens], dim=1)
x = th.gather(
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# add pos embed
x = x + self.decoder_pos_embed
# pass to the basic block
x_before = x
for sideblock in self.sideblocks:
x = sideblock(x, y, t0, y_lens, ids_keep=None)
# masked shortcut
mask = mask.unsqueeze(dim=-1)
x = x*mask + (1-mask)*x_before
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.en_inblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.en_outblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.de_blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.sideblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
if self.x_embedder.ol == [0, 0] or self.x_embedder.ol == (0, 0):
return
lf = self.x_embedder.grid_size[0]
rf = self.x_embedder.grid_size[1]
lp = self.x_embedder.patch_size[0]
rp = self.x_embedder.patch_size[1]
lo = self.x_embedder.ol[0]
ro = self.x_embedder.ol[1]
lm = self.x_embedder.img_size[0]
rm = self.x_embedder.img_size[1]
lpad = self.x_embedder.pad_size[0]
rpad = self.x_embedder.pad_size[1]
torch_map = th.zeros(lm+2*lpad, rm+2*rpad).to('cuda')
for i in range(lf):
for j in range(rf):
xx = (i) * (lp - lo)
yy = (j) * (rp - ro)
torch_map[xx:(xx+lp), yy:(yy+rp)]+=1
torch_map = torch_map[lpad:lm+lpad, rpad:rm+rpad]
self.torch_map = th.reciprocal(torch_map)
@property
def dtype(self):
return next(self.parameters()).dtype
class PixArt_MDT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_ratio=None, decode_layer=4,**kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.lewei_scale = lewei_scale,
decode_layer = int(decode_layer)
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
# self.x_embedder = PatchEmbed_1D(input)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size[0] // self.patch_size[0] * 2
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size))
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = nn.Linear(cond_dim, hidden_size)
half_depth = (depth - decode_layer)//2
self.half_depth=half_depth
drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] # stochastic depth decay rule
drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)]
self.en_inblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False) for i in range(half_depth)
])
self.en_outblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True) for i in range(half_depth)
])
self.de_blocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True) for i in range(decode_layer)
])
self.sideblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False) for _ in range(1)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
self.decoder_pos_embed = nn.Parameter(th.zeros(
1, num_patches, hidden_size), requires_grad=True)
if mask_ratio is not None:
self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size))
self.mask_ratio = float(mask_ratio)
self.decode_layer = int(decode_layer)
else:
self.mask_token = nn.Parameter(th.zeros(
1, 1, hidden_size), requires_grad=False)
self.mask_ratio = None
self.decode_layer = int(decode_layer)
print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer)
self.initialize_weights()
def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
# print(f'debug_MDT : {x.shape[0]}')
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
# import pdb
# print(f'debug_MDT : {x.shape[0]}')
# pdb.set_trace()
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
# print(f'debug_MDT : {x.shape[0]}')
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
# print(f'debug_MDT : {x.shape[0]}')
y = self.y_embedder(y) # (N, L, D)
# if not self.training:
try:
mask = context_mask_list[0] # (N, L)
except:
mask = th.ones(x.shape[0], 1).to(x.device)
print("MASK !!!!!!!!!!!!!!!!!!!!!!!!!")
assert mask is not None
# if mask is not None:
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
y_lens = [int(_) for _ in y_lens]
# print(f'debug_MDT : {x.shape[0]}')
input_skip = x
masked_stage = False
skips = []
# TODO : masking op for training
if self.mask_ratio is not None and self.training:
# masking: length -> length * mask_ratio
rand_mask_ratio = th.rand(1, device=x.device) # noise in [0, 1]
rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio # mask_ratio, mask_ratio + 0.2
# print(rand_mask_ratio)
x, mask, ids_restore, ids_keep = self.random_masking(
x, rand_mask_ratio)
masked_stage = True
for block in self.en_inblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None)
skips.append(x)
for block in self.en_outblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None)
if self.mask_ratio is not None and self.training:
x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore)
masked_stage = False
else:
# add pos embed
x = x + self.decoder_pos_embed
for i in range(len(self.de_blocks)):
block = self.de_blocks[i]
this_skip = input_skip
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=this_skip, ids_keep=None)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
dpm solver donnot need variance prediction
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, y, mask)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :8], model_out[:, 8:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify(self, x):
"""
x: (N, T, patch_size 0 * patch_size 1 * C)
imgs: (Bs. 256. 16. 8)
"""
if self.x_embedder.ol == (0, 0) or self.x_embedder.ol == [0, 0]:
c = self.out_channels
p0 = self.x_embedder.patch_size[0]
p1 = self.x_embedder.patch_size[1]
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c))
x = th.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1))
return imgs
lf = self.x_embedder.grid_size[0]
rf = self.x_embedder.grid_size[1]
lp = self.x_embedder.patch_size[0]
rp = self.x_embedder.patch_size[1]
lo = self.x_embedder.ol[0]
ro = self.x_embedder.ol[1]
lm = self.x_embedder.img_size[0]
rm = self.x_embedder.img_size[1]
lpad = self.x_embedder.pad_size[0]
rpad = self.x_embedder.pad_size[1]
bs = x.shape[0]
torch_map = self.torch_map
c = self.out_channels
x = x.reshape(shape=(bs, lf, rf, lp, rp, c))
x = th.einsum('nhwpqc->nchwpq', x)
added_map = th.zeros(bs, c, lm+2*lpad, rm+2*rpad).to(x.device)
for i in range(lf):
for j in range(rf):
xx = (i) * (lp - lo)
yy = (j) * (rp - ro)
added_map[:, :, xx:(xx+lp), yy:(yy+rp)] += \
x[:, :, i, j, :, :]
added_map = added_map[:, :, lpad:lm+lpad, rpad:rm+rpad]
return th.mul(added_map, torch_map.to(added_map.device))
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = th.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
# ascend: small is keep, large is remove
ids_shuffle = th.argsort(noise, dim=1)
ids_restore = th.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = th.gather(
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = th.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = th.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore, ids_keep
def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore):
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
x_ = th.cat([x, mask_tokens], dim=1)
x = th.gather(
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# add pos embed
x = x + self.decoder_pos_embed
# pass to the basic block
x_before = x
for sideblock in self.sideblocks:
x = sideblock(x, y, t0, y_lens, ids_keep=None)
# masked shortcut
mask = mask.unsqueeze(dim=-1)
x = x*mask + (1-mask)*x_before
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.en_inblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.en_outblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.de_blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.sideblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
if self.x_embedder.ol == [0, 0] or self.x_embedder.ol == (0, 0):
return
lf = self.x_embedder.grid_size[0]
rf = self.x_embedder.grid_size[1]
lp = self.x_embedder.patch_size[0]
rp = self.x_embedder.patch_size[1]
lo = self.x_embedder.ol[0]
ro = self.x_embedder.ol[1]
lm = self.x_embedder.img_size[0]
rm = self.x_embedder.img_size[1]
lpad = self.x_embedder.pad_size[0]
rpad = self.x_embedder.pad_size[1]
torch_map = th.zeros(lm+2*lpad, rm+2*rpad).to('cuda')
for i in range(lf):
for j in range(rf):
xx = (i) * (lp - lo)
yy = (j) * (rp - ro)
torch_map[xx:(xx+lp), yy:(yy+rp)]+=1
torch_map = torch_map[lpad:lm+lpad, rpad:rm+rpad]
self.torch_map = th.reciprocal(torch_map)
@property
def dtype(self):
return next(self.parameters()).dtype
class PixArt_MDT_FIT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_ratio=None, decode_layer=4,**kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.lewei_scale = lewei_scale,
decode_layer = int(decode_layer)
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
# self.x_embedder = PatchEmbed_1D(input)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size[0] // self.patch_size[0] * 2
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size))
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = nn.Linear(cond_dim, hidden_size)
half_depth = (depth - decode_layer)//2
self.half_depth=half_depth
drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] # stochastic depth decay rule
drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)]
self.en_inblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False) for i in range(half_depth)
])
self.en_outblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True) for i in range(half_depth)
])
self.de_blocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True) for i in range(decode_layer)
])
self.sideblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False) for _ in range(1)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
self.decoder_pos_embed = nn.Parameter(th.zeros(
1, num_patches, hidden_size), requires_grad=True)
if mask_ratio is not None:
self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size))
self.mask_ratio = float(mask_ratio)
self.decode_layer = int(decode_layer)
else:
self.mask_token = nn.Parameter(th.zeros(
1, 1, hidden_size), requires_grad=False)
self.mask_ratio = None
self.decode_layer = int(decode_layer)
print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer)
self.initialize_weights()
def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
# import pdb
# pdb.set_trace()
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
# if not self.training:
try:
mask = context_mask_list[0] # (N, L)
except:
mask = th.ones(x.shape[0], 1).to(x.device)
print("MASK !!!!!!!!!!!!!!!!!!!!!!!!!")
assert mask is not None
# if mask is not None:
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
y_lens = [int(_) for _ in y_lens]
input_skip = x
masked_stage = False
skips = []
# TODO : masking op for training
if self.mask_ratio is not None and self.training:
# masking: length -> length * mask_ratio
rand_mask_ratio = th.rand(1, device=x.device) # noise in [0, 1]
rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio # mask_ratio, mask_ratio + 0.2
# print(rand_mask_ratio)
x, mask, ids_restore, ids_keep = self.random_masking(
x, rand_mask_ratio)
masked_stage = True
for block in self.en_inblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None)
skips.append(x)
for block in self.en_outblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None)
if self.mask_ratio is not None and self.training:
x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore)
masked_stage = False
else:
# add pos embed
x = x + self.decoder_pos_embed
for i in range(len(self.de_blocks)):
block = self.de_blocks[i]
this_skip = input_skip
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=this_skip, ids_keep=None)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
dpm solver donnot need variance prediction
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, y, mask)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :8], model_out[:, 8:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify(self, x):
"""
x: (N, T, patch_size 0 * patch_size 1 * C)
imgs: (Bs. 256. 16. 8)
"""
c = self.out_channels
p0 = self.x_embedder.patch_size[0]
p1 = self.x_embedder.patch_size[1]
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c))
x = th.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1))
return imgs
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = th.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
# ascend: small is keep, large is remove
ids_shuffle = th.argsort(noise, dim=1)
ids_restore = th.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = th.gather(
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = th.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = th.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore, ids_keep
def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore):
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
x_ = th.cat([x, mask_tokens], dim=1)
x = th.gather(
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# add pos embed
x = x + self.decoder_pos_embed
# pass to the basic block
x_before = x
for sideblock in self.sideblocks:
x = sideblock(x, y, t0, y_lens, ids_keep=None)
# masked shortcut
mask = mask.unsqueeze(dim=-1)
x = x*mask + (1-mask)*x_before
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
# Replace the absolute embedding with 2d-rope position embedding:
# pos_embed =
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.en_inblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.en_outblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.de_blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.sideblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@property
def dtype(self):
return next(self.parameters()).dtype
class PixArt_Slow(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, **kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.lewei_scale = lewei_scale,
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
# self.x_embedder = PatchEmbed_1D(input)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size[0] // self.patch_size[0] * 2
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size))
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = nn.Linear(cond_dim, hidden_size)
drop_path = [x.item() for x in th.linspace(0, drop_path, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
PixArtBlock_Slow(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
self.initialize_weights()
def forward(self, x, timestep, context_list, context_mask_list=None, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
mask = context_mask_list[0] # (N, L)
assert mask is not None
# if mask is not None:
# y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
# y_lens = mask.sum(dim=1).tolist()
# y_lens = [int(_) for _ in y_lens]
for block in self.blocks:
x = auto_grad_checkpoint(block, x, y, t0, mask) # (N, T, D) #support grad checkpoint
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
dpm solver donnot need variance prediction
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, y, mask)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :8], model_out[:, 8:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify(self, x):
"""
x: (N, T, patch_size 0 * patch_size 1 * C)
imgs: (Bs. 256. 16. 8)
"""
c = self.out_channels
p0 = self.x_embedder.patch_size[0]
p1 = self.x_embedder.patch_size[1]
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c))
x = th.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1))
return imgs
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
# for block in self.blocks:
# nn.init.constant_(block.cross_attn.proj.weight, 0)
# nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@property
def dtype(self):
return next(self.parameters()).dtype
class PixArtBlock_1D(nn.Module):
"""
A PixArt block with adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, use_rel_pos=False, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True,
input_size=None,
use_rel_pos=use_rel_pos, **block_kwargs)
# self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.window_size = window_size
self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, **kwargs):
B, N, C = x.shape
# x [3, 133, 1152]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
x = x + self.cross_attn(x, y, mask)
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
class PixArt_1D(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, **kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.num_heads = num_heads
self.lewei_scale = lewei_scale,
self.x_embedder = PatchEmbed_1D(input_size, in_channels, hidden_size)
# self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.p_enc_1d_model = PositionalEncoding1D(hidden_size)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = nn.Linear(cond_dim, hidden_size)
drop_path = [x.item() for x in th.linspace(0, drop_path, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
PixArtBlock_1D(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
window_size=0,
use_rel_pos=False)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(hidden_size, (1, input_size[1]), self.out_channels)
self.initialize_weights()
# if config:
# logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
# logger.warning(f"lewei scale: {self.lewei_scale}, base size: {self.base_size}")
# else:
# print(f'Warning: lewei scale: {self.lewei_scale}, base size: {self.base_size}')
def forward(self, x, timestep, context_list, context_mask_list=None, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
x = self.x_embedder(x) # (N, T, D)
pos_embed = self.p_enc_1d_model(x)
x = x + pos_embed
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
try:
mask = context_mask_list[0] # (N, L)
except:
mask = th.ones(x.shape[0], 1).to(x.device)
print("MASK !!!!!!!!!!!!!!!!!!!!!!!!!")
assert mask is not None
# if mask is not None:
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
y_lens = [int(_) for _ in y_lens]
for block in self.blocks:
x = auto_grad_checkpoint(block, x, y, t0, y_lens) # (N, T, D) #support grad checkpoint
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify_1D(x) # (N, out_channels, H, W)
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
dpm solver donnot need variance prediction
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, y, mask)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :8], model_out[:, 8:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify_1D(self, x):
"""
"""
c = self.out_channels
x = x.reshape(shape=(x.shape[0], self.input_size[0], self.input_size[1], c))
x = th.einsum('btfc->bctf', x)
# imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
# pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
# self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@property
def dtype(self):
return next(self.parameters()).dtype
class PixArt_Slow_1D(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, **kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.num_heads = num_heads
self.lewei_scale = lewei_scale,
self.x_embedder = PatchEmbed_1D(input_size, in_channels, hidden_size)
# self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.p_enc_1d_model = PositionalEncoding1D(hidden_size)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = nn.Linear(cond_dim, hidden_size)
drop_path = [x.item() for x in th.linspace(0, drop_path, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
PixArtBlock_Slow(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
window_size=0,
use_rel_pos=False)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(hidden_size, (1, input_size[1]), self.out_channels)
self.initialize_weights()
# if config:
# logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
# logger.warning(f"lewei scale: {self.lewei_scale}, base size: {self.base_size}")
# else:
# print(f'Warning: lewei scale: {self.lewei_scale}, base size: {self.base_size}')
def forward(self, x, timestep, context_list, context_mask_list=None, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
x = self.x_embedder(x) # (N, T, D)
pos_embed = self.p_enc_1d_model(x)
x = x + pos_embed
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
mask = context_mask_list[0] # (N, L)
assert mask is not None
# if mask is not None:
# y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
# y_lens = mask.sum(dim=1).tolist()
# y_lens = [int(_) for _ in y_lens]
for block in self.blocks:
x = auto_grad_checkpoint(block, x, y, t0, mask) # (N, T, D) #support grad checkpoint
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify_1D(x) # (N, out_channels, H, W)
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
dpm solver donnot need variance prediction
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, y, mask)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify_1D(self, x):
"""
"""
c = self.out_channels
x = x.reshape(shape=(x.shape[0], self.input_size[0], self.input_size[1], c))
x = th.einsum('btfc->bctf', x)
# imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
# pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
# self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
# for block in self.blocks:
# nn.init.constant_(block.cross_attn.proj.weight, 0)
# nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@property
def dtype(self):
return next(self.parameters()).dtype
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, lewei_scale=1.0, base_size=16):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
# import pdb
# pdb.set_trace()
if isinstance(grid_size, int):
grid_size = to_2tuple(grid_size)
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0]/base_size) / lewei_scale
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1]/base_size) / lewei_scale
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
return np.concatenate([emb_h, emb_w], axis=1)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
return np.concatenate([emb_sin, emb_cos], axis=1)
# if __name__ == '__main__' :
# import pdb
# pdb.set_trace()
# model = PixArt_1D().to('cuda')
# # x: (N, T, patch_size 0 * patch_size 1 * C)
# th.manual_seed(233)
# # x = th.rand(1, 4*16, 16*4*16).to('cuda')
# x = th.rand(3, 8, 256, 16).to('cuda')
# t = th.tensor([1, 2, 3]).to('cuda')
# c = th.rand(3, 20, 1024).to('cuda')
# c_mask = th.ones(3, 20).to('cuda')
# c_list = [c]
# c_mask_list = [c_mask]
# y = model.forward(x, t, c_list, c_mask_list)
# res = model.unpatchify(x)
# class DiTModel(nn.Module):
# """
# The full UNet model with attention and timestep embedding.
# :param in_channels: channels in the input Tensor.
# :param model_channels: base channel count for the model.
# :param out_channels: channels in the output Tensor.
# :param num_res_blocks: number of residual blocks per downsample.
# :param attention_resolutions: a collection of downsample rates at which
# attention will take place. May be a set, list, or tuple.
# For example, if this contains 4, then at 4x downsampling, attention
# will be used.
# :param dropout: the dropout probability.
# :param channel_mult: channel multiplier for each level of the UNet.
# :param conv_resample: if True, use learned convolutions for upsampling and
# downsampling.
# :param dims: determines if the signal is 1D, 2D, or 3D.
# :param num_classes: if specified (as an int), then this model will be
# class-conditional with `num_classes` classes.
# :param use_checkpoint: use gradient checkpointing to reduce memory usage.
# :param num_heads: the number of attention heads in each attention layer.
# :param num_heads_channels: if specified, ignore num_heads and instead use
# a fixed channel width per attention head.
# :param num_heads_upsample: works with num_heads to set a different number
# of heads for upsampling. Deprecated.
# :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
# :param resblock_updown: use residual blocks for up/downsampling.
# :param use_new_attention_order: use a different attention pattern for potentially
# increased efficiency.
# """
# def __init__(
# self,
# input_size,
# patch_size,
# overlap,
# in_channels,
# embed_dim,
# model_channels,
# out_channels,
# dims=2,
# extra_film_condition_dim=None,
# use_checkpoint=False,
# use_fp16=False,
# num_heads=-1,
# num_head_channels=-1,
# use_scale_shift_norm=False,
# use_new_attention_order=False,
# transformer_depth=1, # custom transformer support
# context_dim=None, # custom transformer support
# n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
# legacy=True,
# ):
# super().__init__()
# self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, embed_dim, bias=True)
# num_patches = self.x_embedder.num_patches
# self.pos_embed = nn.Parameter(th.zeros(1, num_patches, embed_dim), requires_grad=False)
# self.blocks = nn.ModuleList([
# DiTBlock_crossattn
# ])
# def convert_to_fp16(self):
# """
# Convert the torso of the model to float16.
# """
# # self.input_blocks.apply(convert_module_to_f16)
# # self.middle_block.apply(convert_module_to_f16)
# # self.output_blocks.apply(convert_module_to_f16)
# def convert_to_fp32(self):
# """
# Convert the torso of the model to float32.
# """
# # self.input_blocks.apply(convert_module_to_f32)
# # self.middle_block.apply(convert_module_to_f32)
# # self.output_blocks.apply(convert_module_to_f32)
# def forward(
# self,
# x,
# timesteps=None,
# y=None,
# context_list=None,
# context_attn_mask_list=None,
# **kwargs,
# ):
# """
# Apply the model to an input batch.
# :param x: an [N x C x ...] Tensor of inputs.
# :param timesteps: a 1-D batch of timesteps.
# :param context: conditioning plugged in via crossattn
# :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional
# :return: an [N x C x ...] Tensor of outputs.
# """
# x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
# t = self.t_embedder(timesteps) # (N, D)
# y = self.y_embedder(y, self.training) # (N, D)
# c = t + y # (N, D)
# for block in self.blocks:
# x = block(x, c) # (N, T, D)
# x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
# x = self.unpatchify(x) # (N, out_channels, H, W)