# hyw-en-demo / translation.py (author: AriNubar, commit e8c3b4c: "improve translation speed")
# -*- coding: utf-8 -*-
import os
import re
import sys
import typing as tp
import torch
import pysbd
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM
import unicodedata
import time
# Checkpoint name passed to from_pretrained() for both the model and its tokenizer.
MODEL_NAME = "AriNubar/nllb-200-distilled-600m-en-hyw"

# Display name (Armenian | English) -> NLLB language code used as src_lang/tgt_lang.
LANGUAGES = dict(
    [
        ("Արեւմտահայերէն | Western Armenian", "hyw_Armn"),
        ("Անգլերէն | English", "eng_Latn"),
    ]
)

# Optional access token read from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
def get_non_printing_char_replacer(replace_by: str = " "):
non_printable_map = {
ord(c): replace_by
for c in (chr(i) for i in range(sys.maxunicode + 1))
# same as \p{C} in perl
# see https://www.unicode.org/reports/tr44/#General_Category_Values
if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
}
def replace_non_printing_char(line) -> str:
return line.translate(non_printable_map)
return replace_non_printing_char
# Single-character normalisation for Western Armenian input: typographic
# quotes/dashes and the Armenian apostrophe-like marks become ASCII.
_HYW_SINGLE_CHAR_TABLE = str.maketrans(
    {
        "«": '"',
        "»": '"',
        "“": '"',
        "”": '"',
        "’": "'",
        "‘": "'",
        "–": "-",
        "—": "-",
        "ՙ": "'",
        "՚": "'",
    }
)
# Letter + "՛" pairs rewritten with an ASCII apostrophe. The keys are two
# characters long, so they are applied with str.replace, not the table above.
_HYW_DOUBLE_CHARS_TO_NORMALIZE = {
    "Կ՛": "Կ'",
    "կ՛": "կ'",
    "Չ՛": "Չ'",
    "չ՛": "չ'",
    "Մ՛": "Մ'",
    "մ՛": "մ'",
}
_WHITESPACE_RUN = re.compile(r"\s+")
# Built on the first clean_text() call and then reused, so the non-printing
# character replacer is not rebuilt for every sentence.
_clean_text_nonprint_replacer = None


def clean_text(text: str, lang) -> str:
    """Normalise ``text`` before it is tokenised for translation.

    Steps:
      1. Replace non-printing characters with a space
         (``get_non_printing_char_replacer()`` with its default).
      2. Collapse every whitespace run to a single space and strip both ends.
      3. Only when ``lang == "hyw_Armn"``: map typographic quotes/dashes and
         Armenian apostrophe marks to ASCII, then rewrite the letter + "՛" pairs.

    Args:
        text: Raw input text.
        lang: NLLB language code of ``text``.

    Returns:
        The cleaned text.
    """
    global _clean_text_nonprint_replacer
    if _clean_text_nonprint_replacer is None:
        _clean_text_nonprint_replacer = get_non_printing_char_replacer()
    text = _clean_text_nonprint_replacer(text)
    # A real regex is needed here: str.replace(r"\s+", ...) would only match the
    # literal characters "\s+". Non-ASCII characters are intentionally kept,
    # because Armenian input is non-ASCII.
    text = _WHITESPACE_RUN.sub(" ", text).strip()
    if lang == "hyw_Armn":
        text = text.translate(_HYW_SINGLE_CHAR_TABLE)
        for pair, replacement in _HYW_DOUBLE_CHARS_TO_NORMALIZE.items():
            text = text.replace(pair, replacement)
    return text
def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=False):
    """Split ``text`` into sentences and return the separators around them too.

    ``fillers`` has ``len(sentences) + 1`` items. When every sentence is found
    verbatim, ``fillers[0]`` is the text before the first sentence,
    ``fillers[k]`` is the text between sentences ``k-1`` and ``k``, and
    ``fillers[-1]`` is the text after the last sentence.

    Args:
        text: Text to split.
        splitter: Object with a ``segment(text)`` method returning a list of
            sentence strings (e.g. a ``pysbd.Segmenter``).
        fix_double_space: Collapse whitespace runs to a single space first.
        ignore_errors: If a sentence cannot be located in ``text``, assume it
            starts one character after the previous position instead of raising.

    Returns:
        ``(sentences, fillers)``.

    Raises:
        AssertionError: A sentence was not found and ``ignore_errors`` is false.
    """
    if fix_double_space:
        text = re.sub(r"\s+", " ", text)
    text = text.strip()
    sentences = splitter.segment(text)
    fillers = []
    i = 0
    for sent in sentences:
        start_idx = text.find(sent, i)
        if start_idx == -1:
            if not ignore_errors:
                # Raised explicitly rather than via `assert`, so the check is not
                # stripped under `python -O`. The exception type is unchanged.
                raise AssertionError(f"Sent not found after index {i} in text: {text}")
            start_idx = i + 1
        fillers.append(text[i:start_idx])
        i = start_idx + len(sent)
    fillers.append(text[i:])
    return sentences, fillers
def init_tokenizer(tokenizer, new_lang='hyw_Armn'):
    """Register ``new_lang`` as a language code on an NLLB tokenizer, in place.

    Add a new language token to the tokenizer vocabulary (this should be done
    each time after its initialization).

    The function writes directly to the tokenizer's language-code tables:
    ``lang_code_to_id``, ``id_to_lang_code``, ``fairseq_tokens_to_ids``,
    ``fairseq_ids_to_tokens`` and ``_additional_special_tokens``. It also
    resets ``added_tokens_encoder`` / ``added_tokens_decoder`` to empty dicts.

    NOTE(review): these look like internals of the slow ``NllbTokenizer``. Some
    ``transformers`` releases may rename or remove them (e.g. ``lang_code_to_id``),
    so confirm against the pinned version.

    Args:
        tokenizer: The ``NllbTokenizer`` created in ``Translator.__init__``.
        new_lang: Language code to register.

    Returns:
        None. The tokenizer is modified in place.
    """
    # Tokenizer length, not counting new_lang if it is already an added token.
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
    # The new language code is assigned id old_len - 1 (both directions).
    tokenizer.lang_code_to_id[new_lang] = old_len-1
    tokenizer.id_to_lang_code[old_len-1] = new_lang
    # always move "mask" to the last position
    tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    # Mirror the language-code ids into the token<->id tables and rebuild the inverse map.
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    # Register new_lang as a special token only once.
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)
    # clear the added token encoder; otherwise a new token may end up there by mistake
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}
class Translator:
    """English <-> Western Armenian translator built on a fine-tuned NLLB model.

    Loads ``MODEL_NAME`` (moved to CUDA when available) and its tokenizer,
    which is patched by ``init_tokenizer`` to know ``hyw_Armn``. It also holds
    one pysbd sentence splitter per supported source language.
    """

    def __init__(self) -> None:
        self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, token=HF_TOKEN)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.tokenizer = NllbTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
        init_tokenizer(self.tokenizer)
        self.hyw_splitter = pysbd.Segmenter(language="hy", clean=True)
        self.eng_splitter = pysbd.Segmenter(language="en", clean=True)
        self.languages = LANGUAGES

    def _splitter_for(self, lang):
        """Return the sentence splitter for ``lang``, or None if there is none."""
        if lang == "eng_Latn":
            return self.eng_splitter
        if lang == "hyw_Armn":
            return self.hyw_splitter
        return None

    def translate_single(
        self,
        text,
        src_lang,
        tgt_lang,
        max_length="auto",
        num_beams=4,
        n_out=None,
        **kwargs,
    ):
        """Translate ``text`` with a single ``model.generate`` call.

        Args:
            text: One segment, or a list of segments.
            src_lang: NLLB code of the source language.
            tgt_lang: NLLB code of the target language (forced as the first
                generated token).
            max_length: Generation length limit. ``"auto"`` means
                ``32 + 2 * <number of input tokens>``. The input itself is
                truncated to 256 tokens.
            num_beams: Beam width.
            n_out: Number of returned sequences per input (``None`` means 1).
            **kwargs: Forwarded to ``model.generate``.

        Returns:
            A single string when ``text`` is a string and ``n_out`` is None,
            otherwise the list of decoded outputs.
        """
        self.tokenizer.src_lang = src_lang
        encoded = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=256
        )
        if max_length == "auto":
            max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
        generated_tokens = self.model.generate(
            **encoded.to(self.model.device),
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            max_length=max_length,
            num_beams=num_beams,
            num_return_sequences=n_out or 1,
            **kwargs,
        )
        out = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        if isinstance(text, str) and n_out is None:
            return out[0]
        return out

    def translate(self, text: str,
                  src_lang: str,
                  tgt_lang: str,
                  max_length=256,
                  num_beams=4,
                  by_sentence=True,
                  clean=True,
                  **kwargs):
        """Translate ``text`` from ``src_lang`` to ``tgt_lang``.

        If ``by_sentence`` is true and ``src_lang`` has a splitter
        (``eng_Latn`` or ``hyw_Armn``), the text is split into sentences,
        translated, and the translations are joined with single spaces.
        Otherwise the whole text is translated as one segment.

        Args:
            text: Input text.
            src_lang: NLLB code of the source language.
            tgt_lang: NLLB code of the target language.
            max_length: Forwarded to ``translate_single`` (generation limit) or
                ``translate_batch`` (input truncation limit).
            num_beams: Beam width.
            by_sentence: Split into sentences before translating.
            clean: Run ``clean_text`` on every segment first.
            **kwargs: Forwarded to ``model.generate``.

        Returns:
            The translated text, or ``""`` when there is nothing to translate.
        """
        splitter = self._splitter_for(src_lang) if by_sentence else None
        # Always define the segment list, even without a splitter, so
        # by_sentence=False or an unknown src_lang cannot leave it unbound.
        sents = splitter.segment(text) if splitter is not None else [text]
        if clean:
            sents = [clean_text(sent, src_lang) for sent in sents]
        # Empty segments (empty input, or text that cleaned down to nothing)
        # are not sent to the model.
        sents = [sent for sent in sents if sent]
        if not sents:
            return ""
        if len(sents) > 1:
            results = self.translate_batch(sents, src_lang, tgt_lang, num_beams=num_beams, max_length=max_length, **kwargs)
        else:
            results = self.translate_single(sents, src_lang, tgt_lang, max_length=max_length, num_beams=num_beams, **kwargs)
        return " ".join(results)

    def translate_batch(self, texts, src_lang, tgt_lang, num_beams=4, max_length=256, **kwargs):
        """Translate a list of segments as one padded batch.

        Args:
            texts: List of segments.
            src_lang: NLLB code of the source language.
            tgt_lang: NLLB code of the target language.
            num_beams: Beam width.
            max_length: Input truncation length for the tokenizer. It is not
                passed to ``model.generate``.
            **kwargs: Forwarded to ``model.generate``.

        Returns:
            One decoded translation per returned sequence.
        """
        self.tokenizer.src_lang = src_lang
        # Move the whole encoding (input_ids and attention_mask) to the model's
        # device. Passing the mask explicitly avoids relying on generate() to
        # infer it for the padded rows.
        inputs = self.tokenizer(
            texts, return_tensors="pt", max_length=max_length, padding=True, truncation=True
        ).to(self.model.device)
        translated_tokens = self.model.generate(
            **inputs,
            num_beams=num_beams,
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            **kwargs,
        )
        return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
if __name__ == "__main__":
    # Smoke-test demo: load the model once, then time a few translations.
    print("Initializing translator...")
    translator = Translator()
    print("Translator initialized.")
    demo_inputs = (
        "Hello world!",
        "I am the greatest translator! Do not fuck with me!",
    )
    for sample in demo_inputs:
        started = time.time()
        print(translator.translate(sample, "eng_Latn", "hyw_Armn"))
        print("Time elapsed: ", time.time() - started)