# Spaces: Runtime error
# (Page-status text captured together with the source; not part of the program.)
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util
# Make the page layout wider.
st.set_page_config(layout="wide")
def _add_sentence_input():
    """Button callback: register one more (empty) comparison-sentence slot."""
    st.session_state.inputs.append("")
    st.session_state.inputs_index += 1


def _clear_all_inputs():
    """Button callback: drop every comparison sentence and blank the source box."""
    st.session_state.inputs = []
    st.session_state.inputs_index = 0
    # Widget values may be assigned through session state inside a callback,
    # because the callback runs before the widgets are created for this run.
    st.session_state["source_text"] = ""


def _get_model():
    """Return the sentence-embedding model, creating it once per session."""
    if "similarity_model" not in st.session_state:
        st.session_state["similarity_model"] = SentenceTransformer(
            'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
            device="cpu",
        )
    return st.session_state["similarity_model"]


def _rank_by_similarity(source_text, sentences):
    """Score each sentence against the source and sort by descending cosine.

    Args:
        source_text: The reference sentence.
        sentences: Non-empty list of sentences to compare with the source.

    Returns:
        A list of ``(sentence, cosine)`` pairs, highest cosine first. Pairs
        are used instead of a dict so that identical sentences entered twice
        each keep their own row.
    """
    model = _get_model()
    # Encode the source as a one-element batch, matching how the sentences
    # are encoded.
    source_emb = model.encode([source_text], convert_to_tensor=True)
    sent_embs = model.encode(sentences, convert_to_tensor=True)
    # Flatten the similarity result and convert it to plain Python floats.
    scores = torch.flatten(util.cos_sim(source_emb, sent_embs)).tolist()
    return sorted(zip(sentences, scores), key=lambda pair: pair[1], reverse=True)


def app():
    """Render the sentence-similarity page.

    The page shows one source sentence box, a growable list of comparison
    sentence boxes, and three buttons: add a sentence, compute, and clear.
    On compute, every non-blank comparison sentence is scored against the
    source with cosine similarity and listed from most to least similar.
    """
    st.title("对比句子的相似度")
    # The key lets the clear callback reset this box through session state.
    source_text = st.text_input("源句子", value="", key="source_text")
    st.write("待比较的句子:")

    if "inputs" not in st.session_state:
        # Values of the comparison boxes plus the number of boxes to render.
        st.session_state.inputs = []
        st.session_state.inputs_index = 0

    with st.container():
        # Re-render every registered box and store its current value.
        for i in range(st.session_state.inputs_index):
            st.session_state.inputs[i] = st.text_input(
                f"请输入第 {i+1} 个句子", "", key=f"sentence_{i}"
            )
        # The callback updates the state before this script run, so the new
        # box is drawn by the loop above, in the same place on every run.
        st.button("添加一个待比较句子", on_click=_add_sentence_input)

    compute_clicked = st.button("计算")
    # Clearing through a callback needs no explicit rerun call.
    st.button("清空", on_click=_clear_all_inputs)

    if not compute_clicked:
        return

    sentences = [text for text in st.session_state.inputs if text.strip()]
    if not source_text.strip() or not sentences:
        # NOTE(review): with no sentences there is nothing for the
        # encode/cos_sim path to score, so stop before calling the model.
        st.warning("请先输入源句子和至少一个待比较的句子")
        return

    # Show each sentence, a progress bar, and the cosine value.
    with st.container():
        for sentence, score in _rank_by_similarity(source_text, sentences):
            col_text, col_bar, col_score = st.columns(3)
            cos_value = round(score, 4)
            with col_text:
                st.text(sentence)
            with col_bar:
                # Cosine similarity can be negative, but the progress bar
                # only accepts floats in [0.0, 1.0]; clamp for display only.
                st.progress(min(max(cos_value, 0.0), 1.0))
            with col_score:
                st.text(cos_value)
if __name__ == "__main__":
    # Run the application.
    app()