|
--- |
|
license: apache-2.0 |
|
--- |
|
|
|
## 1. SFT on the baichuan-7b base model to align it with human intent
|
|
|
|
|
## 2. SFT data: 150k examples sampled from the open-source MOSS dataset, with balanced sampling across categories
|
|
|
## Model Inference
|
|
|
Install the required packages:
|
```bash
pip install transformers
pip install sentencepiece
pip install vllm
```
|
### Serving with Hugging Face Transformers and FastAPI, with multi-turn dialogue support
|
|
|
```python
import torch
import uvicorn
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer
# from peft import PeftModel  # only needed when loading a LoRA adapter

device = 'cuda'
model_name = 'mxmax/baichuan-7b-sft-001'
max_new_tokens = 500
top_p = 0.9
temperature = 0.35
repetition_penalty = 1.0

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map={'': 0}  # or 'auto'; the model is placed on GPU 0 by device_map
)
# model = PeftModel.from_pretrained(model, adapter_name)  # optional LoRA adapter
model.eval()

# maximum number of history tokens fed to the model
history_max_len = 1024

def model_infer(user_input):
    # prepend the BOS token and keep only the last history_max_len tokens of the history
    history_token_ids = tokenizer('<s>', return_tensors="pt").input_ids
    user_input_ids = tokenizer(user_input, return_tensors="pt").input_ids
    history_token_ids = torch.concat((history_token_ids, user_input_ids[:, -history_max_len:]), dim=1)
    model_input_ids = history_token_ids.to(device)
    outputs = model.generate(
        input_ids=model_input_ids, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p,
        temperature=temperature, repetition_penalty=repetition_penalty, eos_token_id=tokenizer.eos_token_id
    )
    # keep only the newly generated tokens
    model_input_ids_len = model_input_ids.size(1)
    response_ids = outputs[:, model_input_ids_len:]
    response = tokenizer.batch_decode(response_ids)
    return response[0].strip().replace('</s>', "")

app = FastAPI()

@app.get('/')
async def root():
    return {"msg": "Hello World"}

@app.post('/baichuan_sft_001')
async def baichuan_sft_001(message: dict):
    # concatenate the dialogue history into a single prompt;
    # the last turn's 'assistant' field is expected to be empty
    prompt = ''
    for l in message['context']:
        prompt += 'human:' + l['human'] + '\nassistant:' + l['assistant'] + '</s>'
    result = model_infer(prompt)
    message['context'][-1]['assistant'] = result
    return {'model_output': result}

if __name__ == '__main__':
    # assumes this script is saved as model_serving.py
    uvicorn.run('model_serving:app', host="0.0.0.0", port=6006)
```
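
Once the service is running, it can be called over HTTP. The snippet below is a minimal client sketch (not part of the original repo), assuming the server is reachable at `http://localhost:6006`; earlier turns carry the model's previous replies, and the latest turn leaves `assistant` empty for the server to fill in.

```python
import requests

# hypothetical client for the /baichuan_sft_001 endpoint started above;
# adjust host/port to wherever the FastAPI service is running
url = 'http://localhost:6006/baichuan_sft_001'
message = {
    'context': [
        {'human': 'Hello', 'assistant': 'Hi, how can I help you?'},
        {'human': 'Write a short poem about spring', 'assistant': ''},
    ]
}
resp = requests.post(url, json=message)
print(resp.json()['model_output'])
```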
|
|
|
|
|
### Serving with vLLM and FastAPI for accelerated inference, with multi-turn dialogue support
|
|
|
```python
import jsonlines
import uvicorn
from fastapi import FastAPI
from vllm import LLM, SamplingParams

model_name = 'mxmax/baichuan-7b-sft-001'
max_new_tokens = 512
top_p = 0.9
temperature = 0.35
presence_penalty = 0.1  # passed to vLLM as presence_penalty (not the HF-style repetition_penalty)
history_max_len = 1024
sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_new_tokens, presence_penalty=presence_penalty)

# Create an LLM.
llm = LLM(model=model_name, trust_remote_code=True, dtype='float16')
# open a local file for chat-record logging (not written to in this snippet)
file = jsonlines.open('chat_record.json', 'a')
app = FastAPI()

@app.get('/')
async def root():
    return {"msg": "Hello World"}

@app.post('/baichuan_sft_001')
async def baichuan_sft_001(message: dict):
    # concatenate the dialogue history and keep only the last history_max_len characters;
    # the last turn's 'assistant' field is expected to be empty
    prompt = ''
    for l in message['context']:
        prompt += 'human:' + l['human'] + '\nassistant:' + l['assistant'] + '</s>'
    prompt = '<s>' + prompt[-history_max_len:]
    outputs = llm.generate([prompt], sampling_params)
    result = outputs[0].outputs[0].text
    message['context'][-1]['assistant'] = result
    return {'model_output': result}

if __name__ == '__main__':
    # assumes this script is saved as vllm_serving.py
    uvicorn.run('vllm_serving:app', host="0.0.0.0", port=6006)
```
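
For a quick sanity check without the HTTP layer, the same model can also be driven through vLLM directly. This is a minimal sketch, not part of the original repo; it reuses the prompt construction from the handler above, and the sampling values are illustrative.

```python
from vllm import LLM, SamplingParams

# standalone sanity check: same model and prompt construction as the service above
llm = LLM(model='mxmax/baichuan-7b-sft-001', trust_remote_code=True, dtype='float16')
sampling_params = SamplingParams(temperature=0.35, top_p=0.9, max_tokens=512, presence_penalty=0.1)

# one-turn history; the empty 'assistant' slot is what the model will fill in
context = [{'human': 'Write a short poem about spring', 'assistant': ''}]
prompt = ''
for turn in context:
    prompt += 'human:' + turn['human'] + '\nassistant:' + turn['assistant'] + '</s>'
prompt = '<s>' + prompt[-1024:]  # truncate history, as in the handler

outputs = llm.generate([prompt], sampling_params)
print(outputs[0].outputs[0].text)
```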
|
|
|
## Example Outputs
|
![arch](./images/1.jpg) |
|
![arch](./images/2.jpg) |
|
![arch](./images/3.jpg) |
|
![arch](./images/4.jpg) |
|
![arch](./images/5.jpg) |
|
![arch](./images/6.jpg) |
|
|
|
## Contact
|
![arch](./images/微信好友二维码.jpg) |
|
When sending a WeChat friend request, please include a note: "technical discussion, from Hugging Face" plus your name.
|
|
|
QQ group: 621725172
|
|
|
## Citation
|
```bibtex
@misc{mxmax,
  title={baichuan_sft: baichuan-7b-sft-001},
  author={Ma Xin},
  year={2023},
  howpublished={\url{https://huggingface.co/mxmax/baichuan-7b-sft-001}},
}
```