import json
import os

import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from einops import rearrange
from sklearn.metrics import confusion_matrix
from transformers import BertTokenizer

from svitt.config import load_cfg, setup_config
from svitt.datasets import VideoClassyDataset
from svitt.evaluation import get_mean_accuracy
from svitt.evaluation_charades import charades_map
from svitt.model import SViTT
from svitt.video_transforms import Permute


class VideoModel(nn.Module):
    """Base model for video understanding built on the SViTT architecture."""

    def __init__(self, config):
        """Initialize the model.

        Parameters:
            config: path to the evaluation config file
        """
        super().__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Moving to CPU is a no-op, so no separate GPU branch is needed.
        self.model = self.model.to(self.device)
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()

    def build_model(self):
        cfg = self.cfg
        if cfg['model'].get('pretrain', False):
            ckpt_path = cfg['model']['pretrain']
        else:
            raise ValueError('no pretrained checkpoint specified in config')

        if cfg['model'].get('config', False):
            config_path = cfg['model']['config']
        else:
            raise ValueError('no model config specified in config')

        self.model_cfg = setup_config(config_path)
        self.tokenizer = BertTokenizer.from_pretrained(self.model_cfg.text_encoder)
        model = SViTT(config=self.model_cfg, tokenizer=self.tokenizer)

        print(f"Loading checkpoint from {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = checkpoint["model"]

        # Duplicate text-encoder weights under keys without the "bert." prefix
        # so they match the parameter names the model expects.
        for key in list(state_dict.keys()):
            if "bert" in key:
                encoder_key = key.replace("bert.", "")
                state_dict[encoder_key] = state_dict[key]

        if torch.cuda.is_available():
            model.cuda()

        model.load_state_dict(state_dict, strict=False)
        return model
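
    # Illustrative layout of the config consumed above. The keys are inferred
    # from the accesses in this class; the values are placeholders, not the
    # repository's actual settings:
    #
    #   model:
    #     pretrain: /path/to/svitt_checkpoint.pth
    #     config: /path/to/model_config.yaml
    #   data:
    #     dataset: charades_ego
    #     root: /path/to/videos
    #     metadata_val: /path/to/metadata_val_{}.pkl
    #     clip_length: 16
    #     clip_stride: 2
    #     sparse_sample: false
    #   label_map: meta/charades_ego/label_map.json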
    def eval(self):
        """Freeze all parameters and put the wrapped model in eval mode.

        Note: this overrides ``nn.Module.eval`` with a stricter variant that
        also disables gradients.
        """
        cudnn.benchmark = True
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()


class VideoCLSModel(VideoModel):
    """Video model for video classification tasks (Charades-Ego, EGTEA)."""

    def __init__(self, config):
        super().__init__(config)
        self.labels, self.mapping_vn2act = self.gen_label_map()
        self.text_features = self.get_text_features()

    def gen_label_map(self):
        labelmap = self.cfg.get('label_map', 'meta/charades_ego/label_map.json')
        if os.path.isfile(labelmap):
            print(f"=> Loading label maps from {labelmap}")
            with open(labelmap, 'r') as f:
                meta = json.load(f)
            labels, mapping_vn2act = meta['labels'], meta['mapping_vn2act']
        else:
            from svitt.preprocess import generate_label_map
            labels, mapping_vn2act = generate_label_map(self.dataset)
            meta = {'labels': labels, 'mapping_vn2act': mapping_vn2act}
            meta_dir = f'meta/{self.dataset}'
            os.makedirs(meta_dir, exist_ok=True)
            with open(f'{meta_dir}/label_map.json', 'w') as f:
                json.dump(meta, f)
            print(f"=> Label map is generated and saved to {meta_dir}/label_map.json")

        return labels, mapping_vn2act
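
    # Expected label_map.json layout, based on the keys read above. The
    # entries shown are placeholders, not real dataset labels:
    #
    #   {
    #     "labels": ["hold a book", "open a door", ...],
    #     "mapping_vn2act": {"hold a book": 0, "open a door": 1, ...}
    #   }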
    def load_data(self, idx=None):
        print("=> Creating dataset")
        cfg, dataset = self.cfg, self.dataset
        data_cfg = cfg['data']
        crop_size = 224
        val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),  # (T, H, W, C) -> (C, T, H, W)
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            # Channel statistics are on the 0-255 pixel scale.
            transforms_video.NormalizeVideo(
                mean=[108.3272985, 116.7460125, 104.09373615],
                std=[68.5005327, 66.6321579, 70.32316305],
            ),
        ])

        if idx is None:
            metadata_val = data_cfg['metadata_val']
        else:
            metadata_val = data_cfg['metadata_val'].format(idx)
        if dataset in ['charades_ego', 'egtea']:
            val_dataset = VideoClassyDataset(
                dataset,
                data_cfg['root'],
                metadata_val,
                transform=val_transform,
                is_training=False,
                label_mapping=self.mapping_vn2act,
                is_trimmed=False,
                num_clips=1,
                clip_length=data_cfg['clip_length'],
                clip_stride=data_cfg['clip_stride'],
                sparse_sample=data_cfg['sparse_sample'],
            )
        else:
            raise NotImplementedError

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=8, shuffle=False,
            num_workers=4, pin_memory=True, sampler=None, drop_last=False,
        )

        return val_loader

    @torch.no_grad()
    def get_text_features(self):
        """Encode every class label with the text encoder."""
        print('=> Extracting text features')
        embeddings = self.tokenizer(
            self.labels,
            padding="max_length",
            truncation=True,
            max_length=self.model_cfg.max_txt_l.video,
            return_tensors="pt",
        ).to(self.device)
        _, class_embeddings = self.model.encode_text(embeddings)
        return class_embeddings

    @torch.no_grad()
    def forward(self, idx=None):
        print('=> Start forwarding')
        val_loader = self.load_data(idx)
        all_outputs = []
        all_targets = []
        for values in val_loader:
            images, target = values[0], values[1]
            images = images.to(self.device)

            # Split each video into clips of 4 frames, assuming clip_length is
            # a multiple of 4: (B, C, T, H, W) -> (B*T/4, 4, C, H, W).
            bsz = images.shape[0]
            images = rearrange(images, 'b c k h w -> b k c h w')
            dims = images.shape
            images = images.reshape(-1, 4, dims[-3], dims[-2], dims[-1])

            image_features, _ = self.model.encode_image(images)

            # Restore the per-video batch dimension so each video yields one
            # similarity row aligned with its target.
            if image_features.ndim == 3:
                image_features = rearrange(image_features, '(b k) n d -> b (k n) d', b=bsz)
            else:
                image_features = rearrange(image_features, '(b k) d -> b k d', b=bsz)

            similarity = self.model.get_sim(image_features, self.text_features)[0]

            all_outputs.append(similarity.cpu())
            all_targets.append(target)

        all_outputs = torch.cat(all_outputs)
        all_targets = torch.cat(all_targets)
        return all_outputs, all_targets

    @torch.no_grad()
    def predict(self, idx=0):
        all_outputs, all_targets = self.forward(idx)
        preds, targets = all_outputs.numpy(), all_targets.numpy()

        # Report the top-5 predicted actions for the first video, alongside
        # its ground-truth actions.
        sel = 5
        df = pd.DataFrame(self.labels)
        pred_action = df.iloc[preds[0].argsort()[-sel:]].values.tolist()
        gt_action = df.iloc[np.where(targets[0])[0]].values.tolist()
        pred_action = sorted([x[0] for x in pred_action])
        gt_action = sorted([x[0] for x in gt_action])
        return pred_action, gt_action

    @torch.no_grad()
    def evaluate(self):
        all_outputs, all_targets = self.forward()
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        if self.dataset == 'charades_ego':
            m_ap, _, _ = charades_map(preds, targets)
            print('mAP = {:.3f}'.format(m_ap))
        elif self.dataset == 'egtea':
            cm = confusion_matrix(targets, preds.argmax(axis=1))
            mean_class_acc, acc = get_mean_accuracy(cm)
            print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
        else:
            raise NotImplementedError
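

# Minimal usage sketch. The config path below is hypothetical; point it at
# your own config, and note that predict(idx=...) assumes data.metadata_val
# contains a '{}' placeholder for the shard index.
if __name__ == '__main__':
    model = VideoCLSModel('configs/charades_ego.yml')

    # Top-5 predicted actions vs. ground truth for the first video of shard 0.
    pred_action, gt_action = model.predict(idx=0)
    print('predicted:', pred_action)
    print('ground truth:', gt_action)

    # Full validation metrics (mAP on Charades-Ego, accuracy on EGTEA).
    model.evaluate()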