# Meteor: utils/utils.py
# Author: BK-Lee -- revision "v1" (commit 8a25753), 6.74 kB.
import gc
import torch
def memory_optimization():
    """Free unreachable Python objects, then release cached CUDA memory.

    ``gc.collect`` runs the cyclic garbage collector (so objects kept alive
    only by reference cycles are reclaimed). ``torch.cuda.empty_cache`` then
    releases the unoccupied blocks held by PyTorch's CUDA caching allocator.
    Returns ``None``.
    """
    # Order matters: collect first so freed tensors' blocks become cacheable,
    # then hand the cache back.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
def freeze_model(model):
    """Turn off gradient tracking for every parameter of ``model``, in place.

    Args:
        model: Any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        None. The parameters are mutated directly.
    """
    for parameter in model.parameters():
        parameter.requires_grad_(False)
def find_special_token(string, special_token, overlapping=False):
    """Yield the start index of each occurrence of ``special_token`` in ``string``.

    Args:
        string: Text to search.
        special_token: Non-empty substring to look for.
        overlapping: When False (default), matches never overlap: the search
            resumes right after each match. When True, the search resumes one
            character after each match start, so overlapping matches are
            reported as well.

    Yields:
        Match start indices in increasing order.

    Raises:
        ValueError: If ``special_token`` is empty (raised when the generator is
            first advanced). An empty token matches at every position, so the
            search would never advance.
    """
    if not special_token:
        raise ValueError("special_token must be a non-empty string")
    step = 1 if overlapping else len(special_token)
    start = string.find(special_token)
    while start != -1:
        yield start
        start = string.find(special_token, start + step)
def insert_tor(sentence, tor_count):
    """Interleave ``tor_count`` ``<tor>`` tokens into a whitespace-split sentence.

    The sentence is split on whitespace and re-joined with single spaces.
    ``tor_count - 1`` tokens are placed directly in front of words spaced
    ``len(words) // (tor_count - 1)`` apart, starting with the first word. One
    final token is appended after the last word.

    Args:
        sentence: Text to annotate.
        tor_count: Total number of ``<tor>`` tokens in the result (>= 2).

    Returns:
        The annotated string. Returns ``False`` when the sentence is too short
        for the requested count, i.e. the word gap between tokens would be 2 or
        less.

    Raises:
        ValueError: If ``tor_count`` is smaller than 2.
    """
    if tor_count < 2:
        # The gap divides by (tor_count - 1): 1 would divide by zero, and
        # 0 or negative counts cannot produce the requested number of tokens.
        raise ValueError(f"tor_count must be at least 2, got {tor_count}")
    words = sentence.split()
    gap = len(words) // (tor_count - 1)
    # Filtering: skip sentences where the tokens would sit <= 2 words apart.
    if gap <= 2:
        return False
    pieces = []
    inserted = 0
    for i, word in enumerate(words):
        if i % gap == 0 and inserted < tor_count - 1:
            word = '<tor>' + word
            inserted += 1
        pieces.append(word)
    # Every word, including the last one, is separated by exactly one space.
    result = ' '.join(pieces) + '<tor>'
    # Internal sanity check. It fails if the input already contained '<tor>'.
    assert result.count('<tor>') == tor_count
    return result
def add_bundle_tokens(input_string, special_token, num):
    """Replace every occurrence of ``special_token`` with ``num`` copies of it.

    Occurrences are matched left to right without overlap. For example,
    ``add_bundle_tokens("a<image>b", "<image>", 2)`` returns
    ``"a<image><image>b"``.

    Args:
        input_string: Text that may contain ``special_token``.
        special_token: Non-empty token to expand.
        num: Number of copies that replace each occurrence.

    Returns:
        The expanded string. ``input_string`` is returned unchanged when the
        token does not occur.

    Raises:
        ValueError: If ``special_token`` is empty.
    """
    if not special_token:
        # An empty token matches at every position, so expansion is undefined.
        raise ValueError("special_token must be a non-empty string")
    # str.count counts non-overlapping occurrences, scanning left to right.
    num_special_tokens = input_string.count(special_token)
    # No special token -> return the raw string
    if not num_special_tokens:
        return input_string
    # str.replace performs the same leftmost, non-overlapping substitution as
    # a manual character scan, in a single pass.
    result = input_string.replace(special_token, special_token * num)
    assert result.count(special_token) == num_special_tokens * num
    return result
def make_instruction_for_mmamba(question, tor=None):
    """Build a two-turn prompt: a user turn with ``question``, then a rationale turn.

    If ``tor`` is truthy, it is formatted into the rationale turn. Otherwise
    (``None`` or empty), the rationale body is ten ``<tor>`` placeholder
    tokens. The two turns are joined with a newline.
    """
    rationale_body = tor if tor else "<tor>" * 10
    user_turn = f"<s>[UNUSED_TOKEN_146]user\n{question}[UNUSED_TOKEN_145]"
    rationale_turn = f"[UNUSED_TOKEN_146]rationale\n{rationale_body}[UNUSED_TOKEN_145]\n</s>"
    return make_human_string(user_turn, rationale_turn, split='\n')
def make_instruction_for_eval_meteor(question, dataset):
    """Build the evaluation prompt for ``question`` on benchmark ``dataset``.

    Processing steps:
    - ``<image>`` is prepended to the question unless ``dataset`` is one of
      ``mmmu``, ``mathverse``, ``hallusionbench`` or ``demo``.
    - For the benchmarks handled below, an answer-format instruction is
      appended.
    - The prompt is a system turn (preceded by ``<s>`` and ten ``<tor>``
      tokens), the user turn, and an open assistant turn, joined by newlines.
    """
    system_prompt = "You should give helpful answer to user based on the rationale."
    choice_hint = "\nAnswer with the option's letter from the given choices directly."
    short_hint = "\nAnswer the question using a single word or phrase."

    if dataset not in ("mmmu", "mathverse", "hallusionbench", "demo"):
        question = "<image>" + question

    if dataset in ("sqa", "mmbench", "mmbench_cn", "mmbench_dev", "mmbench_cn_dev",
                   "seed", "qbench", "ai2d", "mmstar"):
        question += choice_hint
    elif dataset in ("vqav2", "gqa", "pope", "chartqa"):
        question += short_hint
    elif dataset == "vizwiz":
        question += "\nWhen the provided information is insufficient, respond with 'Unanswerable'. Answer the question using a single word or phrase."
    elif dataset == "mmmu":
        # An "A." option marker means the question is treated as multiple-choice.
        question += choice_hint if "A." in question else short_hint
    elif dataset == "hallusionbench" and "Please answer yes or no." not in question:
        # NOTE(review): appended with no newline/space separator, unlike the
        # other hints -- confirm this is intended.
        question += "Please answer yes or no."

    system_turn = "<s>" + "<tor>" * 10 + f"[UNUSED_TOKEN_146]system\n{system_prompt}[UNUSED_TOKEN_145]"
    user_turn = f"[UNUSED_TOKEN_146]user\n{question}[UNUSED_TOKEN_145]"
    assistant_turn = "[UNUSED_TOKEN_146]assistant\n"
    return make_human_string(system_turn, user_turn, assistant_turn, split='\n')
def make_human_string(*args, split):
    """Concatenate string segments, inserting ``split`` between consecutive ones.

    Args:
        *args: String segments, in output order.
        split: Separator placed between segments (keyword-only).

    Returns:
        The joined string. With no segments the result is ``''``.
    """
    return split.join(args)
def get_max_new_tokens(data_name):
    """Return the ``max_new_tokens`` generation budget for an evaluation dataset.

    The name is matched case-insensitively. The datasets in ``five_token_sets``
    get 5 tokens, ``llava`` and ``mm-vet`` get 1024, and any other name gets
    128.
    """
    five_token_sets = {"mme", "pope", "sqa", "mmbench", "mmbench_cn", "mmbench_dev",
                       "mmbench_cn_dev", "seed", "qbench", "ai2d", "mmstar", "vqav2",
                       "gqa", "chartqa", "hallusionbench", "textvqa", "mmmu"}
    long_form_sets = {"llava", "mm-vet"}
    key = data_name.lower()
    if key in five_token_sets:
        return 5
    return 1024 if key in long_form_sets else 128
"""
Print Data Statistics
"""
def print_data_statistics(data):
# name set
name_set = {'caption',
'instruction',
'minigemini',
'docdownstream',
'docreason',
'gllava',
'mathvision',
'mathinstruct',
'mathplus'}
caption = []
instruction = []
minigemini = []
docdownstream = []
docreason = []
gllava = []
mathvision = []
mathinstruct = []
mathplus = []
for d in data:
for name in name_set:
if name in d['id']:
eval(f'{name}.append(1)')
break
num_caption = sum(caption)
num_instruction = sum(instruction)
num_minigemini = sum(minigemini)
num_docdownstream = sum(docdownstream)
num_docreason = sum(docreason)
num_gllava = sum(gllava)
num_mathvision = sum(mathvision)
num_mathinstruct = sum(mathinstruct)
num_mathplus = sum(mathplus)
total_len = num_caption + num_instruction + num_minigemini + \
num_docdownstream + num_docreason + num_gllava + \
num_mathvision + num_mathinstruct + num_mathplus
print('Meteor Dataset Structure Statistics')
print(f'Total Length: {total_len}')
print('--------------------------------------------')
print(f'ShareGPT4V-Caption: {num_caption}')
print(f'ShareGPT4V-Instruction: {num_instruction}')
print(f'MiniGemini: {num_minigemini}')
print(f'DocDownstream: {num_docdownstream}')
print(f'DocReason: {num_docreason}')
print(f'GLLaVA: {num_gllava}')
print(f'MathVision: {num_mathvision}')
print(f'MathInstruct: {num_mathinstruct}')
print(f'MathPlus: {num_mathplus}')
print('--------------------------------------------')
print(f'Real-World Image: {num_caption + num_instruction}')
print(f'Document & Chart & Diagram & Sign & Symbol: {num_minigemini + num_docdownstream + num_docreason}')
print(f'Math: {num_gllava + num_mathvision + num_mathinstruct + num_mathplus}')
print(f' Math with Vision: {num_gllava + num_mathvision}')
print(f' Math with Text only: {num_mathinstruct + num_mathplus}')
print('--------------------------------------------')
print('')