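Fixed UD break: pass the model config into convert_output_to_ud so that get_prefixes_from_str receives model_cfg.prefix_cfg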
Browse files
- BertForJointParsing.py +3 -3
BertForJointParsing.py
CHANGED
@@ -239,7 +239,7 @@ class BertForJointParsing(BertPreTrainedModel):
|
|
239 |
final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(final_output[sent_idx], parsed)
|
240 |
|
241 |
if output_style in ['ud', 'iahlt_ud']:
|
242 |
-
final_output = convert_output_to_ud(final_output, style='htb' if output_style == 'ud' else 'iahlt')
|
243 |
|
244 |
if is_single_sentence:
|
245 |
final_output = final_output[0]
|
@@ -369,7 +369,7 @@ ud_suffix_to_htb_str = {
|
|
369 |
'Gender=Fem|Number=Sing|Person=2': '_את',
|
370 |
'Gender=Masc|Number=Plur|Person=3': '_הם'
|
371 |
}
|
372 |
-
def convert_output_to_ud(output_sentences, style: Literal['htb', 'iahlt']):
|
373 |
if style not in ['htb', 'iahlt']:
|
374 |
raise ValueError('style must be htb/iahlt')
|
375 |
|
@@ -393,7 +393,7 @@ def convert_output_to_ud(output_sentences, style: Literal['htb', 'iahlt']):
|
|
393 |
start = len(intermediate_output)
|
394 |
# Add in all the prefixes
|
395 |
if len(word['seg']) > 1:
|
396 |
-
for pre in get_prefixes_from_str(word['seg'][0], greedy=True):
|
397 |
# pos - just take the first valid pos that appears in the predicted prefixes list.
|
398 |
pos = next((pos for pos in ud_prefixes_to_pos[pre] if pos in word['morph']['prefixes']), ud_prefixes_to_pos[pre][0])
|
399 |
dep, func = ud_get_prefix_dep(pre, word, word_idx)
|
|
|
239 |
final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(final_output[sent_idx], parsed)
|
240 |
|
241 |
if output_style in ['ud', 'iahlt_ud']:
|
242 |
+
final_output = convert_output_to_ud(final_output, self.config, style='htb' if output_style == 'ud' else 'iahlt')
|
243 |
|
244 |
if is_single_sentence:
|
245 |
final_output = final_output[0]
|
|
|
369 |
'Gender=Fem|Number=Sing|Person=2': '_את',
|
370 |
'Gender=Masc|Number=Plur|Person=3': '_הם'
|
371 |
}
|
372 |
+
def convert_output_to_ud(output_sentences, model_cfg, style: Literal['htb', 'iahlt']):
|
373 |
if style not in ['htb', 'iahlt']:
|
374 |
raise ValueError('style must be htb/iahlt')
|
375 |
|
|
|
393 |
start = len(intermediate_output)
|
394 |
# Add in all the prefixes
|
395 |
if len(word['seg']) > 1:
|
396 |
+
for pre in get_prefixes_from_str(word['seg'][0], model_cfg.prefix_cfg, greedy=True):
|
397 |
# pos - just take the first valid pos that appears in the predicted prefixes list.
|
398 |
pos = next((pos for pos in ud_prefixes_to_pos[pre] if pos in word['morph']['prefixes']), ud_prefixes_to_pos[pre][0])
|
399 |
dep, func = ud_get_prefix_dep(pre, word, word_idx)
|