gina9726's picture
Upload demo files
c6f92cc verified
raw
history blame contribute delete
No virus
11 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from lavila.models.openai_clip import load as load_openai_clip
from lavila.models.openai_model import QuickGELU, Transformer
from lavila.models.timesformer import SpaceTimeTransformer
from lavila.models.utils import remap_keys, rsetattr
from lavila.models.prompt_tuning import PromptLearner
class CLIP(nn.Module):
    """Video/image-text dual encoder with optional prompt tuning.

    The visual encoder is supplied by the caller (``vision_model``). The text
    encoder is a ``Transformer`` built here together with a ``PromptLearner``
    holding ``n_ctx`` context tokens. Both towers are projected into a shared
    ``embed_dim`` space via ``image_projection`` / ``text_projection``.

    Keys read from ``cfg``:
        tune_bias (bool): keep bias terms trainable when freezing the text
            backbone, and keep frozen backbones in train mode.
        freeze_vis_backbone / freeze_txt_backbone (bool): disable gradients for
            the corresponding backbone (prompt parameters stay trainable).
        t_step (int): defaults to ``vision_model.num_frames``.
        text_prompt (dict): ``n_ctx`` (int, default 0) and ``use_bank``
            (bool, default False); the dict is also forwarded to ``Transformer``.
    """

    def __init__(self,
                 cfg,
                 embed_dim: int,
                 # vision
                 vision_width: int,
                 vision_model: nn.Module,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 tempearture_init=0.07,
                 **kwargs,
                 ):
        # NOTE: ``tempearture_init`` is misspelled but kept as-is so existing
        # keyword callers keep working.
        super().__init__()
        self.context_length = context_length
        self.vision_width = vision_width
        self.tune_bias = cfg.get('tune_bias', False)
        self.freeze_vis_backbone = cfg.get('freeze_vis_backbone', False)
        self.freeze_txt_backbone = cfg.get('freeze_txt_backbone', False)

        self.visual = vision_model
        self.t_step = cfg.get('t_step', self.visual.num_frames)
        txt_prompt_cfg = cfg.get('text_prompt', {})
        self.n_ctx = txt_prompt_cfg.get('n_ctx', 0)
        self.txt_use_bank = txt_prompt_cfg.get('use_bank', False)

        # Build the text transformer once. The visual prompt generator is only
        # shared (passed at all) when the prompt bank is enabled.
        transformer_kwargs = dict(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
            prompt_cfg=txt_prompt_cfg,
            prompt_learner=PromptLearner(transformer_width, self.n_ctx),
        )
        if self.txt_use_bank:
            transformer_kwargs['prompt_generator'] = self.visual.prompt_generator
        self.transformer = Transformer(**transformer_kwargs)

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = nn.LayerNorm(transformer_width)  # used to be `models.transformer.LayerNorm``

        self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        print("=> initialize initial temperature with {}".format(tempearture_init))
        # logit_scale holds log(1 / temperature); forward() reports its exp().
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / tempearture_init))

        self.initialize_parameters()

        # Collect parameters to freeze. Names containing 'prompt' are never
        # frozen, so prompt-tuning parameters remain trainable.
        freeze_list = []
        if self.freeze_vis_backbone:
            print("=> Freeze visual backbone")
            freeze_list += self.visual.param_list + [self.image_projection]
        if self.freeze_txt_backbone:
            print("=> Freeze text backbone")
            if self.tune_bias:
                # keep bias terms trainable (bias tuning)
                freeze_list += [m for n, m in self.transformer.named_parameters() if 'prompt' not in n and 'bias' not in n]
                freeze_list += [m for n, m in self.ln_final.named_parameters() if 'bias' not in n]
            else:
                freeze_list += [m for n, m in self.transformer.named_parameters() if 'prompt' not in n]
                freeze_list += list(self.ln_final.parameters())
            freeze_list += list(self.token_embedding.parameters())
            freeze_list += [self.positional_embedding] + [self.text_projection]
        for p in freeze_list:
            p.requires_grad = False

        # Text prompt projections for the shared prompt bank. They are created
        # after freezing, so their parameters keep requires_grad=True.
        if self.n_ctx > 0 and self.txt_use_bank:
            prompt_dim = self.visual.prompt_dim
            if prompt_dim != transformer_width:
                self.transformer.prompt_inproj = nn.Linear(transformer_width, prompt_dim, bias=False)
            else:
                self.transformer.prompt_inproj = nn.Identity()
            self.transformer.prompt_outproj = nn.Linear(prompt_dim, transformer_width, bias=False)
            nn.init.kaiming_normal_(
                self.transformer.prompt_outproj.weight, a=0, mode='fan_out')

        params_to_update = [n for n, m in self.named_parameters() if m.requires_grad]
        num_opt_params = sum([m.numel() for m in self.parameters() if m.requires_grad])
        num_fz_params = sum([m.numel() for m in self.parameters() if not m.requires_grad])
        print("=> Params to update: {}".format(params_to_update))
        print("=> Update/Frozen: {}/{}".format(num_opt_params, num_fz_params))

    def initialize_parameters(self):
        """Initialise text-tower weights and both projections (CLIP-style normal init)."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)
        nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        """Return a ``(context_length, context_length)`` additive causal mask.

        Entries strictly above the diagonal are ``-inf`` (future positions are
        masked). Entries on and below the diagonal are 0 (PyTorch additive-mask
        convention).
        """
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero the diagonal and everything below it
        return mask

    def encode_image(self, image, use_checkpoint=False, apply_project=True, istrain=False, gamma=1.0):
        """Encode ``image`` with the visual tower.

        Returns ``(features, ps_loss)``. ``ps_loss`` is the auxiliary loss
        reported by the visual model. Features are projected with
        ``image_projection`` when ``apply_project`` is true.
        """
        x, ps_loss = self.visual(image, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)
        if isinstance(x, list):
            assert len(x) == 1
            x = x[0]
        if apply_project:
            x = x @ self.image_projection
        return x, ps_loss

    def encode_text(self, text, use_checkpoint=False, istrain=False, gamma=1.0):
        """Encode token ids ``text`` and return ``(projected_features, ps_loss)``."""
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        # eot_token is the highest id in each sequence, so argmax finds its position
        eot = text.argmax(dim=-1)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x, ps_loss = self.transformer(x, self.positional_embedding, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma, eot=eot)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # Take features from the eot embedding. The position is shifted by
        # n_ctx, presumably because the transformer prepends n_ctx prompt
        # tokens — TODO confirm against Transformer.
        x = x[torch.arange(x.shape[0]), self.n_ctx + eot] @ self.text_projection
        return x, ps_loss

    def forward(self, image, text, use_checkpoint=False, norm_embed=False, istrain=False, gamma=1.0):
        """Encode both modalities.

        Returns a dict with ``image_embed``, ``text_embed``, ``logit_scale``
        (already exponentiated) and the summed auxiliary ``ps_loss``.
        """
        image_embed, ps_loss_img = self.encode_image(image, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)
        text_embed, ps_loss_txt = self.encode_text(text, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)

        if norm_embed:
            image_embed = F.normalize(image_embed, dim=-1)
            text_embed = F.normalize(text_embed, dim=-1)
        return {'image_embed': image_embed,
                'text_embed': text_embed,
                'logit_scale': self.logit_scale.exp(),
                'ps_loss': ps_loss_img + ps_loss_txt}

    def train(self, mode=True):
        """Set training mode, keeping frozen backbones in eval mode.

        When a backbone is frozen and ``tune_bias`` is off, every submodule of
        that backbone whose name does not contain ``'prompt'`` is forced to
        ``training=False``. For the text backbone, ``token_embedding`` and
        ``ln_final`` are forced to ``training=False`` as well.

        Returns ``self`` to match ``nn.Module.train``, so that
        ``model.train()`` / ``model.eval()`` can be chained.

        Raises:
            ValueError: if ``mode`` is not a bool.
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for m in self.modules():
            m.training = mode

        if mode:
            if self.freeze_vis_backbone and not self.tune_bias:
                for n, m in self.visual.named_modules():
                    if 'prompt' not in n:
                        m.training = False

            if self.freeze_txt_backbone and not self.tune_bias:
                for n, m in self.transformer.named_modules():
                    if 'prompt' not in n:
                        m.training = False
                self.token_embedding.training = False
                self.ln_final.training = False
        return self
def CLIP_OPENAI_TIMESFORMER_BASE(
    num_frames=4, timesformer_gated_xattn=False, temperature_init=0.07,
    project_embed_dim=256, **kwargs
):
    """Build a ``CLIP`` model whose towers are initialised from OpenAI CLIP ViT-B/16.

    The visual tower is a ``SpaceTimeTransformer`` ("frozen-in-time" attention,
    QuickGELU) loaded non-strictly from the remapped CLIP visual weights; its
    ``head``, ``pre_logits`` and ``fc`` are replaced with ``nn.Identity``.
    The text tower (width 512, 8 heads, 12 layers, context 77, vocab 49408)
    copies CLIP's transformer, token embedding, positional embedding and final
    LayerNorm. The projections and ``logit_scale`` are copied only when
    ``project_embed_dim`` equals CLIP's text-projection output size. In that
    case the copied ``logit_scale`` replaces the value derived from
    ``temperature_init``.

    ``kwargs['model_cfg']`` (popped, default ``{}``) is the config dict
    forwarded to both towers. The remaining kwargs go to ``CLIP``.
    """
    model_cfg = kwargs.pop('model_cfg', {})

    visual_encoder = SpaceTimeTransformer(
        num_frames=num_frames,
        time_init='zeros',
        attention_style='frozen-in-time',
        ln_pre=True,
        act_layer=QuickGELU,
        is_tanh_gating=timesformer_gated_xattn,
        drop_path_rate=model_cfg.get('drop_path_rate', 0),
        tune_bias=model_cfg.get('tune_bias', False),
        prompt_cfg=model_cfg.get('visual_prompt', {}),
    )

    openai_clip, _ = load_openai_clip('ViT-B/16', 'cpu')
    print("=> Loading CLIP (ViT-B/16) weights")
    visual_weights = remap_keys(openai_clip.visual.state_dict(), transformer_layers=12)
    load_report = visual_encoder.load_state_dict(visual_weights, strict=False)
    print(load_report)

    # Remove any classifier head so the encoder emits raw features.
    for head_attr in ('head', 'pre_logits', 'fc'):
        setattr(visual_encoder, head_attr, nn.Identity())

    model = CLIP(
        model_cfg,
        embed_dim=project_embed_dim,
        vision_width=768,
        vision_model=visual_encoder,
        context_length=77,
        vocab_size=49408,
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12,
        tempearture_init=temperature_init,
        **kwargs
    )

    # Text tower. strict=False tolerates keys present on only one side
    # (presumably the prompt-tuning parameters — TODO confirm).
    model.transformer.load_state_dict(openai_clip.transformer.state_dict(), strict=False)
    model.token_embedding.load_state_dict(openai_clip.token_embedding.state_dict())
    model.positional_embedding.data.copy_(openai_clip.positional_embedding.data)
    model.ln_final.load_state_dict(openai_clip.ln_final.state_dict())

    dims_match = project_embed_dim == openai_clip.text_projection.shape[1]
    if dims_match:
        print("=> Loading CLIP's text_projection, image_projection and logit_scale directly")
        weight_pairs = (
            (model.image_projection, openai_clip.visual.proj),
            (model.text_projection, openai_clip.text_projection),
            (model.logit_scale, openai_clip.logit_scale),
        )
        for target, source in weight_pairs:
            target.data.copy_(source.data)

    return model