hvaldez's picture
first commit
c18a21e verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import csv
from lavila.models.tokenizer import MyBertTokenizer, MyDistilBertTokenizer, MyGPT2Tokenizer, SimpleTokenizer
def generate_label_map(dataset):
    """Build the action label space for a downstream classification dataset.

    The annotation files are read from hard-coded locations (absolute
    ``/data/...`` paths for ``ek100_cls`` and ``egtea``, a path relative to
    the current working directory for ``charades_ego``).

    Args:
        dataset: One of ``'ek100_cls'``, ``'charades_ego'`` or ``'egtea'``.

    Returns:
        A ``(labels, mapping_vn2act)`` tuple.

        * ``ek100_cls``: ``labels[i]`` is a list of the distinct narrations
          (column 8) seen for the i-th action key, and ``mapping_vn2act``
          maps each ``'<int col 10>:<int col 12>'`` key to its index.
        * ``charades_ego``: ``labels[i]`` is the description string of the
          i-th line and ``mapping_vn2act`` maps the 4-character class code
          to its line index.
        * ``egtea``: ``labels[i]`` is the normalized (underscores replaced by
          spaces, lower-cased) action name of the i-th line and
          ``mapping_vn2act`` maps that name to its index.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the supported names.
    """
    if dataset == 'ek100_cls':
        print("Preprocess ek100 action label space")
        mapping_vn2narration = {}
        annotation_files = (
            '/data/EK100/epic-kitchens-100-annotations/EPIC_100_train.csv',
            '/data/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv',
        )
        for path in annotation_files:
            # Context manager so the annotation file is closed deterministically;
            # newline='' is what the csv module expects from its input file.
            with open(path, newline='') as f:
                csv_reader = csv.reader(f)
                next(csv_reader)  # skip the header
                for row in csv_reader:
                    # Presumably columns 10 and 12 are the verb and noun class
                    # ids (hence "vn") -- confirm against the annotation schema.
                    vn = '{}:{}'.format(int(row[10]), int(row[12]))
                    mapping_vn2narration.setdefault(vn, []).append(row[8])
        # The dict keys are exactly the distinct action keys, so no separate
        # list with per-row linear membership tests is needed.
        # NOTE: keys are strings, so this ordering is lexicographic
        # (e.g. '10:3' sorts before '2:1'), not numeric.
        vn_list = sorted(mapping_vn2narration)
        print('# of action= {}'.format(len(vn_list)))
        mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
        # De-duplicate narrations per action; order inside each list follows
        # set iteration order, as before.
        labels = [list(set(mapping_vn2narration[vn])) for vn in vn_list]
        print(labels[:5])
    elif dataset == 'charades_ego':
        print("=> preprocessing charades_ego action label space")
        vn_list = []
        labels = []
        with open('data/charades_ego/Charades_v1_classes.txt', newline='') as f:
            csv_reader = csv.reader(f)
            for row in csv_reader:
                # Only the first CSV field is used: a 4-character class code,
                # one separator character, then the description.
                vn_list.append(row[0][:4])
                labels.append(row[0][5:])
        mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
        print(labels[:5])
    elif dataset == 'egtea':
        print("=> preprocessing egtea action label space")
        labels = []
        with open('/data/EGTEA/action_idx.txt') as f:
            for row in f:
                row = row.strip()
                # Drop the last space-separated token of the line and keep
                # the rest as the action name.
                narration = ' '.join(row.split(' ')[:-1])
                labels.append(narration.replace('_', ' ').lower())
        mapping_vn2act = {label: i for i, label in enumerate(labels)}
        print(len(labels), labels[:5])
    else:
        raise NotImplementedError
    return labels, mapping_vn2act
def generate_tokenizer(model):
    """Return the tokenizer that matches the suffix of the model name.

    Args:
        model: Model name string; only its trailing part is inspected.

    Returns:
        A DistilBERT, BERT or GPT-2 tokenizer wrapper (GPT-2 variants are
        built with ``add_bos=True``) when the name ends with a known suffix.
        Otherwise a notice is printed and a ``SimpleTokenizer`` is returned.
    """
    # Checked top to bottom, first match wins. 'DISTILBERT_BASE' has to come
    # before 'BERT_BASE' because the former also ends with the latter.
    # Factories are lambdas so a tokenizer is only built for the chosen entry.
    suffix_to_factory = (
        ('DISTILBERT_BASE', lambda: MyDistilBertTokenizer('distilbert-base-uncased')),
        ('BERT_BASE', lambda: MyBertTokenizer('bert-base-uncased')),
        ('BERT_LARGE', lambda: MyBertTokenizer('bert-large-uncased')),
        ('GPT2', lambda: MyGPT2Tokenizer('gpt2', add_bos=True)),
        ('GPT2_MEDIUM', lambda: MyGPT2Tokenizer('gpt2-medium', add_bos=True)),
        ('GPT2_LARGE', lambda: MyGPT2Tokenizer('gpt2-large', add_bos=True)),
        ('GPT2_XL', lambda: MyGPT2Tokenizer('gpt2-xl', add_bos=True)),
    )
    for suffix, build in suffix_to_factory:
        if model.endswith(suffix):
            return build()

    # No known suffix: fall back to the simple tokenizer and say so.
    notice = ("Using SimpleTokenizer because of model '{}'. "
              "Please check if this is what you want")
    print(notice.format(model))
    return SimpleTokenizer()