CharadesEgo / demo.py
gina9726's picture
Upload demo files
c6f92cc verified
raw
history blame
11 kB
### demo.py
# Define model classes for inference.
###
from collections import OrderedDict
import json
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from sklearn.metrics import confusion_matrix
from lavila.data import datasets
from lavila.data.video_transforms import Permute, SpatialCrop, TemporalCrop
from lavila.models import models
from lavila.models.tokenizer import (MyBertTokenizer, MyDistilBertTokenizer, MyGPT2Tokenizer, SimpleTokenizer)
from lavila.models.utils import inflate_positional_embeds
from lavila.utils.config import load_cfg
from lavila.utils.evaluation_charades import charades_map
from lavila.utils.evaluation import get_mean_accuracy
class VideoModel(nn.Module):
    """Base model for video understanding based on LaViLa architecture.

    On construction the wrapper loads the config, rebuilds the architecture
    recorded in the checkpoint, restores the checkpoint weights, selects a
    tokenizer from the architecture name and switches to frozen eval mode.
    """

    def __init__(self, config):
        """Initializes the model.

        Parameters:
            config: config file (passed to ``load_cfg``)
        """
        super().__init__()
        self.cfg = load_cfg(config)
        # build_model() also sets self.arch, which get_tokenizer() reads.
        self.model = self.build_model()
        self.tokenizer = self.get_tokenizer()
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()

    def build_model(self):
        """Create the network described by the checkpoint and load its weights.

        Reads the checkpoint path from ``cfg['model']['pretrain']``. As side
        effects it sets ``self.arch`` and writes ``arch`` / ``norm_embed``
        into ``cfg['model']``.

        Returns:
            The constructed model with the checkpoint weights loaded
            (moved to CUDA when CUDA is available).

        Raises:
            ValueError: if no checkpoint path is configured.
        """
        cfg = self.cfg
        model_cfg = cfg['model']

        ckpt_path = model_cfg.get('pretrain')
        if not ckpt_path:
            raise ValueError('no checkpoint found')

        # SECURITY: torch.load unpickles the file; the checkpoint also stores
        # an `args` object, so only load checkpoints from a trusted source.
        ckpt = torch.load(ckpt_path, map_location='cpu')

        # NOTE: str.replace removes *every* 'module.' occurrence in a key,
        # not only a leading prefix. Kept as-is to match existing checkpoints.
        state_dict = OrderedDict(
            (key.replace('module.', ''), value)
            for key, value in ckpt['state_dict'].items()
        )

        old_args = vars(ckpt['args'])
        arch = old_args.get('model', 'CLIP_OPENAI_TIMESFORMER_BASE')
        self.arch = arch
        model_cfg['arch'] = arch
        model_cfg['norm_embed'] = old_args.get('norm_embed', True)

        # Values used both for construction and for positional-embedding
        # inflation; looked up once.
        num_frames = model_cfg.get('num_frames', cfg['data']['clip_length'])
        visual_pretrained = old_args.get('load_visual_pretrained', None)

        print("=> creating model: {}".format(arch))
        model = getattr(models, arch)(
            pretrained=visual_pretrained,
            pretrained2d=visual_pretrained is not None,
            text_use_cls_token=old_args.get('use_cls_token', False),
            project_embed_dim=old_args.get('project_embed_dim', 256),
            timesformer_gated_xattn=False,
            num_frames=num_frames,
            model_cfg=model_cfg,
        )
        model.logit_scale.requires_grad = False

        if torch.cuda.is_available():
            model.cuda()

        # 'inflat_posemb' is the config key as spelled in the config files.
        needs_inflation = ('TIMESFORMER' in arch or 'EGOVLP' in arch)
        if needs_inflation and model_cfg.get('inflat_posemb', True):
            # inflate weight
            print('=> inflating PE in models due to different frame numbers')
            state_dict = inflate_positional_embeds(
                model.state_dict(), state_dict,
                num_frames=num_frames,
                load_temporal_fix='bilinear',
            )

        model.load_state_dict(state_dict, strict=True)
        print("=> loaded resume checkpoint '{}' (epoch {})".format(ckpt_path, ckpt['epoch']))
        return model

    def eval(self):
        """Freeze all wrapped-model parameters and switch to evaluation mode.

        Overrides ``nn.Module.eval``. Besides the standard mode switch it
        enables ``cudnn.benchmark`` and disables gradients on every
        parameter of ``self.model``.

        Returns:
            self, as ``nn.Module.eval`` does, so calls can be chained.
        """
        cudnn.benchmark = True
        for param in self.model.parameters():
            param.requires_grad = False
        # The base implementation puts this wrapper and its registered
        # sub-modules (including self.model) into eval mode and returns self.
        return super().eval()

    def get_tokenizer(self):
        """Return the tokenizer matching the suffix of ``self.arch``.

        Falls back to ``SimpleTokenizer`` (with a printed notice) when no
        known suffix matches.
        """
        arch = self.arch
        # Checked in order; the first matching suffix wins. 'DISTILBERT_BASE'
        # must precede 'BERT_BASE' because the former also ends with the
        # latter.
        suffix_table = (
            ('DISTILBERT_BASE', MyDistilBertTokenizer, 'distilbert-base-uncased'),
            ('BERT_BASE', MyBertTokenizer, 'bert-base-uncased'),
            ('BERT_LARGE', MyBertTokenizer, 'bert-large-uncased'),
            ('GPT2', MyGPT2Tokenizer, 'gpt2'),
            ('GPT2_MEDIUM', MyGPT2Tokenizer, 'gpt2-medium'),
            ('GPT2_LARGE', MyGPT2Tokenizer, 'gpt2-large'),
            ('GPT2_XL', MyGPT2Tokenizer, 'gpt2-xl'),
        )
        for suffix, tokenizer_cls, pretrained_name in suffix_table:
            if arch.endswith(suffix):
                return tokenizer_cls(pretrained_name)

        print("Using SimpleTokenizer because of model '{}'. "
              "Please check if this is what you want".format(arch))
        return SimpleTokenizer()
class VideoCLSModel(VideoModel):
    """Video model for video classification tasks (Charades-Ego, EGTEA).

    Class scores are obtained by comparing normalized video embeddings with
    normalized text embeddings of the class labels, which are computed once
    at construction time and stored in ``self.text_features``.
    """

    def __init__(self, config):
        """Build the base model, then the label map and the label embeddings.

        Parameters:
            config: config file (forwarded to ``VideoModel``)
        """
        super().__init__(config)
        self.labels, self.mapping_vn2act = self.gen_label_map()
        self.text_features = self.get_text_features()

    def _model_device(self):
        """Return the device the wrapped model's parameters live on."""
        return next(self.model.parameters()).device

    def gen_label_map(self):
        """Load the label map from JSON, or generate and cache it.

        The path comes from ``cfg['label_map']``; when that key is absent the
        per-dataset cache file ``meta/<dataset>/label_map.json`` is used, so
        each dataset reads back the file this method writes for it.

        Returns:
            (labels, mapping_vn2act) as stored in / generated for the map.
        """
        meta_dir = f'meta/{self.dataset}'
        cache_path = f'{meta_dir}/label_map.json'
        labelmap = self.cfg.get('label_map', cache_path)

        if os.path.isfile(labelmap):
            print(f"=> Loading label maps from {labelmap}")
            with open(labelmap, 'r') as fp:
                meta = json.load(fp)
            return meta['labels'], meta['mapping_vn2act']

        # Imported lazily: only needed when no label-map file exists yet.
        from lavila.utils.preprocess import generate_label_map
        labels, mapping_vn2act = generate_label_map(self.dataset)
        meta = {'labels': labels, 'mapping_vn2act': mapping_vn2act}
        os.makedirs(meta_dir, exist_ok=True)
        with open(cache_path, 'w') as fp:
            json.dump(meta, fp)
        print(f"=> Label map is generated and saved to {meta_dir}/label_map.json")
        return labels, mapping_vn2act

    def load_data(self, idx=None):
        """Create the validation data loader.

        Parameters:
            idx: when given, it is substituted into
                ``cfg['data']['metadata_val']`` via ``str.format``;
                when None the configured path is used unchanged.

        Returns:
            A ``DataLoader`` over the validation dataset (batch size 8,
            no shuffling).

        Raises:
            NotImplementedError: for datasets other than 'charades_ego'
                and 'egtea'.
        """
        print("=> Creating dataset")
        dataset = self.dataset
        data_cfg = self.cfg['data']

        crop_size = 336 if '336PX' in self.arch else 224
        val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),  # T H W C -> C T H W
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(
                mean=[108.3272985, 116.7460125, 104.09373615000001],
                std=[68.5005327, 66.6321579, 70.32316305],
            ),
        ])

        metadata_val = data_cfg['metadata_val']
        if idx is not None:
            metadata_val = metadata_val.format(idx)

        if dataset not in ['charades_ego', 'egtea']:
            raise NotImplementedError

        val_dataset = datasets.VideoClassyDataset(
            dataset, data_cfg['root'], metadata_val,
            transform=val_transform, is_training=False,
            label_mapping=self.mapping_vn2act, is_trimmed=False,
            num_clips=1, clip_length=data_cfg['clip_length'],
            clip_stride=data_cfg['clip_stride'],
            sparse_sample=data_cfg['sparse_sample'],
        )

        # Pinned host memory is only useful when batches are copied to CUDA.
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=8, shuffle=False,
            num_workers=4, pin_memory=torch.cuda.is_available(),
            sampler=None, drop_last=False,
        )
        return val_loader

    @torch.no_grad()
    def get_text_features(self):
        """Encode every class label into one normalized text embedding.

        A label may be a single string or a list of strings; each is
        expanded with ``self.templates``. The embeddings of all resulting
        texts are normalized, averaged, and normalized again.

        Returns:
            Tensor with one row per entry of ``self.labels``.
        """
        print('=> Extracting text features')
        # Follow the model's device instead of assuming CUDA; build_model()
        # only moves the model to CUDA when it is available.
        device = self._model_device()

        text_features = []
        for label in self.labels:
            variants = label if isinstance(label, list) else [label]
            texts = [tmpl.format(lbl) for tmpl in self.templates for lbl in variants]

            tokens = self.tokenizer(texts)
            encode_kwargs = {}
            if isinstance(tokens, tuple):
                # Bert-style tokenizer will output both ids and mask
                tokens, masks = tokens
                masks = masks.to(device, non_blocking=True)
                encode_kwargs['attention_mask'] = masks.view(-1, 77).contiguous()
            tokens = tokens.to(device, non_blocking=True)
            # 77 is presumably the tokenizer's context length -- TODO confirm.
            tokens = tokens.view(-1, 77).contiguous()

            class_embeddings, _ = self.model.encode_text(tokens, **encode_kwargs)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            # Average over all texts of this class, then re-normalize.
            class_embeddings = class_embeddings.mean(dim=0)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            text_features.append(class_embeddings)

        return torch.stack(text_features, dim=0)

    @torch.no_grad()
    def forward(self, idx=None):
        """Score every validation sample against all class embeddings.

        Parameters:
            idx: forwarded to ``load_data``.

        Returns:
            (all_outputs, all_targets): CPU tensors holding the per-class
            softmax scores and the targets of all samples, concatenated
            over batches.
        """
        print('=> Start forwarding')
        val_loader = self.load_data(idx)
        device = self._model_device()

        all_outputs = []
        all_targets = []
        for values in val_loader:
            images = values[0].to(device, non_blocking=True)
            target = values[1]

            # encode images
            image_features, _ = self.model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # cosine similarity as logits
            logits_per_image = image_features @ self.text_features.t()
            logits_per_image = torch.softmax(logits_per_image, dim=1)

            all_outputs.append(logits_per_image.cpu())
            # Targets are only collected, so they never need to leave the CPU.
            all_targets.append(target.cpu())

        all_outputs = torch.cat(all_outputs)
        all_targets = torch.cat(all_targets)
        return all_outputs, all_targets

    @torch.no_grad()
    def predict(self, idx=0, top_k=5):
        """Return predicted and ground-truth labels for the first sample.

        Parameters:
            idx: forwarded to ``forward`` / ``load_data``.
            top_k: number of highest-scoring classes to report.

        Returns:
            (pred_action, gt_action): sorted lists with the first label
            entry of each of the ``top_k`` best classes, and of each class
            whose target entry is non-zero.
        """
        all_outputs, all_targets = self.forward(idx)
        preds, targets = all_outputs.numpy(), all_targets.numpy()

        df = pd.DataFrame(self.labels)
        # argsort is ascending, so the last top_k indices are the best ones.
        top_rows = df.iloc[preds[0].argsort()[-top_k:]].values.tolist()
        gt_rows = df.iloc[np.where(targets[0])[0]].values.tolist()

        pred_action = sorted(row[0] for row in top_rows)
        gt_action = sorted(row[0] for row in gt_rows)
        return pred_action, gt_action

    @torch.no_grad()
    def evaluate(self):
        """Run the full validation set and print the dataset's metric.

        Prints mAP for 'charades_ego', and mean-class / top-1 accuracy for
        'egtea'.

        Raises:
            NotImplementedError: for any other dataset.
        """
        all_outputs, all_targets = self.forward()
        preds, targets = all_outputs.numpy(), all_targets.numpy()

        if self.dataset == 'charades_ego':
            m_ap, _, _ = charades_map(preds, targets)
            print('mAP = {:.3f}'.format(m_ap))
        elif self.dataset == 'egtea':
            cm = confusion_matrix(targets, preds.argmax(axis=1))
            mean_class_acc, acc = get_mean_accuracy(cm)
            print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
        else:
            raise NotImplementedError
def main():
    """Build the zero-shot LaViLa and EgoVPA classifiers, then evaluate both.

    Both models are constructed before either evaluation starts.
    """
    config_paths = (
        "configs/charades_ego/zeroshot.yml",
        "configs/charades_ego/egovpa.yml",
    )
    classifiers = [VideoCLSModel(path) for path in config_paths]
    for classifier in classifiers:
        classifier.evaluate()


if __name__ == '__main__':
    main()