marutenbo / app.py
isonuma's picture
Update app.py
7db7298 verified
raw
history blame contribute delete
No virus
1.49 kB
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
from st_keyup import st_keyup
from utils import is_hiragana_or_katakana, search_candidates
# model_name_or_path = "tokyotech-llm/Llama-3-Swallow-8B-v0.1"
model_name_or_path = "tokyotech-llm/Swallow-7b-hf"
# model_name_or_path = "llm-jp/llm-jp-1.3b-v1.0"


@st.cache_resource
def load_model_and_tokenizer(name_or_path):
    """Load the tokenizer and causal LM for ``name_or_path`` once per server process.

    Streamlit re-executes this script from the top on every widget interaction
    (and ``st_keyup`` below triggers one per keystroke), so loading at module
    level would re-read the whole checkpoint on every keypress.
    ``st.cache_resource`` keeps a single instance per ``name_or_path``, shared
    across sessions.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = LlamaTokenizer.from_pretrained(name_or_path)
    # Reuse EOS as the padding token (presumably the checkpoint defines no
    # pad token of its own — confirm).
    tok.pad_token = tok.eos_token
    # Loaded in half precision; drop torch_dtype to load in full float32.
    lm = AutoModelForCausalLM.from_pretrained(name_or_path, torch_dtype=torch.float16)
    return tok, lm


tokenizer, model = load_model_and_tokenizer(model_name_or_path)
# Page header: the app title, then a description line whose text is
# currently an empty placeholder string.
st.title("丸点棒AI")
st.write("")
# Memo of query -> (candidates, losses), seeded with the empty-prefix entry.
# Streamlit re-runs this script from the top on every keystroke, so a plain
# module-level dict would be rebuilt each time and the lookup below could
# never hit; st.session_state keeps it alive across reruns within a session.
# NOTE(review): this assumes search_candidates stores entries in the same
# (candidates, losses) shape that the lookup below unpacks — confirm in utils.
if "query_candidates" not in st.session_state:
    st.session_state["query_candidates"] = {"": ([""], 0)}
query_candidates = st.session_state["query_candidates"]

# How many candidates to request from the search and to display.
TOP_K = 10

query = st_keyup(
    "お題",
    placeholder="ひらがな/カタカナのみを入力",
)
# Truthiness check rejects both "" and a possible None before validation.
if query and is_hiragana_or_katakana(query):
    if query in query_candidates:
        top_candidates, top_losses = query_candidates[query]
    else:
        top_candidates, top_losses = search_candidates(
            query, query_candidates, model=model, tokenizer=tokenizer, top_k=TOP_K
        )
    # One "candidate: loss" line per displayed candidate (at most TOP_K).
    answers = [
        "{}: {:.2f}".format(candidate, loss)
        for candidate, loss in zip(top_candidates[:TOP_K], top_losses)
    ]
    value = "\n".join(answers)
    # Trailing line reports the total number of candidates found.
    value += f"\n({len(top_candidates)}候補)"
    st.info(value)
else:
    st.info("ひらがな/カタカナのみを入力してください")