hungdungn47
committed on
Commit
•
31521f5
1
Parent(s):
74d656b
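Fix CHDG inference: compute the section count (secnum) from the documents' section assignments instead of hard-coding 4, and remove the debug print of the input docs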
Browse files
- chdg_inference.py +14 -6
chdg_inference.py
CHANGED
@@ -287,7 +287,7 @@ def meanTokenVecs(text):
|
|
287 |
buffer, buffer_str = [], ''
|
288 |
else:
|
289 |
wordVecs[token[0]] = token[1]
|
290 |
-
|
291 |
return torch.mean(torch.stack([vec for w, vec in wordVecs.items() if w not in string.punctuation]), dim=0)
|
292 |
|
293 |
def getPositionEncoding(pos, d=768, n=10000):
|
@@ -299,7 +299,6 @@ def getPositionEncoding(pos, d=768, n=10000):
|
|
299 |
return P
|
300 |
|
301 |
PositionVec = torch.stack([torch.from_numpy(getPositionEncoding(i, d=768)) for i in range(200)], dim=0).float().to(device)
|
302 |
-
|
303 |
stop_w = ['...']
|
304 |
with open('./vietnamese-stopwords-dash.txt', 'r', encoding='utf-8') as f:
|
305 |
for w in f.readlines():
|
@@ -348,9 +347,19 @@ def loadClusterData(docs_org, category): # docs_org: list of text for each docum
|
|
348 |
for d, doc in enumerate(docs_org):
|
349 |
seclist[d], sentTexts = divideSection(doc, category)
|
350 |
docs.append(sentTexts)
|
351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
352 |
sents, sentVecs, secIDs, doc_lens = [], [], [], []
|
353 |
-
secnum = 4
|
354 |
sentnum = sum([len(doc.values()) for doc in seclist.values()])
|
355 |
doc_sec_mask = np.zeros((len(docs), secnum))
|
356 |
sec_sen_mask = np.zeros((secnum, sentnum))
|
@@ -366,7 +375,7 @@ def loadClusterData(docs_org, category): # docs_org: list of text for each docum
|
|
366 |
sentVecs.append(meanTokenVecs(sent))
|
367 |
sec_sen_mask[seclist[d][s], cursent] = 1
|
368 |
cursent += 1
|
369 |
-
|
370 |
return Cluster(sents, sentVecs, doc_lens, doc_sec_mask, sec_sen_mask)
|
371 |
|
372 |
def val_e2e(data, model, max_word_num=200, c_model=None):
|
@@ -414,7 +423,6 @@ c_model.load_state_dict(torch.load('./c_25_0.3071.mdl', map_location=device), st
|
|
414 |
def infer(docs, category):
|
415 |
# docs = [text.strip() for text in full_text.split('<><><><><>')]
|
416 |
docs = [text.strip() for text in docs]
|
417 |
-
print(docs)
|
418 |
data_tree = loadClusterData(docs, category)
|
419 |
summ = val_e2e(data_tree, model, c_model=c_model, max_word_num=200)
|
420 |
summ = re.sub(r'\s+([.,;:"?()/!?])', r'\1', summ.replace('_', ' '))
|
|
|
287 |
buffer, buffer_str = [], ''
|
288 |
else:
|
289 |
wordVecs[token[0]] = token[1]
|
290 |
+
|
291 |
return torch.mean(torch.stack([vec for w, vec in wordVecs.items() if w not in string.punctuation]), dim=0)
|
292 |
|
293 |
def getPositionEncoding(pos, d=768, n=10000):
|
|
|
299 |
return P
|
300 |
|
301 |
PositionVec = torch.stack([torch.from_numpy(getPositionEncoding(i, d=768)) for i in range(200)], dim=0).float().to(device)
|
|
|
302 |
stop_w = ['...']
|
303 |
with open('./vietnamese-stopwords-dash.txt', 'r', encoding='utf-8') as f:
|
304 |
for w in f.readlines():
|
|
|
347 |
for d, doc in enumerate(docs_org):
|
348 |
seclist[d], sentTexts = divideSection(doc, category)
|
349 |
docs.append(sentTexts)
|
350 |
+
|
351 |
+
secnum = 0
|
352 |
+
for k, val_dict in seclist.items():
|
353 |
+
vals = set(val_dict.values())
|
354 |
+
for ki, vi in val_dict.items():
|
355 |
+
for i, v in enumerate(vals):
|
356 |
+
if vi == v:
|
357 |
+
val_dict[ki] = i + secnum
|
358 |
+
break
|
359 |
+
seclist[k] = val_dict
|
360 |
+
secnum += len(vals)
|
361 |
+
|
362 |
sents, sentVecs, secIDs, doc_lens = [], [], [], []
|
|
|
363 |
sentnum = sum([len(doc.values()) for doc in seclist.values()])
|
364 |
doc_sec_mask = np.zeros((len(docs), secnum))
|
365 |
sec_sen_mask = np.zeros((secnum, sentnum))
|
|
|
375 |
sentVecs.append(meanTokenVecs(sent))
|
376 |
sec_sen_mask[seclist[d][s], cursent] = 1
|
377 |
cursent += 1
|
378 |
+
|
379 |
return Cluster(sents, sentVecs, doc_lens, doc_sec_mask, sec_sen_mask)
|
380 |
|
381 |
def val_e2e(data, model, max_word_num=200, c_model=None):
|
|
|
423 |
def infer(docs, category):
|
424 |
# docs = [text.strip() for text in full_text.split('<><><><><>')]
|
425 |
docs = [text.strip() for text in docs]
|
|
|
426 |
data_tree = loadClusterData(docs, category)
|
427 |
summ = val_e2e(data_tree, model, c_model=c_model, max_word_num=200)
|
428 |
summ = re.sub(r'\s+([.,;:"?()/!?])', r'\1', summ.replace('_', ' '))
|