Tapex_Q_A / app.py
17Goutham's picture
Update app.py
8a57a60 verified
raw
history blame contribute delete
No virus
2.19 kB
from transformers import TapexTokenizer, BartForConditionalGeneration
import pandas as pd
import datetime
import torch
import gradio as gr
def execute_query(query, csv_file):
a = datetime.datetime.now()
table = pd.read_csv(csv_file.name, delimiter=",")
table = table.astype(str)
model_name = "microsoft/tapex-large-finetuned-wtq"
model = BartForConditionalGeneration.from_pretrained(model_name)
tokenizer = TapexTokenizer.from_pretrained(model_name)
queries = [query]
encoding = tokenizer(table=table, query=queries, padding=True, return_tensors="pt", truncation=True)
outputs = model.generate(**encoding)
ans = tokenizer.batch_decode(outputs, skip_special_tokens=True)
query_result = {
"query": query,
"answer": ans[0]
}
b = datetime.datetime.now()
print(b - a)
return query_result, table
def main():
    """Build the Gradio interface for TAPEX table question answering and launch it."""
    blurb = "Querying a CSV using the TAPEX model. You can ask a question about tabular data, and the TAPEX model will produce the result. The finetuned TAPEX model runs on data with a maximum of 5000 rows and 20 columns. A sample dataset of Shopify store sales is provided."
    footer_html = "<p style='text-align: center'><a href='https://unscrambl.com/' target='_blank'>Unscrambl</a> | <a href='https://huggingface.co/microsoft/tapex-large-finetuned-wtq' target='_blank'>TAPEX Model</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=abaranovskij_tablequery' alt='visitor badge'></center>"

    # Input widgets: the question text and the CSV upload.
    query_box = gr.Textbox(label="Search query")
    csv_upload = gr.File(label="CSV file")

    # Output widgets: the answer as JSON and the full table that was queried.
    result_view = gr.JSON(label="Result")
    table_view = gr.Dataframe(label="All data")

    demo = gr.Interface(
        fn=execute_query,
        inputs=[query_box, csv_upload],
        outputs=[result_view, table_view],
        title="Table Question Answering (TAPEX)",
        description=blurb,
        article=footer_html,
        allow_flagging='never',
    )

    # Use this config when running on Docker
    # demo.launch(server_name="0.0.0.0", server_port=7000)
    demo.launch(enable_queue=True)
# Start the web app only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()