jhj0517 committed on
Commit
d85434f
2 Parent(s): bb4ed2f a0d6f10

Merge pull request #230 from jhj0517/fix/local-files-only

Browse files
modules/translation/nllb_inference.py CHANGED
@@ -40,10 +40,13 @@ class NLLBInference(TranslationBase):
40
  print("\nInitializing NLLB Model..\n")
41
  progress(0, desc="Initializing NLLB Model..")
42
  self.current_model_size = model_size
 
43
  self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
44
- cache_dir=self.model_dir)
 
45
  self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
46
- cache_dir=os.path.join(self.model_dir, "tokenizers"))
 
47
  src_lang = NLLB_AVAILABLE_LANGS[src_lang]
48
  tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
49
  self.pipeline = pipeline("translation",
@@ -53,6 +56,18 @@ class NLLBInference(TranslationBase):
53
  tgt_lang=tgt_lang,
54
  device=self.device)
55
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  NLLB_AVAILABLE_LANGS = {
57
  "Acehnese (Arabic script)": "ace_Arab",
58
  "Acehnese (Latin script)": "ace_Latn",
 
40
  print("\nInitializing NLLB Model..\n")
41
  progress(0, desc="Initializing NLLB Model..")
42
  self.current_model_size = model_size
43
+ local_files_only = self.is_model_exists(self.current_model_size)
44
  self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
45
+ cache_dir=self.model_dir,
46
+ local_files_only=local_files_only)
47
  self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
48
+ cache_dir=os.path.join(self.model_dir, "tokenizers"),
49
+ local_files_only=local_files_only)
50
  src_lang = NLLB_AVAILABLE_LANGS[src_lang]
51
  tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
52
  self.pipeline = pipeline("translation",
 
56
  tgt_lang=tgt_lang,
57
  device=self.device)
58
 
59
+ def is_model_exists(self,
60
+ model_size: str):
61
+ """Check if model exists or not (Only facebook model)"""
62
+ prefix = "models--facebook--"
63
+ _id, model_size_name = model_size.split("/")
64
+ model_dir_name = prefix + model_size_name
65
+ model_dir_path = os.path.join(self.model_dir, model_dir_name)
66
+ if os.path.exists(model_dir_path) and os.listdir(model_dir_path):
67
+ return True
68
+ return False
69
+
70
+
71
  NLLB_AVAILABLE_LANGS = {
72
  "Acehnese (Arabic script)": "ace_Arab",
73
  "Acehnese (Latin script)": "ace_Latn",