BAAI
/

shunxing1234 committed on
Commit
4b56aa6
1 Parent(s): b1df38b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -24
README.md CHANGED
@@ -58,40 +58,19 @@ For detailed evaluation results, please refer to the website http://flageval.baa
58
  ```python
59
  from transformers import AutoTokenizer, AutoModelForCausalLM
60
  import torch
61
-
62
- device = torch.device("cuda:1")
63
-
64
  model_info = "BAAI/AquilaChat-7B"
65
  tokenizer = AutoTokenizer.from_pretrained(model_info, trust_remote_code=True)
66
  model = AutoModelForCausalLM.from_pretrained(model_info, trust_remote_code=True)
67
  model.eval()
68
  model.to(device)
69
-
70
  text = "请给出10个要到北京旅游的理由。"
71
-
72
  tokens = tokenizer.encode_plus(text)['input_ids'][:-1]
73
-
74
  tokens = torch.tensor(tokens)[None,].to(device)
75
-
76
-
77
  with torch.no_grad():
78
- out = model.generate(tokens, do_sample=True, max_length=512, eos_token_id=100007)[0]
79
-
80
  out = tokenizer.decode(out.cpu().numpy().tolist())
81
- if "###" in out:
82
- special_index = out.index("###")
83
- out = out[: special_index]
84
-
85
- if "[UNK]" in out:
86
- special_index = out.index("[UNK]")
87
- out = out[:special_index]
88
-
89
- if "</s>" in out:
90
- special_index = out.index("</s>")
91
- out = out[: special_index]
92
-
93
- if len(out) > 0 and out[0] == " ":
94
- out = out[1:]
95
  print(out)
96
  ```
97
 
 
58
  ```python
59
  from transformers import AutoTokenizer, AutoModelForCausalLM
60
  import torch
61
+ device = torch.device("cuda")
 
 
62
  model_info = "BAAI/AquilaChat-7B"
63
  tokenizer = AutoTokenizer.from_pretrained(model_info, trust_remote_code=True)
64
  model = AutoModelForCausalLM.from_pretrained(model_info, trust_remote_code=True)
65
  model.eval()
66
  model.to(device)
 
67
  text = "请给出10个要到北京旅游的理由。"
 
68
  tokens = tokenizer.encode_plus(text)['input_ids'][:-1]
 
69
  tokens = torch.tensor(tokens)[None,].to(device)
70
+ stop_tokens = ["###", "[UNK]", "</s>"]
 
71
  with torch.no_grad():
72
+ out = model.generate(tokens, do_sample=True, max_length=512, eos_token_id=100007, bad_words_ids=[[tokenizer.encode(token)[0] for token in stop_tokens]])[0]
 
73
  out = tokenizer.decode(out.cpu().numpy().tolist())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  print(out)
75
  ```
76