import torch
import gradio as gr
import time
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

codes_as_string = '''Assamese	asm_Beng
Awadhi	awa_Deva
Bengali	ben_Beng
Bhojpuri	bho_Deva
Standard Tibetan	bod_Tibt
Dzongkha	dzo_Tibt
English	eng_Latn
Gujarati	guj_Gujr
Hindi	hin_Deva
Chhattisgarhi	hne_Deva
Kannada	kan_Knda
Kashmiri (Arabic script)	kas_Arab
Kashmiri (Devanagari script)	kas_Deva
Mizo	lus_Latn
Magahi	mag_Deva
Maithili	mai_Deva
Malayalam	mal_Mlym
Marathi	mar_Deva
Meitei (Bengali script)	mni_Beng
Burmese	mya_Mymr
Nepali	npi_Deva
Odia	ory_Orya
Punjabi	pan_Guru
Sanskrit	san_Deva
Santali	sat_Olck
Sindhi	snd_Arab
Tamil	tam_Taml
Telugu	tel_Telu
Urdu	urd_Arab
Vietnamese	vie_Latn'''
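
# Each line above maps a display name to its FLORES-200 language code; the
# __main__ block below parses this into the flores_codes dict, e.g.
# flores_codes['Hindi'] == 'hin_Deva'.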



def load_models():
    # build model and tokenizer
    model_name_dict = {
        'nllb-1.3B': "ychenNLP/nllb-200-distilled-1.3B-easyproject",
    }
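    # additional checkpoints could be added to this dict; translation() below
    # would then need logic to pick between them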

    model_dict = {}

    for call_name, real_name in model_name_dict.items():
        print('\tLoading model: %s' % call_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
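        # NLLB-200 checkpoints share one tokenizer, so the distilled-600M
        # tokenizer is reused here for the 1.3B model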
        tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
        model_dict[call_name+'_model'] = model
        model_dict[call_name+'_tokenizer'] = tokenizer

    return model_dict
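
# Inference runs on CPU by default. A minimal sketch for GPU use, assuming a
# CUDA-capable device is available (optional; device=-1 keeps the pipeline
# on CPU):
#
#     device = 0 if torch.cuda.is_available() else -1
#     translator = pipeline('translation', model=model, tokenizer=tokenizer,
#                           src_lang=source, tgt_lang=target, device=device)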


def translation(source, target, text):
    model_name = 'nllb-1.3B'

    start_time = time.time()
    source = flores_codes[source]
    target = flores_codes[target]

    model = model_dict[model_name + '_model']
    tokenizer = model_dict[model_name + '_tokenizer']

    translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=source, tgt_lang=target)
    output = translator(text, max_length=400)

    end_time = time.time()
    print('inference time: %.2fs' % (end_time - start_time))

    # the pipeline returns a list of dicts; keep only the translated text
    return output[0]['translation_text']
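
# Example call (illustrative; requires model_dict and flores_codes to be
# set up by the __main__ block below):
#   translation('English', 'Hindi', 'How are you?')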


if __name__ == '__main__':
    print('\tinit models')
    flores_codes = {}
    for line in codes_as_string.split('\n'):
        lang, lang_code = line.split('\t')
        flores_codes[lang] = lang_code

    model_dict = load_models()
    
    # define gradio demo
    lang_codes = list(flores_codes.keys())
    
    inputs = [gr.Dropdown(lang_codes, value='English', label='Source'),
              gr.Dropdown(lang_codes, value='Hindi', label='Target'),
              gr.Textbox(lines=5, label="Input text"),
              ]

    outputs = gr.Textbox(label="Output text")

    title = "Machine Translation Demo"

    description = "Machine Translation System."

    gr.Interface(translation,
                 inputs,
                 outputs,
                 title=title,
                 description=description,
                 examples_per_page=50,
                 theme="JohnSmith9982/small_and_pretty"
                 ).launch()
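
    # Optional: passing share=True to launch() creates a temporary public
    # URL for the demo.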