wlmov / video_vae /modeling_enc_dec.py
multimodalart's picture
Upload 33 files
f0533a5 verified
raw
history blame
16.4 kB
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.attention_processor import SpatialNorm
from .modeling_block import (
UNetMidBlock2D,
CausalUNetMidBlock2D,
get_down_block,
get_up_block,
get_input_layer,
get_output_layer,
)
from .modeling_resnet import (
Downsample2D,
Upsample2D,
TemporalDownsample2x,
TemporalUpsample2x,
)
from .modeling_causal_conv import CausalConv3d, CausalGroupNorm
@dataclass
class DecoderOutput(BaseOutput):
r"""
Output of decoding method.
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The decoded output sample from the last layer of the model.
"""
sample: torch.FloatTensor
class CausalVaeEncoder(nn.Module):
r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels for the last block.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
spatial_down_sample: Tuple[bool, ...] = (True,),
temporal_down_sample: Tuple[bool, ...] = (False,),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: Tuple[int, ...] = (2,),
norm_num_groups: int = 32,
act_fn: str = "silu",
double_z: bool = True,
block_dropout: Tuple[int, ...] = (0.0,),
mid_block_add_attention=True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = CausalConv3d(
in_channels,
block_out_channels[0],
kernel_size=3,
stride=1,
)
self.mid_block = None
self.down_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
down_block = get_down_block(
down_block_type,
num_layers=self.layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
add_spatial_downsample=spatial_down_sample[i],
add_temporal_downsample=temporal_down_sample[i],
resnet_eps=1e-6,
downsample_padding=0,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=output_channel,
temb_channels=None,
dropout=block_dropout[i],
)
self.down_blocks.append(down_block)
# mid
self.mid_block = CausalUNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
add_attention=mid_block_add_attention,
dropout=block_dropout[-1],
)
# out
self.conv_norm_out = CausalGroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3, stride=1)
self.gradient_checkpointing = False
def forward(self, sample: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = self.conv_in(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# down
if is_torch_version(">=", "1.11.0"):
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block), sample, is_init_image,
temporal_chunk, use_reentrant=False
)
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, is_init_image,
temporal_chunk, use_reentrant=False
)
else:
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample, is_init_image, temporal_chunk)
# middle
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, is_init_image, temporal_chunk)
else:
# down
for down_block in self.down_blocks:
sample = down_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
# middle
sample = self.mid_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
return sample
class CausalVaeDecoder(nn.Module):
r"""
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
norm_type (`str`, *optional*, defaults to `"group"`):
The normalization type to use. Can be either `"group"` or `"spatial"`.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
spatial_up_sample: Tuple[bool, ...] = (True,),
temporal_up_sample: Tuple[bool, ...] = (False,),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: Tuple[int, ...] = (2,),
norm_num_groups: int = 32,
act_fn: str = "silu",
mid_block_add_attention=True,
interpolate: bool = True,
block_dropout: Tuple[int, ...] = (0.0,),
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = CausalConv3d(
in_channels,
block_out_channels[-1],
kernel_size=3,
stride=1,
)
self.mid_block = None
self.up_blocks = nn.ModuleList([])
# mid
self.mid_block = CausalUNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
add_attention=mid_block_add_attention,
dropout=block_dropout[-1],
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=self.layers_per_block[i],
in_channels=prev_output_channel,
out_channels=output_channel,
prev_output_channel=None,
add_spatial_upsample=spatial_up_sample[i],
add_temporal_upsample=temporal_up_sample[i],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=output_channel,
temb_channels=None,
resnet_time_scale_shift='default',
interpolate=interpolate,
dropout=block_dropout[i],
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = CausalGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3, stride=1)
self.gradient_checkpointing = False
def forward(
self,
sample: torch.FloatTensor,
is_init_image=True,
temporal_chunk=False,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
sample = self.conv_in(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
is_init_image=is_init_image,
temporal_chunk=temporal_chunk,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
is_init_image=is_init_image,
temporal_chunk=temporal_chunk,
use_reentrant=False,
)
else:
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample,
is_init_image=is_init_image, temporal_chunk=temporal_chunk,)
else:
# middle
sample = self.mid_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = up_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk,)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
return sample
class DiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = randn_tensor(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[2, 3, 4],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[2, 3, 4],
)
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self) -> torch.Tensor:
return self.mean