# -*- coding: utf-8 -*-
import os
import re
import sys
import time
import unicodedata

import pysbd
import torch
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
MODEL_NAME = "AriNubar/nllb-200-distilled-600m-en-hyw"
LANGUAGES = {
    "Արեւմտահայերէն | Western Armenian": "hyw_Armn",
    "Անգլերէն | English": "eng_Latn",
}
HF_TOKEN = os.environ.get("HF_TOKEN")
def get_non_printing_char_replacer(replace_by: str = " "):
    """Build a callable that replaces every non-printing character with `replace_by`."""
    non_printable_map = {
        ord(c): replace_by
        for c in (chr(i) for i in range(sys.maxunicode + 1))
        # same as \p{C} in Perl,
        # see https://www.unicode.org/reports/tr44/#General_Category_Values
        if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    }

    def replace_non_printing_char(line: str) -> str:
        return line.translate(non_printable_map)

    return replace_non_printing_char
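# Illustrative usage (a sketch, not part of the app's control flow): the returned
# callable maps every \p{C} character, e.g. zero-width spaces and control codes,
# to a single space in one translate() pass:
#   strip_ctrl = get_non_printing_char_replacer()
#   strip_ctrl("a\u200bb\u0007c")  # -> "a b c"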
def clean_text(text: str, lang: str) -> str:
    """Strip non-printing characters, collapse whitespace, and normalize Armenian punctuation."""
    HYW_CHARS_TO_NORMALIZE = {
        "«": '"',
        "»": '"',
        "“": '"',
        "”": '"',
        "’": "'",
        "‘": "'",
        "–": "-",
        "—": "-",
        "ՙ": "'",
        "՚": "'",
    }
    DOUBLE_CHARS_TO_NORMALIZE = {
        "Կ՛": "Կ'",
        "կ՛": "կ'",
        "Չ՛": "Չ'",
        "չ՛": "չ'",
        "Մ՛": "Մ'",
        "մ՛": "մ'",
    }
    replace_nonprint = get_non_printing_char_replacer()
    text = replace_nonprint(text)
    # str.replace() is literal, so regex-looking patterns passed to it are no-ops;
    # collapse whitespace runs with re.sub instead. (A literal "[^\x00-\x7F]+"
    # pattern is dropped here: applied as a real regex it would erase the Armenian
    # text itself.)
    text = text.replace("\t", " ").replace("\n", " ").replace("\r", " ")
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    if lang == "hyw_Armn":
        text = text.translate(str.maketrans(HYW_CHARS_TO_NORMALIZE))
        for k, v in DOUBLE_CHARS_TO_NORMALIZE.items():
            text = text.replace(k, v)
    return text
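# Example of the normalization above (illustrative input, not from the source):
# guillemets become ASCII quotes, and the Armenian emphasis mark after Կ/կ/Չ/չ/Մ/մ
# becomes an apostrophe:
#   clean_text("Կ՛ուզեմ «գիրք» կարդալ", "hyw_Armn")  # -> Կ'ուզեմ "գիրք" կարդալ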
def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=False):
    """Split `text` into sentences, keeping the inter-sentence "fillers" so the text can be reassembled."""
    if fix_double_space:
        text = re.sub(r"\s+", " ", text)
    text = text.strip()
    sentences = splitter.segment(text)
    fillers = []
    i = 0
    for sent in sentences:
        start_idx = text.find(sent, i)
        if ignore_errors and start_idx == -1:
            # the sentence was not found verbatim; skip one character and move on
            start_idx = i + 1
        assert start_idx != -1, f"Sentence not found after index {i} in text: {text}"
        fillers.append(text[i:start_idx])
        i = start_idx + len(sent)
    fillers.append(text[i:])
    return sentences, fillers
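# A property worth knowing when using this helper (assuming `text` is already
# stripped and whitespace-normalized, and the splitter returns verbatim
# substrings, as pysbd does with clean=False): interleaving fillers and
# sentences reconstructs the input exactly:
#   sents, fills = sentenize_with_fillers(text, splitter)
#   assert text == "".join(f + s for f, s in zip(fills, sents)) + fills[-1]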
def init_tokenizer(tokenizer, new_lang="hyw_Armn"):
    """Add a new language token to the tokenizer vocabulary (this should be done each time after its initialization)."""
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
    tokenizer.lang_code_to_id[new_lang] = old_len - 1
    tokenizer.id_to_lang_code[old_len - 1] = new_lang
    # always move "mask" to the last position
    tokenizer.fairseq_tokens_to_ids["<mask>"] = (
        len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    )
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)
    # clear the added token encoder; otherwise a new token may end up there by mistake
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}
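# Quick sanity check for the remapping above (illustrative; pokes at NllbTokenizer
# internals, which vary across transformers versions):
#   tok = NllbTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
#   init_tokenizer(tok)
#   assert "hyw_Armn" in tok.lang_code_to_id                   # routed like a language tag
#   assert tok.lang_code_to_id["hyw_Armn"] != tok.lang_code_to_id["hye_Armn"]  # distinct from Eastern Armenian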
class Translator:
    def __init__(self) -> None:
        self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, token=HF_TOKEN)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.tokenizer = NllbTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
        init_tokenizer(self.tokenizer)
        self.hyw_splitter = pysbd.Segmenter(language="hy", clean=True)
        self.eng_splitter = pysbd.Segmenter(language="en", clean=True)
        self.languages = LANGUAGES

    def translate_single(
        self,
        text,
        src_lang,
        tgt_lang,
        max_length="auto",
        num_beams=4,
        n_out=None,
        **kwargs,
    ):
        self.tokenizer.src_lang = src_lang
        encoded = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=256
        )
        if max_length == "auto":
            max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
        generated_tokens = self.model.generate(
            **encoded.to(self.model.device),
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            max_length=max_length,
            num_beams=num_beams,
            num_return_sequences=n_out or 1,
            **kwargs,
        )
        out = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        if isinstance(text, str) and n_out is None:
            return out[0]
        return out
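    # Worked example of the "auto" budget above: a 20-token source yields
    # max_length = 32 + 2 * 20 = 72 generated tokens, i.e. roughly twice the
    # input length plus constant headroom, a common rule of thumb for translation.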
    def translate(
        self,
        text: str,
        src_lang: str,
        tgt_lang: str,
        max_length=256,
        num_beams=4,
        by_sentence=True,
        clean=True,
        **kwargs,
    ):
        if by_sentence and src_lang == "eng_Latn":
            sents = self.eng_splitter.segment(text)
        elif by_sentence and src_lang == "hyw_Armn":
            sents = self.hyw_splitter.segment(text)
        else:
            # fall back to translating the whole text as one segment
            sents = [text]
        if clean:
            sents = [clean_text(sent, src_lang) for sent in sents]
        if len(sents) > 1:
            results = self.translate_batch(
                sents, src_lang, tgt_lang, num_beams=num_beams, max_length=max_length, **kwargs
            )
        else:
            results = self.translate_single(
                sents, src_lang, tgt_lang, max_length=max_length, num_beams=num_beams, **kwargs
            )
        return " ".join(results)
    def translate_batch(self, texts, src_lang, tgt_lang, num_beams=4, max_length=256, **kwargs):
        self.tokenizer.src_lang = src_lang
        inputs = self.tokenizer(
            texts, return_tensors="pt", max_length=max_length, padding=True, truncation=True
        ).to(self.model.device)
        # pass the attention mask along with input_ids so padded positions are ignored
        translated_tokens = self.model.generate(
            **inputs,
            num_beams=num_beams,
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            **kwargs,
        )
        return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
if __name__ == "__main__":
    print("Initializing translator...")
    translator = Translator()
    print("Translator initialized.")

    start_time = time.time()
    print(translator.translate("Hello world!", "eng_Latn", "hyw_Armn"))
    print("Time elapsed: ", time.time() - start_time)

    start_time = time.time()
    print(translator.translate("I am the greatest translator! Do not fuck with me!", "eng_Latn", "hyw_Armn"))
    print("Time elapsed: ", time.time() - start_time)