"""Streamlit demo "丸点棒AI".

Reads a kana-only prompt ("お題") with a keyup widget, asks ``search_candidates``
for scored candidates using a causal LM, and shows the top results.
"""
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
from st_keyup import st_keyup

from utils import is_hiragana_or_katakana, search_candidates

# Alternative checkpoints that have been tried with this app:
#   "tokyotech-llm/Llama-3-Swallow-8B-v0.1"
#   "llm-jp/llm-jp-1.3b-v1.0"
model_name_or_path = "tokyotech-llm/Swallow-7b-hf"

# Number of candidates requested from the search and displayed in the UI.
TOP_K = 10


@st.cache_resource
def _load_model_and_tokenizer(name_or_path: str):
    """Load the tokenizer and float16 causal LM once per server process.

    Streamlit re-runs this whole script on every widget interaction (here, on
    every keystroke in ``st_keyup``). Without ``st.cache_resource`` the model
    would be reloaded from disk on each rerun.

    Args:
        name_or_path: Hugging Face hub id or local path of the checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = LlamaTokenizer.from_pretrained(name_or_path)
    # Use EOS as the padding token (presumably search_candidates pads batched
    # inputs -- confirm against utils).
    tok.pad_token = tok.eos_token
    # NOTE(review): no device placement happens here, so the model stays
    # wherever from_pretrained puts it by default; fp16 on CPU can be slow or
    # unsupported for some ops -- confirm the intended device.
    lm = AutoModelForCausalLM.from_pretrained(name_or_path, torch_dtype=torch.float16)
    return tok, lm


tokenizer, model = _load_model_and_tokenizer(model_name_or_path)

# Show title and description.
# TODO: the description below is currently empty.
st.title("丸点棒AI")
st.write(
    ""
)

# Memo of query -> (candidates, losses). Stored in session_state so it survives
# Streamlit reruns; a plain dict literal here would be recreated on every
# keystroke and the cache lookup below could never hit.
if "query_candidates" not in st.session_state:
    st.session_state["query_candidates"] = {"": ([""], 0)}
query_candidates = st.session_state["query_candidates"]

query = st_keyup(
    "お題",
    placeholder="ひらがな/カタカナのみを入力",
)

# Truthiness check rejects both "" and a possible None widget value before the
# kana check runs.
if query and is_hiragana_or_katakana(query):
    if query in query_candidates:
        top_candidates, top_losses = query_candidates[query]
    else:
        top_candidates, top_losses = search_candidates(
            query,
            query_candidates,
            model=model,
            tokenizer=tokenizer,
            top_k=TOP_K,
        )
        # setdefault: keep any entry search_candidates may already have
        # written for this query, otherwise memoize the returned result.
        query_candidates.setdefault(query, (top_candidates, top_losses))
    answers = [
        f"{candidate}: {loss:.2f}"
        for candidate, loss in zip(top_candidates[:TOP_K], top_losses[:TOP_K])
    ]
    value = "\n".join(answers)
    value += f"\n({len(top_candidates)}候補)"
    st.info(value)
else:
    st.info("ひらがな/カタカナのみを入力してください")