Spaces:
Runtime error
Runtime error
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification | |
import spacy | |
import torch | |
# Sentence splitter used by NER.chunkanize. Components that are not needed for
# sentence boundaries are disabled, and spaCy's 'sentencizer' is appended to
# provide Doc.sents.
nlp = spacy.load(
    'en_core_web_sm',
    disable=['lemmatizer', 'parser', 'tagger', 'ner'],
)
nlp.add_pipe('sentencizer')

# Device handed to the transformers pipeline: the first CUDA GPU when one is
# available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
class NER:
    """Prompt-based zero-shot NER on top of ``knowledgator/UTC-DeBERTa-small``.

    For every requested label, the ``prompt`` template is formatted with that
    label and prepended to each sentence chunk of the input text. The token
    classification pipeline then runs over all (label, chunk) inputs in a single
    call, and the predicted spans are mapped back to character offsets in the
    original text.

    The tokenizer, model and pipeline are class attributes, so they are created
    once, when the class body executes.
    """

    model_name = 'knowledgator/UTC-DeBERTa-small'
    # Template prepended to every chunk; ``{}`` receives a single label.
    # Its exact characters are part of the model input, and its length is used
    # to translate pipeline offsets back into the original text.
    prompt = """
Identify entities in the text having the following classes:
{}
Text:
"""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    ner_pipeline = pipeline(
        "ner",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy='first',
        batch_size=12,
        device=device,
    )

    @classmethod
    def chunkanize(cls, text, prompt_='', n_sents=10):
        """Split ``text`` into chunks of at most ``n_sents`` sentences.

        Sentence boundaries come from the module-level spaCy ``nlp`` pipeline.

        Args:
            text: Text to split.
            prompt_: String prepended to every chunk.
            n_sents: Maximum number of sentences per chunk (must be >= 1).

        Returns:
            A pair ``(chunks, starts)``. ``chunks[i]`` is ``prompt_`` followed by
            the i-th slice of ``text`` (from the first character of its first
            sentence to the last character of its last sentence), and
            ``starts[i]`` is that slice's character offset in ``text``.

        Raises:
            ValueError: If ``n_sents`` is smaller than 1.
        """
        if n_sents < 1:
            raise ValueError(f"n_sents must be >= 1, got {n_sents}")
        chunks, starts = [], []
        sents = list(nlp(text).sents)
        for first in range(0, len(sents), n_sents):
            group = sents[first:first + n_sents]
            start = group[0][0].idx
            last_token = group[-1][-1]
            # End excludes whitespace trailing the group's last token.
            end = last_token.idx + len(last_token.text)
            starts.append(start)
            chunks.append(prompt_ + text[start:end])
        return chunks, starts

    @classmethod
    def ner(cls, labels, text, treshold=0.):
        """Find entities of the given comma-separated ``labels`` in ``text``.

        Args:
            labels: Comma-separated entity class names, e.g. ``"person, city"``.
                Whitespace around each name is ignored, empty names are skipped
                and repeated names are processed once.
            text: Text to search.
            treshold: Only entities whose score is strictly greater than this
                value are kept. (The spelling is kept for compatibility.)

        Returns:
            ``{"text": text, "entities": [...]}``. Each entity is the dict
            produced by the pipeline, with ``start``/``end`` shifted to offsets
            in ``text`` and ``entity`` set to the label it was found for.
        """
        chunks, starts, classes = [], [], []
        label2prompt_len = {}
        # dict.fromkeys de-duplicates while preserving the caller's order.
        unique_labels = dict.fromkeys(
            name.strip() for name in labels.split(',') if name.strip()
        )
        for label in unique_labels:
            prompt_ = cls.prompt.format(label)
            label2prompt_len[label] = len(prompt_)
            curr_chunks, curr_starts = cls.chunkanize(text, prompt_)
            chunks += curr_chunks
            starts += curr_starts
            classes += [label] * len(curr_chunks)

        outputs = []
        if not chunks:
            # No labels or no text: skip calling the pipeline with an empty batch.
            return {"text": text, "entities": outputs}

        predictions = cls.ner_pipeline(chunks)
        for output, label, chunk_start in zip(predictions, classes, starts):
            prompt_len = label2prompt_len[label]
            # Pipeline offsets are relative to ``prompt_ + text[chunk_start:...]``,
            # so subtract the prompt and add the chunk's position in ``text``.
            shift = chunk_start - prompt_len
            for ent in output:
                # Spans starting inside the prompt (e.g. the label word itself)
                # are not part of ``text`` and would map to bogus offsets.
                if ent['score'] > treshold and ent['start'] >= prompt_len:
                    ent['start'] += shift
                    ent['end'] += shift
                    ent['entity'] = label
                    outputs.append(ent)
        return {"text": text, "entities": outputs}