import string
from typing import Any, Tuple

import gradio as gr
import spacy
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
class NER:
    prompt: str = """
Identify entities in the text having the following classes:
{}
Text:
"""
    def __init__(
        self,
        model_name: str,
        sents_batch: int = 10,
        tokens_limit: int = 2048
    ):
        self.sents_batch = sents_batch
        self.tokens_limit = tokens_limit

        # Lightweight spaCy pipeline used only for sentence splitting.
        self.nlp: spacy.Language = spacy.load(
            'en_core_web_sm',
            disable=['lemmatizer', 'parser', 'tagger', 'ner']
        )
        self.nlp.add_pipe('sentencizer')

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForTokenClassification.from_pretrained(model_name)
        self.pipeline = pipeline(
            'ner',
            model=model,
            tokenizer=self.tokenizer,
            aggregation_strategy='first',
            batch_size=12,
            device=device
        )
    def get_last_sentence_id(self, i: int, sentences_len: int) -> int:
        return min(i + self.sents_batch, sentences_len) - 1
    def chunkanize(self, text: str) -> Tuple[list[str], list[int]]:
        doc = self.nlp(text)
        chunks = []
        starts = []
        sentences = list(doc.sents)
        for i in range(0, len(sentences), self.sents_batch):
            start = sentences[i].start_char
            starts.append(start)
            last_sentence = self.get_last_sentence_id(i, len(sentences))
            end = sentences[last_sentence].end_char
            chunks.append(text[start:end])
        return chunks, starts
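
    # Note on chunkanize() above (hypothetical example, assuming the
    # sentencizer splits on each period): with sents_batch=2 and
    # text = 'A b. C d. E f.', it returns (['A b. C d.', 'E f.'], [0, 10]),
    # i.e. chunks of up to two sentences plus each chunk's character offset
    # in the original text.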
    def get_inputs(
        self, chunks: list[str], labels: list[str]
    ) -> Tuple[list[str], list[int]]:
        inputs = []
        prompts_lens = []
        for label in labels:
            prompt = self.prompt.format(label)
            prompts_lens.append(len(prompt))
            for chunk in chunks:
                inputs.append(prompt + chunk)
        return inputs, prompts_lens
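
    # Note on get_inputs() above: the layout is label-major. For labels
    # ['person', 'location'] and chunks [c1, c2] (hypothetical values),
    # inputs is [p_person + c1, p_person + c2, p_location + c1,
    # p_location + c2], and predict() relies on exactly this ordering to
    # recover the label and chunk indices from the flat batch index.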
    def clean_span(
        self, start: int, end: int, span: str
    ) -> Tuple[int, int, str]:
        # Recursively trim punctuation and whitespace from both ends while
        # keeping the character offsets in sync with the trimmed span.
        if len(span) >= 1:
            if span[0] in string.punctuation or span[0].isspace():
                return self.clean_span(start + 1, end, span[1:])
            if span[-1] in string.punctuation or span[-1].isspace():
                return self.clean_span(start, end - 1, span[:-1])
        return start, end, span
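
    # Note on clean_span() above: e.g. clean_span(0, 8, '"hello,"') walks
    # the punctuation off both ends and returns (1, 6, 'hello'), so start
    # and end still line up with the cleaned span in the original text.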
    def predict(
        self,
        text: str,
        inputs: list[str],
        labels: list[str],
        chunks_starts: list[int],
        prompts_lens: list[int],
        threshold: float
    ) -> list[dict[str, Any]]:
        outputs = []
        for idx, output in enumerate(self.pipeline(inputs)):
            # Inputs are label-major, so the label and chunk offset can be
            # recovered from the flat batch index.
            label = labels[idx // len(chunks_starts)]
            shift = (
                chunks_starts[idx % len(chunks_starts)]
                - prompts_lens[idx // len(chunks_starts)]
            )
            for ent in output:
                # shift maps offsets from the prompted input back into the
                # original text; the +1 skips the leading whitespace the
                # aggregated offsets tend to include.
                start = ent['start'] + shift + 1
                end = ent['end'] + shift
                start, end, span = self.clean_span(start, end, text[start:end])
                if not span:
                    continue
                if ent['score'] >= threshold:
                    outputs.append({
                        'span': span,
                        'start': start,
                        'end': end,
                        'entity': label
                    })
        return outputs
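
    # Note on predict() above: each returned item looks like (hypothetical
    # values) {'span': 'Paris', 'start': 12, 'end': 17, 'entity': 'location'},
    # with start/end indexing into the original, un-prompted text.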
    def check_text(self, text: str) -> None:
        if not text:
            raise gr.Error('No text provided. Please provide text.')
    def check_labels(self, labels: list[str]) -> None:
        if not labels:
            raise gr.Error(
                'No labels provided. Please provide labels.'
                ' Multiple labels should be divided by commas.'
                ' See examples below.'
            )
    def check_tokens_limit(self, inputs: list[str]) -> None:
        tokens = 0
        for input_ in inputs:
            tokens += len(self.tokenizer.encode(input_))
        if tokens > self.tokens_limit:
            raise gr.Error(
                'Too many tokens! Please reduce size of text or amount of labels.'
                f' Max tokens count is: {self.tokens_limit}.'
            )
    def process(
        self, labels: str, text: str, threshold: float = 0.
    ) -> dict[str, Any]:
        # Deduplicate the comma-separated labels, dropping empty entries.
        labels_list = list({
            l for label in labels.split(',')
            if (l := label.strip())
        })
        self.check_labels(labels_list)
        self.check_text(text)
        chunks, chunks_starts = self.chunkanize(text)
        inputs, prompts_lens = self.get_inputs(chunks, labels_list)
        self.check_tokens_limit(inputs)
        outputs = self.predict(
            text, inputs, labels_list, chunks_starts, prompts_lens, threshold
        )
        return {'text': text, 'entities': outputs}
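
# --- Minimal usage sketch (illustrative, not part of the original file). ---
# Assumptions: the model id below is a placeholder for a token-classification
# checkpoint trained on this prompt format, and the Gradio wiring is one
# plausible way to expose NER.process(); adjust both to your deployment.
ner = NER('your-org/your-token-classification-model')  # hypothetical model id

demo = gr.Interface(
    fn=ner.process,
    inputs=[
        gr.Textbox(label='Labels (comma-separated)'),
        gr.Textbox(label='Text', lines=5),
        gr.Slider(0.0, 1.0, value=0.0, label='Threshold'),
    ],
    # process() returns {'text': ..., 'entities': [...]}, which is the dict
    # format gr.HighlightedText accepts directly.
    outputs=gr.HighlightedText(label='Entities'),
)

if __name__ == '__main__':
    demo.launch()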