|
from threading import Thread |
|
|
|
import torch |
|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer |
|
|
|
model_id = "fireballoon/baichuan-vicuna-chinese-7b"

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
torch_device = "cpu"
if torch.cuda.is_available():
    torch_device = "cuda"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())

# GPU: load half-precision (float16) weights and move them onto the device.
# CPU: load with the library's default dtype and leave the model where it is.
if torch_device != "cuda":
    model = AutoModelForCausalLM.from_pretrained(model_id)
else:
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
    model = model.cuda()

# use_fast=False requests the slow (pure-Python) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
|
|
|
|
|
def _build_prompt(history):
    """Render the chat history as a single Vicuna-style prompt string.

    ``history`` is the chatbot value: a list of ``[user_text, assistant_text]``
    pairs whose last entry is the turn still awaiting a reply. Each earlier
    turn is closed with ``</s>``. The final turn ends with an open
    ``ASSISTANT:`` tag for the model to continue.
    """
    instruction = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    )
    # `or ''` guards earlier turns whose reply is still None (e.g. a previous
    # generation failed before its reply slot was initialised). Without it,
    # every later message in that session would crash on None.strip().
    context = ''.join(
        f" USER: {(turn[0] or '').strip()} ASSISTANT: {(turn[1] or '').strip()} </s>"
        for turn in history[:-1]
    )
    return instruction + context + f" USER: {history[-1][0].strip()} ASSISTANT:"


def run_generation(history, *args, **kwargs):
    """Stream the model's reply to the last turn of ``history``.

    Args:
        history: list of ``[user_text, assistant_text]`` pairs. The last pair is
            the pending turn, and its assistant slot is overwritten in place.
        *args, **kwargs: accepted but unused.

    Yields:
        The same ``history`` list after each streamed text chunk has been
        appended to ``history[-1][1]``.
    """
    prompt = _build_prompt(history)

    # Follow the model's device instead of hard-coding .cuda(): the model is
    # loaded on CPU when CUDA is unavailable, and .cuda() would fail there.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    # Debug echo of the prompt and its token-id tensor shape.
    print()
    print(prompt)
    print('##', input_ids.size())

    # NOTE(review): timeout=10s is the longest the loop below waits for the next
    # chunk; slow hardware (e.g. the CPU path) may exceed it. Confirm.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=2048,
        do_sample=True,
        temperature=0.7,
        repetition_penalty=1.1,
        top_p=0.85,
    )
    # generate() runs in a worker thread so this generator can drain the
    # streamer and yield partial replies while tokens are being produced.
    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
    generation_thread.start()

    history[-1][1] = ""
    print("")
    for new_text in streamer:
        history[-1][1] += new_text
        print(new_text, end="", flush=True)
        yield history
    print('</s>')

    # The streamer is exhausted, so generation has finished; reap the worker.
    generation_thread.join()
    return history
|
|
|
|
|
def reset_textbox():
    """Build a Gradio update that sets a component's value to the empty string.

    NOTE(review): nothing in the UI wiring references this function; it may be dead code.
    """
    cleared = gr.update(value='')
    return cleared
|
|
|
|
|
with gr.Blocks() as demo:
    # Header: the app title plus a link to the model's Hugging Face page.
    gr.Markdown(
        "# Baichuan Vicuna Chinese\n"
        f"[{model_id}](https://huggingface.co/{model_id}):使用中英双语sharegpt数据全参数微调的对话模型,基于baichuan-7b"
    )

    # NOTE(review): Component.style() is deprecated in later Gradio 3.x releases
    # and removed in Gradio 4 (use gr.Chatbot(height=600) there). Kept as-is
    # until the pinned Gradio version is confirmed.
    chatbot = gr.Chatbot().style(height=600)
    msg = gr.Textbox()
    # One click empties both the input box and the chat history.
    clear = gr.ClearButton([msg, chatbot])

    def user(user_message, history):
        """Record a submitted message as a new turn with no reply yet.

        The textbox is emptied and made non-interactive, so the user cannot
        submit again until the reply chain re-enables it.
        """
        locked_box = gr.update(value="", interactive=False)
        return locked_box, [*history, [user_message, None]]

    # Event chain: record the message (unqueued) -> stream the model reply
    # into the chatbot -> make the textbox interactive again.
    submitted = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False)
    response = submitted.then(run_generation, chatbot, chatbot)
    response.then(lambda: gr.update(interactive=True), None, [msg], queue=False)
|
|
|
# Turn on Gradio's request queue (presumably needed for the generator-based
# streaming handler), then serve on every network interface (0.0.0.0).
demo.queue().launch(server_name='0.0.0.0')
|
|