Spaces:
Running
Running
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForSeq2SeqLM | |
# Source sentences are truncated to at most max_input_length tokens by the
# tokenizer, and generated translations are capped at max_target_length tokens.
max_input_length = 128
max_target_length = 128
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')
def _load_finetuned(checkpoint, weights_path):
    """Build a seq2seq model from ``checkpoint`` and load fine-tuned weights into it.

    Args:
        checkpoint: Hugging Face Hub id whose tokenizer and model architecture
            are reused as the starting point.
        weights_path: Path to a state dict saved with ``torch.save``.

    Returns:
        ``(tokenizer, model)`` with the model moved to ``device`` and switched
        to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    model = model.to(device)
    # Without map_location, torch.load restores tensors onto the device they
    # were saved from, which fails for GPU-saved weights on a CPU-only host.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return tokenizer, model


# Model 1: Chinese -> Manchu.
model1_checkpoint = "Helsinki-NLP/opus-mt-zh-en"
tokenizer1, model1 = _load_finetuned(
    model1_checkpoint, 'epoch_41_valid_bleu_100.00_model_weights.bin')

# Model 2: Manchu -> Chinese.
# NOTE(review): the weights file name reports a validation BLEU of 0.00 —
# confirm this is the intended checkpoint.
model2_checkpoint = "Helsinki-NLP/opus-mt-en-zh"
tokenizer2, model2 = _load_finetuned(
    model2_checkpoint, 'epoch_41_valid_bleu_0.00_model_weights.bin')
def _translate(tokenizer, model, text):
    """Translate ``text`` with a seq2seq ``model`` and return the decoded strings.

    Args:
        tokenizer: Tokenizer paired with ``model``.
        model: Seq2seq model exposing ``generate``.
        text: A single source sentence, or a list of sentences that are
            translated together as one padded batch.

    Returns:
        A list of translated strings, one per input sentence.
    """
    batch_data = tokenizer(
        text,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt",
    )
    # Freshly tokenized tensors live on the CPU. The model may have been moved
    # to a GPU at load time, so the inputs must follow it or generate() raises
    # a device-mismatch error.
    input_ids = batch_data["input_ids"].to(model.device)
    attention_mask = batch_data["attention_mask"].to(model.device)
    generated_tokens = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_target_length,
    ).cpu().numpy()
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)


def chineseToManju(text):
    """Translate Chinese ``text`` to Manchu using ``tokenizer1``/``model1``.

    Returns a list of translations, one per input sentence.
    """
    return _translate(tokenizer1, model1, text)


def manjuToChinese(text):
    """Translate Manchu ``text`` to Chinese using ``tokenizer2``/``model2``.

    Returns a list of translations, one per input sentence.
    """
    return _translate(tokenizer2, model2, text)
def manju_to_chinese_text(text):
    """Gradio handler: translate Manchu ``text`` to Chinese as one plain string.

    ``manjuToChinese`` returns a list of decoded strings. Gradio renders
    Textbox outputs with ``str()``, so returning that list directly would
    show ``['...']`` to the user. The translations are joined instead.
    """
    return "\n".join(manjuToChinese(text))


def chinese_to_manju_text(text):
    """Gradio handler: translate Chinese ``text`` to Manchu as one plain string."""
    return "\n".join(chineseToManju(text))


with gr.Blocks() as demo:
    # Page heading rendered as Markdown.
    gr.Markdown("## 满语翻译演示")
    # One tab per translation direction: Manchu -> Chinese.
    with gr.Tab("满to中"):
        # Blocks-only layout component that stacks its children vertically.
        # Vertical stacking is already the default, so the Column is optional.
        with gr.Column():
            text_input1 = gr.Textbox(lines=2, placeholder="请输入满语", label="manju")
            text_button1 = gr.Button("翻译")
            text_output1 = gr.Textbox(lines=2, label="chinese")
    # Chinese -> Manchu.
    with gr.Tab("中to满"):
        # Also stacks its children vertically (gr.Column, not gr.Row).
        with gr.Column():
            text_input2 = gr.Textbox(lines=2, placeholder="请输入中文", label="chinese")
            text_button2 = gr.Button("翻译")
            text_output2 = gr.Textbox(lines=2, label="manju")
    # Collapsible section with the lab credit, usage notes and sample pairs.
    with gr.Accordion(""):
        gr.Markdown("## 东北师范大学信息科学与技术学院 满语智能处理实验室")
        gr.Markdown("#### 注意事项")
        gr.Markdown("最长语句不能超过128个词!")
        gr.Markdown("#### 以下是一些例子")
        gr.Markdown('"manju": "sakda amji,be yabume oho.", "chinese": "大爷,我们该走了。"')
        gr.Markdown('"manju": "ume ekxere,jai emu majige teki,majige muke be omiki.", "chinese": "忙什么呀,再坐一会儿,喝点水。"')
        gr.Markdown('"manju": "omirakv oho,ubade emgeri hontoha inenggi tehe,suwembe ambula jobohuha.", "chinese": "不了,在这呆了有半天了,打扰你们了。"')
        gr.Markdown('"manju": "ume yabure,ubade emu erin buda be jejfi jai yabuki.", "chinese": "别走了,在这吃了饭再走吧。"')
    # Wire each button to its direction's handler.
    text_button1.click(manju_to_chinese_text, inputs=text_input1, outputs=text_output1)
    text_button2.click(chinese_to_manju_text, inputs=text_input2, outputs=text_output2)
if __name__ == "__main__":
    # Start the Gradio web server only when run as a script, not on import.
    demo.launch()