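# app.py: Gradio chat demo that streams responses from
# HODACHI/Llama-3.1-8B-EZO-1.1-it through a transformers text-generation pipeline.
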
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, pipeline
import torch
from threading import Thread

# Model checkpoint and the compute dtype used for its weights.
MODEL_ID = "HODACHI/Llama-3.1-8B-EZO-1.1-it"
DTYPE = torch.bfloat16

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    device_map="auto",  # let accelerate place layers across the available devices
    low_cpu_mem_usage=True,
)
# The model is already dispatched via device_map="auto" above, so the pipeline
# inherits its device placement and needs no device argument of its own.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
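

# Start generation on a background thread and return the streamer, so the
# caller can consume tokens as they are produced.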
def generate_text(prompt, max_new_tokens, temperature, top_p):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        streamer=streamer,
    )
    thread = Thread(target=pipe, kwargs=dict(text_inputs=prompt, **generation_kwargs))
    thread.start()
    return streamer
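

# Rebuild the full conversation with the model's chat template and stream the
# assistant's reply back to Gradio incrementally.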
def respond(message, history, max_tokens, temperature, top_p):
    chat = []
    # System prompt (in Japanese): "You are a sincere and excellent Japanese
    # assistant. Unless instructed otherwise, respond in Japanese as a rule."
    chat.append({"role": "system", "content": "あなたは誠実で優秀な日本人のアシスタントです。特に指示が無い場合は、原則日本語で回答してください。"})
    for user, assistant in history:
        chat.append({"role": "user", "content": user})
        chat.append({"role": "assistant", "content": assistant})
    chat.append({"role": "user", "content": message})
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    streamer = generate_text(prompt, max_tokens, temperature, top_p)
    response = ""
    for new_text in streamer:
        response += new_text
        yield response  # yield the running text so the UI updates as tokens arrive
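

# Chat UI; sampling parameters are exposed as additional slider inputs.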
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()