gina9726 committed
Commit c6f92cc
1 Parent(s): afc90f1

Upload demo files

app.py CHANGED
@@ -1,49 +1,75 @@
- import torch
- import torch.nn as nn
- import torch.backends.cudnn as cudnn
- import torch.cuda.amp as amp
- from torch.distributed.optim import ZeroRedundancyOptimizer
- import torch.nn.parallel
- import torchvision.transforms as transforms
- import torchvision.transforms._transforms_video as transforms_video
- from sklearn.metrics import confusion_matrix

  import gradio as gr

- sample_videos = [
-     "https://ak.picdn.net/shutterstock/videos/21179416/preview/stock-footage-aerial-shot-winter-forest.mp4",
-     "https://ak.picdn.net/shutterstock/videos/5629184/preview/stock-footage-senior-couple-looking-through-binoculars-on-sailboat-together-shot-on-red-epic-for-high-quality-k.mp4",
-     "https://ak.picdn.net/shutterstock/videos/1063125190/preview/stock-footage-a-beautiful-cookie-with-oranges-lies-on-a-green-tablecloth.mp4"
- ]
- sample_videos_gt = [
-     "forest",
-     "people",
-     "orange"
- ]
-
- def predict(idx, video):
-     label = sample_videos_gt[idx]
-     return label, label, label
-
- with gr.Blocks() as demo:
-     gr.Markdown(
-         """
-         # Ego-VPA Demo
-         Choose a sample video and click predict to view the results.
-         """
-     )
-
-     with gr.Row():
-         with gr.Column():
-             video = gr.PlayableVideo(label="video", interactive=False)
-         with gr.Column():
-             idx = gr.Number(label="Idx", visible=False)
-             label = gr.Text(label="Ground Truth")
-             zeroshot = gr.Text(label="LaViLa (zero-shot) prediction")
-             ours = gr.Text(label="Ego-VPA prediction")
-     btn = gr.Button("Predict", variant="primary")
-     btn.click(predict, inputs=[idx, video], outputs=[label, zeroshot, ours])
-     gr.Examples(examples=[[i, x] for i, x in enumerate(sample_videos)], inputs=[idx, video])

  if __name__ == "__main__":
-     demo.launch()

+ ### app.py
+ # User interface for the demo.
+ ###
+
+ import os
+ import pandas as pd
  import gradio as gr
+ from gradio_rich_textbox import RichTextbox
+
+ from demo import VideoCLSModel
+
+
+ def load_samples(data_root):
+     sample_videos = []
+     n_sample = len(os.listdir(f'{data_root}/csv'))
+     for i in range(n_sample):
+         df = pd.read_csv(f'{data_root}/csv/{i}.csv')
+         vid = df['id'].values[0]
+         sample_videos.append(f'{data_root}/video/{vid}.mp4')
+
+     return sample_videos
+
+ def format_pred(pred, gt):
+     tp = '[color=green]{}[/color]'
+     fp = '[color=red]{}[/color]'
+     fmt_pred = []
+     for x in pred:
+         if x in gt:
+             fmt_pred.append(tp.format(x))
+         else:
+             fmt_pred.append(fp.format(x))
+
+     return ', '.join(fmt_pred)
+
+ def main():
+     lavila = VideoCLSModel("configs/charades_ego/zeroshot.yml")
+     egovpa = VideoCLSModel("configs/charades_ego/egovpa.yml")
+     sample_videos = load_samples('data/charades_ego')
+     print(sample_videos)
+
+     def predict(idx):
+         zeroshot_action, gt_action = lavila.predict(idx)
+         egovpa_action, gt_action = egovpa.predict(idx)
+         zeroshot_action = format_pred(zeroshot_action, gt_action)
+         egovpa_action = format_pred(egovpa_action, gt_action)
+
+         return gt_action, zeroshot_action, egovpa_action
+
+     with gr.Blocks() as demo:
+         gr.Markdown(
+             """
+             # Ego-VPA Demo
+             Choose a sample video and click predict to view the results
+             (<span style="color:green">correct</span>/<span style="color:red">incorrect</span>).
+             """
+         )
+
+         with gr.Row():
+             with gr.Column():
+                 video = gr.PlayableVideo(label="video", height='300px', interactive=False, autoplay=True)
+             with gr.Column():
+                 idx = gr.Number(label="Idx", visible=False)
+                 label = RichTextbox(label="Ground Truth", visible=False)
+                 zeroshot = RichTextbox(label="LaViLa (zero-shot) prediction")
+                 ours = RichTextbox(label="Ego-VPA prediction")
+         btn = gr.Button("Predict", variant="primary")
+         btn.click(predict, inputs=[idx], outputs=[label, zeroshot, ours])
+         gr.Examples(examples=[[i, x] for i, x in enumerate(sample_videos)], inputs=[idx, video])

+     demo.launch(share=True)
+
+
  if __name__ == "__main__":
+     main()
+
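Note (illustration, not part of the commit): format_pred wraps each predicted action in gradio_rich_textbox color tags, green when the prediction appears in the ground truth and red otherwise. A minimal sketch with hypothetical labels, assuming the demo's dependencies are installed so that app.py can be imported:

from app import format_pred

pred = ["opening a door", "holding a dish"]   # hypothetical predictions
gt = ["opening a door"]                       # hypothetical ground truth
print(format_pred(pred, gt))
# -> '[color=green]opening a door[/color], [color=red]holding a dish[/color]'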
ckpt/charades_ego.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:612fd98f8f9281c4fefdeac0dce88a9fdd6ec3bf0fba9afbb8510e14a542a423
+ size 728841529
ckpt/lavila_epo1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:34048c21e30cfbbfd4f9554f000f893371e78596423100414cf227b07a0539c2
+ size 710793107
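Both checkpoint entries are Git LFS pointer files; the actual weights (~700 MB each) live in LFS storage. A quick sketch (assumes the LFS objects have been pulled, e.g. with `git lfs pull`) of peeking at a checkpoint the same way demo.VideoModel.build_model does:

import torch

ckpt = torch.load("ckpt/charades_ego.pt", map_location="cpu")
print(vars(ckpt["args"]).get("model"))   # architecture name recorded at training time
print(ckpt["epoch"])                     # training epoch stored in the checkpoint
print(len(ckpt["state_dict"]))           # number of parameter tensors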
configs/base.yml ADDED
@@ -0,0 +1,23 @@
+
+ model:
+   pretrain: ""
+   resume: ""
+   timesformer_freeze_space: false
+   drop_path_rate: 0.1
+   dropout_ratio: 0.5
+   freeze_vis_backbone: false
+   freeze_txt_backbone: false
+   use_vn_classifier: false
+
+ data:
+   dataset: ek100_mir
+   root: datasets/EK100/video_ht256px
+   metadata: datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_train.csv
+   metadata_val: datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test.csv
+   relevancy_path: datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl
+   clip_length: 16
+   clip_stride: 4
+   sparse_sample: false
+   num_crops: 1
+   num_clips: 1
+
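base.yml holds the shared defaults (EK100 paths, clip settings) that the task configs below override via lavila.utils.config.load_cfg; the merge logic itself is not part of this commit. A minimal sketch of reading the raw file with PyYAML:

import yaml

with open("configs/base.yml") as f:
    base_cfg = yaml.safe_load(f)

print(base_cfg["data"]["clip_length"])   # -> 16
print(base_cfg["data"]["clip_stride"])   # -> 4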
configs/charades_ego/egovpa.yml ADDED
@@ -0,0 +1,35 @@
+
+ model:
+   pretrain: ckpt/charades_ego.pt
+   freeze_vis_backbone: true
+   freeze_txt_backbone: true
+   num_frames: 16
+   text_prompt:
+     n_ctx: 8
+     use_bank: true
+   visual_prompt:
+     num_layers: 12
+     prompt_dim: 512
+     num_tokens: 128
+     deep: true
+     deep_shared: false
+     split_st: false
+     pt_spt: true
+     pt_tmp: false
+     style: VoP_c_pool # VoP_c: prompts are generated by context fusion; frame-specific attention
+     n_seg: 16 # number of segments per video (n_seg=clip_length -> 1 frame/seg)
+     K_s: 8 # boundary of intra-frame/inter-frame attention (VoP_f+c)
+     pool:
+       size: 10
+
+ data:
+   dataset: charades_ego
+   #root: /data/CharadesEgo/CharadesEgo_v1_480
+   #metadata_val: /data/CharadesEgo/CharadesEgo/CharadesEgo_v1_test_only1st.csv # all testing data
+   root: data/charades_ego/video
+   metadata_val: data/charades_ego/csv/{}.csv
+   label_map: meta/charades_ego/charades_ego.json
+   clip_length: 16
+   sparse_sample: true
+
+
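The text_prompt/visual_prompt blocks above are the Ego-VPA additions on top of the frozen backbones. A rough sketch of how lavila.models.models.CLIP (later in this commit) reads them; the dictionary below merely stands in for the parsed 'model' section of this file:

cfg = {"text_prompt": {"n_ctx": 8, "use_bank": True}}   # stand-in for the parsed YAML

txt_prompt_cfg = cfg.get("text_prompt", {})
n_ctx = txt_prompt_cfg.get("n_ctx", 0)                  # 8 learnable text-prompt tokens
use_bank = txt_prompt_cfg.get("use_bank", False)        # share the visual prompt bank with the text branch
print(n_ctx, use_bank)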
configs/charades_ego/zeroshot.yml ADDED
@@ -0,0 +1,18 @@
+
+ model:
+   pretrain: ckpt/lavila_epo1.pth
+   freeze_vis_backbone: true
+   freeze_txt_backbone: true
+   num_frames: 16
+
+ data:
+   dataset: charades_ego
+   #root: /data/CharadesEgo/CharadesEgo_v1_480
+   #metadata_val: /data/CharadesEgo/CharadesEgo/CharadesEgo_v1_test_only1st.csv # all testing data
+   root: data/charades_ego/video
+   metadata_val: data/charades_ego/csv/{}.csv
+   label_map: meta/charades_ego/charades_ego.json
+   clip_length: 16
+   sparse_sample: true
+
+
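Both Charades-Ego configs are consumed by demo.VideoCLSModel; usage mirrors app.py and demo.py (requires the checkpoints above plus the data/charades_ego/{csv,video} samples referenced by metadata_val):

from demo import VideoCLSModel

lavila = VideoCLSModel("configs/charades_ego/zeroshot.yml")   # frozen LaViLa baseline
egovpa = VideoCLSModel("configs/charades_ego/egovpa.yml")     # Ego-VPA prompt-tuned model
pred_actions, gt_actions = egovpa.predict(idx=0)              # top-5 predicted vs. ground-truth actions
print(pred_actions, gt_actions)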
demo.py ADDED
@@ -0,0 +1,275 @@
+ ### demo.py
+ # Define model classes for inference.
+ ###
+
+ from collections import OrderedDict
+ import json
+ import numpy as np
+ import os
+ import pandas as pd
+
+ import torch
+ import torch.nn as nn
+ import torch.backends.cudnn as cudnn
+ import torchvision.transforms as transforms
+ import torchvision.transforms._transforms_video as transforms_video
+ from sklearn.metrics import confusion_matrix
+
+ from lavila.data import datasets
+ from lavila.data.video_transforms import Permute, SpatialCrop, TemporalCrop
+ from lavila.models import models
+ from lavila.models.tokenizer import (MyBertTokenizer, MyDistilBertTokenizer, MyGPT2Tokenizer, SimpleTokenizer)
+ from lavila.models.utils import inflate_positional_embeds
+ from lavila.utils.config import load_cfg
+ from lavila.utils.evaluation_charades import charades_map
+ from lavila.utils.evaluation import get_mean_accuracy
+
+
+ class VideoModel(nn.Module):
+     """ Base model for video understanding based on LaViLa architecture. """
+     def __init__(self, config):
+         """ Initializes the model.
+         Parameters:
+             config: config file
+         """
+         super(VideoModel, self).__init__()
+         self.cfg = load_cfg(config)
+         self.model = self.build_model()
+         self.tokenizer = self.get_tokenizer()
+         self.templates = ['{}']
+         self.dataset = self.cfg['data']['dataset']
+         self.eval()
+
+     def build_model(self):
+         cfg = self.cfg
+         if cfg['model'].get('pretrain', False):
+             ckpt_path = cfg['model']['pretrain']
+         else:
+             raise Exception('no checkpoint found')
+         ckpt = torch.load(ckpt_path, map_location='cpu')
+
+         state_dict = OrderedDict()
+         for k, v in ckpt['state_dict'].items():
+             state_dict[k.replace('module.', '')] = v
+
+         old_args = vars(ckpt['args'])
+         arch = old_args.get('model', 'CLIP_OPENAI_TIMESFORMER_BASE')
+         self.arch = arch
+         cfg['model']['arch'] = arch
+         cfg['model']['norm_embed'] = old_args.get('norm_embed', True)
+         print("=> creating model: {}".format(arch))
+         model = getattr(models, arch)(
+             pretrained=old_args.get('load_visual_pretrained', None),
+             pretrained2d=old_args.get('load_visual_pretrained', None) is not None,
+             text_use_cls_token=old_args.get('use_cls_token', False),
+             project_embed_dim=old_args.get('project_embed_dim', 256),
+             timesformer_gated_xattn=False,
+             num_frames=cfg['model'].get('num_frames', cfg['data']['clip_length']),
+             model_cfg=cfg['model']
+         )
+         model.logit_scale.requires_grad = False
+
+         if torch.cuda.is_available():
+             model.cuda()
+
+         if ('TIMESFORMER' in arch or 'EGOVLP' in arch) and cfg['model'].get('inflat_posemb', True):
+             # inflate weight
+             print('=> inflating PE in models due to different frame numbers')
+             state_dict = inflate_positional_embeds(
+                 model.state_dict(), state_dict,
+                 num_frames=cfg['model'].get('num_frames', cfg['data']['clip_length']),
+                 load_temporal_fix='bilinear',
+             )
+         model.load_state_dict(state_dict, strict=True)
+         print("=> loaded resume checkpoint '{}' (epoch {})".format(ckpt_path, ckpt['epoch']))
+
+         return model
+
+     def eval(self):
+         cudnn.benchmark = True
+         for p in self.model.parameters():
+             p.requires_grad = False
+         self.model.eval()
+
+     def get_tokenizer(self):
+         arch = self.arch
+         if arch.endswith('DISTILBERT_BASE'):
+             tokenizer = MyDistilBertTokenizer('distilbert-base-uncased')
+         elif arch.endswith('BERT_BASE'):
+             tokenizer = MyBertTokenizer('bert-base-uncased')
+         elif arch.endswith('BERT_LARGE'):
+             tokenizer = MyBertTokenizer('bert-large-uncased')
+         elif arch.endswith('GPT2'):
+             tokenizer = MyGPT2Tokenizer('gpt2')
+         elif arch.endswith('GPT2_MEDIUM'):
+             tokenizer = MyGPT2Tokenizer('gpt2-medium')
+         elif arch.endswith('GPT2_LARGE'):
+             tokenizer = MyGPT2Tokenizer('gpt2-large')
+         elif arch.endswith('GPT2_XL'):
+             tokenizer = MyGPT2Tokenizer('gpt2-xl')
+         else:
+             print("Using SimpleTokenizer because of model '{}'. "
+                   "Please check if this is what you want".format(arch))
+             tokenizer = SimpleTokenizer()
+
+         return tokenizer
+
+
+ class VideoCLSModel(VideoModel):
+     """ Video model for video classification tasks (Charades-Ego, EGTEA). """
+     def __init__(self, config):
+         super(VideoCLSModel, self).__init__(config)
+         self.labels, self.mapping_vn2act = self.gen_label_map()
+         self.text_features = self.get_text_features()
+
+     def gen_label_map(self):
+         labelmap = self.cfg.get('label_map', 'meta/charades_ego/label_map.json')
+         if os.path.isfile(labelmap):
+             print(f"=> Loading label maps from {labelmap}")
+             meta = json.load(open(labelmap, 'r'))
+             labels, mapping_vn2act = meta['labels'], meta['mapping_vn2act']
+         else:
+             from lavila.utils.preprocess import generate_label_map
+             labels, mapping_vn2act = generate_label_map(self.dataset)
+             meta = {'labels': labels, 'mapping_vn2act': mapping_vn2act}
+             meta_dir = f'meta/{self.dataset}'
+             if not os.path.exists(meta_dir):
+                 os.makedirs(meta_dir)
+             json.dump(meta, open(f'{meta_dir}/label_map.json', 'w'))
+             print(f"=> Label map is generated and saved to {meta_dir}/label_map.json")
+
+         return labels, mapping_vn2act
+
+     def load_data(self, idx=None):
+         print(f"=> Creating dataset")
+         cfg, dataset = self.cfg, self.dataset
+         data_cfg = cfg['data']
+         crop_size = 224 if '336PX' not in self.arch else 336
+         val_transform = transforms.Compose([
+             Permute([3, 0, 1, 2]),  # T H W C -> C T H W
+             transforms.Resize(crop_size),
+             transforms.CenterCrop(crop_size),
+             transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305]),
+         ])
+
+         if idx is None:
+             metadata_val = data_cfg['metadata_val']
+         else:
+             metadata_val = data_cfg['metadata_val'].format(idx)
+         if dataset in ['charades_ego', 'egtea']:
+             val_dataset = datasets.VideoClassyDataset(
+                 dataset, data_cfg['root'], metadata_val,
+                 transform=val_transform, is_training=False,
+                 label_mapping=self.mapping_vn2act, is_trimmed=False,
+                 num_clips=1, clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
+                 sparse_sample=data_cfg['sparse_sample']
+             )
+         else:
+             raise NotImplementedError
+
+         val_loader = torch.utils.data.DataLoader(
+             val_dataset, batch_size=8, shuffle=False,
+             num_workers=4, pin_memory=True, sampler=None, drop_last=False
+         )
+
+         return val_loader
+
+     @torch.no_grad()
+     def get_text_features(self):
+         print('=> Extracting text features')
+         text_features = []
+         for label in self.labels:
+             if isinstance(label, list):
+                 texts = [tmpl.format(lbl) for tmpl in self.templates for lbl in label]
+             else:
+                 texts = [tmpl.format(label) for tmpl in self.templates]
+             texts = self.tokenizer(texts)
+             if isinstance(texts, tuple):
+                 # Bert-style tokenizer will output both ids and mask
+                 texts, masks = texts
+                 texts = texts.cuda(non_blocking=True)
+                 masks = masks.cuda(non_blocking=True)
+             else:
+                 texts = texts.cuda(non_blocking=True)
+                 masks = None
+             texts = texts.view(-1, 77).contiguous()
+             masks = masks.view(-1, 77).contiguous() if masks is not None else None
+             if masks is not None:
+                 class_embeddings, _ = self.model.encode_text(texts, attention_mask=masks)
+             else:
+                 class_embeddings, _ = self.model.encode_text(texts)
+             class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
+             class_embeddings = class_embeddings.mean(dim=0)
+             class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
+
+             text_features.append(class_embeddings)
+         text_features = torch.stack(text_features, dim=0)
+
+         return text_features
+
+     @torch.no_grad()
+     def forward(self, idx=None):
+         print('=> Start forwarding')
+         val_loader = self.load_data(idx)
+         all_outputs = []
+         all_targets = []
+         for i, values in enumerate(val_loader):
+             images = values[0]
+             target = values[1]
+
+             images = images.cuda(non_blocking=True)
+             target = target.cuda(non_blocking=True)
+
+             # encode images
+             image_features, _ = self.model.encode_image(images)
+             image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+             # cosine similarity as logits
+             logits_per_image = image_features @ self.text_features.t()
+             logits_per_image = torch.softmax(logits_per_image, dim=1)
+
+             all_outputs.append(logits_per_image.cpu())
+             all_targets.append(target.cpu())
+
+         all_outputs = torch.cat(all_outputs)
+         all_targets = torch.cat(all_targets)
+
+         return all_outputs, all_targets
+
+     @torch.no_grad()
+     def predict(self, idx=0):
+         all_outputs, all_targets = self.forward(idx)
+         preds, targets = all_outputs.numpy(), all_targets.numpy()
+         #sel = np.where(np.cumsum(sorted(preds[0].tolist(), reverse=True)) > 0.06)[0][0]
+         sel = 5
+         df = pd.DataFrame(self.labels)
+         pred_action = df.iloc[preds[0].argsort()[-sel:]].values.tolist()
+         gt_action = df.iloc[np.where(targets[0])[0]].values.tolist()
+         pred_action = sorted([x[0] for x in pred_action])
+         gt_action = sorted([x[0] for x in gt_action])
+         return pred_action, gt_action
+
+     @torch.no_grad()
+     def evaluate(self):
+         all_outputs, all_targets = self.forward()
+         preds, targets = all_outputs.numpy(), all_targets.numpy()
+         if self.dataset == 'charades_ego':
+             m_ap, _, m_aps = charades_map(preds, targets)
+             print('mAP = {:.3f}'.format(m_ap))
+         elif self.dataset == 'egtea':
+             cm = confusion_matrix(targets, preds.argmax(axis=1))
+             mean_class_acc, acc = get_mean_accuracy(cm)
+             print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
+         else:
+             raise NotImplementedError
+
+
+ def main():
+     lavila = VideoCLSModel("configs/charades_ego/zeroshot.yml")
+     egovpa = VideoCLSModel("configs/charades_ego/egovpa.yml")
+     lavila.evaluate()
+     egovpa.evaluate()
+
+
+ if __name__ == '__main__':
+     main()
+
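For reference, the zero-shot scoring inside VideoCLSModel.forward reduces to cosine similarity between L2-normalized video and class-text embeddings followed by a softmax over the 157 Charades-Ego classes. A self-contained sketch with dummy tensors:

import torch

video_emb = torch.randn(2, 256)      # (batch, embed_dim), dummy values
text_emb = torch.randn(157, 256)     # one embedding per class, dummy values
video_emb = video_emb / video_emb.norm(dim=-1, keepdim=True)
text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
probs = torch.softmax(video_emb @ text_emb.t(), dim=1)    # (2, 157) per-class scores
top5 = probs[0].argsort(descending=True)[:5]              # indices mapped to label names in predict()
print(top5)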
lavila/data/datasets.py ADDED
@@ -0,0 +1,526 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import csv
8
+ import glob
9
+ import json
10
+ import numpy as np
11
+ import os.path as osp
12
+ import pickle
13
+ import random
14
+
15
+ import decord
16
+ import pandas as pd
17
+ import torch
18
+
19
+
20
+ def datetime2sec(str):
21
+ hh, mm, ss = str.split(':')
22
+ return int(hh) * 3600 + int(mm) * 60 + float(ss)
23
+
24
+
25
+ def video_loader(root, vid, second, end_second=None, chunk_len=300, fps=30, clip_length=32, jitter=False):
26
+ if chunk_len == -1:
27
+ vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid)))
28
+ second_offset = second
29
+ if end_second is not None:
30
+ end_second = min(end_second, len(vr) / vr.get_avg_fps())
31
+ else:
32
+ end_second = len(vr) / vr.get_avg_fps()
33
+ else:
34
+ chunk_start = int(second) // chunk_len * chunk_len
35
+ second_offset = second - chunk_start
36
+ vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start)))
37
+ if fps == -1:
38
+ fps = vr.get_avg_fps()
39
+
40
+ # calculate frame_ids
41
+ frame_offset = int(np.round(second_offset * fps))
42
+ total_duration = max(int((end_second - second) * fps), clip_length)
43
+ if chunk_len == -1:
44
+ if end_second <= second:
45
+ raise ValueError("end_second should be greater than second")
46
+ else:
47
+ frame_ids = get_frame_ids(frame_offset, min(frame_offset + total_duration, len(vr)), num_segments=clip_length, jitter=jitter)
48
+ else:
49
+ frame_ids = get_frame_ids(frame_offset, frame_offset + total_duration, num_segments=clip_length, jitter=jitter)
50
+
51
+ # load frames
52
+ if max(frame_ids) < len(vr):
53
+ try:
54
+ frames = vr.get_batch(frame_ids).asnumpy()
55
+ except decord.DECORDError as error:
56
+ print(error)
57
+ frames = vr.get_batch([0] * len(frame_ids)).asnumpy()
58
+ else:
59
+ # find the remaining frames in the next chunk
60
+ try:
61
+ frame_ids_part1 = list(filter(lambda frame_id: frame_id < len(vr), frame_ids))
62
+ frames_part1 = vr.get_batch(frame_ids_part1).asnumpy()
63
+ vr2 = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start + chunk_len)))
64
+ frame_ids_part2 = list(filter(lambda frame_id: frame_id >= len(vr), frame_ids))
65
+ frame_ids_part2 = [min(frame_id % len(vr), len(vr2) - 1) for frame_id in frame_ids_part2]
66
+ frames_part2 = vr2.get_batch(frame_ids_part2).asnumpy()
67
+ frames = np.concatenate([frames_part1, frames_part2], axis=0)
68
+ # the next chunk does not exist; the current chunk is the last one
69
+ except (RuntimeError, decord.DECORDError) as error:
70
+ print(error)
71
+ frame_ids = get_frame_ids(min(frame_offset, len(vr) - 1), len(vr), num_segments=clip_length, jitter=jitter)
72
+ frames = vr.get_batch(frame_ids).asnumpy()
73
+
74
+ frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
75
+ return torch.stack(frames, dim=0)
76
+
77
+
78
+ def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
79
+ seg_size = float(end_frame - start_frame - 1) / num_segments
80
+ seq = []
81
+ for i in range(num_segments):
82
+ start = int(np.round(seg_size * i) + start_frame)
83
+ end = int(np.round(seg_size * (i + 1)) + start_frame)
84
+ end = min(end, end_frame)
85
+ if jitter:
86
+ frame_id = np.random.randint(low=start, high=(end + 1))
87
+ else:
88
+ frame_id = (start + end) // 2
89
+ seq.append(frame_id)
90
+ return seq
91
+
92
+
93
+ def video_loader_by_frames(root, vid, frame_ids):
94
+ vr = decord.VideoReader(osp.join(root, vid))
95
+ try:
96
+ frames = vr.get_batch(frame_ids).asnumpy()
97
+ frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
98
+ except (IndexError, decord.DECORDError) as error:
99
+ print(error)
100
+ print("Erroneous video: ", vid)
101
+ frames = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
102
+ return torch.stack(frames, dim=0)
103
+
104
+
105
+ class VideoCaptionDatasetBase(torch.utils.data.Dataset):
106
+ def __init__(self, dataset, root, metadata, is_trimmed=True):
107
+ self.dataset = dataset
108
+ self.root = root
109
+ self.is_trimmed = is_trimmed
110
+
111
+ if self.dataset == 'ego4d':
112
+ with open(metadata, 'rb') as f:
113
+ self.samples = pickle.load(f)
114
+ elif self.dataset == 'ego4d_mcq':
115
+ with open(metadata, 'r') as f:
116
+ self.samples = json.load(f)
117
+ elif self.dataset in ['ek100_cls', 'ek100_mir']:
118
+ video_list = glob.glob(osp.join(self.root, '*/*.MP4'))
119
+ fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list}
120
+ self.samples = []
121
+ with open(metadata) as f:
122
+ csv_reader = csv.reader(f)
123
+ _ = next(csv_reader) # skip the header
124
+ for row in csv_reader:
125
+ pid, vid = row[1:3]
126
+ # start_frame, end_frame = int(row[6]), int(row[7])
127
+ # Deprecated: some videos might have fps mismatch issue
128
+ start_timestamp, end_timestamp = datetime2sec(row[4]), datetime2sec(row[5])
129
+ narration = row[8]
130
+ verb, noun = int(row[10]), int(row[12])
131
+ vid_path = '{}/{}.MP4'.format(pid, vid)
132
+ fps = fps_dict[osp.join(self.root, vid_path)]
133
+ start_frame = int(np.round(fps * start_timestamp))
134
+ end_frame = int(np.ceil(fps * end_timestamp))
135
+ self.samples.append((vid_path, start_frame, end_frame, narration, verb, noun))
136
+ if self.dataset == 'ek100_mir':
137
+ self.metadata_sentence = pd.read_csv(metadata[:metadata.index('.csv')] + '_sentence.csv')
138
+ if 'train' in metadata:
139
+ self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_train.pkl'), 'rb'))
140
+ elif 'test' in metadata:
141
+ self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_test.pkl'), 'rb'))
142
+ else:
143
+ raise ValueError('{} should contain either "train" or "test"!'.format(metadata))
144
+ self.relevancy = .1
145
+ elif self.dataset == 'egtea':
146
+ video_list = glob.glob(osp.join(self.root, '*/*'))
147
+ len_dict = {video: len(decord.VideoReader(video)) for video in video_list}
148
+
149
+ vn_list, labels = [], []
150
+ for row in open(osp.join(osp.dirname(metadata), 'action_idx.txt')):
151
+ row = row.strip()
152
+ vn = int(row.split(' ')[-1])
153
+ vn_list.append(vn)
154
+ narration = ' '.join(row.split(' ')[:-1])
155
+ labels.append(narration.replace('_', ' ').lower())
156
+ # labels.append(narration)
157
+ mapping_act2narration = {vn: narration for vn, narration in zip(vn_list, labels)}
158
+
159
+ self.samples = []
160
+ with open(metadata) as f:
161
+ for row in f:
162
+ clip_id, action_idx = row.strip().split(' ')[:2]
163
+ video_id = '-'.join(clip_id.split('-')[:3])
164
+ vid_relpath = osp.join(video_id, '{}.mp4'.format(clip_id))
165
+ vid_fullpath = osp.join(self.root, video_id, '{}.mp4'.format(clip_id))
166
+ self.samples.append((vid_relpath, 0, len_dict[vid_fullpath], mapping_act2narration[int(action_idx)]))
167
+ elif self.dataset == 'charades_ego':
168
+ video_list = glob.glob(osp.join(self.root, '*.mp4'))
169
+ fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list}
170
+ self.samples = []
171
+ with open(metadata) as f:
172
+ csv_reader = csv.reader(f)
173
+ _ = next(csv_reader) # skip the header
174
+ for row in csv_reader:
175
+ video_id = row[0]
176
+ if self.is_trimmed:
177
+ for action_tuple in row[9].split(';'):
178
+ if not action_tuple:
179
+ continue
180
+ action, start_timestamp, end_timestamp = action_tuple.split(' ')
181
+ start_timestamp, end_timestamp = float(start_timestamp), float(end_timestamp)
182
+ vid_path = '{}.mp4'.format(video_id)
183
+ fps = fps_dict[osp.join(self.root, vid_path)]
184
+ start_frame = int(np.round(fps * start_timestamp))
185
+ end_frame = int(np.ceil(fps * end_timestamp))
186
+ self.samples.append((vid_path, start_frame, end_frame, action))
187
+ else:
188
+ if not row[9]:
189
+ action_list = []
190
+ else:
191
+ action_list = [action_tuple.split(' ')[0] for action_tuple in row[9].split(';')]
192
+ vid_path = '{}.mp4'.format(video_id)
193
+ fps = fps_dict[osp.join(self.root, vid_path)]
194
+ duration = fps * float(row[10])
195
+ self.samples.append((vid_path, 0, duration, action_list))
196
+ elif self.dataset == 'charades_ego_trimmed':
197
+ with open(metadata, 'rb') as f:
198
+ self.samples = pickle.load(f)
199
+ else:
200
+ raise NotImplementedError
201
+
202
+ def get_raw_item(self, i, is_training=True, num_clips=1, clip_length=32, clip_stride=2, sparse_sample=False,
203
+ narration_selection='random'):
204
+ if self.dataset == 'ego4d':
205
+ if len(self.samples[i]) == 4:
206
+ vid, start_second, end_second, narration = self.samples[i]
207
+ frames = video_loader(self.root, vid, start_second,
208
+ end_second=end_second,
209
+ clip_length=clip_length,
210
+ jitter=is_training)
211
+ if isinstance(narration, list):
212
+ if narration_selection == 'random':
213
+ narration = random.choice(narration)
214
+ elif narration_selection == 'concat':
215
+ narration = '. '.join(narration)
216
+ elif narration_selection == 'list':
217
+ narration = narration
218
+ else:
219
+ raise ValueError
220
+ return frames, narration
221
+ elif len(self.samples[i]) == 5:
222
+ # TODO: need better filtering strategy based on nll
223
+ vid, start_second, end_second, narration, _ = self.samples[i]
224
+ frames = video_loader(self.root, vid, start_second,
225
+ end_second=end_second,
226
+ clip_length=clip_length,
227
+ jitter=is_training)
228
+ if isinstance(narration, list):
229
+ if narration_selection == 'random':
230
+ narration = random.choice(narration)
231
+ elif narration_selection == 'concat':
232
+ narration = '. '.join(narration)
233
+ elif narration_selection == 'list':
234
+ narration = narration
235
+ else:
236
+ raise ValueError
237
+ return frames, narration
238
+ elif self.dataset == 'ego4d_mcq':
239
+ itemMCQ = self.samples[str(i)]
240
+ answerIndex = itemMCQ['answer']
241
+ textQuery = itemMCQ['query']['clip_text']
242
+ sampleOptions = itemMCQ['choices']
243
+ frames_options = []
244
+ narration_options = []
245
+ for option_id in range(len(sampleOptions)):
246
+ option = sampleOptions[str(option_id)]
247
+ frames = video_loader(self.root, option['video_uid'],
248
+ float(option['clip_start']), end_second=float(option['clip_end']),
249
+ clip_length=clip_length,
250
+ jitter=is_training)
251
+ frames_options.append(frames)
252
+ narration_options.append(option['clip_text'])
253
+ return textQuery, frames_options, narration_options, answerIndex, itemMCQ['types']
254
+ elif self.dataset == 'ek100_mir':
255
+ vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
256
+ # from third_party.EgoVLP.base.base_dataset import sample_frames_start_end
257
+ # frame_ids = sample_frames_start_end(clip_length, start_frame, end_frame, sample='uniform', fix_start=None)
258
+ frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
259
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
260
+ if is_training:
261
+ positive_list = np.where(self.relevancy_mat[i] > self.relevancy)[0].tolist()
262
+ if positive_list != []:
263
+ pos = random.sample(positive_list, min(len(positive_list), 1))[0]
264
+ if pos < len(self.metadata_sentence) and pos < self.relevancy_mat.shape[1]:
265
+ return frames, (self.metadata_sentence.iloc[pos][1], self.relevancy_mat[i][pos])
266
+ else:
267
+ return frames, (narration, 1)
268
+ elif self.dataset == 'ek100_cls':
269
+ vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
270
+ frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
271
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
272
+ return frames, '{}:{}'.format(verb, noun)
273
+ elif self.dataset == 'egtea':
274
+ vid_path, start_frame, end_frame, sentence = self.samples[i]
275
+ if is_training:
276
+ assert num_clips == 1
277
+ if end_frame < clip_length * clip_stride:
278
+ frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
279
+ zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
280
+ frames = torch.cat((frames, zeros), dim=0)
281
+ frames = frames[::clip_stride]
282
+ else:
283
+ start_id = np.random.randint(0, end_frame - clip_length * clip_stride + 1)
284
+ frame_ids = np.arange(start_id, start_id + clip_length * clip_stride, clip_stride)
285
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
286
+ else:
287
+ if end_frame < clip_length * clip_stride:
288
+ frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
289
+ zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
290
+ frames = torch.cat((frames, zeros), dim=0)
291
+ frames = frames[::clip_stride]
292
+ frames = frames.repeat(num_clips, 1, 1, 1)
293
+ else:
294
+ frame_ids = []
295
+ for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int):
296
+ frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride))
297
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
298
+ return frames, sentence
299
+ elif self.dataset == 'charades_ego':
300
+ vid_path, start_frame, end_frame, action_list = self.samples[i]
301
+ if sparse_sample:
302
+ frame_ids = get_frame_ids(start_frame, end_frame, num_segments=num_clips * clip_length, jitter=is_training)
303
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
304
+ else:
305
+ if end_frame < clip_length * clip_stride:
306
+ frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
307
+ zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
308
+ frames = torch.cat((frames, zeros), dim=0)
309
+ frames = frames[::clip_stride]
310
+ frames = frames.repeat(num_clips, 1, 1, 1)
311
+ else:
312
+ frame_ids = []
313
+ for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int):
314
+ frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride))
315
+ #print('frame_ids:', frame_ids)
316
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
317
+ return frames, action_list, vid_path
318
+ elif self.dataset == 'charades_ego_trimmed':
319
+ vid, start_second, end_second, narration = self.samples[i]
320
+ frames = video_loader(self.root, vid, start_second,
321
+ end_second=end_second,
322
+ chunk_len=-1, # no chunk for CharadesEgo
323
+ fps=-1, # could be variable fps
324
+ clip_length=clip_length,
325
+ jitter=is_training)
326
+ return frames, narration
327
+ else:
328
+ raise NotImplementedError
329
+
330
+ def __getitem__(self, i):
331
+ raise NotImplementedError
332
+
333
+ def __len__(self):
334
+ return len(self.samples)
335
+
336
+
337
+ class VideoCaptionDatasetCLIP(VideoCaptionDatasetBase):
338
+ def __init__(self, dataset, root, metadata, transform=None,
339
+ is_training=True, tokenizer=None,
340
+ clip_length=32, clip_stride=2, sparse_sample=False,
341
+ narration_selection='random',
342
+ num_hard_negatives=0,
343
+ subsample_stride=None):
344
+ super().__init__(dataset, root, metadata)
345
+
346
+ self.full_samples = self.samples.copy()
347
+ if isinstance(subsample_stride, int):
348
+ self.samples = self.samples[::subsample_stride]
349
+ self.transform = transform
350
+ self.is_training = is_training
351
+ self.tokenizer = tokenizer
352
+ self.clip_length = clip_length
353
+ self.clip_stride = clip_stride
354
+ self.sparse_sample = sparse_sample
355
+ self.narration_selection = narration_selection
356
+ self.num_hard_negatives = num_hard_negatives
357
+ if num_hard_negatives > 0:
358
+ assert self.dataset == 'htm_aa'
359
+
360
+ def __getitem__(self, i):
361
+ frames, caption = self.get_raw_item(
362
+ i, is_training=self.is_training,
363
+ clip_length=self.clip_length,
364
+ clip_stride=self.clip_stride,
365
+ sparse_sample=self.sparse_sample,
366
+ narration_selection=self.narration_selection,
367
+ )
368
+
369
+ # ek100_mir will also output relevancy value
370
+ if isinstance(caption, tuple):
371
+ caption, relevancy = caption
372
+ else:
373
+ relevancy = 0.
374
+
375
+ # apply transformation
376
+ if self.transform is not None:
377
+ frames = self.transform(frames)
378
+
379
+ # tokenize caption
380
+ if self.tokenizer is not None:
381
+ caption = self.tokenizer(caption)
382
+
383
+ if isinstance(caption, tuple):
384
+ caption, mask = caption
385
+ return frames, caption, mask, relevancy
386
+ else:
387
+ return frames, caption, relevancy
388
+
389
+
390
+ class VideoCaptionDatasetMCQ(VideoCaptionDatasetBase):
391
+ def __init__(self, dataset, root, metadata, transform=None,
392
+ is_training=True, tokenizer=None,
393
+ clip_length=32, clip_stride=2, sparse_sample=False,
394
+ narration_selection='random'):
395
+ super().__init__(dataset, root, metadata)
396
+
397
+ self.full_samples = self.samples.copy()
398
+ self.transform = transform
399
+ self.is_training = is_training
400
+ self.tokenizer = tokenizer
401
+ self.clip_length = clip_length
402
+ self.clip_stride = clip_stride
403
+ self.sparse_sample = sparse_sample
404
+ self.narration_selection = narration_selection
405
+
406
+ def __getitem__(self, i):
407
+
408
+ textQuery, frames_options, narration_options, answerIndex, q_type = self.get_raw_item(
409
+ i, is_training=self.is_training,
410
+ clip_length=self.clip_length,
411
+ clip_stride=self.clip_stride,
412
+ sparse_sample=self.sparse_sample,
413
+ narration_selection=self.narration_selection,
414
+ )
415
+
416
+ # apply transformation
417
+ if self.transform is not None:
418
+ frames_options = [self.transform(frames) for frames in frames_options]
419
+
420
+ # tokenize caption
421
+ if self.tokenizer is not None:
422
+ textQuery = self.tokenizer(textQuery)
423
+ narration_options = self.tokenizer(narration_options)
424
+ if isinstance(textQuery, tuple):
425
+ textQuery, mask_query = textQuery
426
+ narration_options, mask_options = narration_options
427
+ return (
428
+ textQuery, torch.stack(frames_options, dim=0),
429
+ narration_options, answerIndex, q_type,
430
+ mask_query, mask_options
431
+ )
432
+ else:
433
+ return textQuery, torch.stack(frames_options, dim=0), narration_options, answerIndex, q_type
434
+
435
+
436
+ class VideoClassyDataset(VideoCaptionDatasetBase):
437
+ def __init__(
438
+ self, dataset, root, metadata, transform=None,
439
+ is_training=True, label_mapping=None,
440
+ num_clips=1,
441
+ clip_length=32, clip_stride=2,
442
+ sparse_sample=False,
443
+ is_trimmed=True,
444
+ ):
445
+ super().__init__(dataset, root, metadata, is_trimmed=is_trimmed)
446
+
447
+ self.transform = transform
448
+ self.is_training = is_training
449
+ self.label_mapping = label_mapping
450
+ self.num_clips = num_clips
451
+ self.clip_length = clip_length
452
+ self.clip_stride = clip_stride
453
+ self.sparse_sample = sparse_sample
454
+
455
+ def __getitem__(self, i):
456
+ frames, label, vid_path = self.get_raw_item(
457
+ i, is_training=self.is_training,
458
+ num_clips=self.num_clips,
459
+ clip_length=self.clip_length,
460
+ clip_stride=self.clip_stride,
461
+ sparse_sample=self.sparse_sample,
462
+ )
463
+
464
+ # apply transformation
465
+ if self.transform is not None:
466
+ frames = self.transform(frames)
467
+
468
+ if self.label_mapping is not None:
469
+ if isinstance(label, list):
470
+ # multi-label case
471
+ res_array = np.zeros(len(self.label_mapping))
472
+ for lbl in label:
473
+ res_array[self.label_mapping[lbl]] = 1.
474
+ label = res_array
475
+ else:
476
+ label = self.label_mapping[label]
477
+
478
+ return frames, label, vid_path
479
+
480
+
481
+ def get_dataset(train_transform, tokenizer, cfg, is_training=True):
482
+ narration_selection = cfg.get('narration_selection', 'random')
483
+ num_hard_neg = cfg.get('num_hard_neg', 0)
484
+ data_cfg = cfg['data']
485
+ if cfg['model']['arch'].startswith('CLIP') or cfg['model']['arch'].startswith('VCLM'):
486
+ if is_training:
487
+ metadata = data_cfg['metadata']
488
+ else:
489
+ metadata = data_cfg['metadata_val']
490
+
491
+ return VideoCaptionDatasetCLIP(
492
+ data_cfg['dataset'], data_cfg['root'], metadata, train_transform,
493
+ is_training=is_training,
494
+ tokenizer=tokenizer,
495
+ clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
496
+ sparse_sample=data_cfg['sparse_sample'],
497
+ narration_selection=narration_selection,
498
+ num_hard_negatives=num_hard_neg
499
+ )
500
+ else:
501
+ raise NotImplementedError
502
+
503
+
504
+ def get_downstream_dataset(transform, tokenizer, cfg, is_training=True, num_clips=0, label_mapping=None):
505
+ data_cfg = cfg['data']
506
+ n_clips = num_clips if num_clips > 0 else data_cfg['num_clips']
507
+ if is_training:
508
+ metadata = data_cfg['metadata']
509
+ return VideoClassyDataset(
510
+ data_cfg['dataset'], data_cfg['root'], metadata, transform,
511
+ is_training=True, label_mapping=label_mapping,
512
+ num_clips=n_clips,
513
+ clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
514
+ sparse_sample=data_cfg['sparse_sample'],
515
+ )
516
+ else:
517
+ metadata = data_cfg['metadata_val']
518
+ return VideoClassyDataset(
519
+ data_cfg['dataset'], data_cfg['root'], metadata, transform,
520
+ is_training=False, label_mapping=label_mapping,
521
+ num_clips=n_clips,
522
+ clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
523
+ sparse_sample=data_cfg['sparse_sample'],
524
+ is_trimmed=not data_cfg['dataset'] == 'charades_ego'
525
+ )
526
+
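The charades_ego branch above relies on get_frame_ids for sparse sampling: with jitter=False each segment contributes its middle frame, with jitter=True a random frame per segment. Illustrative call:

from lavila.data.datasets import get_frame_ids

frame_ids = get_frame_ids(start_frame=0, end_frame=480, num_segments=16, jitter=False)
print(len(frame_ids))   # 16 roughly evenly spaced frame indices within [0, 480)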
lavila/data/video_transforms.py ADDED
@@ -0,0 +1,186 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Sequence
9
+ import torch
10
+ import torch.nn as nn
11
+ from torchvision import transforms
12
+
13
+
14
+ class Permute(nn.Module):
15
+ """
16
+ Permutation as an op
17
+ """
18
+
19
+ def __init__(self, ordering):
20
+ super().__init__()
21
+ self.ordering = ordering
22
+
23
+ def forward(self, frames):
24
+ """
25
+ Args:
26
+ frames in some ordering, by default (C, T, H, W)
27
+ Returns:
28
+ frames in the ordering that was specified
29
+ """
30
+ return frames.permute(self.ordering)
31
+
32
+
33
+ class TemporalCrop(nn.Module):
34
+ """
35
+ Convert the video into smaller clips temporally.
36
+ """
37
+
38
+ def __init__(
39
+ self, frames_per_clip: int = 8, stride: int = 8, frame_stride: int = 1
40
+ ):
41
+ super().__init__()
42
+ self.frames = frames_per_clip
43
+ self.stride = stride
44
+ self.frame_stride = frame_stride
45
+
46
+ def forward(self, video):
47
+ assert video.ndim == 4, "Must be (C, T, H, W)"
48
+ res = []
49
+ for start in range(
50
+ 0, video.size(1) - (self.frames * self.frame_stride) + 1, self.stride
51
+ ):
52
+ end = start + (self.frames) * self.frame_stride
53
+ res.append(video[:, start: end: self.frame_stride, ...])
54
+ return res
55
+
56
+
57
+ def crop_boxes(boxes, x_offset, y_offset):
58
+ """
59
+ Perform crop on the bounding boxes given the offsets.
60
+ Args:
61
+ boxes (ndarray or None): bounding boxes to perform crop. The dimension
62
+ is `num boxes` x 4.
63
+ x_offset (int): cropping offset in the x axis.
64
+ y_offset (int): cropping offset in the y axis.
65
+ Returns:
66
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
67
+ `num boxes` x 4.
68
+ """
69
+ cropped_boxes = boxes.copy()
70
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
71
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
72
+
73
+ return cropped_boxes
74
+
75
+
76
+ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
77
+ """
78
+ Perform uniform spatial sampling on the images and corresponding boxes.
79
+ Args:
80
+ images (tensor): images to perform uniform crop. The dimension is
81
+ `num frames` x `channel` x `height` x `width`.
82
+ size (int): size of height and width to crop the images.
83
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
84
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
85
+ crop if height is larger than width.
86
+ boxes (ndarray or None): optional. Corresponding boxes to images.
87
+ Dimension is `num boxes` x 4.
88
+ scale_size (int): optional. If not None, resize the images to scale_size before
89
+ performing any crop.
90
+ Returns:
91
+ cropped (tensor): images with dimension of
92
+ `num frames` x `channel` x `size` x `size`.
93
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
94
+ `num boxes` x 4.
95
+ """
96
+ assert spatial_idx in [0, 1, 2]
97
+ ndim = len(images.shape)
98
+ if ndim == 3:
99
+ images = images.unsqueeze(0)
100
+ height = images.shape[2]
101
+ width = images.shape[3]
102
+
103
+ if scale_size is not None:
104
+ if width <= height:
105
+ width, height = scale_size, int(height / width * scale_size)
106
+ else:
107
+ width, height = int(width / height * scale_size), scale_size
108
+ images = torch.nn.functional.interpolate(
109
+ images,
110
+ size=(height, width),
111
+ mode="bilinear",
112
+ align_corners=False,
113
+ )
114
+
115
+ y_offset = int(math.ceil((height - size) / 2))
116
+ x_offset = int(math.ceil((width - size) / 2))
117
+
118
+ if height > width:
119
+ if spatial_idx == 0:
120
+ y_offset = 0
121
+ elif spatial_idx == 2:
122
+ y_offset = height - size
123
+ else:
124
+ if spatial_idx == 0:
125
+ x_offset = 0
126
+ elif spatial_idx == 2:
127
+ x_offset = width - size
128
+ cropped = images[:, :, y_offset: y_offset + size, x_offset: x_offset + size]
129
+ cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
130
+ if ndim == 3:
131
+ cropped = cropped.squeeze(0)
132
+ return cropped, cropped_boxes
133
+
134
+
135
+ class SpatialCrop(nn.Module):
136
+ """
137
+ Convert the video into 3 smaller clips spatially. Must be used after the
138
+ temporal crops to get spatial crops, and should be used with
139
+ -2 in the spatial crop at the slowfast augmentation stage (so full
140
+ frames are passed in here). Will return a larger list with the
141
+ 3x spatial crops as well. It's useful for 3x4 testing (eg in SwinT)
142
+ or 3x10 testing in SlowFast etc.
143
+ """
144
+
145
+ def __init__(self, crop_size: int = 224, num_crops: int = 3):
146
+ super().__init__()
147
+ self.crop_size = crop_size
148
+ if num_crops == 6:
149
+ self.crops_to_ext = [0, 1, 2]
150
+ # I guess Swin uses 5 crops without flipping, but that doesn't
151
+ # make sense given they first resize to 224 and take 224 crops.
152
+ # (pg 6 of https://arxiv.org/pdf/2106.13230.pdf)
153
+ # So I'm assuming we can use flipped crops and that will add sth..
154
+ self.flipped_crops_to_ext = [0, 1, 2]
155
+ elif num_crops == 3:
156
+ self.crops_to_ext = [0, 1, 2]
157
+ self.flipped_crops_to_ext = []
158
+ elif num_crops == 1:
159
+ self.crops_to_ext = [1]
160
+ self.flipped_crops_to_ext = []
161
+ else:
162
+ raise NotImplementedError(
163
+ "Nothing else supported yet, "
164
+ "slowfast only takes 0, 1, 2 as arguments"
165
+ )
166
+
167
+ def forward(self, videos: Sequence[torch.Tensor]):
168
+ """
169
+ Args:
170
+ videos: A list of C, T, H, W videos.
171
+ Returns:
172
+ videos: A list with 3x the number of elements. Each video converted
173
+ to C, T, H', W' by spatial cropping.
174
+ """
175
+ assert isinstance(videos, list), "Must be a list of videos after temporal crops"
176
+ assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
177
+ res = []
178
+ for video in videos:
179
+ for spatial_idx in self.crops_to_ext:
180
+ res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
181
+ if not self.flipped_crops_to_ext:
182
+ continue
183
+ flipped_video = transforms.functional.hflip(video)
184
+ for spatial_idx in self.flipped_crops_to_ext:
185
+ res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
186
+ return res
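Permute is combined with standard torchvision ops into the evaluation pipeline built in demo.VideoCLSModel.load_data (values copied from demo.py; the input is a (T, H, W, C) frame stack):

import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from lavila.data.video_transforms import Permute

val_transform = transforms.Compose([
    Permute([3, 0, 1, 2]),        # T H W C -> C T H W
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001],
                                    std=[68.5005327, 66.6321579, 70.32316305]),
])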
lavila/models/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+ size 1356917
lavila/models/distributed_utils.py ADDED
@@ -0,0 +1,89 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Part of the code is from
+ # `https://github.com/facebookresearch/vissl/blob/main/vissl/utils/distributed_utils.py` and
+ # `https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/generic/distributed_util.py`
+ # Modified by Yue Zhao
+ # The original code is under MIT License
+
+ import torch
+ import torch.distributed as dist
+ from typing import Tuple
+
+
+ def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
+     """
+     For some backends, such as NCCL, communication only works if the
+     tensor is on the GPU. This helper function converts to the correct
+     device and returns the tensor + original device.
+     """
+     orig_device = "cpu" if not tensor.is_cuda else "gpu"
+     if (
+         torch.distributed.is_available()
+         and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
+         and not tensor.is_cuda
+     ):
+         tensor = tensor.cuda()
+     return (tensor, orig_device)
+
+
+ def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
+     """
+     For some backends, such as NCCL, communication only works if the
+     tensor is on the GPU. This converts the tensor back to original device.
+     """
+     if tensor.is_cuda and orig_device == "cpu":
+         tensor = tensor.cpu()
+     return tensor
+
+
+ def is_distributed_training_run() -> bool:
+     return (
+         torch.distributed.is_available()
+         and torch.distributed.is_initialized()
+         and (torch.distributed.get_world_size() > 1)
+     )
+
+
+ class GatherLayer(torch.autograd.Function):
+     """
+     Gather tensors from all workers with support for backward propagation:
+     This implementation does not cut the gradients as torch.distributed.all_gather does.
+     """
+
+     @staticmethod
+     def forward(ctx, x):
+         output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
+         dist.all_gather(output, x)
+         return tuple(output)
+
+     @staticmethod
+     def backward(ctx, *grads):
+         all_gradients = torch.stack(grads)
+         dist.all_reduce(all_gradients)
+         return all_gradients[dist.get_rank()]
+
+
+ def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
+     """
+     Similar to classy_vision.generic.distributed_util.gather_from_all
+     except that it does not cut the gradients
+     """
+     if tensor.ndim == 0:
+         # 0 dim tensors cannot be gathered. so unsqueeze
+         tensor = tensor.unsqueeze(0)
+
+     if is_distributed_training_run():
+         tensor, orig_device = convert_to_distributed_tensor(tensor)
+         gathered_tensors = GatherLayer.apply(tensor)
+         gathered_tensors = [
+             convert_to_normal_tensor(_tensor, orig_device)
+             for _tensor in gathered_tensors
+         ]
+     else:
+         gathered_tensors = [tensor]
+     gathered_tensor = torch.cat(gathered_tensors, 0)
+     return gathered_tensor
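In the single-process demo setting, gather_from_all is effectively an identity: is_distributed_training_run() is False, so the local tensor is returned (after an unsqueeze for 0-dim inputs). Quick check:

import torch
from lavila.models.distributed_utils import gather_from_all

x = torch.randn(4, 256)
print(gather_from_all(x).shape)   # torch.Size([4, 256]) when not running distributed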
lavila/models/models.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import timm
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from lavila.models.openai_clip import load as load_openai_clip
14
+ from lavila.models.openai_model import QuickGELU, Transformer
15
+ from lavila.models.timesformer import SpaceTimeTransformer
16
+ from lavila.models.utils import remap_keys, rsetattr
17
+ from lavila.models.prompt_tuning import PromptLearner
18
+
19
+
20
+ class CLIP(nn.Module):
21
+ def __init__(self,
22
+ cfg,
23
+ embed_dim: int,
24
+ # vision
25
+ vision_width: int,
26
+ vision_model: nn.Module,
27
+ # text
28
+ context_length: int,
29
+ vocab_size: int,
30
+ transformer_width: int,
31
+ transformer_heads: int,
32
+ transformer_layers: int,
33
+ tempearture_init=0.07,
34
+ **kwargs,
35
+ ):
36
+ super().__init__()
37
+
38
+ self.context_length = context_length
39
+ self.vision_width = vision_width
40
+ self.tune_bias = cfg.get('tune_bias', False)
41
+ self.freeze_vis_backbone = cfg.get('freeze_vis_backbone', False)
42
+ self.freeze_txt_backbone = cfg.get('freeze_txt_backbone', False)
43
+
44
+ self.visual = vision_model
45
+ self.t_step = cfg.get('t_step', self.visual.num_frames)
46
+ txt_prompt_cfg = cfg.get('text_prompt', {})
47
+ self.n_ctx = txt_prompt_cfg.get('n_ctx', 0)
48
+ self.txt_use_bank = txt_prompt_cfg.get('use_bank', False)
49
+ if self.txt_use_bank:
50
+ self.transformer = Transformer(
51
+ width=transformer_width,
52
+ layers=transformer_layers,
53
+ heads=transformer_heads,
54
+ attn_mask=self.build_attention_mask(),
55
+ prompt_cfg=txt_prompt_cfg,
56
+ prompt_learner=PromptLearner(transformer_width, self.n_ctx),
57
+ prompt_generator=self.visual.prompt_generator
58
+ )
59
+ else:
60
+ self.transformer = Transformer(
61
+ width=transformer_width,
62
+ layers=transformer_layers,
63
+ heads=transformer_heads,
64
+ attn_mask=self.build_attention_mask(),
65
+ prompt_cfg=txt_prompt_cfg,
66
+ prompt_learner=PromptLearner(transformer_width, self.n_ctx)
67
+ )
68
+
69
+ self.vocab_size = vocab_size
70
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
71
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
72
+ self.ln_final = nn.LayerNorm(transformer_width) # used to be `models.transformer.LayerNorm`
73
+
74
+ self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
75
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
76
+ print("=> initialize initial temperature with {}".format(tempearture_init))
77
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / tempearture_init))
78
+
79
+ self.initialize_parameters()
80
+
81
+ freeze_list = []
82
+ if self.freeze_vis_backbone:
83
+ print("=> Freeze visual backbone")
84
+ freeze_list += self.visual.param_list + [self.image_projection]
85
+
86
+ if self.freeze_txt_backbone:
87
+ print("=> Freeze text backbone")
88
+ if self.tune_bias:
89
+ freeze_list += [m for n, m in self.transformer.named_parameters() if 'prompt' not in n and 'bias' not in n]
90
+ freeze_list += [m for n, m in self.ln_final.named_parameters() if 'bias' not in n]
91
+ else:
92
+ freeze_list += [m for n, m in self.transformer.named_parameters() if 'prompt' not in n]
93
+ freeze_list += list(self.ln_final.parameters())
94
+ freeze_list += list(self.token_embedding.parameters())
95
+ freeze_list += [self.positional_embedding] + [self.text_projection]
96
+
97
+ for p in freeze_list:
98
+ p.requires_grad = False
99
+
100
+ # text prompts
101
+ if self.n_ctx > 0:
102
+ if self.txt_use_bank:
103
+ prompt_dim = self.visual.prompt_dim
104
+ if prompt_dim != transformer_width:
105
+ self.transformer.prompt_inproj = nn.Linear(transformer_width, prompt_dim, bias=False)
106
+ else:
107
+ self.transformer.prompt_inproj = nn.Identity()
108
+ self.transformer.prompt_outproj = nn.Linear(prompt_dim, transformer_width, bias=False)
109
+ nn.init.kaiming_normal_(
110
+ self.transformer.prompt_outproj.weight, a=0, mode='fan_out')
111
+
112
+ params_to_update = [n for n, m in self.named_parameters() if m.requires_grad]
113
+ num_opt_params = sum([m.numel() for m in self.parameters() if m.requires_grad])
114
+ num_fz_params = sum([m.numel() for m in self.parameters() if not m.requires_grad])
115
+ print("=> Params to update: {}".format(params_to_update))
116
+ print("=> Update/Frozen: {}/{}".format(num_opt_params, num_fz_params))
117
+
118
+ def initialize_parameters(self):
119
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
120
+ nn.init.normal_(self.positional_embedding, std=0.01)
121
+
122
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
123
+ attn_std = self.transformer.width ** -0.5
124
+ fc_std = (2 * self.transformer.width) ** -0.5
125
+ for block in self.transformer.resblocks:
126
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
127
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
128
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
129
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
130
+
131
+ nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)
132
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
133
+
134
+ def build_attention_mask(self):
135
+ # lazily create causal attention mask, with full attention between the vision tokens
136
+ # pytorch uses additive attention mask; fill with -inf
137
+ mask = torch.empty(self.context_length, self.context_length)
138
+ mask.fill_(float("-inf"))
139
+ mask.triu_(1) # zero out the lower diagonal
140
+ return mask
141
+
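For intuition, a standalone sketch of what the causal mask built above looks like for a toy context length of 4: entries strictly above the diagonal are -inf, so each text token can only attend to itself and earlier tokens.

import torch

context_length = 4
mask = torch.empty(context_length, context_length)
mask.fill_(float("-inf"))
mask.triu_(1)  # keep -inf strictly above the diagonal; diagonal and below become 0
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])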
142
+ def encode_image(self, image, use_checkpoint=False, apply_project=True, istrain=False, gamma=1.0):
143
+ x, ps_loss = self.visual(image, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)
144
+
145
+ if isinstance(x, list):
146
+ assert len(x) == 1
147
+ x = x[0]
148
+ if apply_project:
149
+ x = x @ self.image_projection
150
+
151
+ return x, ps_loss
152
+
153
+ def encode_text(self, text, use_checkpoint=False, istrain=False, gamma=1.0):
154
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
155
+ B = x.shape[0]
156
+ eot = text.argmax(dim=-1)
157
+
158
+ x = x.permute(1, 0, 2) # NLD -> LND
159
+ x, ps_loss = self.transformer(x, self.positional_embedding, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma, eot=eot)
160
+ x = x.permute(1, 0, 2) # LND -> NLD
161
+ x = self.ln_final(x)
162
+
163
+ # x.shape = [batch_size, n_ctx, transformer.width]
164
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
165
+ x = x[torch.arange(x.shape[0]), self.n_ctx + eot] @ self.text_projection
166
+
167
+ return x, ps_loss
168
+
169
+ def forward(self, image, text, use_checkpoint=False, norm_embed=False, istrain=False, gamma=1.0):
170
+ image_embed, ps_loss_img = self.encode_image(image, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)
171
+ text_embed, ps_loss_txt = self.encode_text(text, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)
172
+
173
+ if norm_embed:
174
+ image_embed = F.normalize(image_embed, dim=-1)
175
+ text_embed = F.normalize(text_embed, dim=-1)
176
+ return {'image_embed': image_embed,
177
+ 'text_embed': text_embed,
178
+ 'logit_scale': self.logit_scale.exp(),
179
+ 'ps_loss': ps_loss_img + ps_loss_txt}
180
+
181
+ def train(self, mode=True):
182
+ if not isinstance(mode, bool):
183
+ raise ValueError("training mode is expected to be boolean")
184
+ self.training = mode
185
+ for m in self.modules():
186
+ m.training = mode
187
+
188
+ if mode:
189
+ if self.freeze_vis_backbone and not self.tune_bias:
190
+ for n, m in self.visual.named_modules():
191
+ if 'prompt' not in n:
192
+ m.training = False
193
+
194
+ if self.freeze_txt_backbone and not self.tune_bias:
195
+ for n, m in self.transformer.named_modules():
196
+ if 'prompt' not in n:
197
+ m.training = False
198
+
199
+ self.token_embedding.training = False
200
+ self.ln_final.training = False
201
+
202
+
203
+ def CLIP_OPENAI_TIMESFORMER_BASE(
204
+ num_frames=4, timesformer_gated_xattn=False, temperature_init=0.07,
205
+ project_embed_dim=256, **kwargs
206
+ ):
207
+ cfg = kwargs.pop('model_cfg', {})
208
+ vision_model = SpaceTimeTransformer(
209
+ num_frames=num_frames,
210
+ time_init='zeros',
211
+ attention_style='frozen-in-time',
212
+ ln_pre=True,
213
+ act_layer=QuickGELU,
214
+ is_tanh_gating=timesformer_gated_xattn,
215
+ drop_path_rate=cfg.get('drop_path_rate', 0),
216
+ tune_bias=cfg.get('tune_bias', False),
217
+ prompt_cfg=cfg.get('visual_prompt', {})
218
+ )
219
+ clip_model, _ = load_openai_clip('ViT-B/16', 'cpu')
220
+ print("=> Loading CLIP (ViT-B/16) weights")
221
+ remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=12)
222
+ res = vision_model.load_state_dict(remapped_state_dict, strict=False)
223
+ print(res)
224
+
225
+ vision_model.head = nn.Identity()
226
+ vision_model.pre_logits = nn.Identity()
227
+ vision_model.fc = nn.Identity()
228
+ model = CLIP(
229
+ cfg,
230
+ embed_dim=project_embed_dim,
231
+ vision_width=768,
232
+ vision_model=vision_model,
233
+ context_length=77,
234
+ vocab_size=49408,
235
+ transformer_width=512,
236
+ transformer_heads=8,
237
+ transformer_layers=12,
238
+ tempearture_init=temperature_init,
239
+ **kwargs
240
+ )
241
+ model.transformer.load_state_dict(clip_model.transformer.state_dict(), strict=False)
242
+ model.token_embedding.load_state_dict(clip_model.token_embedding.state_dict())
243
+ model.positional_embedding.data.copy_(clip_model.positional_embedding.data)
244
+ model.ln_final.load_state_dict(clip_model.ln_final.state_dict())
245
+ if project_embed_dim == clip_model.text_projection.shape[1]:
246
+ print("=> Loading CLIP's text_projection, image_projection and logit_scale directly")
247
+ model.image_projection.data.copy_(clip_model.visual.proj.data)
248
+ model.text_projection.data.copy_(clip_model.text_projection.data)
249
+ model.logit_scale.data.copy_(clip_model.logit_scale.data)
250
+ return model
251
+
252
+
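A minimal, self-contained sketch of the freeze-and-count bookkeeping used in the constructor above, with a hypothetical TinyModel standing in for the CLIP backbone: everything except parameters whose names contain 'prompt' is frozen, then the trainable/frozen counts are reported.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)                            # stands in for the frozen backbone
        self.prompt_embeddings = nn.Parameter(torch.zeros(4, 8))   # stays trainable

model = TinyModel()
for name, p in model.named_parameters():
    if 'prompt' not in name:
        p.requires_grad = False

num_opt = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_fz = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print("=> Update/Frozen: {}/{}".format(num_opt, num_fz))           # => Update/Frozen: 32/72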
lavila/models/openai_clip.py ADDED
@@ -0,0 +1,237 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Part of the code is from https://github.com/openai/CLIP/blob/main/clip/clip.py
8
+ # Modified by Yue Zhao
9
+ # The original code is under MIT License
10
+
11
+ import hashlib
12
+ import os
13
+ import urllib
14
+ import warnings
15
+ from typing import Union, List
16
+ from pkg_resources import packaging
17
+
18
+ import torch
19
+ from PIL import Image
20
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
21
+ from tqdm import tqdm
22
+
23
+ from .openai_model import build_model
24
+ from .tokenizer import SimpleTokenizer as _Tokenizer
25
+
26
+ try:
27
+ from torchvision.transforms import InterpolationMode
28
+ BICUBIC = InterpolationMode.BICUBIC
29
+ except ImportError:
30
+ BICUBIC = Image.BICUBIC
31
+
32
+
33
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
34
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
35
+
36
+
37
+ __all__ = ["available_models", "load", "tokenize"]
38
+ _tokenizer = _Tokenizer()
39
+
40
+ _MODELS = {
41
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
42
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
43
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
44
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
45
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
46
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
47
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
48
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
49
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
50
+ }
51
+
52
+
53
+ def _download(url: str, root: str):
54
+ os.makedirs(root, exist_ok=True)
55
+ filename = os.path.basename(url)
56
+
57
+ expected_sha256 = url.split("/")[-2]
58
+ download_target = os.path.join(root, filename)
59
+
60
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
61
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
62
+
63
+ if os.path.isfile(download_target):
64
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
65
+ return download_target
66
+ else:
67
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
68
+
69
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
70
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
71
+ while True:
72
+ buffer = source.read(8192)
73
+ if not buffer:
74
+ break
75
+
76
+ output.write(buffer)
77
+ loop.update(len(buffer))
78
+
79
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
80
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
81
+
82
+ return download_target
83
+
84
+
85
+ def _convert_image_to_rgb(image):
86
+ return image.convert("RGB")
87
+
88
+
89
+ def _transform(n_px):
90
+ return Compose([
91
+ Resize(n_px, interpolation=BICUBIC),
92
+ CenterCrop(n_px),
93
+ _convert_image_to_rgb,
94
+ ToTensor(),
95
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
96
+ ])
97
+
98
+
99
+ def available_models() -> List[str]:
100
+ """Returns the names of available CLIP models"""
101
+ return list(_MODELS.keys())
102
+
103
+
104
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
105
+ """Load a CLIP model
106
+ Parameters
107
+ ----------
108
+ name : str
109
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
110
+ device : Union[str, torch.device]
111
+ The device to put the loaded model
112
+ jit : bool
113
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
114
+ download_root: str
115
+ path to download the model files; by default, it uses "/store/nosnap/.cache/clip"
116
+ Returns
117
+ -------
118
+ model : torch.nn.Module
119
+ The CLIP model
120
+ preprocess : Callable[[PIL.Image], torch.Tensor]
121
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
122
+ """
123
+ if name in _MODELS:
124
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("/store/nosnap/.cache/clip"))
125
+ elif os.path.isfile(name):
126
+ model_path = name
127
+ else:
128
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
129
+
130
+ with open(model_path, 'rb') as opened_file:
131
+ try:
132
+ # loading JIT archive
133
+ model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
134
+ state_dict = None
135
+ except RuntimeError:
136
+ # loading saved state dict
137
+ if jit:
138
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
139
+ jit = False
140
+ state_dict = torch.load(opened_file, map_location="cpu")
141
+
142
+ if not jit:
143
+ model = build_model(state_dict or model.state_dict()).to(device)
144
+ if str(device) == "cpu":
145
+ model.float()
146
+ return model, _transform(model.visual.input_resolution)
147
+
148
+ # patch the device names
149
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
150
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
151
+
152
+ def patch_device(module):
153
+ try:
154
+ graphs = [module.graph] if hasattr(module, "graph") else []
155
+ except RuntimeError:
156
+ graphs = []
157
+
158
+ if hasattr(module, "forward1"):
159
+ graphs.append(module.forward1.graph)
160
+
161
+ for graph in graphs:
162
+ for node in graph.findAllNodes("prim::Constant"):
163
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
164
+ node.copyAttributes(device_node)
165
+
166
+ model.apply(patch_device)
167
+ patch_device(model.encode_image)
168
+ patch_device(model.encode_text)
169
+
170
+ # patch dtype to float32 on CPU
171
+ if str(device) == "cpu":
172
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
173
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
174
+ float_node = float_input.node()
175
+
176
+ def patch_float(module):
177
+ try:
178
+ graphs = [module.graph] if hasattr(module, "graph") else []
179
+ except RuntimeError:
180
+ graphs = []
181
+
182
+ if hasattr(module, "forward1"):
183
+ graphs.append(module.forward1.graph)
184
+
185
+ for graph in graphs:
186
+ for node in graph.findAllNodes("aten::to"):
187
+ inputs = list(node.inputs())
188
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
189
+ if inputs[i].node()["value"] == 5:
190
+ inputs[i].node().copyAttributes(float_node)
191
+
192
+ model.apply(patch_float)
193
+ patch_float(model.encode_image)
194
+ patch_float(model.encode_text)
195
+
196
+ model.float()
197
+
198
+ return model, _transform(model.input_resolution.item())
199
+
200
+
201
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
202
+ """
203
+ Returns the tokenized representation of given input string(s)
204
+ Parameters
205
+ ----------
206
+ texts : Union[str, List[str]]
207
+ An input string or a list of input strings to tokenize
208
+ context_length : int
209
+ The context length to use; all CLIP models use 77 as the context length
210
+ truncate: bool
211
+ Whether to truncate the text in case its encoding is longer than the context length
212
+ Returns
213
+ -------
214
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
215
+ We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
216
+ """
217
+ if isinstance(texts, str):
218
+ texts = [texts]
219
+
220
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
221
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
222
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
223
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
224
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
225
+ else:
226
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
227
+
228
+ for i, tokens in enumerate(all_tokens):
229
+ if len(tokens) > context_length:
230
+ if truncate:
231
+ tokens = tokens[:context_length]
232
+ tokens[-1] = eot_token
233
+ else:
234
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
235
+ result[i, :len(tokens)] = torch.tensor(tokens)
236
+
237
+ return result
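A quick usage sketch for tokenize() (a minimal example, assuming the package is importable as lavila.models.openai_clip and the bundled BPE vocabulary is present): token ids are zero-padded to context_length, and the end-of-text token has the largest id in each row, which is why encode_text() can recover its position with argmax.

from lavila.models.openai_clip import tokenize

tokens = tokenize(["a person opens a door"], context_length=77)
print(tokens.shape)           # torch.Size([1, 77])
print(tokens.argmax(dim=-1))  # position of <|endoftext|> in each sequence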
lavila/models/openai_model.py ADDED
@@ -0,0 +1,535 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Part of the code is from https://github.com/openai/CLIP/blob/main/clip/model.py
8
+ # Modified by Yue Zhao
9
+ # The original code is under MIT License
10
+
11
+ from collections import OrderedDict
12
+ from typing import Tuple, Union
13
+ from einops import rearrange
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import torch.utils.checkpoint as checkpoint
19
+ from torch import nn
20
+ import pdb
21
+
22
+
23
+ class Bottleneck(nn.Module):
24
+ expansion = 4
25
+
26
+ def __init__(self, inplanes, planes, stride=1):
27
+ super().__init__()
28
+
29
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
30
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
31
+ self.bn1 = nn.BatchNorm2d(planes)
32
+ self.relu1 = nn.ReLU(inplace=True)
33
+
34
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
35
+ self.bn2 = nn.BatchNorm2d(planes)
36
+ self.relu2 = nn.ReLU(inplace=True)
37
+
38
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
39
+
40
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
41
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
42
+ self.relu3 = nn.ReLU(inplace=True)
43
+
44
+ self.downsample = None
45
+ self.stride = stride
46
+
47
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
48
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
49
+ self.downsample = nn.Sequential(OrderedDict([
50
+ ("-1", nn.AvgPool2d(stride)),
51
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
52
+ ("1", nn.BatchNorm2d(planes * self.expansion))
53
+ ]))
54
+
55
+ def forward(self, x: torch.Tensor):
56
+ identity = x
57
+
58
+ out = self.relu1(self.bn1(self.conv1(x)))
59
+ out = self.relu2(self.bn2(self.conv2(out)))
60
+ out = self.avgpool(out)
61
+ out = self.bn3(self.conv3(out))
62
+
63
+ if self.downsample is not None:
64
+ identity = self.downsample(x)
65
+
66
+ out += identity
67
+ out = self.relu3(out)
68
+ return out
69
+
70
+
71
+ class AttentionPool2d(nn.Module):
72
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
73
+ super().__init__()
74
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
75
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
76
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
77
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
78
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
79
+ self.num_heads = num_heads
80
+
81
+ def forward(self, x):
82
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
83
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
84
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
85
+ x, _ = F.multi_head_attention_forward(
86
+ query=x[:1], key=x, value=x,
87
+ embed_dim_to_check=x.shape[-1],
88
+ num_heads=self.num_heads,
89
+ q_proj_weight=self.q_proj.weight,
90
+ k_proj_weight=self.k_proj.weight,
91
+ v_proj_weight=self.v_proj.weight,
92
+ in_proj_weight=None,
93
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
94
+ bias_k=None,
95
+ bias_v=None,
96
+ add_zero_attn=False,
97
+ dropout_p=0,
98
+ out_proj_weight=self.c_proj.weight,
99
+ out_proj_bias=self.c_proj.bias,
100
+ use_separate_proj_weight=True,
101
+ training=self.training,
102
+ need_weights=False
103
+ )
104
+ return x.squeeze(0)
105
+
106
+
107
+ class ModifiedResNet(nn.Module):
108
+ """
109
+ A ResNet class that is similar to torchvision's but contains the following changes:
110
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
111
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
112
+ - The final pooling layer is a QKV attention instead of an average pool
113
+ """
114
+
115
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
116
+ super().__init__()
117
+ self.output_dim = output_dim
118
+ self.input_resolution = input_resolution
119
+
120
+ # the 3-layer stem
121
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
122
+ self.bn1 = nn.BatchNorm2d(width // 2)
123
+ self.relu1 = nn.ReLU(inplace=True)
124
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
125
+ self.bn2 = nn.BatchNorm2d(width // 2)
126
+ self.relu2 = nn.ReLU(inplace=True)
127
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
128
+ self.bn3 = nn.BatchNorm2d(width)
129
+ self.relu3 = nn.ReLU(inplace=True)
130
+ self.avgpool = nn.AvgPool2d(2)
131
+
132
+ # residual layers
133
+ self._inplanes = width # this is a *mutable* variable used during construction
134
+ self.layer1 = self._make_layer(width, layers[0])
135
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
136
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
137
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
138
+
139
+ embed_dim = width * 32 # the ResNet feature dimension
140
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
141
+
142
+ def _make_layer(self, planes, blocks, stride=1):
143
+ layers = [Bottleneck(self._inplanes, planes, stride)]
144
+
145
+ self._inplanes = planes * Bottleneck.expansion
146
+ for _ in range(1, blocks):
147
+ layers.append(Bottleneck(self._inplanes, planes))
148
+
149
+ return nn.Sequential(*layers)
150
+
151
+ def forward(self, x):
152
+ def stem(x):
153
+ x = self.relu1(self.bn1(self.conv1(x)))
154
+ x = self.relu2(self.bn2(self.conv2(x)))
155
+ x = self.relu3(self.bn3(self.conv3(x)))
156
+ x = self.avgpool(x)
157
+ return x
158
+
159
+ x = x.type(self.conv1.weight.dtype)
160
+ x = stem(x)
161
+ x = self.layer1(x)
162
+ x = self.layer2(x)
163
+ x = self.layer3(x)
164
+ x = self.layer4(x)
165
+ x = self.attnpool(x)
166
+
167
+ return x
168
+
169
+
170
+ class LayerNorm(nn.LayerNorm):
171
+ """Subclass torch's LayerNorm to handle fp16."""
172
+
173
+ def forward(self, x: torch.Tensor):
174
+ orig_type = x.dtype
175
+ ret = super().forward(x.type(torch.float32))
176
+ return ret.type(orig_type)
177
+
178
+
179
+ class QuickGELU(nn.Module):
180
+ def forward(self, x: torch.Tensor):
181
+ return x * torch.sigmoid(1.702 * x)
182
+
183
+
184
+ class ResidualAttentionBlock(nn.Module):
185
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
186
+ super().__init__()
187
+
188
+ self.attn = nn.MultiheadAttention(d_model, n_head)
189
+ self.ln_1 = nn.LayerNorm(d_model) # used to be `models.transformer.LayerNorm`
190
+ self.mlp = nn.Sequential(OrderedDict([
191
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
192
+ ("gelu", QuickGELU()),
193
+ ("c_proj", nn.Linear(d_model * 4, d_model))
194
+ ]))
195
+ self.ln_2 = nn.LayerNorm(d_model) # used to be `models.transformer.LayerNorm`
196
+ self.attn_mask = attn_mask
197
+
198
+ def attention(self, x: torch.Tensor):
199
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
200
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
201
+
202
+ def forward_part1(self, x):
203
+ return self.attention(self.ln_1(x))
204
+
205
+ def forward_part2(self, x):
206
+ return self.mlp(self.ln_2(x))
207
+
208
+ def forward(self, x: torch.Tensor, use_checkpoint=False):
209
+ if use_checkpoint:
210
+ x = x + checkpoint.checkpoint(self.forward_part1, x)
211
+ else:
212
+ x = x + self.forward_part1(x)
213
+
214
+ if use_checkpoint:
215
+ x = x + checkpoint.checkpoint(self.forward_part2, x)
216
+ else:
217
+ x = x + self.forward_part2(x)
218
+ return x
219
+
220
+
221
+ class Transformer(nn.Module):
222
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, prompt_cfg={}, prompt_learner=None, prompt_generator=None):
223
+ super().__init__()
224
+ self.width = width
225
+ self.layers = layers
226
+ self.resblocks = nn.ModuleList([ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
227
+ self.num_tokens = prompt_cfg.pop('n_ctx', 0)
228
+ self.use_bank = prompt_cfg.pop('use_bank', False)
229
+ if self.num_tokens > 0:
230
+ self.prompt_learner = prompt_learner
231
+ self.prompt_generator = prompt_generator
232
+ self.k_s = 0
233
+ if self.prompt_generator is not None:
234
+ if self.prompt_generator.use_bank:
235
+ self.k_s = len(self.prompt_generator.prompt_pool)
236
+ self.prompt_inproj = None
237
+ self.prompt_outproj = None
238
+
239
+ def forward(self, x: torch.Tensor, pos_emb, use_checkpoint=False, istrain=False, gamma=1.0, eot=None):
240
+ ps_loss = x.new_zeros([1])
241
+ BZ = x.size(1)
242
+ if not self.use_bank:
243
+ if self.num_tokens > 0:
244
+ ctx = self.prompt_learner()
245
+ ctx = ctx.unsqueeze(1).expand(-1, BZ, -1)
246
+ x = torch.cat((
247
+ x[:1, :, :], # SOT
248
+ ctx,
249
+ x[1:, :, :]
250
+ ), dim=0)
251
+ x = x[:pos_emb.size(0)] + pos_emb.unsqueeze(1)
252
+
253
+ for i, blk in enumerate(self.resblocks):
254
+ if self.num_tokens > 0 and self.use_bank:
255
+ k = self.num_tokens
256
+ num_tokens = 0 if i == 0 else self.num_tokens
257
+ x = torch.cat((x[:1, :, :], x[num_tokens+1:, :, :]), dim=0)
258
+ query = self.prompt_inproj(x[eot, torch.arange(BZ), :].detach())
259
+ if i < self.k_s:
260
+ out = self.prompt_generator.prompt_pool[i](query, k, istrain=istrain, gamma=gamma)
261
+ ctx = self.prompt_outproj(out['prompts'])
262
+ ctx = ctx.transpose(1, 0) + pos_emb.unsqueeze(1)[1:self.num_tokens+1, :]
263
+ ps_loss += out.get('ps_loss', 0)
264
+ else:
265
+ ctx = self.prompt_learner()
266
+ ctx = ctx.unsqueeze(1).expand(-1, BZ, -1)
267
+ ctx = ctx + pos_emb.unsqueeze(1)[1:self.num_tokens+1, :]
268
+
269
+ x = torch.cat((
270
+ x[:1, :, :], # SOT
271
+ ctx,
272
+ x[1:, :, :]
273
+ ), dim=0)
274
+ x = x[:pos_emb.size(0)]
275
+
276
+ if use_checkpoint:
277
+ x = checkpoint.checkpoint(blk, x)
278
+ else:
279
+ x = blk(x)
280
+
281
+ return x, ps_loss
282
+
283
+
284
+
285
+ class VisionTransformer(nn.Module):
286
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
287
+ super().__init__()
288
+ self.input_resolution = input_resolution
289
+ self.output_dim = output_dim
290
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
291
+
292
+ scale = width ** -0.5
293
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
294
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
295
+ self.ln_pre = LayerNorm(width)
296
+
297
+ self.transformer = Transformer(width, layers, heads)
298
+
299
+ self.ln_post = LayerNorm(width)
300
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
301
+
302
+ def forward(self, x: torch.Tensor, apply_project=True, use_checkpoint=False, cls_at_last=True):
303
+ x = self.conv1(x) # shape = [*, width, grid, grid]
304
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
305
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
306
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
307
+ x = x + self.positional_embedding.to(x.dtype)
308
+ x = self.ln_pre(x)
309
+
310
+ x = x.permute(1, 0, 2) # NLD -> LND
311
+ x = self.transformer(x, use_checkpoint=use_checkpoint)
312
+ x = x.permute(1, 0, 2) # LND -> NLD
313
+
314
+ if cls_at_last:
315
+ x = self.ln_post(x[:, 0, :])
316
+
317
+ if self.proj is not None and apply_project:
318
+ x = x @ self.proj
319
+
320
+ return x
321
+ else:
322
+ return x[:, 1:, :]
323
+
324
+
325
+ class CLIP(nn.Module):
326
+ def __init__(self,
327
+ embed_dim: int,
328
+ # vision
329
+ image_resolution: int,
330
+ vision_layers: Union[Tuple[int, int, int, int], int],
331
+ vision_width: int,
332
+ vision_patch_size: int,
333
+ # text
334
+ context_length: int,
335
+ vocab_size: int,
336
+ transformer_width: int,
337
+ transformer_heads: int,
338
+ transformer_layers: int
339
+ ):
340
+ super().__init__()
341
+
342
+ self.context_length = context_length
343
+
344
+ if isinstance(vision_layers, (tuple, list)):
345
+ vision_heads = vision_width * 32 // 64
346
+ self.visual = ModifiedResNet(
347
+ layers=vision_layers,
348
+ output_dim=embed_dim,
349
+ heads=vision_heads,
350
+ input_resolution=image_resolution,
351
+ width=vision_width
352
+ )
353
+ else:
354
+ vision_heads = vision_width // 64
355
+ self.visual = VisionTransformer(
356
+ input_resolution=image_resolution,
357
+ patch_size=vision_patch_size,
358
+ width=vision_width,
359
+ layers=vision_layers,
360
+ heads=vision_heads,
361
+ output_dim=embed_dim
362
+ )
363
+
364
+ self.transformer = Transformer(
365
+ width=transformer_width,
366
+ layers=transformer_layers,
367
+ heads=transformer_heads,
368
+ attn_mask=self.build_attention_mask()
369
+ )
370
+
371
+ self.vocab_size = vocab_size
372
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
373
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
374
+ self.ln_final = LayerNorm(transformer_width)
375
+
376
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
377
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
378
+
379
+ self.initialize_parameters()
380
+
381
+ def initialize_parameters(self):
382
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
383
+ nn.init.normal_(self.positional_embedding, std=0.01)
384
+
385
+ if isinstance(self.visual, ModifiedResNet):
386
+ if self.visual.attnpool is not None:
387
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
388
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
389
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
390
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
391
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
392
+
393
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
394
+ for name, param in resnet_block.named_parameters():
395
+ if name.endswith("bn3.weight"):
396
+ nn.init.zeros_(param)
397
+
398
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
399
+ attn_std = self.transformer.width ** -0.5
400
+ fc_std = (2 * self.transformer.width) ** -0.5
401
+ for block in self.transformer.resblocks:
402
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
403
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
404
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
405
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
406
+
407
+ if self.text_projection is not None:
408
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
409
+
410
+ def build_attention_mask(self):
411
+ # lazily create causal attention mask, with full attention between the vision tokens
412
+ # pytorch uses additive attention mask; fill with -inf
413
+ mask = torch.empty(self.context_length, self.context_length)
414
+ mask.fill_(float("-inf"))
415
+ mask.triu_(1) # zero out the lower diagonal
416
+ return mask
417
+
418
+ @property
419
+ def dtype(self):
420
+ return self.visual.conv1.weight.dtype
421
+
422
+ def encode_image(self, image, apply_project=True, use_checkpoint=False):
423
+ if image.ndim == 4:
424
+ return self.visual(image.type(self.dtype))
425
+ else:
426
+ image = image.permute(0, 2, 1, 3, 4) # BCTHW -> BTCHW
427
+ bb, tt, _, _, _ = image.shape
428
+ x = self.visual(image.reshape(-1, *image.shape[2:]), apply_project=apply_project, use_checkpoint=use_checkpoint) # ND
429
+ x = x.view(bb, tt, -1)
430
+ image_features = x.mean(1)
431
+ # image_features = x.max(1).values
432
+ return image_features
433
+
434
+ def encode_text(self, text, use_checkpoint=False):
435
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
436
+
437
+ x = x + self.positional_embedding.type(self.dtype)
438
+ x = x.permute(1, 0, 2) # NLD -> LND
439
+ x = self.transformer(x, use_checkpoint=use_checkpoint)
440
+ x = x.permute(1, 0, 2) # LND -> NLD
441
+ x = self.ln_final(x).type(self.dtype)
442
+
443
+ # x.shape = [batch_size, n_ctx, transformer.width]
444
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
445
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
446
+
447
+ return x
448
+
449
+ def forward(self, image, text, use_checkpoint=False, norm_embed=True):
450
+ image_features = self.encode_image(image, use_checkpoint=use_checkpoint)
451
+ text_features = self.encode_text(text, use_checkpoint=use_checkpoint)
452
+
453
+ # normalized features
454
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
455
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
456
+
457
+ # # cosine similarity as logits
458
+ # logit_scale = self.logit_scale.exp()
459
+ # logits_per_image = logit_scale * image_features @ text_features.t()
460
+ # logits_per_text = logits_per_image.t()
461
+
462
+ # # shape = [global_batch_size, global_batch_size]
463
+ # return logits_per_image, logits_per_text
464
+
465
+ return {'image_embed': image_features,
466
+ 'text_embed': text_features,
467
+ 'logit_scale': self.logit_scale.exp()}
468
+
469
+
470
+ def convert_weights(model: nn.Module):
471
+ """Convert applicable model parameters to fp16"""
472
+
473
+ def _convert_weights_to_fp16(l):
474
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
475
+ l.weight.data = l.weight.data.half()
476
+ if l.bias is not None:
477
+ l.bias.data = l.bias.data.half()
478
+
479
+ if isinstance(l, nn.MultiheadAttention):
480
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
481
+ tensor = getattr(l, attr)
482
+ if tensor is not None:
483
+ tensor.data = tensor.data.half()
484
+
485
+ for name in ["text_projection", "proj"]:
486
+ if hasattr(l, name):
487
+ attr = getattr(l, name)
488
+ if attr is not None:
489
+ attr.data = attr.data.half()
490
+
491
+ model.apply(_convert_weights_to_fp16)
492
+
493
+
494
+ def build_model(state_dict: dict):
495
+ vit = "visual.proj" in state_dict
496
+
497
+ if vit:
498
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
499
+ vision_layers = len(
500
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]
501
+ )
502
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
503
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
504
+ image_resolution = vision_patch_size * grid_size
505
+ else:
506
+ counts: list = [
507
+ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]
508
+ ]
509
+ vision_layers = tuple(counts)
510
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
511
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
512
+ vision_patch_size = None
513
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
514
+ image_resolution = output_width * 32
515
+
516
+ embed_dim = state_dict["text_projection"].shape[1]
517
+ context_length = state_dict["positional_embedding"].shape[0]
518
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
519
+ transformer_width = state_dict["ln_final.weight"].shape[0]
520
+ transformer_heads = transformer_width // 64
521
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
522
+
523
+ model = CLIP(
524
+ embed_dim,
525
+ image_resolution, vision_layers, vision_width, vision_patch_size,
526
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
527
+ )
528
+
529
+ for key in ["input_resolution", "context_length", "vocab_size"]:
530
+ if key in state_dict:
531
+ del state_dict[key]
532
+
533
+ convert_weights(model)
534
+ model.load_state_dict(state_dict)
535
+ return model.eval()
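For reference, a minimal sketch of how build_model() is reached through load() from openai_clip.py (this assumes network access for the first download, or a previously cached ViT-B/16 checkpoint at the path configured in _download above):

import torch
from lavila.models.openai_clip import load, tokenize

model, preprocess = load("ViT-B/16", device="cpu", jit=False)
with torch.no_grad():
    text_feat = model.encode_text(tokenize(["a person washes dishes"]))
print(text_feat.shape)  # torch.Size([1, 512]) for ViT-B/16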
lavila/models/prompt_tuning.py ADDED
@@ -0,0 +1,291 @@
1
+
2
+ import math
3
+ from functools import reduce
4
+ from operator import mul
5
+ from einops import rearrange, repeat
6
+ import pdb
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ class PromptLearner(nn.Module):
13
+ def __init__(self, ctx_dim=512, n_ctx=16):
14
+ super(PromptLearner, self).__init__()
15
+ self.n_ctx = n_ctx
16
+ self.ctx_dim = ctx_dim
17
+
18
+ # initialize prompts
19
+ ctx_vectors = torch.empty(n_ctx, ctx_dim)
20
+ nn.init.normal_(ctx_vectors, std=0.02)
21
+ prompt_prefix = " ".join(["X"] * n_ctx)
22
+ self.ctx = nn.Parameter(ctx_vectors) # to be optimized
23
+ print(f'Initial context: "{prompt_prefix}"')
24
+ print(f"Number of context words (tokens): {n_ctx}")
25
+
26
+ def forward(self):
27
+ return self.ctx
28
+
29
+ class PromptPoolLearner(nn.Module):
30
+ def __init__(self, prompt_dim=256, size=128, length=1):
31
+ super(PromptPoolLearner, self).__init__()
32
+ self.prompt_dim = prompt_dim
33
+ self.length = length
34
+ self.size = size
35
+
36
+ # initiate prompt
37
+ self.prompt_values = nn.Parameter(torch.zeros(size, length, prompt_dim))
38
+ self.id_table = torch.ones([size]).cuda()
39
+
40
+ # uniform initialization in [-1, 1]
41
+ nn.init.uniform_(self.prompt_values.data, -1, 1)
42
+
43
+ def l2_normalize(self, x, dim=None, epsilon=1e-12):
44
+ """Normalizes a given vector or matrix."""
45
+ square_sum = torch.sum(x ** 2, dim=dim, keepdim=True)
46
+ x_inv_norm = torch.rsqrt(torch.maximum(square_sum, torch.tensor(epsilon, device=x.device)))
47
+ return x * x_inv_norm
48
+
49
+ def forward(self, query, k=0, istrain=False, gamma=1.0):
50
+ BZ = query.shape[0]
51
+ out = dict()
52
+ query = self.l2_normalize(query.squeeze(1), dim=1)
53
+ keys = self.prompt_values.mean(dim=1)
54
+ keys = self.l2_normalize(keys, dim=1)
55
+ similarity = torch.matmul(query, keys.t())
56
+
57
+ if k > 0 and k < self.size:
58
+
59
+ if istrain:
60
+ inv_freq = self.id_table.sum() / self.id_table.float()
61
+ weights = torch.softmax((1 + similarity) / 2 + 0.5 * (1 - gamma) * inv_freq / inv_freq.sum(), dim=1)
62
+ idx = torch.multinomial(weights, k, replacement=False)
63
+ else:
64
+ idx = torch.argsort(similarity, dim=-1, descending=True)[:, :k]
65
+
66
+ prompt_id, id_counts = torch.unique(idx, return_counts=True, sorted=True)
67
+ self.id_table[prompt_id] += id_counts
68
+ prompts = self.prompt_values[idx.flatten(), ...].view(BZ, k * self.length, self.prompt_dim)
69
+ else:
70
+ idx = torch.arange(self.size).unsqueeze(0).expand(BZ, -1)
71
+ prompts = self.prompt_values.flatten(0, 1).unsqueeze(0).expand(BZ, -1, -1)
72
+
73
+ prompts = self.l2_normalize(prompts, dim=-1)
74
+ out['prompts'] = prompts
75
+ sel_sim = similarity[torch.arange(BZ).view(-1, 1), idx]
76
+ sel_key = keys[idx.flatten(), ...].view(BZ, k, self.prompt_dim)
77
+ diff = F.mse_loss((sel_sim.unsqueeze(1) @ sel_key).squeeze(), query.detach(), reduction='sum') / BZ
78
+ ksim = torch.mean(torch.matmul(keys, keys.t()) - torch.eye(self.size).to(keys.device))
79
+ out['ps_loss'] = diff + ksim
80
+
81
+ return out
82
+
83
+
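Minimal usage sketch of PromptPoolLearner: for each query feature, the k most similar prompts (by cosine similarity between the query and the pool keys) are gathered, and a key/query alignment plus key-diversity term comes back as 'ps_loss'. This assumes a CUDA device, since the pool allocates its usage table with .cuda() above, and that the repo is importable.

import torch
from lavila.models.prompt_tuning import PromptPoolLearner

pool = PromptPoolLearner(prompt_dim=256, size=16, length=1).cuda()
query = torch.randn(4, 256).cuda()     # a batch of 4 query features
out = pool(query, k=8, istrain=False)
print(out['prompts'].shape)            # torch.Size([4, 8, 256])
print(out['ps_loss'])                  # scalar regularization term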
84
+ class VisualPromptLearner(nn.Module):
85
+ def __init__(self, patch_size=16, embed_dim=768, num_layers=12, prompt_dim=256, num_tokens=5, deep=False,
86
+ deep_shared=False, split_st=False, dropout=0.1, pool={}):
87
+ super(VisualPromptLearner, self).__init__()
88
+ self.num_layers = num_layers
89
+ self.embed_dim = embed_dim
90
+ self.prompt_dim = prompt_dim
91
+ self.num_tokens = num_tokens # number of prompted tokens
92
+ self.prompt_dropout = nn.Dropout(dropout)
93
+ pool_size = pool.get('size', 0)
94
+ self.pool_length = pool.get('length', 1)
95
+ self.use_bank = True if pool_size > 0 and num_tokens <= (pool_size * self.pool_length) else False
96
+ if self.use_bank:
97
+ print(f'Using feature bank with size {pool_size} (dimension: {prompt_dim})')
98
+
99
+ if prompt_dim != embed_dim:
100
+ self.prompt_inproj = nn.Linear(embed_dim, prompt_dim, bias=False)
101
+ else:
102
+ self.prompt_inproj = nn.Identity()
103
+
104
+ if self.use_bank:
105
+ self.prompt_outproj = nn.Linear(prompt_dim, embed_dim, bias=False)
106
+ nn.init.kaiming_normal_(
107
+ self.prompt_outproj.weight, a=0, mode='fan_out')
108
+ else:
109
+ self.prompt_outproj = nn.Identity()
110
+
111
+ self.split_st = split_st # split spatial and temporal prompts
112
+
113
+ # initiate prompt:
114
+ val = math.sqrt(6. / float(3 * reduce(mul, (patch_size, patch_size), 1) + prompt_dim))
115
+ if split_st:
116
+ if self.use_bank:
117
+ pool['size'] //= 2
118
+ self.spatial_prompt_pool = PromptPoolLearner(prompt_dim, **pool)
119
+ self.temporal_prompt_pool = PromptPoolLearner(prompt_dim, **pool)
120
+ else:
121
+ self.spatial_prompt_embeddings = nn.Parameter(torch.zeros(
122
+ 1, num_tokens // 2, prompt_dim))
123
+ self.temporal_prompt_embeddings = nn.Parameter(torch.zeros(
124
+ 1, num_tokens // 2, prompt_dim))
125
+ # xavier_uniform initialization
126
+ nn.init.uniform_(self.spatial_prompt_embeddings.data, -val, val)
127
+ nn.init.uniform_(self.temporal_prompt_embeddings.data, -val, val)
128
+ else:
129
+ if self.use_bank:
130
+ self.prompt_pool = PromptPoolLearner(prompt_dim, **pool)
131
+ else:
132
+ self.prompt_embeddings = nn.Parameter(torch.zeros(
133
+ 1, num_tokens, prompt_dim))
134
+ # xavier_uniform initialization
135
+ nn.init.uniform_(self.prompt_embeddings.data, -val, val)
136
+
137
+ self.deep = deep or deep_shared
138
+ self.deep_shared = deep_shared
139
+ if deep and (not deep_shared):
140
+ total_d_layer = num_layers - 1
141
+ if split_st:
142
+ if self.use_bank:
143
+ self.spatial_deep_prompt_pool = nn.ModuleList([
144
+ PromptPoolLearner(prompt_dim, **pool)
145
+ for i in range(total_d_layer)])
146
+ self.temporal_deep_prompt_pool = nn.ModuleList([
147
+ PromptPoolLearner(prompt_dim, **pool)
148
+ for i in range(total_d_layer)])
149
+ else:
150
+ self.spatial_deep_prompt_embeddings = nn.Parameter(torch.zeros(
151
+ total_d_layer, num_tokens // 2, prompt_dim))
152
+ self.temporal_deep_prompt_embeddings = nn.Parameter(torch.zeros(
153
+ total_d_layer, num_tokens // 2, prompt_dim))
154
+ # xavier_uniform initialization
155
+ nn.init.uniform_(self.spatial_deep_prompt_embeddings.data, -val, val)
156
+ nn.init.uniform_(self.temporal_deep_prompt_embeddings.data, -val, val)
157
+ else:
158
+ if self.use_bank:
159
+ self.deep_prompt_pool = nn.ModuleList([
160
+ PromptPoolLearner(prompt_dim, **pool)
161
+ for i in range(total_d_layer)])
162
+ else:
163
+ self.deep_prompt_embeddings = nn.Parameter(torch.zeros(
164
+ total_d_layer, num_tokens, prompt_dim))
165
+ # xavier_uniform initialization
166
+ nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)
167
+
168
+ def forward(self, query=None, layer=0, istrain=False, gamma=1.0):
169
+ query = query.detach()
170
+ query = self.prompt_inproj(query)
171
+ ps_loss = query.new_zeros([1])
172
+ if self.split_st:
173
+ if self.deep and (not self.deep_shared) and layer > 0:
174
+ if self.use_bank:
175
+ k = (self.num_tokens // 2) // self.pool_length
176
+ spatial_out = self.spatial_deep_prompt_pool[layer-1](query, k, istrain, gamma)
177
+ spatial_prompts = spatial_out['prompts']
178
+ temporal_out = self.temporal_deep_prompt_pool[layer-1](query, k, istrain, gamma)
179
+ temporal_prompts = temporal_out['prompts']
180
+ ps_loss += spatial_out.get('ps_loss', 0) + temporal_out.get('ps_loss', 0)
181
+ else:
182
+ spatial_prompts = self.spatial_deep_prompt_embeddings[layer-1]
183
+ temporal_prompts = self.temporal_deep_prompt_embeddings[layer-1]
184
+ else:
185
+ if self.use_bank:
186
+ k = (self.num_tokens // 2) // self.pool_length
187
+ spatial_out = self.spatial_prompt_pool(query, k, istrain, gamma)
188
+ spatial_prompts = spatial_out['prompts']
189
+ temporal_out = self.temporal_prompt_pool(query, k, istrain, gamma)
190
+ temporal_prompts = temporal_out['prompts']
191
+ ps_loss += spatial_out.get('ps_loss', 0) + temporal_out.get('ps_loss', 0)
192
+ else:
193
+ spatial_prompts = self.spatial_prompt_embeddings
194
+ temporal_prompts = self.temporal_prompt_embeddings
195
+
196
+ prompts = torch.cat((spatial_prompts, temporal_prompts), dim=1)
197
+
198
+ else:
199
+ if self.deep and (not self.deep_shared) and layer > 0:
200
+ if self.use_bank:
201
+ k = self.num_tokens // self.pool_length
202
+ out = self.deep_prompt_pool[layer-1](query, k, istrain, gamma)
203
+ prompts = out['prompts']
204
+ ps_loss += out.get('ps_loss', 0)
205
+ else:
206
+ prompts = self.deep_prompt_embeddings[layer-1]
207
+ else:
208
+ if self.use_bank:
209
+ k = self.num_tokens // self.pool_length
210
+ out = self.prompt_pool(query, k, istrain, gamma)
211
+ prompts = out['prompts']
212
+ ps_loss += out.get('ps_loss', 0)
213
+ else:
214
+ prompts = self.prompt_embeddings
215
+
216
+ prompts = self.prompt_dropout(self.prompt_outproj(prompts))
217
+ return prompts, ps_loss
218
+
219
+
220
+ class CMM(nn.Module):
221
+ '''Context modeling module'''
222
+ def __init__(self, num_tokens=8, num_frames=16, embed_dim=768, prompt_dim=256, dropout=0., num_layer=1, shared=False, pool={}):
223
+ super(CMM, self).__init__()
224
+ self.num_tokens = num_tokens
225
+ self.num_frames = num_frames
226
+ self.embed_dim = embed_dim
227
+ self.prompt_dim = prompt_dim
228
+ self.pool_size = pool.get('size', 0)
229
+ self.pool_length = pool.get('length', 1)
230
+ self.use_bank = True if self.pool_size > 0 else False
231
+ self.use_rnn = not self.use_bank
232
+ if self.use_rnn:
233
+ self.rnn = nn.LSTM(input_size=embed_dim, hidden_size=embed_dim,
234
+ num_layers=1, batch_first=True, dropout=dropout, bidirectional=True)
235
+ self.shared = shared
236
+ self.prompt_dropout = nn.Dropout(dropout)
237
+
238
+ if self.use_bank:
239
+ print(f'Using feature bank with size {self.pool_size} (dimension: {prompt_dim})')
240
+ if self.use_rnn:
241
+ self.prompt_inproj = nn.Linear(embed_dim * 2, prompt_dim)
242
+ nn.init.kaiming_normal_(
243
+ self.prompt_inproj.weight, a=0, mode='fan_out')
244
+ else:
245
+ if embed_dim != prompt_dim:
246
+ self.prompt_inproj = nn.Linear(embed_dim, prompt_dim, bias=False)
247
+ else:
248
+ self.prompt_inproj = nn.Identity()
249
+
250
+ self.prompt_outproj = nn.Linear(prompt_dim, embed_dim, bias=False)
251
+ nn.init.kaiming_normal_(
252
+ self.prompt_outproj.weight, a=0, mode='fan_out')
253
+
254
+ if shared:
255
+ self.prompt_pool = PromptPoolLearner(prompt_dim, **pool)
256
+ else:
257
+ self.prompt_pool = nn.ModuleList([
258
+ PromptPoolLearner(prompt_dim, **pool)
259
+ for i in range(num_layer)])
260
+ else:
261
+ self.fc = nn.Linear(embed_dim * 2, embed_dim * num_tokens)
262
+
263
+ def forward(self, x, layer=0, istrain=False, gamma=1.0):
264
+ BZ = x.size(0)
265
+ x = x.detach()
266
+ x = rearrange(x, 'b (f n) d -> b f n d', f=self.num_frames)
267
+ x = torch.mean(x, dim=2)
268
+
269
+ if self.use_rnn:
270
+ x, _ = self.rnn(x)
271
+
272
+ ps_loss = x.new_zeros([1])
273
+ if self.use_bank:
274
+ query = self.prompt_inproj(x).flatten(0, 1)
275
+ k = self.num_tokens // self.pool_length
276
+ if self.shared:
277
+ out = self.prompt_pool(query, k, istrain, gamma)
278
+ else:
279
+ out = self.prompt_pool[layer](query, k, istrain, gamma)
280
+
281
+ prompts = rearrange(out['prompts'], '(b f) p d -> b (f p) d', f=self.num_frames)
282
+ prompts = self.prompt_outproj(prompts)
283
+ ps_loss += out.get('ps_loss', 0) * self.num_frames
284
+
285
+ else:
286
+ prompts = self.fc(x)
287
+ prompts = rearrange(prompts, 'b f (p d) -> b (f p) d', p=self.num_tokens)
288
+
289
+ return prompts, ps_loss
290
+
291
+
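A usage sketch for CMM with the prompt bank enabled: frame-level features are average-pooled over patches, projected into the prompt space, and k prompts per frame are fetched from the pool. As above, this assumes a CUDA device and an importable repo; the input shape (2 clips, 4 frames of 14x14 patches) is illustrative.

import torch
from lavila.models.prompt_tuning import CMM

cmm = CMM(num_tokens=8, num_frames=4, embed_dim=768, prompt_dim=256,
          pool={'size': 64, 'length': 1}).cuda()
x = torch.randn(2, 4 * 196, 768).cuda()   # B x (frames * patches) x embed_dim
prompts, ps_loss = cmm(x)
print(prompts.shape)                      # torch.Size([2, 32, 768]) = B x (frames * k) x embed_dim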
lavila/models/timesformer.py ADDED
@@ -0,0 +1,650 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Part of the code is from https://github.com/m-bain/frozen-in-time/blob/main/model/video_transformer.py
8
+ # Modified by Yue Zhao
9
+ # The original code is under MIT License
10
+
11
+ """
12
+ Implementations of Video Transformers in PyTorch
13
+ A PyTorch implementation of space-time transformer as described in
14
+ 'Frozen in Time: A Joint Image and Video Encoder for End-to-End Retrieval' - https://arxiv.org/abs/2104.00650
15
+ A PyTorch implementation of timesformer as described in
16
+ 'Is Space-Time Attention All You Need for Video Understanding?' - https://arxiv.org/abs/2102.05095
17
+ Acknowledgments:
18
+ - This code builds on Ross Wightman's vision_transformer code in pytorch-image-models:
19
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
20
+ - It is also inspired by lucidrains timesformer implementation:
21
+ https://github.com/lucidrains/TimeSformer-pytorch
22
+ Hacked together by Max Bain
23
+ """
24
+
25
+ from collections import OrderedDict, defaultdict
26
+ from functools import partial, reduce
27
+ import operator
28
+ import copy
29
+
30
+ import torch
31
+ import torch.utils.checkpoint as checkpoint
32
+ from einops import rearrange, repeat
33
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
34
+ from torch import einsum, nn
35
+ import torch.nn.functional as F
36
+ import pdb
37
+
38
+ from lavila.models.prompt_tuning import VisualPromptLearner, CMM
39
+
40
+
41
+ def attn(q, k, v):
42
+ sim = einsum('b i d, b j d -> b i j', q, k)
43
+ attn = sim.softmax(dim=-1)
44
+ out = einsum('b i j, b j d -> b i d', attn, v)
45
+ return out
46
+
47
+
48
+ class Mlp(nn.Module):
49
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
50
+ super().__init__()
51
+ out_features = out_features or in_features
52
+ hidden_features = hidden_features or in_features
53
+ self.fc1 = nn.Linear(in_features, hidden_features)
54
+ self.act = act_layer()
55
+ self.fc2 = nn.Linear(hidden_features, out_features)
56
+ self.drop = nn.Dropout(drop)
57
+
58
+ def forward(self, x):
59
+ x = self.fc1(x)
60
+ x = self.act(x)
61
+ x = self.drop(x)
62
+ x = self.fc2(x)
63
+ x = self.drop(x)
64
+ return x
65
+
66
+
67
+ class VideoPatchEmbed(nn.Module):
68
+ """ Video to Patch Embedding
69
+ """
70
+
71
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
72
+ num_frames=8, ln_pre=False):
73
+ super().__init__()
74
+ img_size = to_2tuple(img_size)
75
+ patch_size = to_2tuple(patch_size)
76
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * num_frames
77
+ self.img_size = img_size
78
+ self.patch_size = patch_size
79
+ self.num_patches = num_patches
80
+ self.num_frames = num_frames
81
+ self.embed_dim = embed_dim
82
+ # ln_pre is inserted to be compatible with CLIP-style model
83
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=not ln_pre)
84
+
85
+ def forward(self, x):
86
+ B, F, C, H, W = x.shape
87
+ assert F <= self.num_frames
88
+ x = x.view(-1, C, H, W)
89
+ x = self.proj(x)
90
+ return x
91
+
92
+
93
+ class VarAttention(nn.Module):
94
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
95
+ initialize='random', num_tokens=0):
96
+ super().__init__()
97
+ self.num_heads = num_heads
98
+ head_dim = dim // num_heads
99
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
100
+ self.scale = qk_scale or head_dim ** -0.5
101
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
102
+ self.proj = nn.Linear(dim, dim)
103
+ if initialize == 'zeros':
104
+ self.qkv.weight.data.fill_(0)
105
+ self.qkv.bias.data.fill_(0)
106
+ # fill proj weight with 1 here to improve training dynamics. Otherwise temporal attention inputs
107
+ # are multiplied by 0*0, which is hard for the model to move out of.
108
+ self.proj.weight.data.fill_(1)
109
+ self.proj.bias.data.fill_(0)
110
+ self.attn_drop = nn.Dropout(attn_drop)
111
+ self.proj_drop = nn.Dropout(proj_drop)
112
+ self.num_tokens = num_tokens
113
+
114
+ def forward(self, x, einops_from, einops_to, einops_dims, cfg):
115
+ style = cfg.get('style', 'default')
116
+ pt_att = cfg.get('pt_att', True)
117
+ n_seg = cfg.get('n_seg', 4)
118
+ if 'VoP' in style:
119
+ return self.forward_VoP(x, einops_from, einops_to, einops_dims, n_seg)
120
+ elif style == 'attall':
121
+ return self.forward_attall(x, pt_att)
122
+ else:
123
+ return self.forward_features(x, einops_from, einops_to, einops_dims, pt_att)
124
+
125
+ def forward_features(self, x, einops_from, einops_to, einops_dims, pt_att=True):
126
+ h = self.num_heads
127
+ num_tokens = self.num_tokens
128
+ if self.num_tokens > 0 and not pt_att:
129
+ prompts = x[:, 1:self.num_tokens+1, :]
130
+ x = torch.cat((
131
+ x[:, :1, :], # cls_token
132
+ x[:, self.num_tokens+1:, :] # patch embeddings
133
+ ), dim=1)
134
+ num_tokens = 0
135
+
136
+ # project x to q, k, v values
137
+ q, k, v = self.qkv(x).chunk(3, dim=-1)
138
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
139
+
140
+ q *= self.scale
141
+
142
+ # splice out CLS token at index 1 (and prompts)
143
+ (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:num_tokens+1], t[:, num_tokens+1:]), (q, k, v)) # Bh x () x d
144
+
145
+ # let CLS token attend to key / values of all patches across time and space
146
+ cls_out = attn(cls_q, k, v) # Bh x (1 + p) x d
147
+ # rearrange across time or space
148
+ q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_)) # Bh x NT x d -> Bhr x s x d
149
+
150
+ # expand cls token keys and values across time or space and concat
151
+ r = q_.shape[0] // cls_k.shape[0]
152
+ cls_k, cls_v = map(lambda t: repeat(t, 'b p d -> (b r) p d', r=r), (cls_k, cls_v)) # Bhr x (1 + p) x d
153
+ k_ = torch.cat((cls_k, k_), dim=1)
154
+ v_ = torch.cat((cls_v, v_), dim=1)
155
+
156
+ # attention
157
+ out = attn(q_, k_, v_)
158
+
159
+ # merge back time or space
160
+ out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims) # Bh x NT x d
161
+
162
+ # concat back the cls token
163
+ out = torch.cat((cls_out, out), dim=1) # Bh x (1 + p + NT) x d
164
+
165
+ # merge back the heads
166
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h) # B x (1 + p + NT) x hd
167
+ if self.num_tokens > 0 and not pt_att:
168
+ out = torch.cat((
169
+ out[:, :1, :], # cls_tokens
170
+ prompts,
171
+ out[:, 1:, :] # patch embeddings
172
+ ), dim=1)
173
+
174
+ # to out
175
+ x = self.proj(out)
176
+ x = self.proj_drop(x)
177
+ return x
178
+
179
+ def forward_VoP(self, x, einops_from, einops_to, einops_dims, n_seg=4):
180
+ # position-specific prompts for spatial attention
181
+ h = self.num_heads
182
+ num_tokens = self.num_tokens
183
+
184
+ # project x to q, k, v values
185
+ q, k, v = self.qkv(x).chunk(3, dim=-1) # B x (1+p+NT) x hd
186
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) # Bh x (1+p+NT) x d
187
+
188
+ q *= self.scale
189
+
190
+ # splice out the CLS token (index 0) and the prompt tokens
191
+ (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:num_tokens+1], t[:, num_tokens+1:]), (q, k, v)) # Bh x () x d
192
+ # let CLS token attend to key / values of all patches across time and space
193
+ cls_out = attn(cls_q[:, :1, :], k, v) # cls token: Bh x 1 x d
194
+
195
+ # segment prompts into s segments in time
196
+ pstep = num_tokens // n_seg
197
+ pseg = [range(st, en) for st, en in zip(range(1, num_tokens+1, pstep), range(pstep+1, num_tokens+2, pstep))]
198
+ p_q, p_k, p_v = map(lambda t: rearrange(t[:, pseg, :], 'b s p d -> (b s) p d'), (cls_q, cls_k, cls_v)) # prompt query: (Bh x n_seg) x p_per_seg x d
199
+
200
+ # segment patch embeddings into s segments in time
201
+ q_, k_, v_ = map(lambda t: rearrange(t, 'b (f n) d -> b f n d', **einops_dims), (q_, k_, v_)) # Bh x T x N x d
202
+ num_frames = k_.size(1)
203
+ tstep = num_frames // n_seg
204
+ tseg = [range(st, en) for st, en in zip(range(0, num_frames, tstep), range(tstep, num_frames+1, tstep))]
205
+ q_, k_, v_ = map(lambda t: t[:, tseg, ...], (q_, k_, v_)) # Bh x n_seg x f_per_seg x n x d
206
+ q_, k_, v_ = map(lambda t: rearrange(t, 'b s f n d -> (b s) (f n) d'), (q_, k_, v_)) # (Bh x n_seg) x (f_per_seg x n) x d
207
+
208
+ # concatenate prompts and patch embeddings
209
+ k_, v_ = map(lambda t: torch.cat((t[0], t[1]), dim=1), ((p_k, k_), (p_v, v_)))
210
+ p_out = attn(p_q, k_, v_) # (Bh x n_seg) x p_per_seg x d
211
+ out = attn(q_, k_, v_) # (Bh x n_seg) x (f_per_seg x n) x d
212
+ p_out = rearrange(p_out, '(b s) p d -> b (s p) d', s=n_seg) # Bh x p x d
213
+ out = rearrange(out, '(b s) (f n) d -> b (s f n) d', s=n_seg, f=tstep) # Bh x NT x d
214
+
215
+ # merge tokens
216
+ out = torch.cat((cls_out, p_out, out), dim=1) # Bh x (1+p+NT) x d
217
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h) # B x (NT+1) x hd
218
+
219
+ # to out
220
+ x = self.proj(out)
221
+ x = self.proj_drop(x)
222
+ return x
223
+
224
+ def forward_attall(self, x, pt_att=True):
225
+ h = self.num_heads
226
+ if self.num_tokens > 0 and not pt_att:
227
+ prompts = x[:, 1:self.num_tokens+1, :]
228
+ x = torch.cat((
229
+ x[:, :1, :], # cls_token
230
+ x[:, self.num_tokens+1:, :] # patch embeddings
231
+ ), dim=1)
232
+
233
+ # project x to q, k, v values
234
+ q, k, v = self.qkv(x).chunk(3, dim=-1)
235
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
236
+
237
+ q *= self.scale
238
+
239
+ # all tokens attend to all tokens
240
+ out = attn(q, k, v)
241
+
242
+ # merge back the heads
243
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h) # B x (1 + p + NT) x hd
244
+ if self.num_tokens > 0 and not pt_att:
245
+ out = torch.cat((
246
+ out[:, :1, :], # cls_tokens
247
+ prompts,
248
+ out[:, 1:, :] # patch embeddings
249
+ ), dim=1)
250
+
251
+ # to out
252
+ x = self.proj(out)
253
+ x = self.proj_drop(x)
254
+ return x
255
+
256
+
257
+ class SpaceTimeBlock(nn.Module):
258
+
259
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
260
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, time_init='zeros',
261
+ attention_style='frozen-in-time', is_tanh_gating=False, num_tokens=0, split_st=False):
262
+ super().__init__()
263
+
264
+ self.split_st = split_st # split spatial and temporal prompts
265
+ if split_st:
266
+ num_tokens = num_tokens // 2
267
+ self.num_tokens = num_tokens # learnable prompts
268
+
269
+ self.norm1 = norm_layer(dim)
270
+ self.attn = VarAttention(
271
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, num_tokens=num_tokens)
272
+
273
+ self.timeattn = VarAttention(
274
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, num_tokens=num_tokens,
275
+ initialize=time_init)
276
+
277
+ if is_tanh_gating:
278
+ self.alpha_timeattn = nn.Parameter(torch.zeros([]))
279
+
280
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
281
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
282
+ self.norm2 = norm_layer(dim)
283
+ mlp_hidden_dim = int(dim * mlp_ratio)
284
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
285
+ self.norm3 = norm_layer(dim)
286
+
287
+ self.attention_style = attention_style
288
+
289
+ def forward(self, x, einops_from_space, einops_to_space, einops_from_time, einops_to_time,
290
+ time_n, space_f, use_checkpoint=False, pt_spt=True, pt_tmp=True, style='default', n_seg=4):
291
+ if self.split_st:
292
+ spatial_prompts = x[:, 1:self.num_tokens+1, :]
293
+ x = torch.cat((
294
+ x[:, :1, :], # cls_token
295
+ x[:, self.num_tokens+1:, :] # temporal prompts and patch embeddings
296
+ ), dim=1)
297
+
298
+ if use_checkpoint:
299
+ time_output = checkpoint.checkpoint(
300
+ self.timeattn, self.norm3(x), einops_from_time, einops_to_time, {"n": time_n}, {'pt_att': pt_tmp}
301
+ )
302
+ else:
303
+ time_output = self.timeattn(self.norm3(x), einops_from_time, einops_to_time, {"n": time_n}, {'pt_att': pt_tmp})
304
+ if hasattr(self, "alpha_timeattn"):
305
+ time_output = torch.tanh(self.alpha_timeattn) * time_output
306
+ time_residual = x + time_output
307
+
308
+ if self.split_st:
309
+ temporal_prompts = time_residual[:, 1:self.num_tokens+1, :]
310
+ time_residual = torch.cat((
311
+ time_residual[:, :1, :], # cls_token
312
+ spatial_prompts,
313
+ time_residual[:, self.num_tokens+1:, :] # patch embeddings
314
+ ), dim=1)
315
+
316
+ cfg = {'style': style, 'pt_att': pt_spt, 'n_seg': n_seg}
317
+ if use_checkpoint:
318
+ space_output = checkpoint.checkpoint(
319
+ self.attn, self.norm1(time_residual), einops_from_space, einops_to_space, {"f": space_f}, cfg
320
+ )
321
+ else:
322
+ space_output = self.attn(self.norm1(time_residual), einops_from_space,
323
+ einops_to_space, {"f": space_f}, cfg)
324
+ if self.attention_style == 'frozen-in-time':
325
+ space_residual = x + self.drop_path(space_output)
326
+ else:
327
+ raise NotImplementedError
328
+
329
+ if self.split_st:
330
+ space_residual = torch.cat((
331
+ space_residual[:, :self.num_tokens+1, :], # cls_token and spatial prompts
332
+ temporal_prompts,
333
+ space_residual[:, self.num_tokens+1:, :] # patch embeddings
334
+ ), dim=1)
335
+
336
+ x = space_residual + self.drop_path(self.mlp(self.norm2(space_residual)))
337
+
338
+ return x
339
+
340
+
341
+ class SpaceTimeTransformer(nn.Module):
342
+ """ Vision Transformer
343
+ A PyTorch impl of : `Space-Time Transformer` from Frozen-in-time - by Max Bain.
344
+ https://arxiv.org/abs/2104.00650
345
+ Based off:
346
+ - ViT implementation from the timm library [https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py]
347
+ lucidrains timesformer implementation [https://github.com/lucidrains/TimeSformer-pytorch].
348
+ Notable differences:
349
+ - allows for variable length input frames (<= num_frames)
350
+ - allows for variable length input resolution (<= (img_size, img_size)) [UNTESTED]
351
+ - different attention block mechanism
352
+ """
353
+
354
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
355
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
356
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
357
+ num_frames=8, time_init='rand', attention_style='frozen-in-time', ln_pre=False,
358
+ act_layer=nn.GELU, is_tanh_gating=False, tune_bias=False, prompt_cfg={}):
359
+ """
360
+ Args:
361
+ img_size (int, tuple): input image size
362
+ patch_size (int, tuple): patch size
363
+ in_chans (int): number of input channels
364
+ num_classes (int): number of classes for classification head
365
+ embed_dim (int): embedding dimension
366
+ depth (int): depth of transformer
367
+ num_heads (int): number of attention heads
368
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
369
+ qkv_bias (bool): enable bias for qkv if True
370
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
371
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
372
+ drop_rate (float): dropout rate
373
+ attn_drop_rate (float): attention dropout rate
374
+ drop_path_rate (float): stochastic depth rate
375
+ hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
376
+ norm_layer: (nn.Module): normalization layer
377
+ num_frames: (int) maximum number of frames expected as input
378
+ time_init: (str) how to initialise the time attention layer, 'zeros' allows for the timesformer to start off
379
+ as ViT.
380
+ attention_style: (str) how to attend to space and time.
381
+ """
382
+ super().__init__()
383
+ self.num_classes = num_classes
384
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
385
+ self.num_frames = num_frames
386
+ self.embed_dim = embed_dim
387
+ self.tune_bias = tune_bias
388
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
389
+ print("######USING ATTENTION STYLE: ", attention_style)
390
+ self.param_list = []
391
+ if hybrid_backbone is not None:
392
+ raise NotImplementedError('hybrid backbone not implemented')
393
+ else:
394
+ self.patch_embed = VideoPatchEmbed(
395
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=num_frames, ln_pre=ln_pre)
396
+ self.param_list += list(self.patch_embed.parameters())
397
+ num_patches = self.patch_embed.num_patches
398
+ self.patches_per_frame = num_patches // num_frames
399
+
400
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
401
+ self.pos_embed = nn.Parameter(
402
+ torch.zeros(1, self.patches_per_frame + 1,
403
+ embed_dim)) # remember to take pos_embed[1:] for tiling over time
404
+ self.temporal_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim))
405
+ self.param_list += [self.cls_token, self.pos_embed, self.temporal_embed]
406
+
407
+ if ln_pre:
408
+ self.ln_pre = nn.LayerNorm(embed_dim)
409
+ if self.tune_bias:
410
+ self.param_list += [m for n, m in self.ln_pre.named_parameters() if 'bias' not in n]
411
+ else:
412
+ self.param_list += list(self.ln_pre.parameters())
413
+ else:
414
+ self.ln_pre = None
415
+
416
+ self.pos_drop = nn.Dropout(p=drop_rate)
417
+
418
+ # config for prompts
419
+ self.num_tokens = prompt_cfg.get('num_tokens', 0)
420
+ self.prompt_dim = prompt_cfg.get('prompt_dim', 768)
421
+ self.pt_spt = prompt_cfg.pop('pt_spt', True)
422
+ self.pt_tmp = prompt_cfg.pop('pt_tmp', True)
423
+ self.style = prompt_cfg.pop('style', 'default')
424
+ self.query = prompt_cfg.pop('query', 'cls')
425
+ self.n_seg = prompt_cfg.pop('n_seg', 4)
426
+ self.k_s = prompt_cfg.pop('K_s', depth)
427
+ self.st = prompt_cfg.pop('st', 0)
428
+ self.end = prompt_cfg.pop('end', depth)
429
+ assert self.st <= self.end
430
+ if self.style == 'default':
431
+ print(f'Prompting layers {self.st}-{self.end} of the visual backbone')
432
+ elif self.style == 'VoP_c' and self.k_s < depth:
433
+ self.prompt_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim))
434
+ elif self.style == 'VoP_c_pool':
435
+ self.prompt_temp_embed = nn.Parameter(torch.zeros(1, self.n_seg, embed_dim))
436
+ trunc_normal_(self.prompt_temp_embed, std=.02)
437
+
438
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
439
+
440
+ blocks = []
441
+ for i in range(depth):
442
+ stblk_cfg = {}
443
+ if self.num_tokens > 0:
444
+ stblk_cfg = {'num_tokens': prompt_cfg['num_tokens'], 'split_st': prompt_cfg.get('split_st', False)}
445
+ blocks.append(
446
+ SpaceTimeBlock(
447
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
448
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, time_init=time_init,
449
+ attention_style=attention_style, act_layer=act_layer, is_tanh_gating=is_tanh_gating, **stblk_cfg)
450
+ )
451
+
452
+ self.blocks = nn.ModuleList(blocks)
453
+ self.norm = norm_layer(embed_dim)
454
+ if self.tune_bias:
455
+ self.param_list += reduce(operator.add, [[m for n, m in x.named_parameters() if 'bias' not in n] for x in self.blocks])
456
+ self.param_list += [m for n, m in self.norm.named_parameters() if 'bias' not in n]
457
+ else:
458
+ self.param_list += reduce(operator.add, [list(x.parameters()) for x in self.blocks])
459
+ self.param_list += list(self.norm.parameters())
460
+
461
+ # Representation layer
462
+ if representation_size:
463
+ self.num_features = representation_size
464
+ self.pre_logits = nn.Sequential(OrderedDict([
465
+ ('fc', nn.Linear(embed_dim, representation_size)),
466
+ ('act', nn.Tanh())
467
+ ]))
468
+ if self.tune_bias:
469
+ self.param_list += [m for n, m in self.pre_logits.named_parameters() if 'bias' not in n]
470
+ else:
471
+ self.param_list += list(self.pre_logits.parameters())
472
+ else:
473
+ self.pre_logits = nn.Identity()
474
+
475
+ # Classifier head
476
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
477
+
478
+ trunc_normal_(self.pos_embed, std=.02)
479
+ trunc_normal_(self.cls_token, std=.02)
480
+
481
+ # if num_frames > 1, the weights come from ViT inflation and time attention is initialised to zero, so this init is not necessary.
482
+ if num_frames == 1:
483
+ self.apply(self._init_weights)
484
+
485
+ # einops transformations
486
+ self.einops_from_space = 'b (f n) d'
487
+ self.einops_to_space = '(b f) n d'
488
+ self.einops_from_time = 'b (f n) d'
489
+ self.einops_to_time = '(b n) f d'
490
+
491
+ # freeze the backbone and only learn the prompts
492
+ self.prompt_learner = None
493
+ if self.num_tokens > 0:
494
+ if 'VoP_c' in self.style:
495
+ pool = prompt_cfg.pop('pool', {}) if 'pool' in self.style else {}
496
+ if self.k_s > 0:
497
+ self.prompt_generator = CMM(self.num_tokens // self.n_seg, self.n_seg, embed_dim, self.prompt_dim, num_layer=self.k_s, \
498
+ shared=prompt_cfg.get('deep_shared', False), pool=pool)
499
+ n_prompt_layer = depth - self.k_s
500
+
501
+ else:
502
+ n_prompt_layer = self.end - self.st
503
+
504
+ if n_prompt_layer > 0:
505
+ prompt_cfg['num_layers'] = n_prompt_layer
506
+ prompt_cfg['prompt_dim'] = embed_dim
507
+ self.prompt_learner = VisualPromptLearner(patch_size, embed_dim, **prompt_cfg)
508
+
509
+ for p in self.param_list:
510
+ p.requires_grad = False
511
+
512
+ def _init_weights(self, m):
513
+ if isinstance(m, nn.Linear):
514
+ trunc_normal_(m.weight, std=.02)
515
+ if isinstance(m, nn.Linear) and m.bias is not None:
516
+ nn.init.constant_(m.bias, 0)
517
+ elif isinstance(m, nn.LayerNorm):
518
+ nn.init.constant_(m.bias, 0)
519
+ nn.init.constant_(m.weight, 1.0)
520
+
521
+ @torch.jit.ignore
522
+ def no_weight_decay(self):
523
+ return {'pos_embed', 'cls_token'}
524
+
525
+ def get_classifier(self):
526
+ return self.head
527
+
528
+ def reset_classifier(self, num_classes, global_pool=''):
529
+ self.num_classes = num_classes
530
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
531
+
532
+ def forward_features(self, x, use_checkpoint=False, cls_at_last=True, istrain=False, gamma=1.0):
533
+ # print(x.shape)
534
+ b, curr_frames, channels, _, _ = x.shape
535
+ x = self.patch_embed(x)
536
+ x = x.flatten(2).transpose(2, 1)
537
+ x = x.reshape(b, -1, self.patch_embed.embed_dim)
538
+
539
+ BF = x.shape[0]
540
+ cls_tokens = self.cls_token.expand(BF, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
541
+ x = torch.cat((cls_tokens, x), dim=1)
542
+ # positional embed needs to be tiled for each frame (this does [1,2,3] --> [1,2,3,1,2,3]...)
543
+ cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
544
+ tile_pos_embed = self.pos_embed[:, 1:, :].repeat(1, self.num_frames, 1)
545
+ # temporal embed needs to be repeated within each frame (this does [1,2,3] --> [1,1,1,2,2,2,3,3,3]...)
546
+ tile_temporal_embed = self.temporal_embed.repeat_interleave(self.patches_per_frame, 1)
547
+ total_pos_embed = tile_pos_embed + tile_temporal_embed
548
+ total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1) # 1 x (NT + 1) x D
549
+
550
+ curr_patches = x.shape[1]
551
+ x = x + total_pos_embed[:, :curr_patches] # B x (NT + 1) x D
552
+ ps_loss = x.new_zeros([1])
553
+ # incorporate prompts
554
+ if self.num_tokens > 0:
555
+ if 'VoP_c' in self.style and self.k_s > 0:
556
+ ctx, ps = self.prompt_generator(x[:, 1:, :], 0, istrain=istrain, gamma=gamma)
557
+ ps_loss += ps
558
+ if self.prompt_generator.use_bank:
559
+ prompt_temp_embed = self.prompt_temp_embed.repeat_interleave(self.num_tokens // self.n_seg, 1)
560
+ ctx = ctx + prompt_temp_embed
561
+
562
+ elif self.prompt_learner is not None:
563
+ ctx, ps = self.prompt_learner(x[:, :1, :], 0, istrain=istrain, gamma=gamma)
564
+ ps_loss += ps
565
+ if ctx.size(0) != BF:
566
+ ctx = ctx.expand(BF, -1, -1)
567
+
568
+ x = torch.cat((
569
+ x[:, :1, :], # cls_token
570
+ ctx,
571
+ x[:, 1:, :]
572
+ ), dim=1)
573
+
574
+ if self.ln_pre is not None:
575
+ x = self.ln_pre(x)
576
+ x = self.pos_drop(x)
577
+ n = self.patches_per_frame
578
+ f = curr_frames
579
+
580
+ for i, blk in enumerate(self.blocks):
581
+ if self.num_tokens > 0 and i > 0 and i >= self.st and i < self.end:
582
+ if 'VoP_c' in self.style:
583
+ if i < self.k_s:
584
+ ctx, ps = self.prompt_generator(x[:, self.num_tokens+1:, :], i, istrain=istrain, gamma=gamma)
585
+ ps_loss += ps
586
+ if self.prompt_generator.use_bank:
587
+ prompt_temp_embed = self.prompt_temp_embed.repeat_interleave(self.num_tokens // self.n_seg, 1)
588
+ ctx = ctx + prompt_temp_embed
589
+ else:
590
+ ctx, ps = self.prompt_learner(x[:, :1, :], i-self.k_s, istrain=istrain, gamma=gamma)
591
+ ps_loss += ps
592
+
593
+ if 'pool' in self.style:
594
+ prompt_embed = self.prompt_temp_embed.repeat_interleave(self.num_tokens // self.n_seg, 1)
595
+ else:
596
+ prompt_embed = self.prompt_embed.repeat_interleave(self.num_tokens // self.num_frames, 1)
597
+ ctx = ctx + prompt_embed
598
+ if ctx.size(0) != BF:
599
+ ctx = ctx.expand(BF, -1, -1)
600
+
601
+ elif (i - self.st) < self.prompt_learner.num_layers:
602
+ ctx, ps = self.prompt_learner(x[:, :1, :], i-self.st, istrain=istrain, gamma=gamma)
603
+ ps_loss += ps
604
+ if ctx.size(0) != BF:
605
+ ctx = ctx.expand(BF, -1, -1)
606
+
607
+ x = torch.cat((
608
+ x[:, :1, :], # cls_token
609
+ ctx,
610
+ x[:, self.num_tokens+1:, :]
611
+ ), dim=1)
612
+
613
+ style = 'default' if i >= self.k_s else self.style
614
+ pt_tmp = self.pt_tmp if i >= self.st and i < self.end else False
615
+ pt_spt = self.pt_spt if i >= self.st and i < self.end else False
616
+ x = blk(x, self.einops_from_space, self.einops_to_space, self.einops_from_time,
617
+ self.einops_to_time,
618
+ time_n=n, space_f=f, use_checkpoint=use_checkpoint, pt_spt=pt_spt,
619
+ pt_tmp=pt_tmp, style=style, n_seg=self.n_seg)
620
+
621
+ if cls_at_last:
622
+ x = self.norm(x)
623
+ x = x[:, 0]
624
+ x = self.pre_logits(x)
625
+
626
+ return x, ps_loss
627
+ else:
628
+ return self.norm(x), ps_loss
629
+
630
+ def forward(self, x, use_checkpoint=False, istrain=False, gamma=1.0):
631
+ # Note: B C T H W => B T C H W
632
+ # The default input order is different from the one in Frozen-in-Time
633
+ x = x.permute(0, 2, 1, 3, 4).contiguous()
634
+ x, ps_loss = self.forward_features(x, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)
635
+ x = self.head(x)
636
+
637
+ return x, ps_loss
638
+
639
+ def train(self, mode=True):
640
+ if not isinstance(mode, bool):
641
+ raise ValueError("training mode is expected to be boolean")
642
+ self.training = mode
643
+ for m in self.modules():
644
+ m.training = mode
645
+
646
+ if mode and self.num_tokens > 0:
647
+ for n, m in self.named_modules():
648
+ if 'prompt' not in n:
649
+ m.training = False
650
+
lavila/models/tokenizer.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Part of the code is from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
8
+ # Modified by Yue Zhao
9
+ # The original code is under MIT License
10
+
11
+ import gzip
12
+ import html
13
+ import os
14
+ from functools import lru_cache
15
+
16
+ import ftfy
17
+ import regex as re
18
+ import torch
19
+
20
+ from transformers import (BertTokenizer, DistilBertTokenizer, GPT2Tokenizer)
21
+
22
+
23
+ @lru_cache()
24
+ def default_bpe():
25
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
26
+
27
+
28
+ @lru_cache()
29
+ def bytes_to_unicode():
30
+ """
31
+ Returns list of utf-8 bytes and a corresponding list of unicode strings.
32
+ The reversible bpe codes work on unicode strings.
33
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
34
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
35
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
36
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
37
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
38
+ """
39
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
40
+ cs = bs[:]
41
+ n = 0
42
+ for b in range(2**8):
43
+ if b not in bs:
44
+ bs.append(b)
45
+ cs.append(2**8+n)
46
+ n += 1
47
+ cs = [chr(n) for n in cs]
48
+ return dict(zip(bs, cs))
49
+
50
+
51
+ def get_pairs(word):
52
+ """Return set of symbol pairs in a word.
53
+ Word is represented as tuple of symbols (symbols being variable-length strings).
54
+ """
55
+ pairs = set()
56
+ prev_char = word[0]
57
+ for char in word[1:]:
58
+ pairs.add((prev_char, char))
59
+ prev_char = char
60
+ return pairs
61
+
62
+
63
+ def basic_clean(text):
64
+ text = ftfy.fix_text(text)
65
+ text = html.unescape(html.unescape(text))
66
+ return text.strip()
67
+
68
+
69
+ def whitespace_clean(text):
70
+ text = re.sub(r'\s+', ' ', text)
71
+ text = text.strip()
72
+ return text
73
+
74
+
75
+ class SimpleTokenizer(object):
76
+ def __init__(self, bpe_path: str = default_bpe()):
77
+ self.byte_encoder = bytes_to_unicode()
78
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
79
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
80
+ merges = merges[1:49152-256-2+1]
81
+ merges = [tuple(merge.split()) for merge in merges]
82
+ vocab = list(bytes_to_unicode().values())
83
+ vocab = vocab + [v+'</w>' for v in vocab]
84
+ for merge in merges:
85
+ vocab.append(''.join(merge))
86
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
87
+ self.encoder = dict(zip(vocab, range(len(vocab))))
88
+ self.decoder = {v: k for k, v in self.encoder.items()}
89
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
90
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
91
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
92
+
93
+ def bpe(self, token):
94
+ if token in self.cache:
95
+ return self.cache[token]
96
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
97
+ pairs = get_pairs(word)
98
+
99
+ if not pairs:
100
+ return token+'</w>'
101
+
102
+ while True:
103
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
104
+ if bigram not in self.bpe_ranks:
105
+ break
106
+ first, second = bigram
107
+ new_word = []
108
+ i = 0
109
+ while i < len(word):
110
+ try:
111
+ j = word.index(first, i)
112
+ new_word.extend(word[i:j])
113
+ i = j
114
+ except:
115
+ new_word.extend(word[i:])
116
+ break
117
+
118
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
119
+ new_word.append(first+second)
120
+ i += 2
121
+ else:
122
+ new_word.append(word[i])
123
+ i += 1
124
+ new_word = tuple(new_word)
125
+ word = new_word
126
+ if len(word) == 1:
127
+ break
128
+ else:
129
+ pairs = get_pairs(word)
130
+ word = ' '.join(word)
131
+ self.cache[token] = word
132
+ return word
133
+
134
+ def encode(self, text):
135
+ bpe_tokens = []
136
+ text = whitespace_clean(basic_clean(text)).lower()
137
+ for token in re.findall(self.pat, text):
138
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
139
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
140
+ return bpe_tokens
141
+
142
+ def decode(self, tokens):
143
+ text = ''.join([self.decoder[token] for token in tokens])
144
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
145
+ return text
146
+
147
+ def __call__(self, texts, context_length=77):
148
+ if isinstance(texts, str):
149
+ texts = [texts]
150
+
151
+ sot_token = self.encoder["<|startoftext|>"]
152
+ eot_token = self.encoder["<|endoftext|>"]
153
+ all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
154
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
155
+
156
+ for i, tokens in enumerate(all_tokens):
157
+ tokens = tokens[:context_length]
158
+ result[i, :len(tokens)] = torch.tensor(tokens)
159
+
160
+ if len(result) == 1:
161
+ return result[0]
162
+ return result
163
+
164
+
165
+ class MyBertTokenizer(object):
166
+ def __init__(self, name=''):
167
+ print('=> Initialize MyBertTokenizer ({})'.format(name))
168
+ self.tokenizer = BertTokenizer.from_pretrained(name)
169
+ self.bos_token_id, self.eos_token_id = self.tokenizer('').input_ids
170
+ self.pad_token_id = 0
171
+
172
+ def __call__(self, texts, context_length=77):
173
+ if isinstance(texts, str):
174
+ texts = [texts]
175
+ result = torch.zeros(len(texts), context_length, dtype=torch.long)
176
+ mask = torch.zeros(len(texts), context_length, dtype=torch.float32)
177
+ for i, text in enumerate(texts):
178
+ tokens = self.tokenizer(text)
179
+ input_ids = tokens.input_ids[:context_length]
180
+ attention_mask = tokens.attention_mask[:context_length]
181
+ result[i, :len(input_ids)] = torch.tensor(input_ids)
182
+ mask[i, :len(attention_mask)] = torch.tensor(attention_mask)
183
+
184
+ if len(result) == 1:
185
+ return result[0], mask[0]
186
+ return result, mask
187
+
188
+
189
+ class MyDistilBertTokenizer(object):
190
+ def __init__(self, name=''):
191
+ print('=> Initialize MyDistilBertTokenizer ({})'.format(name))
192
+ self.tokenizer = DistilBertTokenizer.from_pretrained(name)
193
+
194
+ def __call__(self, texts, context_length=77):
195
+ if isinstance(texts, str):
196
+ texts = [texts]
197
+ result = torch.zeros(len(texts), context_length, dtype=torch.long)
198
+ mask = torch.zeros(len(texts), context_length, dtype=torch.float32)
199
+ for i, text in enumerate(texts):
200
+ tokens = self.tokenizer(text)
201
+ input_ids = tokens.input_ids[:context_length]
202
+ attention_mask = tokens.attention_mask[:context_length]
203
+ result[i, :len(input_ids)] = torch.tensor(input_ids)
204
+ mask[i, :len(attention_mask)] = torch.tensor(attention_mask)
205
+
206
+ if len(result) == 1:
207
+ return result[0], mask[0]
208
+ return result, mask
209
+
210
+
211
+ class MyGPT2Tokenizer(object):
212
+ def __init__(self, name='', add_bos=False):
213
+ print('=> Initialize MyGPT2Tokenizer ({})'.format(name))
214
+ self.tokenizer = GPT2Tokenizer.from_pretrained(name)
215
+ self.bos_token_id, self.eos_token_id = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
216
+ self.pad_token_id = 0
217
+ self.add_bos = add_bos
218
+ # num_added_tokens = self.tokenizer.add_special_tokens({'pad_token': "[PAD]"})
219
+ # print('num_added_tokens={}'.format(len(num_added_tokens)))
220
+
221
+ def __call__(self, texts, context_length=77):
222
+ if isinstance(texts, str):
223
+ texts = [texts]
224
+ result = torch.zeros(len(texts), context_length, dtype=torch.long)
225
+ for i, text in enumerate(texts):
226
+ tokens = self.tokenizer(text)
227
+ if not self.add_bos:
228
+ input_ids = tokens.input_ids[:context_length - 1]
229
+ input_ids = input_ids + [self.tokenizer.eos_token_id] # add [EOS]
230
+ else:
231
+ input_ids = tokens.input_ids[:context_length - 2]
232
+ input_ids = [self.tokenizer.bos_token_id] + input_ids + [self.tokenizer.eos_token_id] # add [EOS]
233
+ # attention_mask = tokens.attention_mask[:context_length]
234
+ # attention_mask = attention_mask + [0.] * pad_length
235
+ result[i, :len(input_ids)] = torch.tensor(input_ids)
236
+
237
+ if len(result) == 1:
238
+ return result[0]
239
+ return result
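A small usage sketch of SimpleTokenizer, assuming the BPE vocabulary file bpe_simple_vocab_16e6.txt.gz sits next to tokenizer.py as default_bpe() expects:

from lavila.models.tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
ids = tokenizer.encode('a person opens the fridge')   # list of BPE token ids
print(tokenizer.decode(ids))                          # round-trips the text

# the batched call pads/truncates to a fixed context length (77 by default)
batch = tokenizer(['a person opens the fridge', 'a person closes the door'])
print(batch.shape)                                    # torch.Size([2, 77])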
lavila/models/utils.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import OrderedDict
8
+ import functools
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+
13
+ def inflate_positional_embeds(
14
+ current_model_state_dict, new_state_dict,
15
+ num_frames=4,
16
+ load_temporal_fix='bilinear',
17
+ ):
18
+ # allow loading of timesformer with fewer num_frames
19
+ curr_keys = list(current_model_state_dict.keys())
20
+ temporal_embed = ['visual.temporal_embed', 'visual.prompt_embed']
21
+ for x in temporal_embed:
22
+ if x in new_state_dict and x in curr_keys:
23
+ load_temporal_embed = new_state_dict[x]
24
+ load_num_frames = load_temporal_embed.shape[1]
25
+ curr_num_frames = num_frames
26
+ embed_dim = load_temporal_embed.shape[2]
27
+
28
+ if load_num_frames != curr_num_frames:
29
+ if load_num_frames > curr_num_frames:
30
+ print(f'### loaded SpaceTimeTransformer model has MORE frames than current...'
31
+ f'### loading {x} weights, filling in the extras via {load_temporal_fix}')
32
+ new_temporal_embed = load_temporal_embed[:, :curr_num_frames, :]
33
+ else:
34
+ print(f'### loaded SpaceTimeTransformer model has FEWER frames than current...'
35
+ f'### loading {x} weights, filling in the extras via {load_temporal_fix}')
36
+ if load_temporal_fix == 'zeros':
37
+ new_temporal_embed = torch.zeros([load_temporal_embed.shape[0], curr_num_frames, embed_dim])
38
+ new_temporal_embed[:, :load_num_frames] = load_temporal_embed
39
+ elif load_temporal_fix in ['interp', 'bilinear']:
40
+ # interpolate
41
+ # unsqueeze so pytorch thinks its an image
42
+ mode = 'nearest'
43
+ if load_temporal_fix == 'bilinear':
44
+ mode = 'bilinear'
45
+ load_temporal_embed = load_temporal_embed.unsqueeze(0)
46
+ new_temporal_embed = F.interpolate(load_temporal_embed,
47
+ (curr_num_frames, embed_dim), mode=mode).squeeze(0)
48
+ else:
49
+ raise NotImplementedError
50
+ new_state_dict[x] = new_temporal_embed
51
+ # allow loading with smaller spatial patches. assumes custom border crop, to append the
52
+ # border patches to the input sequence
53
+ if 'visual.pos_embed' in new_state_dict and 'visual.pos_embed' in curr_keys:
54
+ load_pos_embed = new_state_dict['visual.pos_embed']
55
+ load_num_patches = load_pos_embed.shape[1]
56
+ curr_pos_embed = current_model_state_dict['visual.pos_embed']
57
+ if load_num_patches != curr_pos_embed.shape[1]:
58
+ raise NotImplementedError(
59
+ 'Loading models with different spatial resolution / patch number not yet implemented, sorry.')
60
+
61
+ return new_state_dict
62
+
63
+
64
+ def rsetattr(obj, attr, val):
65
+ pre, _, post = attr.rpartition('.')
66
+ return setattr(rgetattr(obj, pre) if pre else obj, post, val)
67
+
68
+
69
+ def rgetattr(obj, attr, *args):
70
+ def _getattr(obj, attr):
71
+ return getattr(obj, attr, *args)
72
+ return functools.reduce(_getattr, [obj] + attr.split('.'))
73
+
74
+
75
+ # util functions to convert CLIP-style model keys to TimeSformer-style
76
+ def remap_keys(clip_state_dict, transformer_layers=12):
77
+ remapped_state_dict = OrderedDict()
78
+ key_mapping = {
79
+ "class_embedding": "cls_token",
80
+ "positional_embedding": "pos_embed",
81
+ "conv1.weight": "patch_embed.proj.weight",
82
+ "ln_pre.weight": "ln_pre.weight",
83
+ "ln_pre.bias": "ln_pre.bias",
84
+ "ln_post.weight": "norm.weight",
85
+ "ln_post.bias": "norm.bias",
86
+ }
87
+ for layer in range(transformer_layers):
88
+ key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_weight"] = f"blocks.{layer}.attn.qkv.weight"
89
+ key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_bias"] = f"blocks.{layer}.attn.qkv.bias"
90
+ key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.weight"] = f"blocks.{layer}.attn.proj.weight"
91
+ key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.bias"] = f"blocks.{layer}.attn.proj.bias"
92
+ key_mapping[f"transformer.resblocks.{layer}.ln_1.weight"] = f"blocks.{layer}.norm1.weight"
93
+ key_mapping[f"transformer.resblocks.{layer}.ln_1.bias"] = f"blocks.{layer}.norm1.bias"
94
+ key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.weight"] = f"blocks.{layer}.mlp.fc1.weight"
95
+ key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.bias"] = f"blocks.{layer}.mlp.fc1.bias"
96
+ key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.weight"] = f"blocks.{layer}.mlp.fc2.weight"
97
+ key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.bias"] = f"blocks.{layer}.mlp.fc2.bias"
98
+ key_mapping[f"transformer.resblocks.{layer}.ln_2.weight"] = f"blocks.{layer}.norm2.weight"
99
+ key_mapping[f"transformer.resblocks.{layer}.ln_2.bias"] = f"blocks.{layer}.norm2.bias"
100
+
101
+ for key in clip_state_dict:
102
+ if key == 'proj':
103
+ continue # due to possible dim mismatch, we load this later
104
+ if key == "class_embedding":
105
+ clip_state_dict[key] = clip_state_dict[key].unsqueeze(0).unsqueeze(0)
106
+ if key == "positional_embedding":
107
+ clip_state_dict[key] = clip_state_dict[key].unsqueeze(0)
108
+ remapped_state_dict[key_mapping[key]] = clip_state_dict[key]
109
+
110
+ return remapped_state_dict
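For illustration, a brief sketch of the dotted-attribute helpers defined above (the module and attribute names are arbitrary):

import torch.nn as nn
from lavila.models.utils import rgetattr, rsetattr

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
# read and replace a nested attribute by its dotted path
print(rgetattr(model, '0.weight').shape)              # torch.Size([8, 8])
rsetattr(model, '0.bias', nn.Parameter(rgetattr(model, '0.bias') * 0))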
lavila/utils/config.py ADDED
@@ -0,0 +1,18 @@
1
+
2
+ import yaml
3
+
4
+ def load_base_cfg():
5
+ with open('configs/base.yml', 'r') as fp:
6
+ cfg = yaml.load(fp, Loader=yaml.SafeLoader)
7
+ return cfg
8
+
9
+ def load_cfg(cfg_file):
10
+ cfg = load_base_cfg()
11
+ with open(cfg_file, 'r') as fp:
12
+ exp_cfg = yaml.load(fp, Loader=yaml.SafeLoader)
13
+
14
+ cfg['model'].update(exp_cfg.get('model', {}))
15
+ cfg['data'].update(exp_cfg.get('data', {}))
16
+ dataset = cfg['data'].get('dataset')
17
+ return cfg
18
+
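A usage sketch, assuming configs/base.yml and an experiment file such as configs/charades_ego/egovpa.yml are present in the repo; load_cfg overlays the experiment's model and data sections onto the base config.

from lavila.utils.config import load_cfg

cfg = load_cfg('configs/charades_ego/egovpa.yml')
# 'model' and 'data' entries from the experiment file override the base config
print(cfg['model'], cfg['data'].get('dataset'))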
lavila/utils/evaluation.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ def accuracy(output, target, topk=(1,)):
12
+ """Computes the accuracy over the k top predictions for the specified values of k"""
13
+ with torch.no_grad():
14
+ maxk = max(topk)
15
+ batch_size = target.size(0)
16
+
17
+ _, pred = output.topk(maxk, 1, True, True)
18
+ pred = pred.t()
19
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
20
+
21
+ res = []
22
+ for k in topk:
23
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
24
+ res.append(correct_k.mul_(100.0 / batch_size))
25
+ return res
26
+
27
+
28
+ def get_mean_accuracy(cm):
29
+ list_acc = []
30
+ for i in range(len(cm)):
31
+ acc = 0
32
+ if cm[i, :].sum() > 0:
33
+ acc = cm[i, i] / cm[i, :].sum()
34
+ list_acc.append(acc)
35
+
36
+ return 100 * np.mean(list_acc), 100 * np.trace(cm) / np.sum(cm)
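A toy example of the two helpers above; get_mean_accuracy expects a confusion matrix such as the one produced by sklearn.

import torch
from sklearn.metrics import confusion_matrix
from lavila.utils.evaluation import accuracy, get_mean_accuracy

logits = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
target = torch.tensor([0, 1, 1])
top1, = accuracy(logits, target, topk=(1,))           # ~66.7: 2 of 3 predictions correct

cm = confusion_matrix(target.numpy(), logits.argmax(dim=1).numpy())
mean_class_acc, overall_acc = get_mean_accuracy(cm)   # 75.0 and ~66.7 for this toy case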
lavila/utils/evaluation_charades.py ADDED
@@ -0,0 +1,56 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+
9
+
10
+ def compute_map(submission_array, gt_array):
11
+ """ Returns mAP, weighted mAP, and AP array """
12
+ m_aps = []
13
+ n_classes = submission_array.shape[1]
14
+ for oc_i in range(n_classes):
15
+ sorted_idxs = np.argsort(-submission_array[:, oc_i])
16
+ tp = gt_array[:, oc_i][sorted_idxs] == 1
17
+ fp = np.invert(tp)
18
+ n_pos = tp.sum()
19
+ if n_pos < 0.1:
20
+ m_aps.append(float('nan'))
21
+ continue
22
+ fp.sum()
23
+ f_pcs = np.cumsum(fp)
24
+ t_pcs = np.cumsum(tp)
25
+ prec = t_pcs / (f_pcs+t_pcs).astype(float)
26
+ avg_prec = 0
27
+ for i in range(submission_array.shape[0]):
28
+ if tp[i]:
29
+ avg_prec += prec[i]
30
+ m_aps.append(avg_prec / n_pos.astype(float))
31
+ m_aps = np.array(m_aps)
32
+ #m_ap = np.mean(m_aps)
33
+ m_ap = m_aps[~np.isnan(m_aps)]
34
+ print(f'num of available classes: {len(m_ap)}')
35
+ m_ap = m_ap.mean() # compute mean w/o nan
36
+ w_ap = (m_aps * gt_array.sum(axis=0) / gt_array.sum().sum().astype(float))
37
+ return m_ap, w_ap, m_aps
38
+
39
+
40
+ def charades_map(submission_array, gt_array):
41
+ """
42
+ Approximate version of the charades evaluation function
43
+ For precise numbers, use the submission file with the official matlab script
44
+ """
45
+ fix = submission_array.copy()
46
+ empty = np.sum(gt_array, axis=1) == 0
47
+ fix[empty, :] = np.NINF
48
+ return compute_map(fix, gt_array)
49
+
50
+
51
+ def create_submission(video_list, predictions, out_file):
52
+ assert len(video_list) == predictions.shape[0]
53
+ with open(out_file, 'w') as f:
54
+ for i, video_id in enumerate(video_list):
55
+ pred_str = ' '.join(map(lambda x: str(x), predictions[i].tolist()))
56
+ f.write('{} {}\n\n'.format(video_id, pred_str))
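A minimal sketch of charades_map with random inputs (only the shapes matter here; 157 is the Charades action vocabulary size):

import numpy as np
from lavila.utils.evaluation_charades import charades_map

n_videos, n_classes = 100, 157
scores = np.random.rand(n_videos, n_classes)                    # per-class prediction scores
gt = (np.random.rand(n_videos, n_classes) > 0.9).astype(float)  # multi-label ground truth
m_ap, w_ap, per_class_ap = charades_map(scores, gt)
print(f'video-level mAP: {m_ap:.3f}')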
lavila/utils/evaluation_ek100mir.py ADDED
@@ -0,0 +1,201 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Part of the code is from
8
+ # `https://github.com/mwray/Joint-Part-of-Speech-Embeddings/tree/main/src/evaluation/NDCG.py`
9
+ # and
10
+ # `https://github.com/mwray/Joint-Part-of-Speech-Embeddings/tree/main/src/evaluation/mAP.py`
11
+ # Modified by Yue Zhao
12
+
13
+ import numpy as np
14
+
15
+
16
+ def calculate_DCG(similarity_matrix, relevancy_matrix, k_counts):
17
+ """
18
+ Calculates the Discounted Cumulative Gain (DCG) between two modalities for
19
+ the first modality.
20
+ DCG = \sum_{i=1}^k \frac{rel_i}{log_2(i + 1)}
21
+ i.e. the sum of the k relevant retrievals which is calculated as the scaled
22
+ relevancy for the ith item. The scale is designed such that early
23
+ retrievals are more important than later retrievals.
24
+ Params:
25
+ - similarity_matrix: matrix of size n1 x n2 where n1 is the number of
26
+ items in the first modality and n2 is the number of items in the
27
+ second modality. The [ith,jth] element is the predicted similarity
28
+ between the ith item from the first modality and the jth item from
29
+ the second modality.
30
+ - relevancy_matrix: matrix of size n1 x n2 (see similarity_matrix
31
+ above). The [ith, jth] element is the semantic relevancy between the
32
+ ith item from the first modality and the jth item from the second
33
+ modality.
34
+ - k_counts: matrix of size n1 x n2 (see similarity_matrix above) which
35
+ includes information on which items to use to calculate the DCG for
36
+ (see calculate_k_counts for more info on this matrix).
37
+ Returns:
38
+ - The DCG for each item in the first modality, a n1 length vector.
39
+ """
40
+ x_sz, y_sz = similarity_matrix.shape
41
+ ranks = np.argsort(similarity_matrix)[:, ::-1]
42
+ # Create vector of size (n,) where n is the length of the last dimension in
43
+ # similarity matrix
44
+ # This vector is of the form log(i+1)
45
+ logs = np.log2(np.arange(y_sz) + 2)
46
+ # Convert logs into the divisor for the DCG calculation, of size similarity
47
+ # matrix
48
+ divisors = np.repeat(np.expand_dims(logs, axis=0), x_sz, axis=0)
49
+
50
+ # mask out the sorted relevancy matrix to only use the first k relevant
51
+ # retrievals for each item.
52
+ columns = np.repeat(np.expand_dims(np.arange(x_sz), axis=1), y_sz, axis=1)
53
+ numerators = relevancy_matrix[columns, ranks] * k_counts
54
+ # Calculate the final DCG score (note that this isn't expected to sum to 1)
55
+ return np.sum(numerators / divisors, axis=1)
56
+
57
+
58
+ def calculate_k_counts(relevancy_matrix):
59
+ """
60
+ Works out the maximum number of allowed retrievals when working out the
61
+ Discounted Cumulative Gain. For each query the DCG only uses the first k
62
+ items retrieved which constitute the k relevant items for that query
63
+ (otherwise the nDCG scores can be deceptively high for bad rankings).
64
+ Params:
65
+ - relevancy_matrix: matrix of size n1 x n2 where n1 is the number of
66
+ items in the first modality and n2 is the number of items in the
67
+ second modality. The [ith, jth] element is the semantic relevancy
68
+ between the ith item from the first modality and the jth item from
69
+ the second modality.
70
+ Returns:
71
+ - Matrix of size n1 x n2 (see relevancy matrix for more info). This is
72
+ created as a mask such that if the [ith, jth] element is 1 it
73
+ represents a valid item to use for the calculation of DCG for the
74
+ ith item after sorting. For example, if relevancy matrix of:
75
+ [[1, 0.5, 0],
76
+ [0, 0 , 1]]
77
+ is given, then the k_counts matrix will be:
78
+ [[1, 1, 0],
79
+ [1, 0, 0]]
80
+ i.e. the first row has 2 non-zero items, so the first two retrieved
81
+ items should be used in the calculation. In the second row there is
82
+ only 1 relevant item, therefore only the first retrieved item should
83
+ be used for the DCG calculation.
84
+ """
85
+ return (np.sort(relevancy_matrix)[:, ::-1] > 0).astype(int)
86
+
87
+
88
+ def calculate_IDCG(relevancy_matrix, k_counts):
89
+ """
90
+ Calculates the Ideal Discounted Cumulative Gain (IDCG) which is the value
91
+ of the Discounted Cumulative Gain (DCG) for a perfect retrieval, i.e. the
92
+ items in the second modality were retrieved in order of their descending
93
+ relevancy.
94
+ Params:
95
+ - relevancy_matrix: matrix of size n1 x n2 where n1 is the number of
96
+ items in the first modality and n2 is the number of items in the
97
+ second modality. The [ith, jth] element is the semantic relevancy
98
+ between the ith item from the first modality and the jth item from
99
+ the second modality.
100
+ - k_counts: matrix of size n1 x n2 (see similarity_matrix above) which
101
+ includes information on which items to use to calculate the DCG for
102
+ (see calculate_k_counts for more info on this matrix).
103
+ """
104
+ return calculate_DCG(relevancy_matrix, relevancy_matrix, k_counts)
105
+
106
+
107
+ def calculate_nDCG(similarity_matrix, relevancy_matrix, k_counts=None, IDCG=None, reduction='mean'):
108
+ """
109
+ Calculates the normalised Discounted Cumulative Gain (nDCG) between two
110
+ modalities for the first modality using the Discounted Cumulative Gain
111
+ (DCG) and the Ideal Discounted Cumulative Gain (IDCG).
112
+ nDCG = \frac{DCG}{IDCG}
113
+ Params:
114
+ - similarity_matrix: matrix of size n1 x n2 where n1 is the number of
115
+ items in the first modality and n2 is the number of items in the second
116
+ modality. The [ith,jth] element is the predicted similarity between
117
+ the ith item from the first modality and the jth item from the second
118
+ modality.
119
+ - relevancy_matrix: matrix of size n1 x n2 (see similarity_matrix
120
+ above). The [ith, jth] element is the semantic relevancy between the
121
+ ith item from the first modality and the jth item from the second
122
+ modality.
123
+ - k_counts: optional parameter: matrix of size n1 x n2 (see
124
+ similarity_matrix above) which includes information on which items to
125
+ use to calculate the DCG for (see calculate_k_counts for more info on
126
+ this matrix). This will be calculated using calculate_IDCG if not
127
+ present, but should be pre-processed for efficiency.
128
+ - IDCG: Optional parameter which includes the pre-processed Ideal
129
+ Discounted Cumulative Gain (IDCG). This is a vector of size n1 (see
130
+ similarity_matrix above) which contains the IDCG value for each item
131
+ from the first modality. This will be calculated using calculate_IDCG
132
+ if not present, but should be pre-processed for efficiency.
133
+ - reduction: what to use to reduce the different nDCG scores. By
134
+ default this applies np.mean across all different queries.
135
+ Returns:
136
+ - The nDCG values for the first modality.
137
+ """
138
+ if k_counts is None:
139
+ k_counts = calculate_k_counts(relevancy_matrix)
140
+ DCG = calculate_DCG(similarity_matrix, relevancy_matrix, k_counts)
141
+ if IDCG is None:
142
+ IDCG = calculate_IDCG(relevancy_matrix, k_counts)
143
+ if reduction == 'mean':
144
+ return np.mean(DCG / IDCG)
145
+ elif reduction is None:
146
+ return DCG / IDCG
147
+
148
+
149
+ def calculate_mAP(sim_mat, relevancy_matrix):
150
+ """
151
+ Computes the mean average precision according to the following formula of
152
+ average precision:
153
+ \frac{\sum_{k=1}^n p(k) x rel(k)}{num_rel_docs}
154
+ where p(k) is the precision at k, rel(k) is an indicator function
155
+ determining whether the kth returned item is relevant or not and
156
+ num_rel_docs is the number of relevant items to find within the search.
157
+ The mean average precision is the mean of the average precision for each
158
+ query item (i.e row in the matrix)
159
+ This function takes in two parameters:
160
+ - sim_mat: a NxM matrix which represents the similarity between two
161
+ modalities (with modality 1 being of size N and modality 2 of size M).
162
+ - relevancy_matrix: an NxM matrix which represents the relevancy between two
163
+ modalities of items (with modality 1 being of size N and modality 2 of
164
+ size M).
165
+ """
166
+ # Find the order of the items in modality 2 according to modality 1
167
+ ranked_order = (-sim_mat).argsort()
168
+ ranked_sim_mat = sim_mat[np.arange(sim_mat.shape[0])[:, None], ranked_order]
169
+ # re-order the relevancy matrix to accommodate the proposals
170
+ ranked_rel_mat = relevancy_matrix[np.arange(relevancy_matrix.shape[0])[:, None], ranked_order]
171
+
172
+ # find the number of relevant items found at each k
173
+ cumulative_rel_mat = np.cumsum(ranked_rel_mat, axis=1)
174
+ # Mask this ensuring that it is non zero if the kth term is 1 (rel(k) above)
175
+ cumulative_rel_mat[ranked_rel_mat != 1] = 0
176
+ # find the divisor for p(k)
177
+ divisor = np.arange(ranked_rel_mat.shape[1]) + 1
178
+
179
+ # find the number of relevant docs per query item
180
+ number_rel_docs = np.sum(ranked_rel_mat == 1, axis=1)
181
+
182
+ # find the average precision per query, within np.sum finds p(k) * rel(k)
183
+ avg_precision = np.sum(cumulative_rel_mat / divisor, axis=1) / number_rel_docs
184
+ mAP = np.mean(avg_precision)
185
+ return mAP
186
+
187
+
188
+ def get_mAP(similarity_matrix, rel_matrix):
189
+ vis_map = calculate_mAP(similarity_matrix, rel_matrix)
190
+ txt_map = calculate_mAP(similarity_matrix.T, rel_matrix.T)
191
+ return vis_map, txt_map, (vis_map + txt_map) / 2
192
+
193
+
194
+ def get_nDCG(similarity_matrix, rel_matrix):
195
+ vis_k_counts = calculate_k_counts(rel_matrix)
196
+ txt_k_counts = calculate_k_counts(rel_matrix.T)
197
+ vis_IDCG = calculate_IDCG(rel_matrix, vis_k_counts)
198
+ txt_IDCG = calculate_IDCG(rel_matrix.T, txt_k_counts)
199
+ vis_nDCG = calculate_nDCG(similarity_matrix, rel_matrix, k_counts=vis_k_counts, IDCG=vis_IDCG)
200
+ txt_nDCG = calculate_nDCG(similarity_matrix.T, rel_matrix.T, k_counts=txt_k_counts, IDCG=txt_IDCG)
201
+ return vis_nDCG, txt_nDCG, (vis_nDCG + txt_nDCG) / 2
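A toy retrieval example for the metrics above, using an identity matrix as a stand-in for the real caption relevancy:

import numpy as np
from lavila.utils.evaluation_ek100mir import get_mAP, get_nDCG

sim = np.random.rand(50, 50)     # video-to-text similarity scores
rel = np.eye(50)                 # toy relevancy: each video matches exactly one caption
vis_map, txt_map, avg_map = get_mAP(sim, rel)
vis_ndcg, txt_ndcg, avg_ndcg = get_nDCG(sim, rel)
print(avg_map, avg_ndcg)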
lavila/utils/preprocess.py ADDED
@@ -0,0 +1,86 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import csv
8
+
9
+ from lavila.models.tokenizer import MyBertTokenizer, MyDistilBertTokenizer, MyGPT2Tokenizer, SimpleTokenizer
10
+
11
+
12
+ def generate_label_map(dataset):
13
+ if dataset == 'ek100_cls':
14
+ print("Preprocess ek100 action label space")
15
+ vn_list = []
16
+ mapping_vn2narration = {}
17
+ for f in [
18
+ '/data/EK100/epic-kitchens-100-annotations/EPIC_100_train.csv',
19
+ '/data/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv',
20
+ ]:
21
+ csv_reader = csv.reader(open(f))
22
+ _ = next(csv_reader) # skip the header
23
+ for row in csv_reader:
24
+ vn = '{}:{}'.format(int(row[10]), int(row[12]))
25
+ narration = row[8]
26
+ if vn not in vn_list:
27
+ vn_list.append(vn)
28
+ if vn not in mapping_vn2narration:
29
+ mapping_vn2narration[vn] = [narration]
30
+ else:
31
+ mapping_vn2narration[vn].append(narration)
32
+ # mapping_vn2narration[vn] = [narration]
33
+ vn_list = sorted(vn_list)
34
+ print('# of action= {}'.format(len(vn_list)))
35
+ mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
36
+ labels = [list(set(mapping_vn2narration[vn_list[i]])) for i in range(len(mapping_vn2act))]
37
+ print(labels[:5])
38
+ elif dataset == 'charades_ego':
39
+ print("=> preprocessing charades_ego action label space")
40
+ vn_list = []
41
+ labels = []
42
+ with open('/data/CharadesEgo/CharadesEgo/Charades_v1_classes.txt') as f:
43
+ csv_reader = csv.reader(f)
44
+ for row in csv_reader:
45
+ vn = row[0][:4]
46
+ vn_list.append(vn)
47
+ narration = row[0][5:]
48
+ labels.append(narration)
49
+ mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
50
+ print(labels[:5])
51
+ elif dataset == 'egtea':
52
+ print("=> preprocessing egtea action label space")
53
+ labels = []
54
+ with open('/data/EGTEA/action_idx.txt') as f:
55
+ for row in f:
56
+ row = row.strip()
57
+ narration = ' '.join(row.split(' ')[:-1])
58
+ labels.append(narration.replace('_', ' ').lower())
59
+ # labels.append(narration)
60
+ mapping_vn2act = {label: i for i, label in enumerate(labels)}
61
+ print(len(labels), labels[:5])
62
+ else:
63
+ raise NotImplementedError
64
+ return labels, mapping_vn2act
65
+
66
+
67
+ def generate_tokenizer(model):
68
+ if model.endswith('DISTILBERT_BASE'):
69
+ tokenizer = MyDistilBertTokenizer('distilbert-base-uncased')
70
+ elif model.endswith('BERT_BASE'):
71
+ tokenizer = MyBertTokenizer('bert-base-uncased')
72
+ elif model.endswith('BERT_LARGE'):
73
+ tokenizer = MyBertTokenizer('bert-large-uncased')
74
+ elif model.endswith('GPT2'):
75
+ tokenizer = MyGPT2Tokenizer('gpt2', add_bos=True)
76
+ elif model.endswith('GPT2_MEDIUM'):
77
+ tokenizer = MyGPT2Tokenizer('gpt2-medium', add_bos=True)
78
+ elif model.endswith('GPT2_LARGE'):
79
+ tokenizer = MyGPT2Tokenizer('gpt2-large', add_bos=True)
80
+ elif model.endswith('GPT2_XL'):
81
+ tokenizer = MyGPT2Tokenizer('gpt2-xl', add_bos=True)
82
+ else:
83
+ print("Using SimpleTokenizer because of model '{}'. "
84
+ "Please check if this is what you want".format(model))
85
+ tokenizer = SimpleTokenizer()
86
+ return tokenizer
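Finally, a sketch of how the preprocessing helpers are typically called; it assumes the CharadesEgo class file exists at the hard-coded path above, and the model name is only an example.

from lavila.utils.preprocess import generate_label_map, generate_tokenizer

labels, mapping_vn2act = generate_label_map('charades_ego')     # one narration per Charades action class
tokenizer = generate_tokenizer('CLIP_OPENAI_TIMESFORMER_BASE')  # falls back to SimpleTokenizer here
tokens = tokenizer(labels[:8])                                  # tensor of shape (8, 77)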