from typing import Tuple, Union
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from collections import deque
from einops import rearrange
from timm.models.layers import trunc_normal_
from torch import Tensor
from utils import (
is_context_parallel_initialized,
get_context_parallel_group,
get_context_parallel_world_size,
get_context_parallel_rank,
get_context_parallel_group_rank,
)
from .context_parallel_ops import (
conv_scatter_to_context_parallel_region,
conv_gather_from_context_parallel_region,
cp_pass_from_previous_rank,
)


def divisible_by(num, den):
    return (num % den) == 0


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


def is_odd(n):
    return not divisible_by(n, 2)


class CausalGroupNorm(nn.GroupNorm):
    """GroupNorm applied per frame: time is folded into the batch dimension,
    so normalization statistics never mix information across frames."""

    def forward(self, x: Tensor) -> Tensor:
        t = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = super().forward(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
        return x


class CausalConv3d(nn.Module):
    """Conv3d that is causal in time: all temporal padding is applied on the
    left (past) side, so an output frame never depends on future frames."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        pad_mode: str = 'constant',
        **kwargs
    ):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = cast_tuple(kernel_size, 3)  # int -> (t, t, t)
        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
        self.time_kernel_size = time_kernel_size
        assert is_odd(height_kernel_size) and is_odd(width_kernel_size), \
            'spatial kernel sizes must be odd so that spatial padding is symmetric'
        dilation = kwargs.pop('dilation', 1)
        self.pad_mode = pad_mode
        if isinstance(stride, int):
            stride = (stride, 1, 1)
        # Causal padding: pad (kernel_size - 1) * dilation frames on the left
        # (the past) and none on the right; height/width are padded symmetrically.
        time_pad = dilation * (time_kernel_size - 1)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2
        self.temporal_stride = stride[0]
        self.time_pad = time_pad
        # F.pad order: (w_left, w_right, h_top, h_bottom, t_front, t_back)
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
        self.time_uncausal_padding = (width_pad, width_pad, height_pad, height_pad, 0, 0)
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride,
                              padding=0, dilation=dilation, **kwargs)
        # Cache of trailing input frames, used to stitch temporal chunks together.
        self.cache_front_feat = deque()

    def _clear_context_parallel_cache(self):
        del self.cache_front_feat
        self.cache_front_feat = deque()

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def context_parallel_forward(self, x):
        # Pull the trailing frames of the previous rank's chunk so the temporal
        # receptive field stays causal across rank boundaries.
        x = cp_pass_from_previous_rank(x, dim=2, kernel_size=self.time_kernel_size)
        # Spatial padding only; temporal context comes from the previous rank.
        x = F.pad(x, self.time_uncausal_padding, mode='constant')
        cp_rank = get_context_parallel_rank()
        if cp_rank != 0 and self.temporal_stride == 2 and self.time_kernel_size == 3:
            # Drop one received frame so the stride-2 sampling grid stays
            # aligned with the single-device result.
            x = x[:, :, 1:]
        x = self.conv(x)
        return x

    def forward(self, x, is_init_image=True, temporal_chunk=False):
        # temporal_chunk: process the video in chunks, caching trailing frames
        # between calls so chunked inference matches a full-sequence pass.
        if is_context_parallel_initialized():
            return self.context_parallel_forward(x)
        # Fall back to zero padding when the clip is shorter than the temporal
        # pad width (modes like 'replicate' need enough frames to sample from).
        pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'

        if not temporal_chunk:
            x = F.pad(x, self.time_causal_padding, mode=pad_mode)
        else:
            assert not self.training, "The feature cache should not be used in training"
            if is_init_image:
                # First chunk: pad causally and cache the last two input frames.
                x = F.pad(x, self.time_causal_padding, mode=pad_mode)
                self._clear_context_parallel_cache()
                self.cache_front_feat.append(x[:, :, -2:].clone().detach())
            else:
                # Subsequent chunks: no temporal padding; prepend the cached
                # frames from the previous chunk instead.
                x = F.pad(x, self.time_uncausal_padding, mode=pad_mode)
                video_front_context = self.cache_front_feat.pop()
                self._clear_context_parallel_cache()

                if self.temporal_stride == 1 and self.time_kernel_size == 3:
                    x = torch.cat([video_front_context, x], dim=2)
                elif self.temporal_stride == 2 and self.time_kernel_size == 3:
                    x = torch.cat([video_front_context[:, :, -1:], x], dim=2)
                self.cache_front_feat.append(x[:, :, -2:].clone().detach())
        x = self.conv(x)
        return x
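

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # module). It assumes context parallelism has not been initialized, so
    # is_context_parallel_initialized() returns False. It checks that chunked
    # inference with the feature cache matches a single full-sequence pass.
    torch.manual_seed(0)
    conv = CausalConv3d(in_channels=4, out_channels=4, kernel_size=3, stride=1).eval()
    x = torch.randn(1, 4, 8, 16, 16)  # (batch, channels, time, height, width)

    with torch.no_grad():
        full = conv(x, is_init_image=True, temporal_chunk=False)
        head = conv(x[:, :, :4], is_init_image=True, temporal_chunk=True)
        tail = conv(x[:, :, 4:], is_init_image=False, temporal_chunk=True)

    chunked = torch.cat([head, tail], dim=2)
    print(torch.allclose(full, chunked, atol=1e-5))  # expected: True

    # CausalGroupNorm normalizes each frame independently: time is folded into
    # the batch, so statistics are computed per frame over each channel group.
    norm = CausalGroupNorm(num_groups=2, num_channels=4)
    print(norm(x).shape)  # torch.Size([1, 4, 8, 16, 16])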