import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 512
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
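# MAX_INPUT_TOKEN_LENGTH can be overridden from the environment before launch, e.g.:
#   MAX_INPUT_TOKEN_LENGTH=2048 python app.py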
DESCRIPTION = """\
# UlizaLlama (Swahili Llama 2)
This Space demonstrates the [UlizaLlama model](https://huggingface.co/Jacaranda/UlizaLlama) from Jacaranda, a Swahili model fine-tuned from Llama-2 7b and used as a daily-life AI assistant.
"""
LICENSE = """
<p/>
---
As a derivative work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""
SYSTEM_PROMPT = ""
PROMPT_TEMPLATE = """Below is a chat history between a user and an AI assistant (you). The user gives instructions. Write a response that appropriately completes the latest request.\n\n{% for message in messages %}{% if message['role'] == 'user' %}{{ '\n### Instruction:\n' + message['content'] + '\n'}}{% elif message['role'] == 'assistant' %}{{ '\n### Response:\n' + message['content'] + '\n'}}{% endif %}{% endfor %}\n\n### Response:\n"""
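# For illustration, rendering PROMPT_TEMPLATE for a one-turn conversation (a sketch,
# assuming standard Jinja2 rendering, which is what tokenizer.apply_chat_template uses):
#
#     from jinja2 import Template
#     Template(PROMPT_TEMPLATE).render(messages=[{"role": "user", "content": "Habari?"}])
#
# yields roughly:
#
#     Below is a chat history between a user and an AI assistant (you). ...
#
#     ### Instruction:
#     Habari?
#
#     ### Response:
#
# Note that the template only renders "user" and "assistant" turns, so a "system"
# message appended to the conversation is silently dropped.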
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
if torch.cuda.is_available():
    model_id = "Jacaranda/UlizaLlama"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    tokenizer.chat_template = PROMPT_TEMPLATE
    tokenizer.use_default_system_prompt = False
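    # Note: if "[PAD]" was not already in the vocabulary, add_special_tokens grows the
    # tokenizer, and the embedding matrix would normally be resized to match, e.g.
    # model.resize_token_embeddings(len(tokenizer)). It is harmless here because the
    # pad token is never fed to the model (requests are unbatched), but worth flagging.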
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str = SYSTEM_PROMPT,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent MAX_INPUT_TOKEN_LENGTH tokens of the conversation.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a background thread so tokens can be yielded as they stream in.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Ninawezaje kupata usingizi haraka?"],  # "How can I fall asleep quickly?"
        ["Bosi wangu anadhibiti sana, nifanye nini?"],  # "My boss is very controlling, what should I do?"
        ["Je, ni vipindi gani muhimu katika historia vya kujua kuvihusu?"],  # "What are the important periods in history to know about?"
        ["Ni kazi gani nzuri ikiwa ninataka kupata pesa lakini pia kufurahiya?"],  # "What is a good job if I want to earn money but also have fun?"
        ["Nivae nini kwenye harusi?"],  # "What should I wear to a wedding?"
    ],
)
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()