Sybghat commited on
Commit
c389554
1 Parent(s): 670bb5f

Create Q_A.py

Browse files
Files changed (1) hide show
  1. Q_A.py +56 -0
Q_A.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pysbd
3
+ from transformers import pipeline
4
+ from sentence_transformers import CrossEncoder
5
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline
6
+
7
+ class QuestionAnswering:
8
+
9
+ def __init__(self):
10
+ model_name = "MaRiOrOsSi/t5-base-finetuned-question-answering"
11
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ self.model = AutoModelWithLMHead.from_pretrained(model_name)
13
+ self.sentence_segmenter = pysbd.Segmenter(language='en',clean=False)
14
+ self.passage_retreival_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
15
+ self.qa_model = pipeline("question-answering",'a-ware/bart-squadv2')
16
+
17
+ def fetch_answers(self, question, document):
18
+ document_paragraphs = document.splitlines()
19
+ query_paragraph_list = [(question, para) for para in document_paragraphs if len(para.strip()) > 0 ]
20
+
21
+ scores = self.passage_retreival_model.predict(query_paragraph_list)
22
+ top_5_indices = scores.argsort()[-3:]
23
+ top_5_query_paragraph_list = [query_paragraph_list[i] for i in top_5_indices ]
24
+ top_5_query_paragraph_list.reverse()
25
+
26
+ top_5_query_paragraph_answer_list = ""
27
+ count = 1
28
+ for query, passage in top_5_query_paragraph_list:
29
+ passage_sentences = self.sentence_segmenter.segment(passage)
30
+ answer = self.qa_model(question = query, context = passage)['answer']
31
+ evidence_sentence = ""
32
+ for i in range(len(passage_sentences)):
33
+ if answer.startswith('.') or answer.startswith(':'):
34
+ answer = answer[1:].strip()
35
+ if answer in passage_sentences[i]:
36
+ evidence_sentence = evidence_sentence + " " + passage_sentences[i]
37
+
38
+
39
+ model_input = f"question: {query} context: {evidence_sentence}"
40
+ encoded_input = self.tokenizer([model_input],
41
+ return_tensors='pt',
42
+ max_length=512,
43
+ truncation=True)
44
+
45
+ output = self.model.generate(input_ids = encoded_input.input_ids,
46
+ attention_mask = encoded_input.attention_mask)
47
+ output_answer = self.tokenizer.decode(output[0], skip_special_tokens=True)
48
+
49
+ result_str = ""+str(count)+": "+ output_answer +"\n"
50
+ result_str = result_str + " "+ evidence_sentence + "\n\n"
51
+ top_5_query_paragraph_answer_list += result_str
52
+ count+=1
53
+
54
+ return top_5_query_paragraph_answer_list
55
+
56
+