# hvaldez -- "update to demo.py" (commit e0cb1ec, verified)
### demo.py
# Define model classes for inference.
###
import json
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from sklearn.metrics import confusion_matrix
from einops import rearrange
from transformers import BertTokenizer
from svitt.model import SViTT
from svitt.datasets import VideoClassyDataset
from svitt.video_transforms import Permute
from svitt.config import load_cfg, setup_config
from svitt.evaluation_charades import charades_map
from svitt.evaluation import get_mean_accuracy
class VideoModel(nn.Module):
    """Base wrapper around an SViTT network for video-understanding inference.

    The constructor loads the experiment config, builds the SViTT model and
    its BERT tokenizer from the paths named in that config, moves the network
    to the GPU when one is available, and freezes it for inference.
    """

    def __init__(self, config):
        """Build the model described by ``config`` and switch it to eval mode.

        Parameters:
            config: value forwarded unchanged to ``load_cfg``. The loaded
                config must provide ``model.pretrain``, ``model.config`` and
                ``data.dataset``.
        """
        super().__init__()
        self.cfg = load_cfg(config)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # ``.to`` is a no-op when the target device is the CPU.
        self.model = self.build_model().to(self.device)
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()

    def build_model(self):
        """Instantiate SViTT and load the pretrained checkpoint into it.

        Side effects: sets ``self.model_cfg`` and ``self.tokenizer``.

        Returns:
            The SViTT model with the checkpoint weights loaded non-strictly
            (missing or unexpected keys are ignored).

        Raises:
            Exception: if the config names no checkpoint (``model.pretrain``)
                or no model config (``model.config``).
        """
        model_section = self.cfg['model']
        ckpt_path = model_section.get('pretrain', False)
        if not ckpt_path:
            raise Exception('no checkpoint found')
        config_path = model_section.get('config', False)
        if not config_path:
            raise Exception('no model config found')

        self.model_cfg = setup_config(config_path)
        self.tokenizer = BertTokenizer.from_pretrained(self.model_cfg.text_encoder)
        model = SViTT(config=self.model_cfg, tokenizer=self.tokenizer)

        print(f"Loading checkpoint from {ckpt_path}")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = checkpoint["model"]

        # Fix for zero-shot evaluation: every entry whose key contains "bert"
        # is also stored under the same key with "bert." removed. Iterate over
        # a snapshot of the keys because the dict grows inside the loop.
        for key in list(state_dict.keys()):
            if "bert" in key:
                state_dict[key.replace("bert.", "")] = state_dict[key]

        if torch.cuda.is_available():
            model.cuda()
        model.load_state_dict(state_dict, strict=False)
        return model

    def eval(self):
        """Freeze the wrapped network and put the module in evaluation mode.

        Enables cuDNN autotuning, disables gradients for every parameter of
        ``self.model`` and switches this module (and its children) to eval
        mode.

        Returns:
            This module, matching the ``nn.Module.eval`` contract so that
            ``model = VideoModel(cfg).eval()`` keeps working.
        """
        cudnn.benchmark = True
        for param in self.model.parameters():
            param.requires_grad = False
        # nn.Module.eval() clears ``training`` on self and on self.model, and
        # returns self.
        return super().eval()
class VideoCLSModel(VideoModel):
    """Video model for video classification tasks (Charades-Ego, EGTEA).

    On construction the label map is loaded (or generated and cached) and the
    text features of all class labels are computed once, so that ``forward``
    only has to encode the videos.
    """

    def __init__(self, config):
        """Build the base model, then prepare labels and their text features.

        Parameters:
            config: value forwarded unchanged to ``VideoModel.__init__``.
        """
        super().__init__(config)
        self.labels, self.mapping_vn2act = self.gen_label_map()
        self.text_features = self.get_text_features()

    def gen_label_map(self):
        """Load the label map from disk, or generate and cache it.

        The path is taken from the ``label_map`` config entry and defaults to
        ``meta/<dataset>/label_map.json``, the same location a generated map
        is written to, so each dataset finds its own cached file.

        Returns:
            Tuple ``(labels, mapping_vn2act)`` as stored in / produced for the
            label map.
        """
        meta_dir = f'meta/{self.dataset}'
        labelmap = self.cfg.get('label_map', f'{meta_dir}/label_map.json')

        if os.path.isfile(labelmap):
            print(f"=> Loading label maps from {labelmap}")
            with open(labelmap, 'r') as fh:
                meta = json.load(fh)
            return meta['labels'], meta['mapping_vn2act']

        # Imported lazily: only needed when no cached label map exists.
        from svitt.preprocess import generate_label_map

        labels, mapping_vn2act = generate_label_map(self.dataset)
        meta = {'labels': labels, 'mapping_vn2act': mapping_vn2act}
        os.makedirs(meta_dir, exist_ok=True)
        # The context manager guarantees the JSON is flushed and closed.
        with open(f'{meta_dir}/label_map.json', 'w') as fh:
            json.dump(meta, fh)
        print(f"=> Label map is generated and saved to {meta_dir}/label_map.json")
        return labels, mapping_vn2act

    def load_data(self, idx=None):
        """Create the validation data loader.

        Parameters:
            idx: when not None, it is substituted into the ``metadata_val``
                config string via ``str.format``; otherwise that string is
                used as is.

        Returns:
            A ``torch.utils.data.DataLoader`` over the validation dataset
            (batch size 8, no shuffling).

        Raises:
            NotImplementedError: for datasets other than ``charades_ego`` and
                ``egtea``.
        """
        print("=> Creating dataset")
        data_cfg = self.cfg['data']
        crop_size = 224
        val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),  # T H W C -> C T H W
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            # NOTE(review): mean/std look like 0-255 pixel-scale statistics,
            # so frames are presumably not rescaled to [0, 1] - confirm.
            transforms_video.NormalizeVideo(
                mean=[108.3272985, 116.7460125, 104.09373615000001],
                std=[68.5005327, 66.6321579, 70.32316305],
            ),
        ])

        metadata_val = data_cfg['metadata_val']
        if idx is not None:
            metadata_val = metadata_val.format(idx)

        if self.dataset not in ['charades_ego', 'egtea']:
            raise NotImplementedError

        val_dataset = VideoClassyDataset(
            self.dataset,
            data_cfg['root'],
            metadata_val,
            transform=val_transform,
            is_training=False,
            label_mapping=self.mapping_vn2act,
            is_trimmed=False,
            num_clips=1,
            clip_length=data_cfg['clip_length'],
            clip_stride=data_cfg['clip_stride'],
            sparse_sample=data_cfg['sparse_sample'],
        )
        return torch.utils.data.DataLoader(
            val_dataset, batch_size=8, shuffle=False,
            num_workers=4, pin_memory=True, sampler=None, drop_last=False
        )

    @torch.no_grad()
    def get_text_features(self):
        """Tokenize all class labels and encode them with the text encoder.

        Returns:
            The second value returned by ``self.model.encode_text`` for the
            tokenized labels.
        """
        print('=> Extracting text features')
        embeddings = self.tokenizer(
            self.labels,
            padding="max_length",
            truncation=True,
            max_length=self.model_cfg.max_txt_l.video,
            return_tensors="pt",
        ).to(self.device)
        _, class_embeddings = self.model.encode_text(embeddings)
        return class_embeddings

    @torch.no_grad()
    def forward(self, idx=None):
        """Score every validation batch against the class text features.

        Parameters:
            idx: forwarded to ``load_data`` to select the metadata file.

        Returns:
            Tuple ``(all_outputs, all_targets)``: the per-batch similarity
            tensors (moved to CPU) and the per-batch targets, each
            concatenated along the first dimension.
        """
        print('=> Start forwarding')
        val_loader = self.load_data(idx)
        all_outputs = []
        all_targets = []
        for values in val_loader:
            images, target = values[0], values[1]
            images = images.to(self.device)

            # encode images
            images = rearrange(images, 'b c k h w -> b k c h w')
            dims = images.shape
            # NOTE(review): the group size 4 is hard-coded; presumably the
            # number of frames per clip - confirm against data.clip_length.
            images = images.reshape(-1, 4, dims[-3], dims[-2], dims[-1])
            image_features, _ = self.model.encode_image(images)

            # NOTE(review): b=1 folds every group in the batch into a single
            # sample; looks intended for one video per batch - confirm.
            if image_features.ndim == 3:
                image_features = rearrange(image_features, '(b k) n d -> b (k n) d', b=1)
            else:
                image_features = rearrange(image_features, '(b k) d -> b k d', b=1)

            # cosine similarity as logits (per the original author's note)
            similarity = self.model.get_sim(image_features, self.text_features)[0]

            all_outputs.append(similarity.cpu())
            all_targets.append(target)

        all_outputs = torch.cat(all_outputs)
        all_targets = torch.cat(all_targets)
        return all_outputs, all_targets

    @torch.no_grad()
    def predict(self, idx=0, topk=5):
        """Return predicted and ground-truth action labels for the first sample.

        Parameters:
            idx: forwarded to ``forward`` to select the metadata file.
            topk: number of highest-scoring labels to report (default 5, the
                value that used to be hard-coded).

        Returns:
            Tuple ``(pred_action, gt_action)`` of alphabetically sorted label
            lists: the ``topk`` best-scoring labels of the first output row,
            and the labels at the non-zero positions of the first target row.
        """
        all_outputs, all_targets = self.forward(idx)
        preds, targets = all_outputs.numpy(), all_targets.numpy()

        df = pd.DataFrame(self.labels)
        order = preds[0].argsort()
        # Explicit start index so topk <= 0 yields no labels instead of all.
        top_idx = order[max(len(order) - topk, 0):]
        pred_rows = df.iloc[top_idx].values.tolist()
        gt_rows = df.iloc[np.where(targets[0])[0]].values.tolist()

        pred_action = sorted(row[0] for row in pred_rows)
        gt_action = sorted(row[0] for row in gt_rows)
        return pred_action, gt_action

    @torch.no_grad()
    def evaluate(self):
        """Run the full validation set and print the dataset's metric.

        Prints mAP for ``charades_ego`` and mean-class / top-1 accuracy for
        ``egtea``.

        Raises:
            NotImplementedError: for any other dataset.
        """
        all_outputs, all_targets = self.forward()
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        if self.dataset == 'charades_ego':
            m_ap, _, _ = charades_map(preds, targets)
            print('mAP = {:.3f}'.format(m_ap))
        elif self.dataset == 'egtea':
            cm = confusion_matrix(targets, preds.argmax(axis=1))
            mean_class_acc, acc = get_mean_accuracy(cm)
            print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
        else:
            raise NotImplementedError