shunxing1234 commited on
Commit
d53e5a3
1 Parent(s): 811643a

Upload test_chat.py

Browse files
Files changed (1) hide show
  1. test_chat.py +21 -0
test_chat.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AquilaForCausalLM
2
+ import torch
3
+ from cyg_conversation import default_conversation, covert_prompt_to_input_ids_with_history
4
+
5
+ tokenizer = AutoTokenizer.from_pretrained("BAAI/AquilaChat-7B")
6
+ model = AquilaForCausalLM.from_pretrained("BAAI/AquilaChat-7B")
7
+ model.eval()
8
+ model.to("cuda:4")
9
+ vocab = tokenizer.vocab
10
+ print(len(vocab))
11
+
12
+ text = "请给出10个要到北京旅游的理由。"
13
+
14
+ tokens = covert_prompt_to_input_ids_with_history(text, history=[], tokenizer=tokenizer, max_token=512)
15
+
16
+ tokens = torch.tensor(tokens)[None,].to("cuda:4")
17
+
18
+ out = model.generate(tokens, do_sample=True, max_length=512, eos_token_id=100007)[0]
19
+ out = tokenizer.decode(out.cpu().numpy().tolist())
20
+
21
+ print(out)