schroneko commited on
Commit
e8b1714
1 Parent(s): 00a4be6

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +48 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
+ import gradio as gr
5
+ import spaces
6
+
7
+ huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
8
+ if not huggingface_token:
9
+ raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
10
+
11
+ model_id = "tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.1"
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ dtype = torch.bfloat16
14
+
15
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)
16
+ model = AutoModelForCausalLM.from_pretrained(
17
+ model_id,
18
+ device_map=device,
19
+ torch_dtype=dtype,
20
+ token=huggingface_token
21
+ )
22
+
23
+ @spaces.GPU
24
+ def generate_text(prompt, system_message="あなたは誠実で優秀な日本人のアシスタントです。"):
25
+ messages = [
26
+ {"role": "user", "content": prompt},
27
+ ]
28
+
29
+ inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(model.device)
30
+
31
+ outputs = model.generate(inputs, max_new_tokens=256, do_sample=True, temperature=0.7)
32
+ generated_text = tokenizer.batch_decode(outputs[:, inputs.shape[1]:], skip_special_tokens=True)[0]
33
+
34
+ return generated_text.strip()
35
+
36
+ iface = gr.Interface(
37
+ fn=generate_text,
38
+ inputs=[
39
+ gr.Textbox(lines=3, label="Input Prompt"),
40
+ gr.Textbox(lines=2, label="System Message", value="あなたは誠実で優秀な日本人のアシスタントです。"),
41
+ ],
42
+ outputs=gr.Textbox(label="Generated Text"),
43
+ title="Llama-3.1-Swallow Text Generation",
44
+ description="Enter a prompt and optional system message to generate text using the Llama-3.1-Swallow model. This model is optimized for Japanese language input and output.",
45
+ )
46
+
47
+ if __name__ == "__main__":
48
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ gradio
3
+ transformers
4
+ accelerate