import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import spaces
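
# `spaces` provides the @spaces.GPU decorator used below, which requests a GPU
# for the decorated function when running on Hugging Face ZeroGPU Spaces.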
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
if not huggingface_token:
    raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
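
# Llama-3.1-Swallow is a Japanese-enhanced Llama 3.1 variant from tokyotech-llm;
# bfloat16 halves memory use relative to float32 with minimal quality loss.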
model_id = "tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.1"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device,  # place the whole model on the GPU if one is available
    torch_dtype=dtype,
    token=huggingface_token,
)
@spaces.GPU
def generate_text(prompt, system_message="あなたは誠実で優秀な日本人のアシスタントです。"):
    # Default system message: "You are a sincere and excellent Japanese assistant."
    # Include the system message so the UI field actually affects generation.
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt},
    ]
    inputs = tokenizer.apply_chat_template(
        messages, return_tensors="pt", add_generation_prompt=True
    ).to(model.device)
    outputs = model.generate(inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
    # Decode only the newly generated tokens, skipping the echoed prompt.
    generated_text = tokenizer.batch_decode(outputs[:, inputs.shape[1]:], skip_special_tokens=True)[0]
    return generated_text.strip()
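
# A minimal sanity check one could run from a Python REPL after the model
# loads (hypothetical prompt, bypassing the Gradio UI):
#
#   print(generate_text("日本で一番高い山は何ですか？"))  # "What is the tallest mountain in Japan?"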
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=3, label="Input Prompt"),
        gr.Textbox(lines=2, label="System Message", value="あなたは誠実で優秀な日本人のアシスタントです。"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Llama-3.1-Swallow Text Generation",
    description="Enter a prompt and an optional system message to generate text using the Llama-3.1-Swallow model. This model is optimized for Japanese language input and output.",
)
if __name__ == "__main__":
    iface.launch()
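
# Note: passing share=True to iface.launch() would additionally create a
# temporary public URL when running outside Hugging Face Spaces (where the
# platform serves the app itself).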