Spaces: Runtime error
import streamlit as st | |
import torch | |
from sentence_transformers import SentenceTransformer, util | |
_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Session-state slot that holds the loaded model between Streamlit reruns.
_MODEL_STATE_KEY = "_similarity_model"


def rank_by_similarity(sentences, scores):
    """Map each sentence to its similarity score, most similar first.

    Args:
        sentences: Candidate sentences, in input order.
        scores: One similarity score per sentence (anything ``float()`` accepts).

    Returns:
        A dict ordered by descending score. Duplicate sentences collapse into a
        single entry that keeps the score of their last occurrence.
    """
    by_sentence = dict(zip(sentences, map(float, scores)))
    return dict(sorted(by_sentence.items(), key=lambda item: item[1], reverse=True))


def similarity_to_progress(score):
    """Clamp a similarity score into the [0.0, 1.0] range ``st.progress`` accepts.

    Cosine similarity can be negative, and float rounding can push it slightly
    above 1.0; ``st.progress`` rejects float values outside [0.0, 1.0].
    """
    return min(max(float(score), 0.0), 1.0)


def _get_model():
    """Return the embedding model, loading it at most once per browser session."""
    # Streamlit re-runs the whole script on every interaction; session_state
    # survives those reruns, so the model is not reloaded on each click.
    if _MODEL_STATE_KEY not in st.session_state:
        st.session_state[_MODEL_STATE_KEY] = SentenceTransformer(_MODEL_NAME, device="cpu")
    return st.session_state[_MODEL_STATE_KEY]


def _score_sentences(source_text, sentences):
    """Embed the source and the candidates, then rank candidates by cosine similarity."""
    model = _get_model()
    source_emb = model.encode([source_text], convert_to_tensor=True)
    sent_embs = model.encode(sentences, convert_to_tensor=True)
    # One source row against len(sentences) candidates -> a 1 x N matrix.
    cos_sim = util.cos_sim(source_emb, sent_embs)
    return rank_by_similarity(sentences, cos_sim.flatten().tolist())


def _render_results(ranked):
    """Show one row per sentence: its text, a similarity bar, and the numeric score."""
    with st.container():
        for sentence, score in ranked.items():
            col1, col2, col3 = st.columns(3)
            with col1:
                st.text(sentence)
            with col2:
                st.progress(similarity_to_progress(score))
            with col3:
                st.text(f"{score:.4f}")


def _rerun():
    """Restart the script run, preferring ``st.rerun`` when this Streamlit has it."""
    # st.experimental_rerun was deprecated in favour of st.rerun and later removed.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def app():
    """Render the sentence-similarity page.

    The user enters a source sentence plus any number of candidate sentences
    (one new box per click on the add button). "计算" ranks the non-blank
    candidates by cosine similarity to the source; "清空" removes every
    candidate box.
    """
    st.title("对比句子的相似度")
    source_text = st.text_input("源句子", value="")
    st.write("待比较的句子:")
    if "inputs" not in st.session_state:
        # inputs[i] holds the current text of candidate box i; inputs_index is
        # the number of boxes created so far (and the key of the next one).
        st.session_state.inputs = []
        st.session_state.inputs_index = 0
    with st.container():
        # Re-render every existing candidate box and capture its current value.
        for i in range(st.session_state.inputs_index):
            st.session_state.inputs[i] = st.text_input(f"请输入第 {i+1} 个句子", "", key=i)
    if st.button("添加一个待比较句子"):
        i = st.session_state.inputs_index
        st.session_state.inputs.append(st.text_input(f"请输入第 {i+1} 个句子", "", key=i))
        st.session_state.inputs_index += 1
    button_generate = st.button("计算")
    button_clear = st.button("清空")
    if button_generate:
        # Skip blank boxes (e.g. one just added and not yet filled in).
        sentences = [s for s in st.session_state.inputs if s.strip()]
        if not source_text.strip() or not sentences:
            st.warning("请输入源句子和至少一个待比较句子")
        else:
            _render_results(_score_sentences(source_text, sentences))
    if button_clear:
        # Dropping "inputs" makes the next run re-initialise an empty box list.
        del st.session_state["inputs"]
        st.session_state.inputs_index = 0
        _rerun()
if __name__ == "__main__":
    # Start the UI when this file is executed as the main script.
    app()