# 0-shot-NER / ner.py
import string
from typing import Any

import gradio as gr
import spacy
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

class NER:
    # Prompt template prepended to every chunk; '{}' is filled with a single
    # label. The rendered prompt's length is later subtracted from entity
    # offsets, so the template text is kept flush-left.
    prompt: str = """
Identify entities in the text having the following classes:
{}
Text:
"""
    def __init__(
        self,
        model_name: str,
        sents_batch: int = 10,
        tokens_limit: int = 2048
    ):
        self.sents_batch = sents_batch
        self.tokens_limit = tokens_limit

        # Lightweight spaCy pipeline used only for sentence splitting.
        self.nlp: spacy.Language = spacy.load(
            'en_core_web_sm',
            disable=['lemmatizer', 'parser', 'tagger', 'ner']
        )
        self.nlp.add_pipe('sentencizer')

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForTokenClassification.from_pretrained(model_name)
        self.pipeline = pipeline(
            "ner",
            model=model,
            tokenizer=self.tokenizer,
            aggregation_strategy='first',
            batch_size=12,
            device=device
        )

    def get_last_sentence_id(self, i: int, sentences_len: int) -> int:
        # Index of the last sentence in the batch that starts at sentence `i`.
        return min(i + self.sents_batch, sentences_len) - 1

    def chunkanize(self, text: str) -> tuple[list[str], list[int]]:
        # Split the text into chunks of up to `sents_batch` sentences, returning
        # the chunks and each chunk's character start offset in `text`.
        doc = self.nlp(text)
        chunks = []
        starts = []
        sentences = list(doc.sents)
        for i in range(0, len(sentences), self.sents_batch):
            start = sentences[i].start_char
            starts.append(start)
            last_sentence = self.get_last_sentence_id(i, len(sentences))
            end = sentences[last_sentence].end_char
            chunks.append(text[start:end])
        return chunks, starts
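
    # Illustration (with the default sents_batch=10): a 25-sentence text yields
    # three chunks covering sentences 0-9, 10-19, and 20-24; `starts` holds each
    # chunk's character offset into the original text.
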
    def get_inputs(
        self, chunks: list[str], labels: list[str]
    ) -> tuple[list[str], list[int]]:
        inputs = []
        prompts_lens = []
        for label in labels:
            prompt = self.prompt.format(label)
            prompts_lens.append(len(prompt))
            for chunk in chunks:
                inputs.append(prompt + chunk)
        return inputs, prompts_lens
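
    # Note the label-major ordering of `inputs`: [label0+chunk0, label0+chunk1,
    # ..., label1+chunk0, ...]. predict() relies on this layout to recover the
    # label via idx // n_chunks and the chunk via idx % n_chunks.
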
    @classmethod
    def clean_span(
        cls, start: int, end: int, span: str
    ) -> tuple[int, int, str]:
        # Recursively strip leading/trailing punctuation, keeping the character
        # offsets in sync with the trimmed span.
        if len(span) >= 1:
            if span[0] in string.punctuation:
                return cls.clean_span(start + 1, end, span[1:])
            if span[-1] in string.punctuation:
                return cls.clean_span(start, end - 1, span[:-1])
        return start, end, span.strip()
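
    # Worked example: clean_span(10, 18, '"Paris,"') peels the leading quote,
    # the trailing quote, and the trailing comma one character per recursive
    # call, returning (11, 16, 'Paris').
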
    def predict(
        self,
        text: str,
        inputs: list[str],
        labels: list[str],
        chunks_starts: list[int],
        prompts_lens: list[int],
        threshold: float
    ) -> list[dict[str, Any]]:
        outputs = []
        for idx, output in enumerate(self.pipeline(inputs)):
            # Recover which label and which chunk this input belongs to
            # (inputs are label-major, see get_inputs above).
            label = labels[idx // len(chunks_starts)]
            # Map offsets from prompt+chunk space back into the original text.
            shift = (
                chunks_starts[idx % len(chunks_starts)]
                - prompts_lens[idx // len(chunks_starts)]
            )
            for ent in output:
                # The +1 on `start` skips the leading space the pipeline's
                # aggregated word offsets appear to include here.
                start = ent['start'] + shift + 1
                end = ent['end'] + shift
                start, end, span = self.clean_span(start, end, text[start:end])
                if not span:
                    continue
                if ent['score'] >= threshold:
                    outputs.append({
                        'span': span,
                        'start': start,
                        'end': end,
                        'entity': label
                    })
        return outputs

    def check_text(self, text: str) -> None:
        if not text:
            raise gr.Error('No text provided. Please provide text.')

    def check_labels(self, labels: list[str]) -> None:
        if not labels:
            raise gr.Error(
                'No labels provided. Please provide labels.'
                ' Multiple labels should be divided by commas.'
                ' See examples below.'
            )

    def check_tokens_limit(self, inputs: list[str]) -> None:
        # Guard against overly long requests: count tokens across all inputs.
        tokens = 0
        for input_ in inputs:
            tokens += len(self.tokenizer.encode(input_))
        if tokens > self.tokens_limit:
            raise gr.Error(
                'Too many tokens! Please reduce the size of the text or the'
                f' number of labels. Maximum token count is {self.tokens_limit}.'
            )

    def process(
        self, labels: str, text: str, threshold: float = 0.0
    ) -> dict[str, Any]:
        # Deduplicate labels while preserving the order the user typed them in.
        labels_list = list(dict.fromkeys(
            l for label in labels.split(',')
            if (l := label.strip())
        ))
        self.check_labels(labels_list)
        self.check_text(text)

        chunks, chunks_starts = self.chunkanize(text)
        inputs, prompts_lens = self.get_inputs(chunks, labels_list)
        self.check_tokens_limit(inputs)

        outputs = self.predict(
            text, inputs, labels_list, chunks_starts, prompts_lens, threshold
        )
        return {"text": text, "entities": outputs}