Spaces:
Running
Running
Merge pull request #230 from jhj0517/fix/local-files-only
Browse files
modules/translation/nllb_inference.py
CHANGED
@@ -40,10 +40,13 @@ class NLLBInference(TranslationBase):
|
|
40 |
print("\nInitializing NLLB Model..\n")
|
41 |
progress(0, desc="Initializing NLLB Model..")
|
42 |
self.current_model_size = model_size
|
|
|
43 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
|
44 |
-
cache_dir=self.model_dir
|
|
|
45 |
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
|
46 |
-
cache_dir=os.path.join(self.model_dir, "tokenizers")
|
|
|
47 |
src_lang = NLLB_AVAILABLE_LANGS[src_lang]
|
48 |
tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
|
49 |
self.pipeline = pipeline("translation",
|
@@ -53,6 +56,18 @@ class NLLBInference(TranslationBase):
|
|
53 |
tgt_lang=tgt_lang,
|
54 |
device=self.device)
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
NLLB_AVAILABLE_LANGS = {
|
57 |
"Acehnese (Arabic script)": "ace_Arab",
|
58 |
"Acehnese (Latin script)": "ace_Latn",
|
|
|
40 |
print("\nInitializing NLLB Model..\n")
|
41 |
progress(0, desc="Initializing NLLB Model..")
|
42 |
self.current_model_size = model_size
|
43 |
+
local_files_only = self.is_model_exists(self.current_model_size)
|
44 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
|
45 |
+
cache_dir=self.model_dir,
|
46 |
+
local_files_only=local_files_only)
|
47 |
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
|
48 |
+
cache_dir=os.path.join(self.model_dir, "tokenizers"),
|
49 |
+
local_files_only=local_files_only)
|
50 |
src_lang = NLLB_AVAILABLE_LANGS[src_lang]
|
51 |
tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
|
52 |
self.pipeline = pipeline("translation",
|
|
|
56 |
tgt_lang=tgt_lang,
|
57 |
device=self.device)
|
58 |
|
59 |
+
def is_model_exists(self,
|
60 |
+
model_size: str):
|
61 |
+
"""Check if model exists or not (Only facebook model)"""
|
62 |
+
prefix = "models--facebook--"
|
63 |
+
_id, model_size_name = model_size.split("/")
|
64 |
+
model_dir_name = prefix + model_size_name
|
65 |
+
model_dir_path = os.path.join(self.model_dir, model_dir_name)
|
66 |
+
if os.path.exists(model_dir_path) and os.listdir(model_dir_path):
|
67 |
+
return True
|
68 |
+
return False
|
69 |
+
|
70 |
+
|
71 |
NLLB_AVAILABLE_LANGS = {
|
72 |
"Acehnese (Arabic script)": "ace_Arab",
|
73 |
"Acehnese (Latin script)": "ace_Latn",
|