import csv

from lavila.models.tokenizer import MyBertTokenizer, MyDistilBertTokenizer, MyGPT2Tokenizer, SimpleTokenizer


def generate_label_map(dataset):
    """Build the list of class labels and the class-key -> index map for `dataset`."""
    if dataset == 'ek100_cls':
        print("=> preprocessing ek100 action label space")
        vn_list = []
        mapping_vn2narration = {}
        for f in [
            '/data/EK100/epic-kitchens-100-annotations/EPIC_100_train.csv',
            '/data/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv',
        ]:
            with open(f) as csv_file:
                csv_reader = csv.reader(csv_file)
                _ = next(csv_reader)  # skip the header row
                for row in csv_reader:
                    # row[10] is the verb class and row[12] the noun class; the
                    # 'verb:noun' pair (e.g. '3:7') identifies an EK100 action.
                    vn = '{}:{}'.format(int(row[10]), int(row[12]))
                    narration = row[8]
                    if vn not in vn_list:
                        vn_list.append(vn)
                    if vn not in mapping_vn2narration:
                        mapping_vn2narration[vn] = [narration]
                    else:
                        mapping_vn2narration[vn].append(narration)
        vn_list = sorted(vn_list)
        print('# of actions = {}'.format(len(vn_list)))
        mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
        labels = [list(set(mapping_vn2narration[vn_list[i]])) for i in range(len(mapping_vn2act))]
        print(labels[:5])
    elif dataset == 'charades_ego':
        print("=> preprocessing charades_ego action label space")
        vn_list = []
        labels = []
        with open('data/charades_ego/Charades_v1_classes.txt') as f:
            csv_reader = csv.reader(f)
            for row in csv_reader:
                # Each line starts with a 4-character class id (e.g. 'c093'),
                # followed by a space and the action description.
                vn = row[0][:4]
                vn_list.append(vn)
                narration = row[0][5:]
                labels.append(narration)
        mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
        print(labels[:5])
    elif dataset == 'egtea':
        print("=> preprocessing egtea action label space")
        labels = []
        with open('/data/EGTEA/action_idx.txt') as f:
            for row in f:
                # Each line ends with the action index; everything before the
                # last space is the (underscore-joined) action name.
                row = row.strip()
                narration = ' '.join(row.split(' ')[:-1])
                labels.append(narration.replace('_', ' ').lower())
        mapping_vn2act = {label: i for i, label in enumerate(labels)}
        print(len(labels), labels[:5])
    else:
        raise NotImplementedError
    return labels, mapping_vn2act
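
# Usage sketch (not from the original module; it assumes the hard-coded EK100
# annotation CSVs above exist on your machine):
#
#   labels, mapping_vn2act = generate_label_map('ek100_cls')
#   # labels[i] is a list of unique narrations for action class i;
#   # mapping_vn2act maps a 'verb_class:noun_class' key to that index i.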


def generate_tokenizer(model):
    """Pick the tokenizer that matches the text encoder named by `model`'s suffix."""
    if model.endswith('DISTILBERT_BASE'):
        tokenizer = MyDistilBertTokenizer('distilbert-base-uncased')
    elif model.endswith('BERT_BASE'):
        tokenizer = MyBertTokenizer('bert-base-uncased')
    elif model.endswith('BERT_LARGE'):
        tokenizer = MyBertTokenizer('bert-large-uncased')
    elif model.endswith('GPT2'):
        tokenizer = MyGPT2Tokenizer('gpt2', add_bos=True)
    elif model.endswith('GPT2_MEDIUM'):
        tokenizer = MyGPT2Tokenizer('gpt2-medium', add_bos=True)
    elif model.endswith('GPT2_LARGE'):
        tokenizer = MyGPT2Tokenizer('gpt2-large', add_bos=True)
    elif model.endswith('GPT2_XL'):
        tokenizer = MyGPT2Tokenizer('gpt2-xl', add_bos=True)
    else:
        print("=> using SimpleTokenizer because of model '{}'; "
              "please check that this is what you want".format(model))
        tokenizer = SimpleTokenizer()
    return tokenizer
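
if __name__ == '__main__':
    # Minimal smoke test: a sketch, not part of the original module. It assumes
    # /data/EGTEA/action_idx.txt exists locally; 'GPT2' exercises the GPT-2
    # branch of generate_tokenizer (unknown names fall back to SimpleTokenizer).
    labels, mapping_vn2act = generate_label_map('egtea')
    print('{} EGTEA classes'.format(len(mapping_vn2act)))
    tokenizer = generate_tokenizer('GPT2')
    print('tokenizer: {}'.format(type(tokenizer).__name__))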