import re import gradio as gr from dataclasses import dataclass from prettytable import PrettyTable from pytorch_ie.annotations import LabeledSpan, BinaryRelation from pytorch_ie.auto import AutoPipeline from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument from typing import List @dataclass class ExampleDocument(TextDocument): entities: AnnotationList[LabeledSpan] = annotation_field(target="text") relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") ner_model_name_or_path = "pie/example-ner-spanclf-conll03" re_model_name_or_path = "pie/example-re-textclf-tacred" ner_pipeline = AutoPipeline.from_pretrained(ner_model_name_or_path, device=-1, num_workers=0) re_pipeline = AutoPipeline.from_pretrained(re_model_name_or_path, device=-1, num_workers=0) def predict(text): document = ExampleDocument(text) ner_pipeline(document) while len(document.entities.predictions) > 0: document.entities.append(document.entities.predictions.pop(0)) re_pipeline(document) t = PrettyTable() t.field_names = ["head", "tail", "relation"] t.align = "l" for relation in document.relations.predictions: t.add_row([str(relation.head), str(relation.tail), relation.label]) html = t.get_html_string(format=True) html = ( "