patrawtf's picture
Update app/tapas.py
6d20a1a
raw
history blame
894 Bytes
from transformers import TapasTokenizer, TapexTokenizer, BartForConditionalGeneration
import pandas as pd
import datetime
_TAPEX_MODEL_NAME = "microsoft/tapex-large-finetuned-wtq"

# Loaded (tokenizer, model) pairs keyed by checkpoint name, so the large
# checkpoint is instantiated once per process instead of once per query.
_TAPEX_CACHE = {}


def _load_tapex(model_name):
    """Return the ``(tokenizer, model)`` pair for *model_name*, loading it on first use."""
    if model_name not in _TAPEX_CACHE:
        _TAPEX_CACHE[model_name] = (
            TapexTokenizer.from_pretrained(model_name),
            BartForConditionalGeneration.from_pretrained(model_name),
        )
    return _TAPEX_CACHE[model_name]


def execute_query(query, csv_file):
    """Answer a natural-language question about a CSV table with TAPEX.

    Args:
        query: The question to ask about the table.
        csv_file: An object whose ``name`` attribute is the path of a
            comma-delimited CSV file (only ``csv_file.name`` is read).

    Returns:
        A ``(query_result, table)`` tuple. ``query_result`` is a dict with
        the keys ``"query"`` (the question as given) and ``"answer"`` (the
        decoded model output). ``table`` is the DataFrame that was fed to
        the model: missing values replaced by 0 and every cell cast to str.

    Side effects:
        Prints the elapsed wall-clock time of the call to stdout.
    """
    started = datetime.datetime.now()

    # The tokenizer is given string cells only, so fill gaps first and then
    # stringify the whole frame.
    table = pd.read_csv(csv_file.name, delimiter=",")
    table = table.fillna(0).astype(str)

    tokenizer, model = _load_tapex(_TAPEX_MODEL_NAME)

    # ``BartForConditionalGeneration`` is the PyTorch class, so the encoding
    # must be PyTorch tensors ("pt"), not TensorFlow ones. ``truncation=True``
    # keeps large tables within the model's maximum input length. The
    # question is passed via ``query`` -- the keyword TapexTokenizer accepts.
    encoding = tokenizer(
        table=table,
        query=[query],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    outputs = model.generate(**encoding)
    answers = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    query_result = {
        "query": query,
        "answer": answers[0],
    }

    print(datetime.datetime.now() - started)
    return query_result, table