---
license: apache-2.0
---
## 1. SFT of the baichuan-7B base model to align it with human intent

## 2. SFT data: 150k examples sampled from the open-source MOSS dataset, balanced across categories
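A category-balanced sample like the one described above can be drawn along the following lines. This is only an illustrative sketch, not the exact preprocessing script used for this model: the input file name (`moss_data.jsonl`), the `category` field, and the per-category quota are assumptions.

```python
import json
import random
from collections import defaultdict

# Illustrative sketch of category-balanced sampling (assumed field/file names,
# not the actual script used to build the SFT set).
random.seed(42)
target_total = 150_000          # roughly 15w examples in total
by_category = defaultdict(list)

# moss_data.jsonl is a hypothetical dump of the open-source MOSS dialogues,
# one JSON object per line with a 'category' field.
with open('moss_data.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        example = json.loads(line)
        by_category[example['category']].append(example)

# draw roughly the same number of examples from every category
quota = target_total // len(by_category)
sampled = []
for category, examples in by_category.items():
    sampled.extend(random.sample(examples, min(quota, len(examples))))

with open('sft_data_15w.jsonl', 'w', encoding='utf-8') as f:
    for example in sampled:
        f.write(json.dumps(example, ensure_ascii=False) + '\n')
```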
## Model Inference

Install the required packages:
```
pip install transformers
pip install sentencepiece
pip install vllm
```
### Serving with Hugging Face Transformers and FastAPI (supports multi-turn dialogue)

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel  # only needed if you load a LoRA adapter (see the commented line below)
import torch
import uvicorn
from fastapi import FastAPI

device = 'cuda'
model_name = 'mxmax/baichuan-7b-sft-001'
max_new_tokens = 500
top_p = 0.9
temperature = 0.35
repetition_penalty = 1.0

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map={'': 0}  # or 'auto'; this already places the model on GPU 0
)
# model = PeftModel.from_pretrained(model, adapter_name)
model.eval()

# maximum number of prompt tokens fed to the model
history_max_len = 1024

def model_infer(user_input):
    # prepend the BOS marker, then keep only the last history_max_len tokens of the dialogue
    history_token_ids = tokenizer('<s>', return_tensors="pt").input_ids
    user_input_ids = tokenizer(user_input, return_tensors="pt").input_ids
    history_token_ids = torch.concat((history_token_ids, user_input_ids[:, -history_max_len:]), dim=1)
    model_input_ids = history_token_ids.to(device)
    outputs = model.generate(
        input_ids=model_input_ids, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p,
        temperature=temperature, repetition_penalty=repetition_penalty, eos_token_id=tokenizer.eos_token_id
    )
    # decode only the newly generated tokens, dropping the prompt
    model_input_ids_len = model_input_ids.size(1)
    response_ids = outputs[:, model_input_ids_len:]
    response = tokenizer.batch_decode(response_ids)
    return response[0].strip().replace('</s>', "")

app = FastAPI()

@app.get('/')
async def root():
    return {"msg": "Hello World"}

@app.post('/baichuan_sft_001')
async def baichuan_sft_001(message: dict):
    # flatten the multi-turn context into one prompt: human/assistant turns separated by </s>
    prompt = ''
    for turn in message['context']:
        prompt += 'human:' + turn['human'] + '\nassistant:' + turn['assistant'] + '</s>'
    result = model_infer(prompt)
    message['context'][-1]['assistant'] = result
    return {'model_output': result}

if __name__ == '__main__':
    # save this file as model_serving.py so the 'model_serving:app' import string resolves
    uvicorn.run('model_serving:app', host="0.0.0.0", port=6006)
```
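Once the server is running, it can be queried with a JSON body containing a `context` list of `{human, assistant}` turns, with the `assistant` field of the latest turn left empty. The snippet below is a minimal client sketch using the `requests` package (installed separately); the host, port, and payload layout simply mirror the server code above.

```python
import requests

# minimal client sketch for the service above (assumes it is running on localhost:6006)
payload = {
    "context": [
        {"human": "Hello, please introduce yourself.", "assistant": ""}
    ]
}
resp = requests.post("http://127.0.0.1:6006/baichuan_sft_001", json=payload)
print(resp.json()["model_output"])
```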

### Serving with vLLM and FastAPI for faster inference (supports multi-turn dialogue)

```python
import uvicorn
from fastapi import FastAPI
import jsonlines
from vllm import LLM, SamplingParams

model_name = 'mxmax/baichuan-7b-sft-001'
max_new_tokens = 512
top_p = 0.9
temperature = 0.35
presence_penalty = 0.1
# maximum number of prompt characters kept from the dialogue history
history_max_len = 1024
sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_new_tokens, presence_penalty=presence_penalty)

# Create an LLM.
llm = LLM(model=model_name, trust_remote_code=True, dtype='float16')
# file handle for logging chat records (writing is left to the caller)
file = jsonlines.open('chat_record.json', 'a')
app = FastAPI()

@app.get('/')
async def root():
    return {"msg": "Hello World"}

@app.post('/baichuan_sft_001')
async def baichuan_sft_001(message: dict):
    # flatten the multi-turn context into one prompt and keep only the most recent history
    prompt = ''
    for turn in message['context']:
        prompt += 'human:' + turn['human'] + '\nassistant:' + turn['assistant'] + '</s>'
    prompt = '<s>' + prompt[-history_max_len:]
    outputs = llm.generate([prompt], sampling_params)
    result = outputs[0].outputs[0].text
    message['context'][-1]['assistant'] = result
    return {'model_output': result}

if __name__ == '__main__':
    # save this file as vllm_serving.py so the 'vllm_serving:app' import string resolves
    uvicorn.run('vllm_serving:app', host="0.0.0.0", port=6006)
```
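The vLLM service exposes the same request and response format as the Transformers service above, so the client snippet shown earlier works unchanged against this endpoint. The speedup comes from vLLM's batching and KV-cache management rather than from any change to the prompt format.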