TAPASxHF2 / question_answerer.py
jskinner215's picture
Upload 17 files
25fc3a2
raw
history blame contribute delete
674 Bytes
from transformers import TapasTokenizer, TapasForQuestionAnswering, pipeline
class QuestionAnswerer:
    """Answer natural-language questions over a table with a TAPAS model.

    Wraps a Hugging Face ``table-question-answering`` pipeline built from a
    TAPAS checkpoint (by default ``google/tapas-large-finetuned-wtq``).
    """

    def __init__(self, model_name="google/tapas-large-finetuned-wtq"):
        """Load the TAPAS tokenizer and model and build the QA pipeline.

        Args:
            model_name: Hugging Face Hub id (or local path) of the TAPAS
                checkpoint to load. Defaults to the checkpoint this class
                has always used.
        """
        self.model_name = model_name
        self.tokenizer = TapasTokenizer.from_pretrained(self.model_name)
        self.model = TapasForQuestionAnswering.from_pretrained(self.model_name)
        self.pipe = pipeline(
            "table-question-answering",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def query_table(self, query):
        """Look up the table relevant to *query* (not implemented yet).

        Placeholder: currently does nothing and returns ``None``.
        """
        # TODO: query the Chroma DB for the table relevant to `query`.
        return None

    def answer_question(self, query, table):
        """Answer *query* against *table* using the TAPAS pipeline.

        Args:
            query: Natural-language question to ask about the table.
            table: A ``pandas.DataFrame``, or anything ``pandas.DataFrame()``
                accepts (e.g. a dict mapping column name -> list of values).
                Cells and column names may be non-string; they are
                converted to text before being handed to the model.

        Returns:
            The result of the table-question-answering pipeline for this
            single (table, query) pair, unchanged.

        Raises:
            ValueError: If *table* is None.
        """
        import pandas as pd

        if table is None:
            raise ValueError("table cannot be None.")
        # The TAPAS tokenizer tokenizes every header and cell as text, so
        # numeric values would make it fail. astype(str) returns a new frame,
        # so the caller's table is left untouched.
        table = pd.DataFrame(table).astype(str)
        table.columns = [str(column) for column in table.columns]
        result = self.pipe(table, query)
        return result