Spaces:
Runtime error
Runtime error
File size: 2,703 Bytes
8428b8e 7fcfcdc 30a40c5 8428b8e 30a40c5 8428b8e ea6e2e4 8428b8e 285e2b7 8428b8e 30a40c5 8428b8e 51f25dc 7fcfcdc 15936c4 8428b8e 7fcfcdc 30a40c5 8428b8e 285e2b7 8428b8e a91b3c4 8428b8e 285e2b7 51f25dc 22c71f8 51f25dc 285e2b7 ea6e2e4 9db1c97 ea6e2e4 30a40c5 ea6e2e4 285e2b7 8428b8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import functools

import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util
def app():
# 创建Streamlit应用程序
st.title("对比句子的相似度")
source_text = st.text_input("源句子", value="")
st.write("待比较的句子:")
if "inputs" not in st.session_state:
# 创建一个空列表来存储输入框列表
st.session_state.inputs = []
st.session_state.inputs_index = 0
with st.container():
# 在容器中渲染已经存在的输入框列表
for i in range(0, st.session_state.inputs_index):
st.session_state.inputs[i]= st.text_input(f"请输入第 {i+1} 个句子", "", key=i)
# 创建一个添加输入框的按钮
add_input_button = st.button("添加一个待比较句子")
# 当用户点击按钮时往容器中添加新的输入框
if add_input_button:
i = st.session_state.inputs_index
st.session_state.inputs.append(st.text_input(f"请输入第 {i+1} 个句子", "", key=i))
# 自增输入框的key
st.session_state.inputs_index += 1
button_generate = st.button("计算")
button_clear = st.button("清空")
def transformer(source_text, sentences):
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device="cpu")
source_emb = model.encode(source_text, convert_to_tensor=True)
sent_embs = model.encode(sentences, convert_to_tensor=True)
cos_sim = util.cos_sim(source_emb, sent_embs)
st.write(source_emb)
st.write(sent_embs)
st.write(cos_sim) #output tensor([[1.0000, 0.3624]])
cosin_dict = {}
for i, cos in enumerate(torch.flatten(cos_sim)):
cosin_dict[sentences[i]] = cos
sorted_dict = dict(sorted(cosin_dict.items(), key=lambda item: item[1],reverse = True))
return sorted_dict
if button_generate:
# embeddings
embeddings = transformer([source_text], st.session_state.inputs)
# 显示生成的文本
st.write(embeddings)
#output_text.success(generated_text)
with st.container():
for sent, cos in embeddings:
col1, col2, col3 = st.columns(3)
with col1:
st.text(sent)
with col2:
bar = st.progress(cos)
with col3:
st.text(cos)
if button_clear:
st.session_state.inputs.clear()
del st.session_state["inputs"]
st.session_state.inputs_index = 0
source_text = ''
st.experimental_rerun()
if __name__ == "__main__":
    # Run the Streamlit application.
    app()