File size: 2,703 Bytes
8428b8e
7fcfcdc
30a40c5
8428b8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30a40c5
 
 
 
 
 
 
 
 
8428b8e
 
ea6e2e4
8428b8e
285e2b7
8428b8e
 
30a40c5
 
8428b8e
51f25dc
 
7fcfcdc
15936c4
8428b8e
7fcfcdc
30a40c5
8428b8e
 
 
 
 
 
 
285e2b7
8428b8e
 
a91b3c4
8428b8e
 
285e2b7
 
51f25dc
22c71f8
 
51f25dc
 
 
 
285e2b7
ea6e2e4
 
9db1c97
ea6e2e4
30a40c5
 
ea6e2e4
285e2b7
 
8428b8e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util

def _clear_inputs():
    """Button callback: wipe the source sentence and all comparison inputs.

    Runs as an ``on_click`` callback, i.e. before the widgets are
    re-instantiated for the next script run, which is the only point at
    which keyed widget values may be reset through ``st.session_state``.
    """
    count = (
        st.session_state.inputs_index
        if "inputs_index" in st.session_state
        else 0
    )
    # Drop the stored text of every comparison input so that a later
    # "add" starts from an empty box again.
    for idx in range(count):
        widget_key = f"sentence_{idx}"
        if widget_key in st.session_state:
            del st.session_state[widget_key]
    st.session_state.inputs = []
    st.session_state.inputs_index = 0
    st.session_state["source_text"] = ""


def _rank_by_similarity(source_text, sentences):
    """Score ``sentences`` against ``source_text`` by cosine similarity.

    Args:
        source_text: The reference sentence(s), passed straight to
            ``model.encode`` (the caller passes a one-element list).
        sentences: Non-empty list of sentences to compare.

    Returns:
        A dict mapping each sentence to its similarity as a plain float,
        ordered from most to least similar. Duplicate sentences collapse
        into a single entry.
    """
    # NOTE(review): the model is re-created on every call; consider caching
    # it (e.g. with Streamlit's resource cache) if start-up time matters.
    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device="cpu")
    source_emb = model.encode(source_text, convert_to_tensor=True)
    sent_embs = model.encode(sentences, convert_to_tensor=True)
    cos_sim = util.cos_sim(source_emb, sent_embs)

    # Diagnostic output of the raw embeddings and the similarity matrix.
    st.write(source_emb)
    st.write(sent_embs)
    st.write(cos_sim)

    # Convert to Python floats so the scores can be sorted, formatted and
    # passed to widgets without carrying tensors around.
    scores = dict(zip(sentences, cos_sim.flatten().tolist()))
    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))


def app():
    """Render the sentence-similarity page.

    The page shows one source sentence input, a growable list of comparison
    sentence inputs, and three buttons: add an input, compute the
    similarities, and clear everything. Results are listed from most to
    least similar with a progress bar and the numeric score.
    """
    st.title("对比句子的相似度")

    # Keyed so that the clear callback can reset it.
    source_text = st.text_input("源句子", key="source_text")

    st.write("待比较的句子:")

    if "inputs" not in st.session_state:
        # Values of the comparison inputs and how many of them exist.
        st.session_state.inputs = []
        st.session_state.inputs_index = 0

    # Reserve the area above the button first so that a freshly added input
    # appears in its final position instead of below the button.
    inputs_area = st.container()

    if st.button("添加一个待比较句子"):
        st.session_state.inputs.append("")
        st.session_state.inputs_index += 1

    with inputs_area:
        for i in range(st.session_state.inputs_index):
            st.session_state.inputs[i] = st.text_input(
                f"请输入第 {i+1} 个句子", key=f"sentence_{i}"
            )

    button_generate = st.button("计算")
    # Clearing happens in a callback; Streamlit reruns the script afterwards.
    st.button("清空", on_click=_clear_inputs)

    if not button_generate:
        return

    # Blank boxes carry no information and would only add noise rows.
    candidates = [text for text in st.session_state.inputs if text.strip()]
    if not source_text.strip() or not candidates:
        st.warning("请先输入源句子和至少一个待比较的句子")
        return

    scores = _rank_by_similarity([source_text], candidates)

    # Show the ranked scores.
    st.write(scores)

    with st.container():
        for sent, score in scores.items():
            col1, col2, col3 = st.columns(3)
            with col1:
                st.text(sent)
            with col2:
                # Cosine similarity can be negative or overshoot 1.0 by
                # rounding; the progress bar only accepts [0.0, 1.0].
                st.progress(min(max(score, 0.0), 1.0))
            with col3:
                st.text(f"{score:.4f}")

   

if __name__ == "__main__":
    # Run the application when this file is executed as a script.
    app()