hungdungn47 commited on
Commit
31521f5
1 Parent(s): 74d656b

fix chdg infer

Browse files
Files changed (1) hide show
  1. chdg_inference.py +14 -6
chdg_inference.py CHANGED
@@ -287,7 +287,7 @@ def meanTokenVecs(text):
287
  buffer, buffer_str = [], ''
288
  else:
289
  wordVecs[token[0]] = token[1]
290
-
291
  return torch.mean(torch.stack([vec for w, vec in wordVecs.items() if w not in string.punctuation]), dim=0)
292
 
293
  def getPositionEncoding(pos, d=768, n=10000):
@@ -299,7 +299,6 @@ def getPositionEncoding(pos, d=768, n=10000):
299
  return P
300
 
301
  PositionVec = torch.stack([torch.from_numpy(getPositionEncoding(i, d=768)) for i in range(200)], dim=0).float().to(device)
302
-
303
  stop_w = ['...']
304
  with open('./vietnamese-stopwords-dash.txt', 'r', encoding='utf-8') as f:
305
  for w in f.readlines():
@@ -348,9 +347,19 @@ def loadClusterData(docs_org, category): # docs_org: list of text for each docum
348
  for d, doc in enumerate(docs_org):
349
  seclist[d], sentTexts = divideSection(doc, category)
350
  docs.append(sentTexts)
351
-
 
 
 
 
 
 
 
 
 
 
 
352
  sents, sentVecs, secIDs, doc_lens = [], [], [], []
353
- secnum = 4
354
  sentnum = sum([len(doc.values()) for doc in seclist.values()])
355
  doc_sec_mask = np.zeros((len(docs), secnum))
356
  sec_sen_mask = np.zeros((secnum, sentnum))
@@ -366,7 +375,7 @@ def loadClusterData(docs_org, category): # docs_org: list of text for each docum
366
  sentVecs.append(meanTokenVecs(sent))
367
  sec_sen_mask[seclist[d][s], cursent] = 1
368
  cursent += 1
369
-
370
  return Cluster(sents, sentVecs, doc_lens, doc_sec_mask, sec_sen_mask)
371
 
372
  def val_e2e(data, model, max_word_num=200, c_model=None):
@@ -414,7 +423,6 @@ c_model.load_state_dict(torch.load('./c_25_0.3071.mdl', map_location=device), st
414
  def infer(docs, category):
415
  # docs = [text.strip() for text in full_text.split('<><><><><>')]
416
  docs = [text.strip() for text in docs]
417
- print(docs)
418
  data_tree = loadClusterData(docs, category)
419
  summ = val_e2e(data_tree, model, c_model=c_model, max_word_num=200)
420
  summ = re.sub(r'\s+([.,;:"?()/!?])', r'\1', summ.replace('_', ' '))
 
287
  buffer, buffer_str = [], ''
288
  else:
289
  wordVecs[token[0]] = token[1]
290
+
291
  return torch.mean(torch.stack([vec for w, vec in wordVecs.items() if w not in string.punctuation]), dim=0)
292
 
293
  def getPositionEncoding(pos, d=768, n=10000):
 
299
  return P
300
 
301
  PositionVec = torch.stack([torch.from_numpy(getPositionEncoding(i, d=768)) for i in range(200)], dim=0).float().to(device)
 
302
  stop_w = ['...']
303
  with open('./vietnamese-stopwords-dash.txt', 'r', encoding='utf-8') as f:
304
  for w in f.readlines():
 
347
  for d, doc in enumerate(docs_org):
348
  seclist[d], sentTexts = divideSection(doc, category)
349
  docs.append(sentTexts)
350
+
351
+ secnum = 0
352
+ for k, val_dict in seclist.items():
353
+ vals = set(val_dict.values())
354
+ for ki, vi in val_dict.items():
355
+ for i, v in enumerate(vals):
356
+ if vi == v:
357
+ val_dict[ki] = i + secnum
358
+ break
359
+ seclist[k] = val_dict
360
+ secnum += len(vals)
361
+
362
  sents, sentVecs, secIDs, doc_lens = [], [], [], []
 
363
  sentnum = sum([len(doc.values()) for doc in seclist.values()])
364
  doc_sec_mask = np.zeros((len(docs), secnum))
365
  sec_sen_mask = np.zeros((secnum, sentnum))
 
375
  sentVecs.append(meanTokenVecs(sent))
376
  sec_sen_mask[seclist[d][s], cursent] = 1
377
  cursent += 1
378
+
379
  return Cluster(sents, sentVecs, doc_lens, doc_sec_mask, sec_sen_mask)
380
 
381
  def val_e2e(data, model, max_word_num=200, c_model=None):
 
423
  def infer(docs, category):
424
  # docs = [text.strip() for text in full_text.split('<><><><><>')]
425
  docs = [text.strip() for text in docs]
 
426
  data_tree = loadClusterData(docs, category)
427
  summ = val_e2e(data_tree, model, c_model=c_model, max_word_num=200)
428
  summ = re.sub(r'\s+([.,;:"?()/!?])', r'\1', summ.replace('_', ' '))