Dmytro Vodianytskyi commited on
Commit
6d8648d
1 Parent(s): 3219277

space updated

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -7,13 +7,13 @@ TOKENIZER = T5Tokenizer.from_pretrained('werent4/mt5TranslatorLT')
7
  MODEL = MT5ForConditionalGeneration.from_pretrained("werent4/mt5TranslatorLT")
8
  MODEL.to(DEVICE)
9
 
10
- def translate(text, max_length, num_beams, translation_way = "en-lt"):
11
  translations_ways = {
12
  "en-lt": "<EN2LT>",
13
  "lt-en": "<LT2EN>"
14
  }
15
  if translation_way not in translations_ways:
16
- raise ValueError(f"Invalid translation way. Supported ways: {list(translations_ways.keys())}")
17
  input_text = f"{translations_ways[translation_way]} {text}"
18
  encoded_input = TOKENIZER(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128).to(DEVICE)
19
  with torch.no_grad():
 
7
  MODEL = MT5ForConditionalGeneration.from_pretrained("werent4/mt5TranslatorLT")
8
  MODEL.to(DEVICE)
9
 
10
+ def translate(text, max_length, num_beams, translation_way):
11
  translations_ways = {
12
  "en-lt": "<EN2LT>",
13
  "lt-en": "<LT2EN>"
14
  }
15
  if translation_way not in translations_ways:
16
+ raise ValueError(f"Invalid translation way: {translation_way}. Supported ways: {list(translations_ways.keys())}")
17
  input_text = f"{translations_ways[translation_way]} {text}"
18
  encoded_input = TOKENIZER(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128).to(DEVICE)
19
  with torch.no_grad():