"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
# coding=utf-8
__author__ = "aagrawal"
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
import sys
import re
class VQAEval:
    """Score predicted answers against ground-truth answers.

    For every question the predicted answer is normalized (whitespace,
    punctuation, number words, articles, contractions) and compared with the
    ground-truth answers. The per-question score is the average, over each
    ground-truth answer held out in turn, of ``min(1, matches / 3)`` where
    ``matches`` counts the remaining ground-truth answers equal to the
    prediction. Scores are reported as percentages rounded to ``n`` decimals.

    Results are stored on the instance:
        accuracy:     {"overall", "perQuestionType", "perAnswerType"}
        evalQA:       {question id: percentage}
        evalQuesType: {question type: {question id: percentage}}
        evalAnsType:  {answer type: {question id: percentage}}
    """

    def __init__(self, vqa=None, vqaRes=None, n=2):
        """Store the data sources and build the normalization tables.

        Args:
            vqa: Ground-truth container. Must expose ``getQuesIds()`` and a
                ``qa`` mapping whose entries hold ``"answers"`` (a list of
                dicts with an ``"answer"`` key), ``"question_type"`` and
                ``"answer_type"``. May be None when only the text
                normalization helpers are needed.
            vqaRes: Prediction container exposing a ``qa`` mapping whose
                entries hold an ``"answer"`` string.
            n: Number of decimal places passed to ``round`` for every
                reported percentage.
        """
        self.n = n
        self.accuracy = {}
        self.evalQA = {}
        self.evalQuesType = {}
        self.evalAnsType = {}
        self.vqa = vqa
        self.vqaRes = vqaRes
        # ``params`` only exists when ground truth was supplied; evaluate()
        # reads it when no explicit question ids are passed.
        if vqa is not None:
            self.params = {"question_id": vqa.getQuesIds()}
        # Apostrophe-less spellings -> canonical contractions. Lookups happen
        # on lower-cased words (see processDigitArticle).
        # NOTE(review): keys containing capitals ("Im", "Ive", "Id've",
        # "I'dve") can therefore never match, and "somebody'd" maps to the
        # bare form, the reverse of its neighbours. Both are left untouched
        # because changing normalization changes reported scores - confirm.
        self.contractions = {
            "aint": "ain't",
            "arent": "aren't",
            "cant": "can't",
            "couldve": "could've",
            "couldnt": "couldn't",
            "couldn'tve": "couldn't've",
            "couldnt've": "couldn't've",
            "didnt": "didn't",
            "doesnt": "doesn't",
            "dont": "don't",
            "hadnt": "hadn't",
            "hadnt've": "hadn't've",
            "hadn'tve": "hadn't've",
            "hasnt": "hasn't",
            "havent": "haven't",
            "hed": "he'd",
            "hed've": "he'd've",
            "he'dve": "he'd've",
            "hes": "he's",
            "howd": "how'd",
            "howll": "how'll",
            "hows": "how's",
            "Id've": "I'd've",
            "I'dve": "I'd've",
            "Im": "I'm",
            "Ive": "I've",
            "isnt": "isn't",
            "itd": "it'd",
            "itd've": "it'd've",
            "it'dve": "it'd've",
            "itll": "it'll",
            "let's": "let's",
            "maam": "ma'am",
            "mightnt": "mightn't",
            "mightnt've": "mightn't've",
            "mightn'tve": "mightn't've",
            "mightve": "might've",
            "mustnt": "mustn't",
            "mustve": "must've",
            "neednt": "needn't",
            "notve": "not've",
            "oclock": "o'clock",
            "oughtnt": "oughtn't",
            "ow's'at": "'ow's'at",
            "'ows'at": "'ow's'at",
            "'ow'sat": "'ow's'at",
            "shant": "shan't",
            "shed've": "she'd've",
            "she'dve": "she'd've",
            "she's": "she's",
            "shouldve": "should've",
            "shouldnt": "shouldn't",
            "shouldnt've": "shouldn't've",
            "shouldn'tve": "shouldn't've",
            "somebody'd": "somebodyd",
            "somebodyd've": "somebody'd've",
            "somebody'dve": "somebody'd've",
            "somebodyll": "somebody'll",
            "somebodys": "somebody's",
            "someoned": "someone'd",
            "someoned've": "someone'd've",
            "someone'dve": "someone'd've",
            "someonell": "someone'll",
            "someones": "someone's",
            "somethingd": "something'd",
            "somethingd've": "something'd've",
            "something'dve": "something'd've",
            "somethingll": "something'll",
            "thats": "that's",
            "thered": "there'd",
            "thered've": "there'd've",
            "there'dve": "there'd've",
            "therere": "there're",
            "theres": "there's",
            "theyd": "they'd",
            "theyd've": "they'd've",
            "they'dve": "they'd've",
            "theyll": "they'll",
            "theyre": "they're",
            "theyve": "they've",
            "twas": "'twas",
            "wasnt": "wasn't",
            "wed've": "we'd've",
            "we'dve": "we'd've",
            "weve": "we've",
            "werent": "weren't",
            "whatll": "what'll",
            "whatre": "what're",
            "whats": "what's",
            "whatve": "what've",
            "whens": "when's",
            "whered": "where'd",
            "wheres": "where's",
            "whereve": "where've",
            "whod": "who'd",
            "whod've": "who'd've",
            "who'dve": "who'd've",
            "wholl": "who'll",
            "whos": "who's",
            "whove": "who've",
            "whyll": "why'll",
            "whyre": "why're",
            "whys": "why's",
            "wont": "won't",
            "wouldve": "would've",
            "wouldnt": "wouldn't",
            "wouldnt've": "wouldn't've",
            "wouldn'tve": "wouldn't've",
            "yall": "y'all",
            "yall'll": "y'all'll",
            "y'allll": "y'all'll",
            "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've",
            "y'all'dve": "y'all'd've",
            "youd": "you'd",
            "youd've": "you'd've",
            "you'dve": "you'd've",
            "youll": "you'll",
            "youre": "you're",
            "youve": "you've",
        }
        # Number words -> digit strings.
        self.manualMap = {
            "none": "0",
            "zero": "0",
            "one": "1",
            "two": "2",
            "three": "3",
            "four": "4",
            "five": "5",
            "six": "6",
            "seven": "7",
            "eight": "8",
            "nine": "9",
            "ten": "10",
        }
        # Words dropped entirely during normalization.
        self.articles = ["a", "an", "the"]
        # Raw strings avoid invalid-escape warnings; the patterns themselves
        # are unchanged.
        # NOTE(review): "(?!<=\d)" is a negative lookahead for the literal
        # text "<=" followed by a digit, not a lookbehind, so it never blocks
        # a match at a period; only a *following* digit protects a period.
        # Kept as-is because tightening it would change normalized answers -
        # confirm whether "(?<!\d)" was intended.
        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        # Detects a comma sitting between two digits, e.g. "1,000".
        self.commaStrip = re.compile(r"(\d)(,)(\d)")
        # Punctuation marks removed or replaced by a space.
        self.punct = [
            ";",
            r"/",
            "[",
            "]",
            '"',
            "{",
            "}",
            "(",
            ")",
            "=",
            "+",
            "\\",
            "_",
            "-",
            ">",
            "<",
            "@",
            "`",
            ",",
            "?",
            "!",
        ]

    def evaluate(self, quesIds=None):
        """Compute accuracies for the given questions and store them on self.

        Args:
            quesIds: Question ids to score. When None, the ids collected from
                ``vqa.getQuesIds()`` at construction time are used.

        Progress and start/end messages are written to stdout. Ground-truth
        records are read but not modified.
        """
        if quesIds is None:
            quesIds = list(self.params["question_id"])
        gts = {}
        res = {}
        for quesId in quesIds:
            gts[quesId] = self.vqa.qa[quesId]
            res[quesId] = self.vqaRes.qa[quesId]

        # =================================================
        # Compute accuracy
        # =================================================
        accQA = []
        accQuesType = {}
        accAnsType = {}
        print("computing accuracy")
        total = len(quesIds)
        for step, quesId in enumerate(quesIds):
            resAns = res[quesId]["answer"]
            resAns = resAns.replace("\n", " ").replace("\t", " ").strip()
            resAns = self.processPunctuation(resAns)
            resAns = self.processDigitArticle(resAns)

            gtAnswers = gts[quesId]["answers"]
            # Ground-truth answers are punctuation-normalized only when the
            # annotators disagree. Work on copies so the caller's dataset is
            # not rewritten as a side effect of scoring.
            if len({ans["answer"] for ans in gtAnswers}) > 1:
                normalized = []
                for ans in gtAnswers:
                    ansCopy = dict(ans)
                    ansCopy["answer"] = self.processPunctuation(ans["answer"])
                    normalized.append(ansCopy)
                gtAnswers = normalized

            # Leave-one-out: hold each ground-truth answer out in turn and
            # count how many of the others agree with the prediction.
            gtAcc = []
            for gtAnsDatum in gtAnswers:
                matchingAns = [
                    item
                    for item in gtAnswers
                    if item != gtAnsDatum and item["answer"] == resAns
                ]
                gtAcc.append(min(1, float(len(matchingAns)) / 3))

            quesType = gts[quesId]["question_type"]
            ansType = gts[quesId]["answer_type"]
            # A question without ground-truth answers scores 0 instead of
            # raising ZeroDivisionError.
            avgGTAcc = float(sum(gtAcc)) / len(gtAcc) if gtAcc else 0.0
            accQA.append(avgGTAcc)
            accQuesType.setdefault(quesType, []).append(avgGTAcc)
            accAnsType.setdefault(ansType, []).append(avgGTAcc)
            self.setEvalQA(quesId, avgGTAcc)
            self.setEvalQuesType(quesId, quesType, avgGTAcc)
            self.setEvalAnsType(quesId, ansType, avgGTAcc)
            if step % 100 == 0:
                self.updateProgress(step / float(total))

        self.setAccuracy(accQA, accQuesType, accAnsType)
        print("Done computing accuracy")

    def processPunctuation(self, inText):
        """Return ``inText`` with punctuation normalized.

        Each mark in ``self.punct`` is deleted when it touches a space in the
        original text or when the original text contains a digit,digit comma;
        otherwise it is replaced by a space. Afterwards every period not
        followed by a digit is removed.
        """
        outText = inText
        # Loop-invariant: depends only on the original text.
        hasDigitComma = self.commaStrip.search(inText) is not None
        for p in self.punct:
            if hasDigitComma or (p + " " in inText) or (" " + p in inText):
                outText = outText.replace(p, "")
            else:
                outText = outText.replace(p, " ")
        # No third positional argument here: that slot is ``count``, so
        # passing a regex flag would cap the number of periods removed.
        return self.periodStrip.sub("", outText)

    def processDigitArticle(self, inText):
        """Lower-case ``inText``, map number words to digits, drop articles,
        and restore apostrophes in known contractions.

        Returns the remaining words joined by single spaces.
        """
        outText = []
        for word in inText.lower().split():
            # .get() rather than .setdefault(): a lookup must not add every
            # seen word to the number-word table.
            word = self.manualMap.get(word, word)
            if word not in self.articles:
                outText.append(word)
        outText = [self.contractions.get(word, word) for word in outText]
        return " ".join(outText)

    def _meanPercent(self, values):
        """Return the mean of ``values`` as a percentage rounded to ``self.n``
        decimals, or 0.0 when ``values`` is empty."""
        if not values:
            return 0.0
        return round(100 * float(sum(values)) / len(values), self.n)

    def setAccuracy(self, accQA, accQuesType, accAnsType):
        """Store overall, per-question-type and per-answer-type accuracy.

        Args:
            accQA: Per-question scores in [0, 1].
            accQuesType: {question type: list of scores}.
            accAnsType: {answer type: list of scores}.
        """
        self.accuracy["overall"] = self._meanPercent(accQA)
        self.accuracy["perQuestionType"] = {
            quesType: self._meanPercent(scores)
            for quesType, scores in accQuesType.items()
        }
        self.accuracy["perAnswerType"] = {
            ansType: self._meanPercent(scores)
            for ansType, scores in accAnsType.items()
        }

    def setEvalQA(self, quesId, acc):
        """Record the percentage score of one question."""
        self.evalQA[quesId] = round(100 * acc, self.n)

    def setEvalQuesType(self, quesId, quesType, acc):
        """Record the percentage score of one question under its question type."""
        if quesType not in self.evalQuesType:
            self.evalQuesType[quesType] = {}
        self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)

    def setEvalAnsType(self, quesId, ansType, acc):
        """Record the percentage score of one question under its answer type."""
        if ansType not in self.evalAnsType:
            self.evalAnsType[ansType] = {}
        self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)

    def updateProgress(self, progress):
        """Draw a 20-character text progress bar on stdout.

        Args:
            progress: Fraction completed. Ints are converted to float; any
                other non-float value is reported as an error and shown as 0.
                Values below 0 are clamped to 0, values of 1 or more to 1.
        """
        barLength = 20
        status = ""
        if isinstance(progress, int):
            progress = float(progress)
        if not isinstance(progress, float):
            progress = 0
            status = "error: progress var must be float\r\n"
        if progress < 0:
            progress = 0
            status = "Halt...\r\n"
        if progress >= 1:
            progress = 1
            status = "Done...\r\n"
        block = int(round(barLength * progress))
        # The leading "\r" returns to the start of the line so each call
        # overwrites the previous bar.
        text = "\rFinshed Percent: [{0}] {1}% {2}".format(
            "#" * block + "-" * (barLength - block), int(progress * 100), status
        )
        sys.stdout.write(text)
        sys.stdout.flush()