wolfrage89 committed on
Commit c388795
1 Parent(s): e53bc50
Files changed (4)
  1. .gitignore +2 -0
  2. app.py +354 -0
  3. requirements.txt +5 -0
  4. sample_qa.json +0 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ venv
+ trained_pytorch.pth
app.py ADDED
@@ -0,0 +1,354 @@
+ # coding=utf8
+ 
+ from transformers import AutoModel, AutoTokenizer, AutoConfig
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset, DataLoader
+ import streamlit as st
+ import gdown
+ import numpy as np
+ import pandas as pd
+ import collections
+ from string import punctuation
+ 
+ 
+ class CONFIG:
+     # model params
+     model = 'deepset/xlm-roberta-large-squad2'
+     max_input_length = 384  # hyperparameter to be tuned, following the Hugging Face QA guide
+     doc_stride = 128  # hyperparameter to be tuned, following the Hugging Face QA guide
+     model_checkpoint = "pytorch_model.pth"
+     trained_model_url = 'https://drive.google.com/uc?id=16Vp918RglyLEFEyDlFuRD1HeNZ8SI7P5'
+     trained_model_output_fp = 'trained_pytorch.pth'
+     sample_df_fp = "sample_qa.json"
+ 
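+ # Note: with max_input_length=384 and doc_stride=128, the tokenizer splits a
+ # long context into overlapping windows that share 128 context tokens, so an
+ # answer near a window boundary still appears whole in at least one window.
+ 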
+ # model class
+ class ChaiModel(nn.Module):
+     def __init__(self, model_config):
+         super(ChaiModel, self).__init__()
+         self.backbone = AutoModel.from_pretrained(CONFIG.model)
+         self.linear = nn.Linear(model_config.hidden_size, 2)
+ 
+     def forward(self, input_ids, attention_mask):
+         model_output = self.backbone(input_ids, attention_mask=attention_mask)
+         sequence_output = model_output[0]  # (batch_size, sequence_length, hidden_dim)
+ 
+         qa_logits = self.linear(sequence_output)  # (batch_size, sequence_length, 2)
+         start_logit, end_logit = qa_logits.split(1, dim=-1)  # two tensors of (batch_size, sequence_length, 1)
+         start_logits = start_logit.squeeze(-1)  # remove last dim -> (batch_size, sequence_length)
+         end_logits = end_logit.squeeze(-1)  # remove last dim -> (batch_size, sequence_length)
+ 
+         return start_logits, end_logits  # two tensors, each (batch_size, sequence_length)
+ 
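+ # The linear head scores every token twice: once as a candidate answer start
+ # and once as a candidate end; postprocess_qa_predictions later pairs these
+ # per-token scores back into character spans.
+ 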
+ # dataset class
+ class ChaiDataset(Dataset):
+     def __init__(self, dataset, is_train=True):
+         super(ChaiDataset, self).__init__()
+         self.dataset = dataset  # list of features
+         self.is_train = is_train
+ 
+     def __len__(self):
+         return len(self.dataset)
+ 
+     def __getitem__(self, index):
+         features = self.dataset[index]
+         if self.is_train:
+             return {
+                 'input_ids': torch.tensor(features['input_ids'], dtype=torch.long),
+                 'attention_mask': torch.tensor(features['attention_mask'], dtype=torch.long),
+                 'offset_mapping': torch.tensor(features['offset_mapping'], dtype=torch.long),
+                 'start_position': torch.tensor(features['start_position'], dtype=torch.long),
+                 'end_position': torch.tensor(features['end_position'], dtype=torch.long)
+             }
+         else:
+             return {
+                 'input_ids': torch.tensor(features['input_ids'], dtype=torch.long),
+                 'attention_mask': torch.tensor(features['attention_mask'], dtype=torch.long),
+                 'offset_mapping': torch.tensor(features['offset_mapping'], dtype=torch.long),
+                 'sequence_ids': features['sequence_ids'],
+                 'id': features['example_id'],
+                 'context': features['context'],
+                 'question': features['question']
+             }
+ 
+ def break_long_context(df, tokenizer, train=True):
+     if train:
+         n_examples = len(df)
+         full_set = []
+         for i in range(n_examples):
+             row = df.iloc[i]
+             # tokenizer parameters are documented at
+             # https://huggingface.co/transformers/internal/tokenization_utils.html#transformers.tokenization_utils_base.PreTrainedTokenizerBase
+             tokenized_examples = tokenizer(row['question'],
+                                            row['context'],
+                                            padding='max_length',
+                                            max_length=CONFIG.max_input_length,
+                                            truncation='only_second',
+                                            stride=CONFIG.doc_stride,
+                                            return_overflowing_tokens=True,  # return the overflowing windows as extra features
+                                            return_offsets_mapping=True  # return each token's character span in the original text
+                                            )
+ 
+             # tokenized_examples keys:
+             # 'input_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'
+             sample_mappings = tokenized_examples.pop("overflow_to_sample_mapping")
+             offset_mappings = tokenized_examples.pop("offset_mapping")
+ 
+             final_examples = []
+             n_sub_examples = len(sample_mappings)
+             for j in range(n_sub_examples):
+                 input_ids = tokenized_examples["input_ids"][j]
+                 attention_mask = tokenized_examples["attention_mask"][j]
+ 
+                 sliced_text = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids))
+                 final_example = dict(input_ids=input_ids,
+                                      attention_mask=attention_mask,
+                                      sliced_text=sliced_text,
+                                      offset_mapping=offset_mappings[j],
+                                      fold=row['fold'])
+ 
+                 # most of the time cls_index is 0
+                 cls_index = input_ids.index(tokenizer.cls_token_id)
+                 # sequence_ids look like None, 0, 0, ..., None, None, 1, 1, ...
+                 # (None for special tokens, 0 for question tokens, 1 for context tokens)
+                 sequence_ids = tokenized_examples.sequence_ids(j)
+ 
+                 sample_index = sample_mappings[j]
+                 offset_map = offset_mappings[j]
+ 
+                 if np.isnan(row["answer_start"]):  # if there is no answer, both positions point at cls_index
+                     final_example['start_position'] = cls_index
+                     final_example['end_position'] = cls_index
+                     final_example['tokenized_answer'] = ""
+                     final_example['answer_text'] = ""
+                 else:
+                     start_char = row["answer_start"]
+                     end_char = start_char + len(row["answer_text"])
+ 
+                     token_start_index = sequence_ids.index(1)
+                     token_end_index = len(sequence_ids) - 1 - (sequence_ids[::-1].index(1))
+ 
+                     if not (offset_map[token_start_index][0] <= start_char and offset_map[token_end_index][1] >= end_char):
+                         # the answer is not fully contained in this window
+                         final_example['start_position'] = cls_index
+                         final_example['end_position'] = cls_index
+                         final_example['tokenized_answer'] = ""
+                         final_example['answer_text'] = ""
+                     else:
+                         # move token_start_index to the correct context index
+                         while token_start_index < len(offset_map) and offset_map[token_start_index][0] <= start_char:
+                             token_start_index += 1
+                         final_example['start_position'] = token_start_index - 1
+ 
+                         while offset_map[token_end_index][1] >= end_char:  # note that end_index is inclusive; slice with end_position + 1 later
+                             token_end_index -= 1
+                         final_example['end_position'] = token_end_index + 1
+                         tokenized_answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[final_example['start_position']:final_example['end_position'] + 1]))
+                         final_example['tokenized_answer'] = tokenized_answer
+                         final_example['answer_text'] = row['answer_text']
+ 
+                 final_examples.append(final_example)
+             full_set += final_examples
+ 
+     else:
+         n_examples = len(df)
+         full_set = []
+         for i in range(n_examples):
+             row = df.iloc[i]
+             tokenized_examples = tokenizer(row['question'],
+                                            row['context'],
+                                            padding='max_length',
+                                            max_length=CONFIG.max_input_length,
+                                            truncation='only_second',
+                                            stride=CONFIG.doc_stride,
+                                            return_overflowing_tokens=True,  # return the overflowing windows as extra features
+                                            return_offsets_mapping=True  # return each token's character span in the original text
+                                            )
+ 
+             sample_mappings = tokenized_examples.pop("overflow_to_sample_mapping")
+             offset_mappings = tokenized_examples.pop("offset_mapping")
+             n_sub_examples = len(sample_mappings)
+ 
+             final_examples = []
+             for j in range(n_sub_examples):
+                 input_ids = tokenized_examples["input_ids"][j]
+                 attention_mask = tokenized_examples["attention_mask"][j]
+ 
+                 final_example = dict(
+                     input_ids=input_ids,
+                     attention_mask=attention_mask,
+                     offset_mapping=offset_mappings[j],
+                     example_id=row['id'],
+                     context=row['context'],
+                     question=row['question'],
+                     sequence_ids=[0 if value is None else value for value in tokenized_examples.sequence_ids(j)]
+                 )
+ 
+                 final_examples.append(final_example)
+             full_set += final_examples
+     return full_set
+ 
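+ # Rough sizing (illustrative): a 1,000-token context cut into 384-token windows
+ # overlapping by 128 tokens yields on the order of 4 features; in eval mode each
+ # feature keeps its source row's example_id so predictions can be regrouped.
+ 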
+ def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size=20, max_answer_length=30):
+     all_start_logits, all_end_logits = raw_predictions
+ 
+     # map each example id to the indices of the features it was split into
+     example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
+     features_per_example = collections.defaultdict(list)
+     for i, feature in enumerate(features):
+         features_per_example[example_id_to_index[feature["example_id"]]].append(i)
+ 
+     predictions = collections.OrderedDict()
+ 
+     print(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
+ 
+     for example_index, example in examples.iterrows():
+         feature_indices = features_per_example[example_index]
+ 
+         min_null_score = None
+         valid_answers = []
+ 
+         context = example["context"]
+         for feature_index in feature_indices:
+             start_logits = all_start_logits[feature_index]
+             end_logits = all_end_logits[feature_index]
+ 
+             sequence_ids = features[feature_index]["sequence_ids"]
+             context_index = 1
+ 
+             # null out offsets that do not belong to the context
+             features[feature_index]["offset_mapping"] = [
+                 (o if sequence_ids[k] == context_index else None)
+                 for k, o in enumerate(features[feature_index]["offset_mapping"])
+             ]
+             offset_mapping = features[feature_index]["offset_mapping"]
+             cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
+             # track the CLS "no answer" score (kept from the reference notebook; not used below)
+             feature_null_score = start_logits[cls_index] + end_logits[cls_index]
+             if min_null_score is None or min_null_score < feature_null_score:
+                 min_null_score = feature_null_score
+ 
+             start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
+             end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
+             for start_index in start_indexes:
+                 for end_index in end_indexes:
+                     # skip candidates that point outside the context
+                     if (
+                         start_index >= len(offset_mapping)
+                         or end_index >= len(offset_mapping)
+                         or offset_mapping[start_index] is None
+                         or offset_mapping[end_index] is None
+                     ):
+                         continue
+                     # don't consider answers with a length that is either < 0 or > max_answer_length
+                     if end_index < start_index or end_index - start_index + 1 > max_answer_length:
+                         continue
+ 
+                     start_char = offset_mapping[start_index][0]
+                     end_char = offset_mapping[end_index][1]
+                     valid_answers.append(
+                         {
+                             "score": start_logits[start_index] + end_logits[end_index],
+                             "text": context[start_char:end_char]
+                         }
+                     )
+ 
+         if len(valid_answers) > 0:
+             best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
+         else:
+             best_answer = {"text": "", "score": 0.0}
+ 
+         predictions[example["id"]] = best_answer["text"]
+ 
+     return predictions
+ 
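+ # With n_best_size=20 the nested loop above scores at most 20 x 20 = 400
+ # candidate (start, end) pairs per feature before the offset and length
+ # filters, then keeps the single highest-scoring span across all features.
+ 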
+ def download_finetuned_model():
+     gdown.download(url=CONFIG.trained_model_url, output=CONFIG.trained_model_output_fp, quiet=False)
+ 
+ def get_prediction(context: str, question: str, model, tokenizer) -> str:
+     # convert to dataframe format to stay consistent with the training pipeline
+     test_df = pd.DataFrame({"id": [1], "context": [context.strip()], "question": [question.strip()]})
+     test_set = break_long_context(test_df, tokenizer, train=False)
+ 
+     # create a dataset and dataloader with batch size 1 to prevent OOM
+     test_dataset = ChaiDataset(test_set, is_train=False)
+     test_dataloader = DataLoader(test_dataset,
+                                  batch_size=1,
+                                  shuffle=False,
+                                  drop_last=False
+                                  )
+ 
+     # main prediction loop
+     start_logits = []
+     end_logits = []
+ 
+     for features in test_dataloader:
+         input_ids = features['input_ids']
+         attention_mask = features['attention_mask']
+         with torch.no_grad():
+             start_logit, end_logit = model(input_ids, attention_mask)  # (batch, 384), (batch, 384)
+         start_logits.append(start_logit.to("cpu").numpy())
+         end_logits.append(end_logit.to("cpu").numpy())
+ 
+     start_logits, end_logits = np.vstack(start_logits), np.vstack(end_logits)
+ 
+     predictions = postprocess_qa_predictions(test_df, test_set, (start_logits, end_logits))
+     predictions = list(predictions.items())[0][1]
+     predictions = predictions.strip(punctuation)
+ 
+     return predictions
+ 
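+ # Usage sketch (hypothetical strings, for illustration):
+ #   answer = get_prediction(some_hindi_context, some_hindi_question, model, tokenizer)
+ # returns the best answer span as plain text, stripped of surrounding punctuation.
+ 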
+ @st.cache(allow_output_mutation=True)
+ def load_model():
+     gdown.download(url=CONFIG.trained_model_url, output=CONFIG.trained_model_output_fp, quiet=False)
+     print("Downloaded pretrained model")
+     config = AutoConfig.from_pretrained(CONFIG.model)
+     model = ChaiModel(config)
+     model.load_state_dict(torch.load(CONFIG.trained_model_output_fp, map_location=torch.device('cpu')))
+     model.eval()
+     tokenizer = AutoTokenizer.from_pretrained(CONFIG.model)
+     sample_df = pd.read_json(CONFIG.sample_df_fp)
+     return model, tokenizer, sample_df
+ 
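+ # st.cache keeps the downloaded weights, tokenizer, and sample dataframe in
+ # memory across Streamlit reruns; allow_output_mutation=True skips hashing the
+ # returned (unhashable) model object on later runs.
+ 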
+ model, tokenizer, sample_df = load_model()
+ 
+ ## initialize session_state
+ if "context" not in st.session_state:
+     st.session_state["context"] = ""
+ if "question" not in st.session_state:
+     st.session_state['question'] = ""
+ if "answer" not in st.session_state:
+     st.session_state['answer'] = ""
+ 
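+ # Streamlit re-executes this whole script on every widget interaction, so the
+ # context, question, and answer are kept in st.session_state to survive reruns.
+ 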
+ ## Layout
+ st.sidebar.title("Hindi/Tamil Extractive Question Answering")
+ st.sidebar.markdown("---")
+ random_button = st.sidebar.button("Random")
+ st.sidebar.write("Randomly Generates a Hindi/Tamil Context and Question")
+ st.sidebar.markdown("---")
+ answer_button = st.sidebar.button("Answer!")
+ 
+ if random_button:
+     sample = sample_df.sample(1)
+     st.session_state['context'] = sample['context'].item()
+     st.session_state['question'] = sample['question'].item()
+     st.session_state['answer'] = ""
+ 
+ if answer_button:
+     # if the question or context is empty text
+     if len(st.session_state['context']) == 0 or len(st.session_state['question']) == 0:
+         st.session_state['answer'] = " "
+     else:
+         st.session_state['answer'] = get_prediction(st.session_state['context'], st.session_state['question'], model, tokenizer)
+ 
+ st.session_state["context"] = st.text_area("Context", value=st.session_state['context'], height=300)
+ 
+ with st.container():
+     col_1, col_2 = st.columns(2)
+     with col_1:
+         st.session_state['question'] = st.text_area("Question", value=st.session_state['question'], height=200)
+ 
+     with col_2:
+         st.text_area("Answer", value=st.session_state['answer'], height=200)
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ sentencepiece
+ transformers
+ streamlit==1.0.0
+ gdown==4.2.0
sample_qa.json ADDED
The diff for this file is too large to render. See raw diff