# Spaces: Runtime error
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from transformers import AutoProcessor, AutoTokenizer, XCLIPVisionModel, AutoModel, AutoModelForSequenceClassification | |
import numpy as np | |
import cv2 | |
import opensmile | |
class TextClassificationModel:
    """Inference-only wrapper around a sequence-classification model.

    The wrapped model is moved to ``device`` once at construction.  Each call
    puts it in eval mode and runs it without gradient tracking.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def __call__(self, input_ids, attn_mask, return_last_hidden_state=False):
        """Predict class ids for a batch of token sequences.

        Args:
            input_ids: Token-id tensor, forwarded as ``input_ids``.
            attn_mask: Attention-mask tensor, forwarded as ``attention_mask``.
            return_last_hidden_state: When true, the model is also asked for
                its hidden states.  The vector at position 0 (dim 1) of the
                last hidden-state tensor is then returned as well.

        Returns:
            ``pred``, the argmax over dim 1 of the logits.  When
            ``return_last_hidden_state`` is set, returns the tuple
            ``(pred, first_position_vector)`` instead.
        """
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids.to(self.device),
                attention_mask=attn_mask.to(self.device),
                output_hidden_states=return_last_hidden_state,
            )
            predictions = outputs['logits'].argmax(dim=1)
            if not return_last_hidden_state:
                return predictions
            # Take the last entry of the hidden-state sequence, position 0 along dim 1.
            final_layer = outputs['hidden_states'][-1]
            return predictions, final_layer[:, 0, :]
class XCLIPClassificationModel(nn.Module):
    """Video classifier: X-CLIP vision encoder, two average pools, linear head.

    Every frame is encoded independently by the X-CLIP vision tower.  Token
    features are average-pooled within each frame, the per-frame vectors are
    average-pooled across the clip, and the result feeds a linear classifier.
    """

    def __init__(self, num_labels):
        super(XCLIPClassificationModel, self).__init__()
        self.base_model = XCLIPVisionModel.from_pretrained("microsoft/xclip-base-patch32")
        self.num_labels = num_labels
        feature_dim = self.base_model.config.hidden_size
        # NOTE(review): fc_norm is never used in forward(); it is kept so that
        # checkpoints containing its weights still map onto this module.
        self.fc_norm = nn.LayerNorm(feature_dim)
        self.classifier = nn.Linear(feature_dim, self.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.pool1 = nn.AdaptiveAvgPool1d(1)  # averages over tokens of one frame
        self.pool2 = nn.AdaptiveAvgPool1d(1)  # averages over frames of one clip

    def forward(self, pixel_values, labels=None, return_last_hidden_state=False):
        """Classify a batch of clips.

        Args:
            pixel_values: 5-D tensor ``(batch, frames, channels, height, width)``.
            labels: Optional class-index tensor.  When given, a cross-entropy
                loss is computed.
            return_last_hidden_state: When true, the pooled clip vectors are
                also returned under ``'last_hidden_state'``.

        Returns:
            dict with ``'logits'``, ``'loss'`` (``None`` without labels) and,
            optionally, ``'last_hidden_state'``.
        """
        batch_size, num_frames, channels, img_h, img_w = pixel_values.shape
        frame_batch = pixel_values.reshape(-1, channels, img_h, img_w)
        # Per the original shape notes: (batch*frames, tokens, hidden).
        token_feats = self.base_model(frame_batch)[0]
        # Mean over tokens -> (batch*frames, hidden).
        frame_feats = self.pool1(token_feats.transpose(1, 2)).squeeze(-1)
        clip_frames = frame_feats.view(batch_size, num_frames, -1)
        # Mean over frames -> (batch, hidden).
        clip_feats = self.pool2(clip_frames.transpose(1, 2)).squeeze(-1)

        logits = self.classifier(clip_feats)
        if labels is None:
            loss = None
        else:
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        result = {'logits': logits, 'loss': loss}
        if return_last_hidden_state:
            result['last_hidden_state'] = clip_feats
        return result
class VideoClassificationModel:
    """Inference-only wrapper around a video classifier.

    The wrapped model must accept ``(pixel_values, return_last_hidden_state=...)``
    and return a dict with ``'logits'``, plus ``'last_hidden_state'`` when
    requested.  It is moved to ``device`` once; each call runs in eval mode
    without gradient tracking.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def __call__(self, pixel_values, return_last_hidden_state=False):
        """Return argmax class ids and, optionally, the model's hidden state.

        Returns:
            ``pred`` (argmax over dim 1 of ``'logits'``), or
            ``(pred, last_hidden_state)`` when ``return_last_hidden_state`` is set.
        """
        self.model.eval()
        with torch.no_grad():
            result = self.model(
                pixel_values.to(self.device),
                return_last_hidden_state=return_last_hidden_state,
            )
            predictions = torch.argmax(result['logits'], dim=1)
            if return_last_hidden_state:
                return predictions, result['last_hidden_state']
            return predictions
class ConvNet(nn.Module):
    """1-D CNN classifier over a flat vector of audio features.

    Pipeline: LayerNorm, then two unpadded kernel-3 convolutions (each
    followed by BatchNorm + ReLU), a width-2 max-pool, dropout, flatten, and
    a two-layer MLP head.  ``forward`` returns ``log_softmax`` outputs (not
    raw logits) under the ``'logits'`` key.

    Args:
        num_labels: Number of output classes.
        n_input: Number of input channels (dim 1 of ``x``).
        n_channel: Number of convolution filters.
        input_len: Length of the feature axis (last dim of ``x``).  The
            default 6191 matches the previously hard-coded value.
        hidden_size: Width of the hidden layer, which is returned as
            ``last_hidden_state``.  Defaults to the pooled sequence length.
            This reproduces the previous 3093 when ``input_len=6191``, so
            existing checkpoints still load.

    Raises:
        ValueError: If ``input_len`` is too short to survive the two
            convolutions and the pooling step.
    """

    def __init__(self, num_labels, n_input=1, n_channel=32, input_len=6191, hidden_size=None):
        super(ConvNet, self).__init__()
        pooled_len = self._pooled_length(input_len)
        if pooled_len < 1:
            raise ValueError(f"input_len={input_len} is too short for ConvNet (need >= 6)")
        if hidden_size is None:
            hidden_size = pooled_len
        # Normalize over (channels, length).  The old hard-coded (1, 6191)
        # only worked for n_input == 1.
        self.ln0 = nn.LayerNorm((n_input, input_len))
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=3)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(2)
        self.fc1 = nn.Linear(n_channel * pooled_len, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_labels)
        self.flat = nn.Flatten()
        self.dropout = nn.Dropout(0.3)

    @staticmethod
    def _pooled_length(input_len):
        """Return the sequence length after conv1, conv2 (k=3, no padding) and MaxPool1d(2)."""
        return (input_len - 4) // 2

    def forward(self, x, return_last_hidden_state=False):
        """Run the network on ``x``, shaped ``(batch, n_input, input_len)``.

        Returns:
            dict with ``'logits'`` (log-probabilities over classes).  When
            ``return_last_hidden_state`` is set, it also contains
            ``'last_hidden_state'``, the ReLU output of ``fc1``.
        """
        x = self.ln0(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.dropout(self.pool1(x))
        hid = F.relu(self.fc1(self.flat(x)))
        log_probs = F.log_softmax(self.fc2(hid), dim=1)
        if not return_last_hidden_state:
            return {'logits': log_probs}
        return {'logits': log_probs, 'last_hidden_state': hid}
class AudioClassificationModel:
    """Inference-only wrapper around an audio classifier.

    The wrapped model must accept ``(features, return_last_hidden_state=...)``
    and return a dict with ``'logits'``, plus ``'last_hidden_state'`` when
    requested.  It is moved to ``device`` once; each call runs in eval mode
    without gradient tracking.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def __call__(self, input_ids, return_last_hidden_state=False):
        """Predict class ids for a batch of audio feature vectors.

        Args:
            input_ids: Array-like features (NumPy array, nested list or
                tensor).  They are converted to a ``torch.float`` tensor on
                ``self.device``.
            return_last_hidden_state: When true, also return the model's
                ``'last_hidden_state'`` output.

        Returns:
            ``pred`` (argmax over dim 1 of ``'logits'``), or
            ``(pred, hidden_state)`` when ``return_last_hidden_state`` is set.
        """
        self.model.eval()
        with torch.no_grad():
            # torch.as_tensor accepts arrays, lists and tensors alike.  Unlike
            # torch.tensor(), it does not warn (or re-copy) when handed a
            # tensor that already has the requested dtype.
            features = torch.as_tensor(input_ids, dtype=torch.float).to(self.device)
            output = self.model(features, return_last_hidden_state=return_last_hidden_state)
            pred = torch.argmax(output['logits'], dim=1)
            if return_last_hidden_state:
                return pred, output['last_hidden_state']
            return pred
class MultimodalClassificationModel(nn.Module):
    """Late-fusion classifier over text, video and audio representations.

    Each sub-model is called with ``return_last_hidden_state=True`` and must
    return a ``(pred, hidden)`` pair.  The three hidden tensors are
    concatenated along dim 1 and passed through
    Linear -> ReLU -> Dropout -> Linear.

    NOTE(review): in this file the sub-models are the plain-object wrappers
    (TextClassificationModel etc.), not nn.Modules.  They are therefore not
    registered as children, not part of this module's state_dict, and run
    under ``torch.no_grad()``.  Only the fusion head is trainable here.
    """

    def __init__(self, text_model, video_model, audio_model, num_labels, input_size, hidden_size=256):
        super(MultimodalClassificationModel, self).__init__()
        self.text_model = text_model
        self.video_model = video_model
        self.audio_model = audio_model
        self.num_labels = num_labels
        # input_size must equal the summed widths of the three hidden vectors.
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, self.num_labels)
        self.relu1 = nn.ReLU()
        self.drop1 = nn.Dropout()
        self.loss_func = nn.CrossEntropyLoss()

    def forward(self, batch, labels=None):
        """Fuse the three modalities and classify.

        Args:
            batch: Mapping with ``batch['text']['input_ids']``,
                ``batch['text']['attention_mask']``,
                ``batch['video']['pixel_values']`` and ``batch['audio']``.
                The text and video tensors have dim 1 squeezed before use.
            labels: Optional class-index tensor for the cross-entropy loss.

        Returns:
            dict with ``'logits'`` and ``'loss'`` (``None`` without labels).
        """
        text_batch = batch['text']
        _, text_vec = self.text_model(
            text_batch['input_ids'].squeeze(1),
            text_batch['attention_mask'].squeeze(1),
            return_last_hidden_state=True,
        )
        _, video_vec = self.video_model(
            batch['video']['pixel_values'].squeeze(1),
            return_last_hidden_state=True,
        )
        _, audio_vec = self.audio_model(batch['audio'], return_last_hidden_state=True)

        fused = torch.cat((text_vec, video_vec, audio_vec), dim=1)
        logits = self.linear2(self.drop1(self.relu1(self.linear1(fused))))

        loss = None
        if labels is not None:
            loss = self.loss_func(logits.view(-1, self.num_labels), labels.view(-1))
        return {'logits': logits, 'loss': loss}
class MainModel:
    """Inference-only wrapper that turns a fusion model's logits into class ids."""

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def __call__(self, batch):
        """Return the argmax over dim 1 of ``self.model(batch)['logits']``.

        The model runs in eval mode without gradient tracking.

        NOTE(review): ``batch`` is passed through unchanged; it is not moved
        to ``self.device`` here.
        """
        self.model.eval()
        with torch.no_grad():
            scores = self.model(batch)['logits']
            return scores.argmax(dim=1)
def _load_weights(model, checkpoint_path, *, key=None, strict=True):
    """Load a CPU-mapped checkpoint into *model*, warning about unmatched keys.

    Args:
        model: ``nn.Module`` to load into (modified in place).
        checkpoint_path: Path handed to ``torch.load``.
        key: If given, the state dict is ``checkpoint[key]`` rather than the
            whole loaded object.
        strict: Forwarded to ``load_state_dict``.  With ``strict=False``,
            unmatched keys are reported through ``warnings.warn`` instead of
            being silently discarded.

    Returns:
        *model*.
    """
    import warnings

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    state_dict = checkpoint if key is None else checkpoint[key]
    result = model.load_state_dict(state_dict, strict=strict)
    # With strict=True a mismatch has already raised; this only fires for strict=False.
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"{type(model).__name__}: checkpoint {checkpoint_path!r} does not fully match the model "
            f"({len(result.missing_keys)} missing keys, e.g. {list(result.missing_keys)[:5]}; "
            f"{len(result.unexpected_keys)} unexpected keys, e.g. {list(result.unexpected_keys)[:5]})",
            stacklevel=2,
        )
    return model


def prepare_models(num_labels: int,
                   text_model_path: str,
                   video_model_path: str,
                   audio_model_path: str,
                   device: str = 'cpu'):
    """Build the three unimodal classifiers and load their fine-tuned weights.

    Args:
        num_labels: Number of output classes for every model.
        text_model_path: Checkpoint holding a state dict for the text model.
        video_model_path: Checkpoint holding a state dict for the video model.
        audio_model_path: Checkpoint dict with the ConvNet weights under
            ``'model_state_dict'``.
        device: Torch device the wrapped models are moved to.

    Returns:
        ``(text_model, video_model, audio_model)`` inference wrappers.
    """
    # TEXT: strict=False tolerates key differences between checkpoint and
    # model; any mismatch is now surfaced as a warning.
    text_model_name = 'bert-large-uncased'
    text_base_model = AutoModelForSequenceClassification.from_pretrained(
        text_model_name,
        num_labels=num_labels
    )
    _load_weights(text_base_model, text_model_path, strict=False)
    text_model = TextClassificationModel(text_base_model, device=device)
    # VIDEO
    video_base_model = XCLIPClassificationModel(num_labels)
    _load_weights(video_base_model, video_model_path, strict=False)
    video_model = VideoClassificationModel(video_base_model, device=device)
    # AUDIO: strict load from the 'model_state_dict' entry of the checkpoint.
    audio_base_model = ConvNet(num_labels)
    _load_weights(audio_base_model, audio_model_path, key='model_state_dict')
    audio_model = AudioClassificationModel(audio_base_model, device=device)
    return text_model, video_model, audio_model
def sample_frame_indices(seg_len, clip_len=16, frame_sample_rate=4, mode="video", rng=None):
    """Pick evenly spaced frame indices from a randomly placed window.

    The window spans ``clip_len * frame_sample_rate`` frames, shrunk to
    ``seg_len - 1`` for short videos.  It ends at a random index drawn from
    ``[window, seg_len)``.

    Args:
        seg_len: Number of frames available in the video.
        clip_len: Number of indices returned in ``"video"`` mode.
        frame_sample_rate: Window-length multiplier.  In any mode other than
            ``"video"``, ``clip_len * frame_sample_rate`` indices are returned.
        mode: ``"video"`` or any other string (see above).
        rng: Optional ``numpy.random.Generator`` for reproducible sampling.
            When omitted, the global ``np.random`` state is used, as before.

    Returns:
        ``np.int64`` array of non-decreasing indices in ``[0, seg_len)``.

    Raises:
        ValueError: If ``seg_len < 1``.  Previously this produced negative,
            meaningless indices.
    """
    if seg_len < 1:
        raise ValueError(f"seg_len must be >= 1, got {seg_len}")
    converted_len = min(int(clip_len * frame_sample_rate), seg_len - 1)
    if rng is None:
        end_idx = np.random.randint(converted_len, seg_len)
    else:
        end_idx = int(rng.integers(converted_len, seg_len))
    start_idx = end_idx - converted_len
    num = clip_len if mode == "video" else clip_len * frame_sample_rate
    indices = np.linspace(start_idx, end_idx, num=num)
    # Keep the original "end_idx - 1" upper bound, but never below start_idx.
    # Otherwise a 1-frame video (start == end == 0) would yield index -1.
    upper = max(end_idx - 1, start_idx)
    return np.clip(indices, start_idx, upper).astype(np.int64)
def get_frames(file_path, clip_len=16):
    """Decode ``clip_len`` RGB frames from a video file.

    Frame indices come from ``sample_frame_indices``.  Each selected frame is
    converted BGR->RGB, cropped to rows ``90:-80`` and columns ``60:-100``,
    and resized to 224x224 with bicubic interpolation.  Duplicate sampled
    indices contribute a single frame.  Any shortfall is padded by repeating
    the last decoded frame.

    Args:
        file_path: Path or URL accepted by ``cv2.VideoCapture``.
        clip_len: Number of frames to return.  It is now also forwarded to
            the sampler, which previously always used its own default.

    Returns:
        List of ``clip_len`` frame arrays.

    Raises:
        ValueError: If none of the sampled frames could be decoded.
    """
    cap = cv2.VideoCapture(file_path)
    frames = []
    try:
        v_len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        wanted = set(sample_frame_indices(v_len, clip_len=clip_len).tolist())
        last_wanted = max(wanted)
        for fn in range(v_len):
            if fn > last_wanted:
                break  # no later frame is selected; skip decoding the remainder
            success, frame = cap.read()
            if not success:
                continue
            if fn in wanted:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(cv2.resize(rgb[90:-80, 60:-100], dsize=(224, 224),
                                         interpolation=cv2.INTER_CUBIC))
    finally:
        # Release the capture even if decoding or resizing raises.
        cap.release()
    if not frames:
        raise ValueError(f"could not decode any sampled frame from {file_path!r}")
    if len(frames) < clip_len:
        frames.extend([frames[-1]] * (clip_len - len(frames)))
    return frames
def prepare_data_input(text: str,
                       video_path: str):
    """Build the model input batch for one (text, video) pair.

    Args:
        text: Input text, tokenized with the ``bert-large-uncased`` tokenizer.
        video_path: Video file.  Frames are sampled from it and openSMILE
            features are extracted from it.

    Returns:
        dict with keys:

        * ``'text'``: tokenizer output padded/truncated to 128 tokens.
        * ``'video'``: X-CLIP processor output for the sampled frames.
        * ``'audio'``: ComParE_2016 functionals, minus the columns listed in
          ``files/redundant_feat.txt``, reshaped to ``(1, 1, 6191)``.
    """
    # VIDEO
    video_frames = get_frames(video_path)
    video_model_name = "microsoft/xclip-base-patch32"
    video_feature_extractor = AutoProcessor.from_pretrained(video_model_name)
    video_encoding = video_feature_extractor(videos=video_frames, return_tensors="pt")
    # AUDIO
    smile = opensmile.Smile(
        opensmile.FeatureSet.ComParE_2016,
        opensmile.FeatureLevel.Functionals,
        sampling_rate=16000,
        resample=True,
        num_workers=5,
        verbose=True,
    )
    audio_features = smile.process_files([video_path])
    # Comma-separated column names to drop.  Close the file, and strip
    # whitespace/newlines so a trailing newline does not turn the last name
    # into a non-existent column (which makes drop() raise KeyError).
    with open('files/redundant_feat.txt', encoding='utf-8') as feat_file:
        redundant_feat = [name.strip() for name in feat_file.read().split(',') if name.strip()]
    audio_features.drop(columns=redundant_feat, inplace=True)
    # TEXT
    text_model_name = 'bert-large-uncased'
    tokenizer = AutoTokenizer.from_pretrained(text_model_name)
    text_encoding = tokenizer(text,
                              padding='max_length',
                              truncation=True,
                              max_length=128,
                              return_tensors='pt')
    # The explicit 6191 doubles as a sanity check: reshape fails unless
    # exactly that many features remain.
    return {'text': text_encoding, 'video': video_encoding, 'audio': audio_features.values.reshape((1, 1, 6191))}
def infer_multimodal_model(text: str,
                           video_path: str,
                           model_pathes: dict,
                           device: str = 'cpu'):
    """Predict the emotion label for a (text, video) pair.

    Args:
        text: Input text.
        video_path: Video file; both frames and audio features are read from it.
        model_pathes: Mapping with the keys ``'text_model_path'``,
            ``'video_model_path'``, ``'audio_model_path'`` and
            ``'multimodal_model_path'``.
        device: Torch device for every model (default ``'cpu'``, the
            previously hard-coded value).

    Returns:
        One of the label strings in ``label2id``.
    """
    label2id = {'anger': 0, 'disgust': 1, 'fear': 2, 'joy': 3, 'neutral': 4, 'sadness': 5, 'surprise': 6}
    id2label = {v: k for k, v in label2id.items()}
    num_labels = len(label2id)
    text_model, video_model, audio_model = prepare_models(num_labels,
                                                          model_pathes['text_model_path'],
                                                          model_pathes['video_model_path'],
                                                          model_pathes['audio_model_path'],
                                                          device=device)
    # input_size is the width of the concatenated hidden vectors.  3093 is
    # ConvNet's fc1 width and 768 the X-CLIP width in its shape comments.
    # That leaves 1024 for the text vector (presumably bert-large's hidden
    # size) -- update if any backbone changes.
    multi_model = MultimodalClassificationModel(
        text_model,
        video_model,
        audio_model,
        num_labels,
        input_size=4885,
        hidden_size=512
    )
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_pathes['multimodal_model_path'], map_location=torch.device('cpu'))
    multi_model.load_state_dict(checkpoint)
    final_model = MainModel(multi_model, device=device)
    batch = prepare_data_input(text, video_path)
    label = final_model(batch).detach().cpu().tolist()
    return id2label[label[0]]