Shaltiel committed on
Commit
64f9a08
1 Parent(s): deb5cae

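Fixed UD conversion break (pass the model config into convert_output_to_ud so get_prefixes_from_str receives prefix_cfg)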

Browse files
Files changed (1) hide show
  1. BertForJointParsing.py +3 -3
BertForJointParsing.py CHANGED
@@ -239,7 +239,7 @@ class BertForJointParsing(BertPreTrainedModel):
239
  final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(final_output[sent_idx], parsed)
240
 
241
  if output_style in ['ud', 'iahlt_ud']:
242
- final_output = convert_output_to_ud(final_output, style='htb' if output_style == 'ud' else 'iahlt')
243
 
244
  if is_single_sentence:
245
  final_output = final_output[0]
@@ -369,7 +369,7 @@ ud_suffix_to_htb_str = {
369
  'Gender=Fem|Number=Sing|Person=2': '_את',
370
  'Gender=Masc|Number=Plur|Person=3': '_הם'
371
  }
372
- def convert_output_to_ud(output_sentences, style: Literal['htb', 'iahlt']):
373
  if style not in ['htb', 'iahlt']:
374
  raise ValueError('style must be htb/iahlt')
375
 
@@ -393,7 +393,7 @@ def convert_output_to_ud(output_sentences, style: Literal['htb', 'iahlt']):
393
  start = len(intermediate_output)
394
  # Add in all the prefixes
395
  if len(word['seg']) > 1:
396
- for pre in get_prefixes_from_str(word['seg'][0], greedy=True):
397
  # pos - just take the first valid pos that appears in the predicted prefixes list.
398
  pos = next((pos for pos in ud_prefixes_to_pos[pre] if pos in word['morph']['prefixes']), ud_prefixes_to_pos[pre][0])
399
  dep, func = ud_get_prefix_dep(pre, word, word_idx)
 
239
  final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(final_output[sent_idx], parsed)
240
 
241
  if output_style in ['ud', 'iahlt_ud']:
242
+ final_output = convert_output_to_ud(final_output, self.config, style='htb' if output_style == 'ud' else 'iahlt')
243
 
244
  if is_single_sentence:
245
  final_output = final_output[0]
 
369
  'Gender=Fem|Number=Sing|Person=2': '_את',
370
  'Gender=Masc|Number=Plur|Person=3': '_הם'
371
  }
372
+ def convert_output_to_ud(output_sentences, model_cfg, style: Literal['htb', 'iahlt']):
373
  if style not in ['htb', 'iahlt']:
374
  raise ValueError('style must be htb/iahlt')
375
 
 
393
  start = len(intermediate_output)
394
  # Add in all the prefixes
395
  if len(word['seg']) > 1:
396
+ for pre in get_prefixes_from_str(word['seg'][0], model_cfg.prefix_cfg, greedy=True):
397
  # pos - just take the first valid pos that appears in the predicted prefixes list.
398
  pos = next((pos for pos in ud_prefixes_to_pos[pre] if pos in word['morph']['prefixes']), ud_prefixes_to_pos[pre][0])
399
  dep, func = ud_get_prefix_dep(pre, word, word_idx)