# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

from lavila.models.openai_clip import load as load_openai_clip
from lavila.models.openai_model import QuickGELU, Transformer
from lavila.models.timesformer import SpaceTimeTransformer
from lavila.models.utils import remap_keys, rsetattr
from lavila.models.prompt_tuning import PromptLearner


class CLIP(nn.Module):
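    """CLIP-style video-text dual encoder: a visual backbone (here a
    TimeSformer) and a causal text Transformer. Either backbone can be
    frozen, leaving only prompt parameters (and, with ``tune_bias``,
    biases) trainable.
    """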
    def __init__(self,
                 cfg,
                 embed_dim: int,
                 # vision
                 vision_width: int,
                 vision_model: nn.Module,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 temperature_init=0.07,
                 **kwargs,
                 ):
        super().__init__()
        self.context_length = context_length
        self.vision_width = vision_width
        self.tune_bias = cfg.get('tune_bias', False)
        self.freeze_vis_backbone = cfg.get('freeze_vis_backbone', False)
        self.freeze_txt_backbone = cfg.get('freeze_txt_backbone', False)
        self.visual = vision_model
        self.t_step = cfg.get('t_step', self.visual.num_frames)
        txt_prompt_cfg = cfg.get('text_prompt', {})
        self.n_ctx = txt_prompt_cfg.get('n_ctx', 0)
        self.txt_use_bank = txt_prompt_cfg.get('use_bank', False)
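        # with use_bank, the text transformer shares the visual branch's
        # prompt generator, so both modalities draw prompts from one bank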
        transformer_kwargs = dict(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
            prompt_cfg=txt_prompt_cfg,
            prompt_learner=PromptLearner(transformer_width, self.n_ctx),
        )
        if self.txt_use_bank:
            transformer_kwargs['prompt_generator'] = self.visual.prompt_generator
        self.transformer = Transformer(**transformer_kwargs)
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = nn.LayerNorm(transformer_width)  # used to be `models.transformer.LayerNorm`
        self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        print("=> initializing temperature to {}".format(temperature_init))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / temperature_init))
        self.initialize_parameters()
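        # collect parameters to freeze; prompt parameters always stay
        # trainable, and with tune_bias the biases do too (BitFit-style)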
        freeze_list = []
        if self.freeze_vis_backbone:
            print("=> Freeze visual backbone")
            freeze_list += self.visual.param_list + [self.image_projection]
        if self.freeze_txt_backbone:
            print("=> Freeze text backbone")
            if self.tune_bias:
                freeze_list += [p for n, p in self.transformer.named_parameters()
                                if 'prompt' not in n and 'bias' not in n]
                freeze_list += [p for n, p in self.ln_final.named_parameters() if 'bias' not in n]
            else:
                freeze_list += [p for n, p in self.transformer.named_parameters() if 'prompt' not in n]
                freeze_list += list(self.ln_final.parameters())
            freeze_list += list(self.token_embedding.parameters())
            freeze_list += [self.positional_embedding, self.text_projection]
        for p in freeze_list:
            p.requires_grad = False
        # text prompts
        if self.n_ctx > 0:
            if self.txt_use_bank:
                prompt_dim = self.visual.prompt_dim
                if prompt_dim != transformer_width:
                    self.transformer.prompt_inproj = nn.Linear(transformer_width, prompt_dim, bias=False)
                else:
                    self.transformer.prompt_inproj = nn.Identity()
                self.transformer.prompt_outproj = nn.Linear(prompt_dim, transformer_width, bias=False)
                nn.init.kaiming_normal_(self.transformer.prompt_outproj.weight, a=0, mode='fan_out')

        params_to_update = [n for n, p in self.named_parameters() if p.requires_grad]
        num_opt_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        num_fz_params = sum(p.numel() for p in self.parameters() if not p.requires_grad)
        print("=> Params to update: {}".format(params_to_update))
        print("=> Trainable/Frozen: {}/{}".format(num_opt_params, num_fz_params))

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)
        nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the diagonal and below, keeping -inf strictly above
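        # e.g. with context_length = 4 the mask is
        #   [[0., -inf, -inf, -inf],
        #    [0.,   0., -inf, -inf],
        #    [0.,   0.,   0., -inf],
        #    [0.,   0.,   0.,   0.]]
        # so position i attends only to positions 0..i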
        return mask

    def encode_image(self, image, use_checkpoint=False, apply_project=True, istrain=False, gamma=1.0):
        x, ps_loss = self.visual(image, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)
        if isinstance(x, list):
            assert len(x) == 1
            x = x[0]
        if apply_project:
            x = x @ self.image_projection
        return x, ps_loss

    def encode_text(self, text, use_checkpoint=False, istrain=False, gamma=1.0):
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        B = x.shape[0]
        eot = text.argmax(dim=-1)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x, ps_loss = self.transformer(x, self.positional_embedding, use_checkpoint=use_checkpoint,
                                      istrain=istrain, gamma=gamma, eot=eot)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
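        # the learned prompt tokens inserted by the transformer lengthen the
        # sequence, so the eot index in the output is offset by n_ctx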
        x = x[torch.arange(x.shape[0]), self.n_ctx + eot] @ self.text_projection
        return x, ps_loss

    def forward(self, image, text, use_checkpoint=False, norm_embed=False, istrain=False, gamma=1.0):
        image_embed, ps_loss_img = self.encode_image(
            image, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)
        text_embed, ps_loss_txt = self.encode_text(
            text, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)
        if norm_embed:
            image_embed = F.normalize(image_embed, dim=-1)
            text_embed = F.normalize(text_embed, dim=-1)
        return {'image_embed': image_embed,
                'text_embed': text_embed,
                'logit_scale': self.logit_scale.exp(),
                'ps_loss': ps_loss_img + ps_loss_txt}

    def train(self, mode=True):
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
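        # override nn.Module.train so frozen submodules stay in eval mode:
        # dropout stays disabled and normalization statistics are not updated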
        self.training = mode
        for m in self.modules():
            m.training = mode
        if mode:
            if self.freeze_vis_backbone and not self.tune_bias:
                for n, m in self.visual.named_modules():
                    if 'prompt' not in n:
                        m.training = False
            if self.freeze_txt_backbone and not self.tune_bias:
                for n, m in self.transformer.named_modules():
                    if 'prompt' not in n:
                        m.training = False
                self.token_embedding.training = False
                self.ln_final.training = False
        return self


def CLIP_OPENAI_TIMESFORMER_BASE(
        num_frames=4, timesformer_gated_xattn=False, temperature_init=0.07,
        project_embed_dim=256, **kwargs
):
    cfg = kwargs.pop('model_cfg', {})
    vision_model = SpaceTimeTransformer(
        num_frames=num_frames,
        time_init='zeros',
        attention_style='frozen-in-time',
        ln_pre=True,
        act_layer=QuickGELU,
        is_tanh_gating=timesformer_gated_xattn,
        drop_path_rate=cfg.get('drop_path_rate', 0),
        tune_bias=cfg.get('tune_bias', False),
        prompt_cfg=cfg.get('visual_prompt', {})
    )
    clip_model, _ = load_openai_clip('ViT-B/16', 'cpu')
    print("=> Loading CLIP (ViT-B/16) weights")
    remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=12)
    res = vision_model.load_state_dict(remapped_state_dict, strict=False)
    print(res)
    vision_model.head = nn.Identity()
    vision_model.pre_logits = nn.Identity()
    vision_model.fc = nn.Identity()
    model = CLIP(
        cfg,
        embed_dim=project_embed_dim,
        vision_width=768,
        vision_model=vision_model,
        context_length=77,
        vocab_size=49408,
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12,
        temperature_init=temperature_init,
        **kwargs
    )
    model.transformer.load_state_dict(clip_model.transformer.state_dict(), strict=False)
    model.token_embedding.load_state_dict(clip_model.token_embedding.state_dict())
    model.positional_embedding.data.copy_(clip_model.positional_embedding.data)
    model.ln_final.load_state_dict(clip_model.ln_final.state_dict())
    if project_embed_dim == clip_model.text_projection.shape[1]:
        print("=> Loading CLIP's text_projection, image_projection and logit_scale directly")
        model.image_projection.data.copy_(clip_model.visual.proj.data)
        model.text_projection.data.copy_(clip_model.text_projection.data)
        model.logit_scale.data.copy_(clip_model.logit_scale.data)
    return model
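

# Minimal usage sketch (illustrative, not part of the original API): builds the
# ViT-B/16 TimeSformer CLIP variant and runs a forward pass on dummy inputs.
# The (B, T, C, H, W) clip layout and the random token ids are assumptions for
# demonstration only; a real pipeline would feed decoded frames and text from
# the CLIP tokenizer. Note that load_openai_clip fetches pretrained weights.
if __name__ == '__main__':
    model = CLIP_OPENAI_TIMESFORMER_BASE(num_frames=4, model_cfg={})
    model.eval()
    frames = torch.randn(2, 4, 3, 224, 224)   # dummy clips, layout assumed (B, T, C, H, W)
    text = torch.randint(0, 49408, (2, 77))   # dummy token ids
    with torch.no_grad():
        out = model(frames, text, norm_embed=True)
    print(out['image_embed'].shape, out['text_embed'].shape, out['logit_scale'])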