# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Part of the code is from https://github.com/m-bain/frozen-in-time/blob/main/model/video_transformer.py
# Modified by Yue Zhao
# The original code is under MIT License

"""
Implementations of Video Transformers in PyTorch

A PyTorch implementation of space-time transformer as described in
'Frozen in Time: A Joint Image and Video Encoder for End-to-End Retrieval' - https://arxiv.org/abs/2104.00650

A PyTorch implementation of timesformer as described in
'Is Space-Time Attention All You Need for Video Understanding?' - https://arxiv.org/abs/2102.05095

Acknowledgments:
- This code builds on Ross Wightman's vision_transformer code in pytorch-image-models:
  https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
- It is also inspired by lucidrains' timesformer implementation:
  https://github.com/lucidrains/TimeSformer-pytorch

Hacked together by Max Bain
"""
from collections import OrderedDict, defaultdict
from functools import partial, reduce
import operator
import copy

import torch
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from torch import einsum, nn
import torch.nn.functional as F
import pdb

from lavila.models.prompt_tuning import VisualPromptLearner, CMM


def attn(q, k, v):
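    # q, k, v: (batch * heads, tokens, head_dim); callers are expected to pre-scale q
    # (e.g. by head_dim ** -0.5, as VarAttention does) before calling this helper.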
    sim = einsum('b i d, b j d -> b i j', q, k)
    attn = sim.softmax(dim=-1)
    out = einsum('b i j, b j d -> b i d', attn, v)
    return out


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class VideoPatchEmbed(nn.Module):
    """ Video to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
                 num_frames=8, ln_pre=False):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * num_frames
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        # ln_pre is inserted to be compatible with CLIP-style model
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=not ln_pre)

    def forward(self, x):
        B, F, C, H, W = x.shape
        assert F <= self.num_frames
        x = x.view(-1, C, H, W)
        x = self.proj(x)
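        # Illustrative shapes (assumed sizes): an input clip flattened to (B*F, 3, 224, 224)
        # is projected to (B*F, 768, 14, 14) for patch_size=16 and embed_dim=768.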
        return x


class VarAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
                 initialize='random', num_tokens=0):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        if initialize == 'zeros':
            self.qkv.weight.data.fill_(0)
            self.qkv.bias.data.fill_(0)
            # fill proj weight with 1 here to improve training dynamics. Otherwise temporal attention inputs
            # are multiplied by 0*0, which is hard for the model to move out of.
            self.proj.weight.data.fill_(1)
            self.proj.bias.data.fill_(0)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
        self.num_tokens = num_tokens

    def forward(self, x, einops_from, einops_to, einops_dims, cfg):
        style = cfg.get('style', 'default')
        pt_att = cfg.get('pt_att', True)
        n_seg = cfg.get('n_seg', 4)
        if 'VoP' in style:
            return self.forward_VoP(x, einops_from, einops_to, einops_dims, n_seg)
        elif style == 'attall':
            return self.forward_attall(x, pt_att)
        else:
            return self.forward_features(x, einops_from, einops_to, einops_dims, pt_att)

    def forward_features(self, x, einops_from, einops_to, einops_dims, pt_att=True):
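        # Token layout assumed here: x = [CLS | num_tokens prompt tokens | N*T patch tokens]
        # (the prompt tokens are only present when num_tokens > 0).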
        h = self.num_heads
        num_tokens = self.num_tokens
        if self.num_tokens > 0 and not pt_att:
            prompts = x[:, 1:self.num_tokens+1, :]
            x = torch.cat((
                x[:, :1, :],  # cls_token
                x[:, self.num_tokens+1:, :]  # patch embeddings
            ), dim=1)
            num_tokens = 0
        # project x to q, k, v values
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        q *= self.scale

        # splice out the CLS token (and prompts, if any)
        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:num_tokens+1], t[:, num_tokens+1:]), (q, k, v))  # Bh x (1 + p) x d / Bh x NT x d

        # let CLS token attend to key / values of all patches across time and space
        cls_out = attn(cls_q, k, v)  # Bh x (1 + p) x d

        # rearrange across time or space
        q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))  # Bh x NT x d -> Bhr x s x d

        # expand cls token keys and values across time or space and concat
        r = q_.shape[0] // cls_k.shape[0]
        cls_k, cls_v = map(lambda t: repeat(t, 'b p d -> (b r) p d', r=r), (cls_k, cls_v))  # Bhr x (1 + p) x d

        k_ = torch.cat((cls_k, k_), dim=1)
        v_ = torch.cat((cls_v, v_), dim=1)

        # attention
        out = attn(q_, k_, v_)

        # merge back time or space
        out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)  # Bh x NT x d

        # concat back the cls token (and prompts)
        out = torch.cat((cls_out, out), dim=1)  # Bh x (1 + p + NT) x d

        # merge back the heads
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)  # B x (1 + p + NT) x hd
        if self.num_tokens > 0 and not pt_att:
            out = torch.cat((
                out[:, :1, :],  # cls_tokens
                prompts,
                out[:, 1:, :]  # patch embeddings
            ), dim=1)

        # to out
        x = self.proj(out)
        x = self.proj_drop(x)
        return x

    def forward_VoP(self, x, einops_from, einops_to, einops_dims, n_seg=4):
        # position-specific prompts for spatial attention
        h = self.num_heads
        num_tokens = self.num_tokens
        # project x to q, k, v values
        q, k, v = self.qkv(x).chunk(3, dim=-1)  # B x (1+p+NT) x hd
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))  # Bh x (1+p+NT) x d

        q *= self.scale

        # splice out the CLS token and prompts
        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:num_tokens+1], t[:, num_tokens+1:]), (q, k, v))  # Bh x (1+p) x d / Bh x NT x d

        # let CLS token attend to key / values of all patches across time and space
        cls_out = attn(cls_q[:, :1, :], k, v)  # cls token: Bh x 1 x d

        # segment prompts into s segments in time
        pstep = num_tokens // n_seg
        pseg = [range(st, en) for st, en in zip(range(1, num_tokens+1, pstep), range(pstep+1, num_tokens+2, pstep))]
        p_q, p_k, p_v = map(lambda t: rearrange(t[:, pseg, :], 'b s p d -> (b s) p d'), (cls_q, cls_k, cls_v))  # prompt query: (Bh x n_seg) x p_per_seg x d

        # segment patch embeddings into s segments in time
        q_, k_, v_ = map(lambda t: rearrange(t, 'b (f n) d -> b f n d', **einops_dims), (q_, k_, v_))  # Bh x T x N x d
        num_frames = k_.size(1)
        tstep = num_frames // n_seg
        tseg = [range(st, en) for st, en in zip(range(0, num_frames, tstep), range(tstep, num_frames+1, tstep))]
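        # Illustrative values (assuming num_tokens=8, n_seg=4 and 4 frames):
        # pseg = [range(1, 3), range(3, 5), range(5, 7), range(7, 9)] and
        # tseg = [range(0, 1), range(1, 2), range(2, 3), range(3, 4)],
        # i.e. both the prompt tokens and the frames are chunked into n_seg temporal segments.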
        q_, k_, v_ = map(lambda t: t[:, tseg, ...], (q_, k_, v_))  # Bh x n_seg x f_per_seg x n x d
        q_, k_, v_ = map(lambda t: rearrange(t, 'b s f n d -> (b s) (f n) d'), (q_, k_, v_))  # (Bh x n_seg) x (f_per_seg x n) x d

        # concatenate prompts and patch embeddings
        k_, v_ = map(lambda t: torch.cat((t[0], t[1]), dim=1), ((p_k, k_), (p_v, v_)))
        p_out = attn(p_q, k_, v_)  # (Bh x n_seg) x p_per_seg x d
        out = attn(q_, k_, v_)  # (Bh x n_seg) x (f_per_seg x n) x d
        p_out = rearrange(p_out, '(b s) p d -> b (s p) d', s=n_seg)  # Bh x p x d
        out = rearrange(out, '(b s) (f n) d -> b (s f n) d', s=n_seg, f=tstep)  # Bh x NT x d

        # merge tokens
        out = torch.cat((cls_out, p_out, out), dim=1)  # Bh x (1+p+NT) x d
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)  # B x (1+p+NT) x hd

        # to out
        x = self.proj(out)
        x = self.proj_drop(x)
        return x

    def forward_attall(self, x, pt_att=True):
        h = self.num_heads
        if self.num_tokens > 0 and not pt_att:
            prompts = x[:, 1:self.num_tokens+1, :]
            x = torch.cat((
                x[:, :1, :],  # cls_token
                x[:, self.num_tokens+1:, :]  # patch embeddings
            ), dim=1)
        # project x to q, k, v values
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        q *= self.scale

        # all tokens attend to all tokens
        out = attn(q, k, v)

        # merge back the heads
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)  # B x (1 + p + NT) x hd
        if self.num_tokens > 0 and not pt_att:
            out = torch.cat((
                out[:, :1, :],  # cls_tokens
                prompts,
                out[:, 1:, :]  # patch embeddings
            ), dim=1)

        # to out
        x = self.proj(out)
        x = self.proj_drop(x)
        return x


class SpaceTimeBlock(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, time_init='zeros',
                 attention_style='frozen-in-time', is_tanh_gating=False, num_tokens=0, split_st=False):
        super().__init__()
        self.split_st = split_st  # split spatial and temporal prompts
        if split_st:
            num_tokens = num_tokens // 2
        self.num_tokens = num_tokens  # learnable prompts
        self.norm1 = norm_layer(dim)
        self.attn = VarAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, num_tokens=num_tokens)
        self.timeattn = VarAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, num_tokens=num_tokens,
            initialize=time_init)
        if is_tanh_gating:
            self.alpha_timeattn = nn.Parameter(torch.zeros([]))

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.norm3 = norm_layer(dim)
        self.attention_style = attention_style

    def forward(self, x, einops_from_space, einops_to_space, einops_from_time, einops_to_time,
                time_n, space_f, use_checkpoint=False, pt_spt=True, pt_tmp=True, style='default', n_seg=4):
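        # Divided space-time attention: temporal attention is applied first (norm3 + timeattn),
        # then spatial attention (norm1 + attn); with the 'frozen-in-time' style the spatial
        # residual connects back to the block input x rather than to time_residual.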
        if self.split_st:
            spatial_prompts = x[:, 1:self.num_tokens+1, :]
            x = torch.cat((
                x[:, :1, :],  # cls_token
                x[:, self.num_tokens+1:, :]  # temporal prompts and patch embeddings
            ), dim=1)

        if use_checkpoint:
            time_output = checkpoint.checkpoint(
                self.timeattn, self.norm3(x), einops_from_time, einops_to_time, {"n": time_n}, {'pt_att': pt_tmp}
            )
        else:
            time_output = self.timeattn(self.norm3(x), einops_from_time, einops_to_time, {"n": time_n}, {'pt_att': pt_tmp})
        if hasattr(self, "alpha_timeattn"):
            time_output = torch.tanh(self.alpha_timeattn) * time_output
        time_residual = x + time_output

        if self.split_st:
            temporal_prompts = time_residual[:, 1:self.num_tokens+1, :]
            time_residual = torch.cat((
                time_residual[:, :1, :],  # cls_token
                spatial_prompts,
                time_residual[:, self.num_tokens+1:, :]  # patch embeddings
            ), dim=1)

        cfg = {'style': style, 'pt_att': pt_spt, 'n_seg': n_seg}
        if use_checkpoint:
            space_output = checkpoint.checkpoint(
                self.attn, self.norm1(time_residual), einops_from_space, einops_to_space, {"f": space_f}, cfg
            )
        else:
            space_output = self.attn(self.norm1(time_residual), einops_from_space,
                                     einops_to_space, {"f": space_f}, cfg)
        if self.attention_style == 'frozen-in-time':
            space_residual = x + self.drop_path(space_output)
        else:
            raise NotImplementedError

        if self.split_st:
            space_residual = torch.cat((
                space_residual[:, :self.num_tokens+1, :],  # cls_token and spatial prompts
                temporal_prompts,
                space_residual[:, self.num_tokens+1:, :]  # patch embeddings
            ), dim=1)

        x = space_residual + self.drop_path(self.mlp(self.norm2(space_residual)))
        return x


class SpaceTimeTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of: `Space-Time Transformer` from Frozen-in-time - by Max Bain.
        https://arxiv.org/abs/2104.00650

    Based off:
     - ViT implementation from the timm library [https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py]
     - lucidrains' timesformer implementation [https://github.com/lucidrains/TimeSformer-pytorch]

    Notable differences:
     - allows for variable length input frames (<= num_frames)
     - allows for variable length input resolution (<= (img_size, img_size)) [UNTESTED]
     - different attention block mechanism
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
                 num_frames=8, time_init='rand', attention_style='frozen-in-time', ln_pre=False,
                 act_layer=nn.GELU, is_tanh_gating=False, tune_bias=False, prompt_cfg={}):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            hybrid_backbone (nn.Module): CNN backbone to use in place of the PatchEmbed module
            norm_layer (nn.Module): normalization layer
            num_frames (int): maximum number of frames expected as input
            time_init (str): how to initialise the time attention layer; 'zeros' allows the timesformer to start off
                as a ViT
            attention_style (str): how to attend to space and time
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        self.tune_bias = tune_bias
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        print("######USING ATTENTION STYLE: ", attention_style)
        self.param_list = []
        if hybrid_backbone is not None:
            raise NotImplementedError('hybrid backbone not implemented')
        else:
            self.patch_embed = VideoPatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=num_frames, ln_pre=ln_pre)
        self.param_list += list(self.patch_embed.parameters())
        num_patches = self.patch_embed.num_patches

        self.patches_per_frame = num_patches // num_frames
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.patches_per_frame + 1,
                        embed_dim))  # remember to take pos_embed[1:] for tiling over time
        self.temporal_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim))
        self.param_list += [self.cls_token, self.pos_embed, self.temporal_embed]

        if ln_pre:
            self.ln_pre = nn.LayerNorm(embed_dim)
            if self.tune_bias:
                self.param_list += [m for n, m in self.ln_pre.named_parameters() if 'bias' not in n]
            else:
                self.param_list += list(self.ln_pre.parameters())
        else:
            self.ln_pre = None

        self.pos_drop = nn.Dropout(p=drop_rate)

        # config for prompts
        self.num_tokens = prompt_cfg.get('num_tokens', 0)
        self.prompt_dim = prompt_cfg.get('prompt_dim', 768)
        self.pt_spt = prompt_cfg.pop('pt_spt', True)
        self.pt_tmp = prompt_cfg.pop('pt_tmp', True)
        self.style = prompt_cfg.pop('style', 'default')
        self.query = prompt_cfg.pop('query', 'cls')
        self.n_seg = prompt_cfg.pop('n_seg', 4)
        self.k_s = prompt_cfg.pop('K_s', depth)
        self.st = prompt_cfg.pop('st', 0)
        self.end = prompt_cfg.pop('end', depth)
        assert self.st <= self.end
        if self.style == 'default':
            print(f'Prompting layers {self.st}-{self.end} of the visual backbone')
        elif self.style == 'VoP_c' and self.k_s < depth:
            self.prompt_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim))
        elif self.style == 'VoP_c_pool':
            self.prompt_temp_embed = nn.Parameter(torch.zeros(1, self.n_seg, embed_dim))
            trunc_normal_(self.prompt_temp_embed, std=.02)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        blocks = []
        for i in range(depth):
            stblk_cfg = {}
            if self.num_tokens > 0:
                stblk_cfg = {'num_tokens': prompt_cfg['num_tokens'], 'split_st': prompt_cfg.get('split_st', False)}
            blocks.append(
                SpaceTimeBlock(
                    dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, time_init=time_init,
                    attention_style=attention_style, act_layer=act_layer, is_tanh_gating=is_tanh_gating, **stblk_cfg)
            )
        self.blocks = nn.ModuleList(blocks)
        self.norm = norm_layer(embed_dim)
        if self.tune_bias:
            self.param_list += reduce(operator.add, [[m for n, m in x.named_parameters() if 'bias' not in n] for x in self.blocks])
            self.param_list += [m for n, m in self.norm.named_parameters() if 'bias' not in n]
        else:
            self.param_list += reduce(operator.add, [list(x.parameters()) for x in self.blocks])
            self.param_list += list(self.norm.parameters())

        # Representation layer
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
            if self.tune_bias:
                self.param_list += [m for n, m in self.pre_logits.named_parameters() if 'bias' not in n]
            else:
                self.param_list += list(self.pre_logits.parameters())
        else:
            self.pre_logits = nn.Identity()

        # Classifier head
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)

        # if num_frames > 1, we perform ViT inflation and initialise the time attention to zero, so this is not necessary
        if num_frames == 1:
            self.apply(self._init_weights)

        # einops transformations
        self.einops_from_space = 'b (f n) d'
        self.einops_to_space = '(b f) n d'
        self.einops_from_time = 'b (f n) d'
        self.einops_to_time = '(b n) f d'
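        # Illustrative shapes (assumed sizes): with B=2, heads=12, f=4 frames and n=196
        # patches per frame (head_dim d=64), spatial attention rearranges the per-head
        # patch tokens (B*h, f*n, d) = (24, 784, 64) into ((B*h)*f, n, d) = (96, 196, 64),
        # so each frame only attends within itself; temporal attention instead yields
        # ((B*h)*n, f, d) = (4704, 4, 64), attending across frames at each spatial location.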

        # freeze the backbone and only learn the prompts
        self.prompt_learner = None
        if self.num_tokens > 0:
            if 'VoP_c' in self.style:
                pool = prompt_cfg.pop('pool', {}) if 'pool' in self.style else {}
                if self.k_s > 0:
                    self.prompt_generator = CMM(self.num_tokens // self.n_seg, self.n_seg, embed_dim, self.prompt_dim,
                                                num_layer=self.k_s, shared=prompt_cfg.get('deep_shared', False), pool=pool)
                n_prompt_layer = depth - self.k_s
            else:
                n_prompt_layer = self.end - self.st
            if n_prompt_layer > 0:
                prompt_cfg['num_layers'] = n_prompt_layer
                prompt_cfg['prompt_dim'] = embed_dim
                self.prompt_learner = VisualPromptLearner(patch_size, embed_dim, **prompt_cfg)
            for p in self.param_list:
                p.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, use_checkpoint=False, cls_at_last=True, istrain=False, gamma=1.0):
        b, curr_frames, channels, _, _ = x.shape
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(2, 1)
        x = x.reshape(b, -1, self.patch_embed.embed_dim)

        BF = x.shape[0]
        cls_tokens = self.cls_token.expand(BF, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        # positional embed needs to be tiled for each frame (this does [1,2,3] --> [1,2,3,1,2,3]...)
        cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
        tile_pos_embed = self.pos_embed[:, 1:, :].repeat(1, self.num_frames, 1)
        # temporal embed needs to be repeated within each frame (this does [1,2,3] --> [1,1,1,2,2,2,3,3,3]...)
        tile_temporal_embed = self.temporal_embed.repeat_interleave(self.patches_per_frame, 1)
        total_pos_embed = tile_pos_embed + tile_temporal_embed
        total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1)  # 1 x (NT + 1) x D

        curr_patches = x.shape[1]
        x = x + total_pos_embed[:, :curr_patches]  # B x (NT + 1) x D
        ps_loss = x.new_zeros([1])

        # incorporate prompts
        if self.num_tokens > 0:
            if 'VoP_c' in self.style and self.k_s > 0:
                ctx, ps = self.prompt_generator(x[:, 1:, :], 0, istrain=istrain, gamma=gamma)
                ps_loss += ps
                if self.prompt_generator.use_bank:
                    prompt_temp_embed = self.prompt_temp_embed.repeat_interleave(self.num_tokens // self.n_seg, 1)
                    ctx = ctx + prompt_temp_embed
            elif self.prompt_learner is not None:
                ctx, ps = self.prompt_learner(x[:, :1, :], 0, istrain=istrain, gamma=gamma)
                ps_loss += ps
            if ctx.size(0) != BF:
                ctx = ctx.expand(BF, -1, -1)
            x = torch.cat((
                x[:, :1, :],  # cls_token
                ctx,
                x[:, 1:, :]
            ), dim=1)

        if self.ln_pre is not None:
            x = self.ln_pre(x)

        x = self.pos_drop(x)
        n = self.patches_per_frame
        f = curr_frames

        for i, blk in enumerate(self.blocks):
            if self.num_tokens > 0 and i > 0 and i >= self.st and i < self.end:
                if 'VoP_c' in self.style:
                    if i < self.k_s:
                        ctx, ps = self.prompt_generator(x[:, self.num_tokens+1:, :], i, istrain=istrain, gamma=gamma)
                        ps_loss += ps
                        if self.prompt_generator.use_bank:
                            prompt_temp_embed = self.prompt_temp_embed.repeat_interleave(self.num_tokens // self.n_seg, 1)
                            ctx = ctx + prompt_temp_embed
                    else:
                        ctx, ps = self.prompt_learner(x[:, :1, :], i - self.k_s, istrain=istrain, gamma=gamma)
                        ps_loss += ps
                        if 'pool' in self.style:
                            prompt_embed = self.prompt_temp_embed.repeat_interleave(self.num_tokens // self.n_seg, 1)
                        else:
                            prompt_embed = self.prompt_embed.repeat_interleave(self.num_tokens // self.num_frames, 1)
                        ctx = ctx + prompt_embed
                    if ctx.size(0) != BF:
                        ctx = ctx.expand(BF, -1, -1)
                elif (i - self.st) < self.prompt_learner.num_layers:
                    ctx, ps = self.prompt_learner(x[:, :1, :], i - self.st, istrain=istrain, gamma=gamma)
                    ps_loss += ps
                    if ctx.size(0) != BF:
                        ctx = ctx.expand(BF, -1, -1)
                x = torch.cat((
                    x[:, :1, :],  # cls_token
                    ctx,
                    x[:, self.num_tokens+1:, :]
                ), dim=1)
            style = 'default' if i >= self.k_s else self.style
            pt_tmp = self.pt_tmp if i >= self.st and i < self.end else False
            pt_spt = self.pt_spt if i >= self.st and i < self.end else False
            x = blk(x, self.einops_from_space, self.einops_to_space, self.einops_from_time,
                    self.einops_to_time,
                    time_n=n, space_f=f, use_checkpoint=use_checkpoint, pt_spt=pt_spt,
                    pt_tmp=pt_tmp, style=style, n_seg=self.n_seg)

        if cls_at_last:
            x = self.norm(x)
            x = x[:, 0]
            x = self.pre_logits(x)
            return x, ps_loss
        else:
            return self.norm(x), ps_loss

    def forward(self, x, use_checkpoint=False, istrain=False, gamma=1.0):
        # Note: the input order here is B C T H W and is permuted to B T C H W;
        # this default differs from the input order used in Frozen-in-Time.
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        x, ps_loss = self.forward_features(x, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)
        x = self.head(x)
        return x, ps_loss

    def train(self, mode=True):
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for m in self.modules():
            m.training = mode
        if mode and self.num_tokens > 0:
            for n, m in self.named_modules():
                if 'prompt' not in n:
                    m.training = False
        return self
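

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the default configuration (no prompt tuning, num_classes=1000),
# that the `lavila` package imported above is available, and a dummy clip of
# shape B x C x T x H x W, matching the permute in `SpaceTimeTransformer.forward`.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = SpaceTimeTransformer(img_size=224, patch_size=16, num_frames=4, time_init='zeros')
    model.eval()
    clip = torch.randn(2, 3, 4, 224, 224)  # B x C x T x H x W
    with torch.no_grad():
        logits, ps_loss = model(clip)
    print(logits.shape, ps_loss.shape)  # expected: torch.Size([2, 1000]) torch.Size([1])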