context-game / app.py
Allob's picture
Update app.py
b5f7332
import streamlit as st
import plotly.express as px
import pandas as pd
import random
import logging
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
@st.cache_resource
def load_model(name: str) -> SentenceTransformer:
    """Build the SentenceTransformer identified by *name*.

    The function is wrapped in ``st.cache_resource``, so the returned model
    object is managed by Streamlit's resource cache rather than rebuilt by
    this function's caller.

    Args:
        name: Model identifier handed straight to ``SentenceTransformer``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    model = SentenceTransformer(name)
    return model
@st.cache_data
def load_words_dataset():
    """Return the ``"Word"`` column of the WordNet definitions train split.

    The split is fetched with ``datasets.load_dataset`` from
    ``marksverdhei/wordnet-definitions-en-2021``; the result is wrapped in
    ``st.cache_data``.

    Returns:
        Whatever indexing the loaded split with ``"Word"`` yields
        (presumably a sequence of strings -- TODO confirm against the
        installed ``datasets`` version).
    """
    train_split = load_dataset(
        "marksverdhei/wordnet-definitions-en-2021",
        split="train",
    )
    words = train_split["Word"]
    return words
@st.cache_data
def choose_secret_word():
    """Draw one entry at random from the word list.

    NOTE(review): the function takes no arguments and is wrapped in
    ``st.cache_data``, so the random draw is presumably made once and then
    reused on later calls -- confirm that a single shared secret is the
    intended behaviour.

    Returns:
        One element of ``load_words_dataset()`` chosen by ``random.choice``.
    """
    candidates = load_words_dataset()
    secret = random.choice(candidates)
    return secret
# ---------------------------------------------------------------------------
# Page script: set up models and the secret word, read one guess, score it
# against the secret with every model, and render the table of guesses.
# ---------------------------------------------------------------------------

# Kept as a module-level name; it is not referenced again in the visible code.
all_words = load_words_dataset()

# Models used to score guesses. The FIRST entry decides the sort order of the
# results table (see `similarity_columns[0]` below).
model_names = [
    'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    'BAAI/bge-small-en-v1.5',
]
models = {name: load_model(name) for name in model_names}

# Normalise the secret the same way guesses are normalised below, then embed
# it once per model (list order matches `model_names`).
secret_word = choose_secret_word().lower().strip()
secret_embedding = [models[name].encode(secret_word) for name in model_names]
# Debug aid: this writes the secret to stdout on every script run.
print("Secret word ", secret_word)

# Each stored row is [guess, similarity_for_model_0, similarity_for_model_1, ...].
if 'words' not in st.session_state:
    st.session_state['words'] = []

st.write('Try to guess a secret word by semantic similarity')
word = st.text_input("Input a word")
# The original condition was `st.button("Guess") or word`, i.e. any non-empty
# input is processed whether or not the button was pressed. Because empty
# guesses are now rejected, the button's return value no longer affects the
# outcome; the widget is still rendered so the page looks the same.
st.button("Guess")

# Normalise ONCE and use the same form for the duplicate check, the embedding
# and the stored row, so "Cat", "cat " and "cat" count as a single guess.
guess = word.lower().strip()
used_words = {row[0] for row in st.session_state['words']}

# Skip empty / whitespace-only input and guesses that were already scored.
if guess and guess not in used_words:
    word_embedding = [models[name].encode(guess) for name in model_names]
    # pytorch_cos_sim returns a 1x1 tensor for a pair of vectors; pull out
    # the scalar for each model.
    similarities = [
        util.pytorch_cos_sim(secret_vec, guess_vec).cpu().numpy()[0][0]
        for secret_vec, guess_vec in zip(secret_embedding, word_embedding)
    ]
    st.session_state['words'].append([guess] + similarities)

# Column names are derived from `model_names` so the sort key can never drift
# out of sync with the configured models.
similarity_columns = ["Similarity for " + name for name in model_names]
words_df = pd.DataFrame(
    st.session_state['words'],
    columns=["word"] + similarity_columns,
).sort_values(by=[similarity_columns[0]], ascending=False)
st.dataframe(words_df, use_container_width=True)