wenge-research commited on
Commit
ae584ce
1 Parent(s): 940e2d8

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +7 -5
README.md CHANGED
@@ -19,6 +19,7 @@ tags:
19
 
20
  ```python
21
  from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
 
22
 
23
  yayi_7b_path = "wenge-research/yayi-7b"
24
  tokenizer = AutoTokenizer.from_pretrained(yayi_7b_path)
@@ -26,7 +27,7 @@ model = AutoModelForCausalLM.from_pretrained(yayi_7b_path, device_map="auto", to
26
 
27
  prompt = "你好"
28
  formatted_prompt = f"<|System|>:\nA chat between a human and an AI assistant named YaYi.\nYaYi is a helpful and harmless language model developed by Beijing Wenge Technology Co.,Ltd.\n\n<|Human|>:\n{prompt}\n\n<|YaYi|>:"
29
- inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
30
 
31
  generation_config = GenerationConfig(
32
  do_sample=True,
@@ -36,12 +37,14 @@ generation_config = GenerationConfig(
36
  no_repeat_ngram_size=0
37
  )
38
  response = model.generate(**inputs, generation_config=generation_config)
39
- print(tokenizer.decode(outputs[0]))
40
  ```
41
 
42
 注意,模型训练时添加了 special token `<|End|>` 作为结束符,上述代码在生成时若不能自动停止,可定义 `KeywordsStoppingCriteria` 类,并将其对象传参至 `model.generate()` 函数。
43
 
44
  ```python
 
 
45
  class KeywordsStoppingCriteria(StoppingCriteria):
46
  def __init__(self, keywords_ids:list):
47
  self.keywords = keywords_ids
@@ -54,11 +57,10 @@ class KeywordsStoppingCriteria(StoppingCriteria):
54
 
55
  ```python
56
  stop_criteria = KeywordsStoppingCriteria([tokenizer.encode(w)[0] for w in ["<|End|>"]])
57
- ...
58
- response = model.generate(**inputs, generation_config=generation_config, stop_criteria=stop_criteria)
59
  ```
60
 
61
-
62
  ## 相关协议
63
 
64
  ### 局限性
 
19
 
20
  ```python
21
  from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
22
+ import torch
23
 
24
  yayi_7b_path = "wenge-research/yayi-7b"
25
  tokenizer = AutoTokenizer.from_pretrained(yayi_7b_path)
 
27
 
28
  prompt = "你好"
29
  formatted_prompt = f"<|System|>:\nA chat between a human and an AI assistant named YaYi.\nYaYi is a helpful and harmless language model developed by Beijing Wenge Technology Co.,Ltd.\n\n<|Human|>:\n{prompt}\n\n<|YaYi|>:"
30
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
31
 
32
  generation_config = GenerationConfig(
33
  do_sample=True,
 
37
  no_repeat_ngram_size=0
38
  )
39
  response = model.generate(**inputs, generation_config=generation_config)
40
+ print(tokenizer.decode(response[0]))
41
  ```
42
 
43
 注意,模型训练时添加了 special token `<|End|>` 作为结束符,上述代码在生成时若不能自动停止,可定义 `KeywordsStoppingCriteria` 类,并将其对象传参至 `model.generate()` 函数。
44
 
45
  ```python
46
+ from transformers import StoppingCriteria, StoppingCriteriaList
47
+
48
  class KeywordsStoppingCriteria(StoppingCriteria):
49
  def __init__(self, keywords_ids:list):
50
  self.keywords = keywords_ids
 
57
 
58
  ```python
59
  stop_criteria = KeywordsStoppingCriteria([tokenizer.encode(w)[0] for w in ["<|End|>"]])
60
+ response = model.generate(**inputs, generation_config=generation_config, stopping_criteria=StoppingCriteriaList([stop_criteria]))
61
+ print(tokenizer.decode(response[0]))
62
  ```
63
 
 
64
  ## 相关协议
65
 
66
  ### 局限性