import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
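# Generation limits: hard ceiling on new tokens, the UI slider default, and the
# maximum prompt length (overridable via the MAX_INPUT_TOKEN_LENGTH env var).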
MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 512
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
DESCRIPTION = """\
# Msaidizi wa AI ya Kiswahili
Hii inaonyesha kielelezo cha Kiswahili (Jacaranda) kilichoundwa kutoka Llama-2 7b, kinachotumiwa kama msaidizi wa AI kwa maisha ya kila siku.
(This Space demonstrates the [Swahili (Jacaranda) model](https://huggingface.co/Jacaranda/UlizaLlama) fine-tuned from Llama-2 7b, used as a daily life AI assistant.)
"""
LICENSE = """
<p/>
---
As a derivative work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""
SYSTEM_PROMPT = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
PROMPT_TEMPLATE = """{% if messages[0]['role'] == 'system' %}{{ messages[0]['content'] + '\n\n' }}{% endif %}### Instruction:\nWewe ni msaidizi wa AI unayepiga gumzo na mtumiaji. Hii ndiyo historia ya soga ya watu unaowasiliana nao kufikia sasa:\n\n{% for message in messages %}{% if message['role'] == 'user' %}{{ '\nUser: ' + message['content'] + '\n'}}{% elif message['role'] == 'assistant' %}{{ '\nAI: ' + message['content'] + '\n'}}{% endif %}{% endfor %}\n\nKama msaidizi wa AI, andika jibu lako linalofuata kwenye gumzo.\n\n### Response:\n"""
if not torch.cuda.is_available():
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
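# Load the model in float16 on the available GPU(s); give the tokenizer a pad
# token and attach the custom chat template defined above.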
if torch.cuda.is_available():
model_id = "Jacaranda/UlizaLlama"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
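    # With unbatched inputs the [PAD] token is never fed to the model, so the
    # embedding table does not need to be resized for it.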
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.chat_template = PROMPT_TEMPLATE
tokenizer.use_default_system_prompt = False
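# Stream a reply for `message`, given the prior (user, assistant) turns in
# `chat_history`, yielding the partial response as tokens arrive.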
@spaces.GPU
def generate(
message: str,
chat_history: list[tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
print("chat history: ", chat_history)
conversation = [{"role": "system", "content": SYSTEM_PROMPT}]
for user, assistant in chat_history:
conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
conversation.append({"role": "user", "content": message})
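    # Render the conversation through the chat template and tokenize it.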
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
input_ids = input_ids.to(model.device)
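    # Run generation in a background thread and stream decoded tokens back;
    # the streamer skips the prompt and special tokens in its output.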
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
examples = [
    ["Ninawezaje kupata usingizi haraka?"],  # How can I fall asleep quickly?
    ["Bosi wangu anadhibiti sana, nifanye nini?"],  # My boss is very controlling, what should I do?
    ["Je, ni vipindi gani muhimu katika historia vya kujua kuvihusu?"],  # Which periods in history are important to know about?
    ["Ni kazi gani nzuri ikiwa ninataka kupata pesa lakini pia kufurahiya?"],  # What is a good job if I want to earn money but also enjoy myself?
    ["Nivae nini kwenye harusi?"],  # What should I wear to a wedding?
]
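# Build the Gradio UI: chat window, message box, bilingual controls, example
# prompts, and sliders for the sampling parameters.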
with gr.Blocks(css="style.css") as demo:
gr.Markdown(DESCRIPTION)
chatbot = gr.Chatbot()
msg = gr.Textbox(label="Ingiza ujumbe wako / Enter your message")
submit_btn = gr.Button("Wasilisha / Submit")
clear = gr.Button("Wazi / Clear")
def user(user_message, history):
return "", history + [[user_message, None]]
def bot(history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
user_message = history[-1][0]
        chat_history = [(pair[0], pair[1]) for pair in history[:-1]]
bot_message = ""
for response in generate(user_message, chat_history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
bot_message = response
history[-1][1] = bot_message
yield history
gr.Examples(examples=examples, inputs=[msg], label="Mifano / Examples")
with gr.Accordion("Chaguzi za Juu / Advanced Options", open=False):
max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
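    # Wire the Submit button and Enter-key submission to the same user -> bot
    # pipeline; Clear resets the chat window.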
submit_btn.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot,
[chatbot, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
chatbot,
)
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot,
[chatbot, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
chatbot,
)
clear.click(lambda: None, None, chatbot, queue=False)
gr.Markdown(LICENSE)
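# Queue incoming requests (up to 20 waiting) before launching the app.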
if __name__ == "__main__":
demo.queue(max_size=20).launch()