scutcyr commited on
Commit
0d14890
1 Parent(s): 162f1cb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ # install torch and tf
3
+ os.system('pip install transformers SentencePiece')
4
+ os.system('pip install torch')
5
+
6
+ from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer
7
+ import torch
8
+ import gradio as gr
9
+
10
+ # 下载模型
11
+ tokenizer = T5Tokenizer.from_pretrained("ClueAI/ChatYuan-large-v1")
12
+ model = T5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1")
13
+ # 修改colab笔记本设置为gpu,推理更快
14
+ device = torch.device('cpu')
15
+ model.to(device)
16
+ print('Model Load done!')
17
+
18
+ def preprocess(text):
19
+ text = text.replace("\n", "\\n").replace("\t", "\\t")
20
+ return text
21
+
22
+ def postprocess(text):
23
+ return text.replace("\\n", "\n").replace("\\t", "\t")
24
+
25
+ def answer(text, sample=True, top_p=0.9, temperature=0.7):
26
+ '''sample:是否抽样。生成任务,可以设置为True;
27
+ top_p:0-1之间,生成的内容越多样
28
+ max_new_tokens=512 lost...'''
29
+ text = preprocess(text)
30
+ print('用户: '+text)
31
+ encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=768, return_tensors="pt").to(device)
32
+ if not sample:
33
+ out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512, num_beams=1, length_penalty=0.6)
34
+ else:
35
+ out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512, do_sample=True, top_p=top_p, temperature=temperature, no_repeat_ngram_size=3)
36
+ out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
37
+ print('小元: '+postprocess(out_text[0]))
38
+ return postprocess(out_text[0])
39
+
40
+ def command_result(text):
41
+ output = answer(text)
42
+ return output
43
+
44
+ iface = gr.Interface(fn=command_result, inputs="text", outputs="text", title="中文聊天机器人Demo")
45
+ iface.launch()