### demo.py
# Define model classes for inference.
###

from collections import OrderedDict
import json
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from sklearn.metrics import confusion_matrix

from lavila.data import datasets
from lavila.data.video_transforms import Permute, SpatialCrop, TemporalCrop
from lavila.models import models
from lavila.models.tokenizer import (MyBertTokenizer, MyDistilBertTokenizer,
                                     MyGPT2Tokenizer, SimpleTokenizer)
from lavila.models.utils import inflate_positional_embeds
from lavila.utils.config import load_cfg
from lavila.utils.evaluation_charades import charades_map
from lavila.utils.evaluation import get_mean_accuracy


class VideoModel(nn.Module):
    """Base model for video understanding based on the LaViLa architecture."""

    def __init__(self, config):
        """
        Initializes the model.

        Parameters:
            config: path to the YAML config file
        """
        super(VideoModel, self).__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        self.tokenizer = self.get_tokenizer()
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()

    def build_model(self):
        cfg = self.cfg
        if cfg['model'].get('pretrain', False):
            ckpt_path = cfg['model']['pretrain']
        else:
            raise Exception('no checkpoint found')
        ckpt = torch.load(ckpt_path, map_location='cpu')

        # strip the 'module.' prefix left by DistributedDataParallel
        state_dict = OrderedDict()
        for k, v in ckpt['state_dict'].items():
            state_dict[k.replace('module.', '')] = v

        old_args = vars(ckpt['args'])
        arch = old_args.get('model', 'CLIP_OPENAI_TIMESFORMER_BASE')
        self.arch = arch
        cfg['model']['arch'] = arch
        cfg['model']['norm_embed'] = old_args.get('norm_embed', True)
        print("=> creating model: {}".format(arch))
        model = getattr(models, arch)(
            pretrained=old_args.get('load_visual_pretrained', None),
            pretrained2d=old_args.get('load_visual_pretrained', None) is not None,
            text_use_cls_token=old_args.get('use_cls_token', False),
            project_embed_dim=old_args.get('project_embed_dim', 256),
            timesformer_gated_xattn=False,
            num_frames=cfg['model'].get('num_frames', cfg['data']['clip_length']),
            model_cfg=cfg['model'],
        )
        model.logit_scale.requires_grad = False
        if torch.cuda.is_available():
            model.cuda()

        if ('TIMESFORMER' in arch or 'EGOVLP' in arch) and cfg['model'].get('inflat_posemb', True):
            # inflate positional embeddings to match the configured number of frames
            print('=> inflating PE in models due to different frame numbers')
            state_dict = inflate_positional_embeds(
                model.state_dict(), state_dict,
                num_frames=cfg['model'].get('num_frames', cfg['data']['clip_length']),
                load_temporal_fix='bilinear',
            )
        model.load_state_dict(state_dict, strict=True)
        print("=> loaded resume checkpoint '{}' (epoch {})".format(ckpt_path, ckpt['epoch']))
        return model

    def eval(self):
        cudnn.benchmark = True
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()

    def get_tokenizer(self):
        arch = self.arch
        if arch.endswith('DISTILBERT_BASE'):
            tokenizer = MyDistilBertTokenizer('distilbert-base-uncased')
        elif arch.endswith('BERT_BASE'):
            tokenizer = MyBertTokenizer('bert-base-uncased')
        elif arch.endswith('BERT_LARGE'):
            tokenizer = MyBertTokenizer('bert-large-uncased')
        elif arch.endswith('GPT2'):
            tokenizer = MyGPT2Tokenizer('gpt2')
        elif arch.endswith('GPT2_MEDIUM'):
            tokenizer = MyGPT2Tokenizer('gpt2-medium')
        elif arch.endswith('GPT2_LARGE'):
            tokenizer = MyGPT2Tokenizer('gpt2-large')
        elif arch.endswith('GPT2_XL'):
            tokenizer = MyGPT2Tokenizer('gpt2-xl')
        else:
            print("Using SimpleTokenizer because of model '{}'. "
" "Please check if this is what you want".format(arch)) tokenizer = SimpleTokenizer() return tokenizer class VideoCLSModel(VideoModel): """ Video model for video classification tasks (Charades-Ego, EGTEA). """ def __init__(self, config): super(VideoCLSModel, self).__init__(config) self.labels, self.mapping_vn2act = self.gen_label_map() self.text_features = self.get_text_features() def gen_label_map(self): labelmap = self.cfg.get('label_map', 'meta/charades_ego/label_map.json') if os.path.isfile(labelmap): print(f"=> Loading label maps from {labelmap}") meta = json.load(open(labelmap, 'r')) labels, mapping_vn2act = meta['labels'], meta['mapping_vn2act'] else: from lavila.utils.preprocess import generate_label_map labels, mapping_vn2act = generate_label_map(self.dataset) meta = {'labels': labels, 'mapping_vn2act': mapping_vn2act} meta_dir = f'meta/{self.dataset}' if not os.path.exists(meta_dir): os.makedirs(meta_dir) json.dump(meta, open(f'{meta_dir}/label_map.json', 'w')) print(f"=> Label map is generated and saved to {meta_dir}/label_map.json") return labels, mapping_vn2act def load_data(self, idx=None): print(f"=> Creating dataset") cfg, dataset = self.cfg, self.dataset data_cfg = cfg['data'] crop_size = 224 if '336PX' not in self.arch else 336 val_transform = transforms.Compose([ Permute([3, 0, 1, 2]), # T H W C -> C T H W transforms.Resize(crop_size), transforms.CenterCrop(crop_size), transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305]), ]) if idx is None: metadata_val = data_cfg['metadata_val'] else: metadata_val = data_cfg['metadata_val'].format(idx) if dataset in ['charades_ego', 'egtea']: val_dataset = datasets.VideoClassyDataset( dataset, data_cfg['root'], metadata_val, transform=val_transform, is_training=False, label_mapping=self.mapping_vn2act, is_trimmed=False, num_clips=1, clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'], sparse_sample=data_cfg['sparse_sample'] ) else: raise NotImplementedError val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=8, shuffle=False, num_workers=4, pin_memory=True, sampler=None, drop_last=False ) return val_loader @torch.no_grad() def get_text_features(self): print('=> Extracting text features') text_features = [] for label in self.labels: if isinstance(label, list): texts = [tmpl.format(lbl) for tmpl in self.templates for lbl in label] else: texts = [tmpl.format(label) for tmpl in self.templates] texts = self.tokenizer(texts) if isinstance(texts, tuple): # Bert-style tokenizer will output both ids and mask texts, masks = texts texts = texts.cuda(non_blocking=True) masks = masks.cuda(non_blocking=True) else: texts = texts.cuda(non_blocking=True) masks = None texts = texts.view(-1, 77).contiguous() masks = masks.view(-1, 77).contiguous() if masks is not None else None if masks is not None: class_embeddings, _ = self.model.encode_text(texts, attention_mask=masks) else: class_embeddings, _ = self.model.encode_text(texts) class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) class_embeddings = class_embeddings.mean(dim=0) class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) text_features.append(class_embeddings) text_features = torch.stack(text_features, dim=0) return text_features @torch.no_grad() def forward(self, idx=None): print('=> Start forwarding') val_loader = self.load_data(idx) all_outputs = [] all_targets = [] for i, values in enumerate(val_loader): images = values[0] 
            target = values[1]
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # encode images
            image_features, _ = self.model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # cosine similarity as logits
            logits_per_image = image_features @ self.text_features.t()
            logits_per_image = torch.softmax(logits_per_image, dim=1)
            all_outputs.append(logits_per_image.cpu())
            all_targets.append(target.cpu())

        all_outputs = torch.cat(all_outputs)
        all_targets = torch.cat(all_targets)
        return all_outputs, all_targets

    @torch.no_grad()
    def predict(self, idx=0):
        all_outputs, all_targets = self.forward(idx)
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        # sel = np.where(np.cumsum(sorted(preds[0].tolist(), reverse=True)) > 0.06)[0][0]
        sel = 5
        df = pd.DataFrame(self.labels)
        pred_action = df.iloc[preds[0].argsort()[-sel:]].values.tolist()
        gt_action = df.iloc[np.where(targets[0])[0]].values.tolist()
        pred_action = sorted([x[0] for x in pred_action])
        gt_action = sorted([x[0] for x in gt_action])
        return pred_action, gt_action

    @torch.no_grad()
    def evaluate(self):
        all_outputs, all_targets = self.forward()
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        if self.dataset == 'charades_ego':
            m_ap, _, m_aps = charades_map(preds, targets)
            print('mAP = {:.3f}'.format(m_ap))
        elif self.dataset == 'egtea':
            cm = confusion_matrix(targets, preds.argmax(axis=1))
            mean_class_acc, acc = get_mean_accuracy(cm)
            print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
        else:
            raise NotImplementedError


def main():
    lavila = VideoCLSModel("configs/charades_ego/zeroshot.yml")
    egovpa = VideoCLSModel("configs/charades_ego/egovpa.yml")
    lavila.evaluate()
    egovpa.evaluate()


if __name__ == '__main__':
    main()
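
# Usage sketch (not executed by the evaluation run above): `predict` returns the
# top-5 predicted action labels and the ground-truth labels for the first clip of
# the validation split; `idx` is only substituted into `metadata_val` when that
# path contains a format placeholder. The config path mirrors the one in `main`;
# adjust it to your setup.
#
#   model = VideoCLSModel("configs/charades_ego/zeroshot.yml")
#   pred_action, gt_action = model.predict(idx=0)
#   print("predicted:", pred_action)
#   print("ground truth:", gt_action)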