---
license: apache-2.0
---
## 1. SFT of the baichuan-7b base model to align it with human intent
## 2. The SFT data is 150k examples sampled evenly across categories from the open-source MOSS dataset
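For reference, below is a minimal sketch of what such category-balanced sampling can look like. It is not the exact preprocessing used for this model; it assumes the MOSS data is stored as JSON lines with a `category` field (both the file layout and field name are assumptions):
```python
import json
import random
from collections import defaultdict

def balanced_sample(path, total=150_000, seed=42):
    """Sample roughly `total` records, split evenly across categories."""
    random.seed(seed)
    by_category = defaultdict(list)
    # group records by their (assumed) category field
    with open(path, encoding='utf-8') as f:
        for line in f:
            record = json.loads(line)
            by_category[record['category']].append(record)
    per_category = total // len(by_category)
    sampled = []
    for records in by_category.values():
        random.shuffle(records)
        sampled.extend(records[:per_category])
    random.shuffle(sampled)
    return sampled

# e.g. sampled = balanced_sample('moss_data.jsonl')
```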
## Model Inference
Install the required packages:
```
pip install transformers
pip install sentencepiece
pip install vllm
pip install fastapi uvicorn jsonlines peft
```
### Serving with Hugging Face Transformers and FastAPI (multi-turn dialogue)
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import uvicorn
from fastapi import FastAPI
import jsonlines

device = 'cuda'
model_name = 'mxmax/baichuan-7b-sft-001'
max_new_tokens = 500
top_p = 0.9
temperature = 0.35
repetition_penalty = 1.0

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map={'': 0}  # or 'auto'
).cuda()
# model = PeftModel.from_pretrained(model, adapter_name)
model.eval()
model = model.to(device)

# maximum number of history tokens fed to the model
history_max_len = 1024

def model_infer(user_input):
    history_token_ids = tokenizer('<s>', return_tensors="pt").input_ids
    user_input_ids = tokenizer(user_input, return_tensors="pt").input_ids
    history_token_ids = torch.concat((history_token_ids, user_input_ids[:, -history_max_len:]), dim=1)
    model_input_ids = history_token_ids.to(device)
    outputs = model.generate(
        input_ids=model_input_ids, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p,
        temperature=temperature, repetition_penalty=repetition_penalty, eos_token_id=tokenizer.eos_token_id
    )
    # keep only the newly generated tokens after the prompt
    model_input_ids_len = model_input_ids.size(1)
    response_ids = outputs[:, model_input_ids_len:]
    response = tokenizer.batch_decode(response_ids)
    return response[0].strip().replace('</s>', "")

app = FastAPI()

@app.get('/')
async def root():
    return {"msg": "Hello World"}

@app.post('/baichuan_sft_001')
async def baichuan_sft_001(message: dict):
    # concatenate the multi-turn history into a single prompt
    prompt = ''
    for l in message['context']:
        prompt += 'human:' + l['human'] + '\nassistant:' + l['assistant'] + '</s>'
    result = model_infer(prompt)
    message['context'][-1]['assistant'] = result
    return {'model_output': result}

if __name__ == '__main__':
    uvicorn.run('model_serving:app', host="0.0.0.0", port=6006)
```
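Once the service above is running, it can be called with a JSON payload whose `context` list holds the dialogue turns; the last turn's `assistant` field is left empty and is filled in by the model. A minimal client sketch (URL, port, and field names follow the server code above, assuming the service runs locally):
```python
import requests

# single-turn request: last turn's assistant slot is empty
payload = {
    "context": [
        {"human": "你好,请介绍一下你自己", "assistant": ""}
    ]
}
resp = requests.post("http://127.0.0.1:6006/baichuan_sft_001", json=payload)
print(resp.json()["model_output"])
```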
### Serving with vLLM and FastAPI for faster inference (multi-turn dialogue)
```python
import uvicorn
from fastapi import FastAPI
import jsonlines
from vllm import LLM, SamplingParams

model_name = 'mxmax/baichuan-7b-sft-001'
max_new_tokens = 512
top_p = 0.9
temperature = 0.35
repetition_penalty = 0.1  # passed to vLLM as presence_penalty
history_max_len = 1024

sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_new_tokens, presence_penalty=repetition_penalty)
# Create an LLM.
llm = LLM(model=model_name, trust_remote_code=True, dtype='float16')

# chat records can be appended to a local file
file = jsonlines.open('chat_record.json', 'a')

app = FastAPI()

@app.get('/')
async def root():
    return {"msg": "Hello World"}

@app.post('/baichuan_sft_001')
async def baichuan_sft_001(message: dict):
    # concatenate the multi-turn history into a single prompt
    prompt = ''
    for l in message['context']:
        prompt += 'human:' + l['human'] + '\nassistant:' + l['assistant'] + '</s>'
    prompt = '<s>' + prompt[-history_max_len:]
    outputs = llm.generate([prompt], sampling_params)
    result = outputs[0].outputs[0].text
    message['context'][-1]['assistant'] = result
    return {'model_output': result}

if __name__ == '__main__':
    uvicorn.run('vllm_serving:app', host="0.0.0.0", port=6006)
```
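Both services expect the same request format, so multi-turn dialogue works by resending the accumulated context on every request. A minimal interactive loop sketch (assuming either service is running locally on port 6006):
```python
import requests

URL = "http://127.0.0.1:6006/baichuan_sft_001"
context = []

while True:
    user_input = input("human: ")
    if user_input.strip().lower() in ("exit", "quit"):
        break
    # append a new turn with an empty assistant slot for the model to fill
    context.append({"human": user_input, "assistant": ""})
    resp = requests.post(URL, json={"context": context}).json()
    context[-1]["assistant"] = resp["model_output"]
    print("assistant:", context[-1]["assistant"])
```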
## Model Output Examples
![arch](./images/1.jpg)
![arch](./images/2.jpg)
![arch](./images/3.jpg)
![arch](./images/4.jpg)
![arch](./images/5.jpg)
![arch](./images/6.jpg)
## Contact
![arch](./images/微信好友二维码.jpg)
When adding me on WeChat, please include a note: "from the Hugging Face site, technical discussion" plus your name.
QQ group: 621725172
## Citation
```bibtex
@misc{mxmax,
title={baichuan_sft: baichuan-7b-sft-001},
author={Ma Xin},
year={2023},
howpublished={\url{https://huggingface.co/mxmax/baichuan-7b-sft-001}},
}
```