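"""研报助手 (research-report assistant): a Gradio ChatInterface that, depending on the
selected function, retrieves report passages from per-topic Chroma vector stores,
answers questions over the retrieved report PDFs, or drafts report content, streaming
GPT-4 responses back to the chat window."""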
import os
os.environ["OPENAI_API_KEY"] = "YOUR_OPENAI_API_KEY"  # placeholder; do not commit a real API key to source control
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
from langchain.vectorstores import Chroma
from pypinyin import lazy_pinyin
import gradio as gr
import openai
import random
# import logging
# logging.basicConfig(
#     filename='log/log.log',
#     level=logging.INFO,
#     format='%(asctime)s - %(levelname)s - %(message)s',
#     datefmt='%m/%d/%Y %H:%M:%S'
# )
embedding = OpenAIEmbeddings()
target_files = set()
topics = ["农业", "宗教与文化", "建筑业与制造业", "医疗卫生保健", "国家治理", "法律法规", "财政税收", "教育", "金融", "贸易", "宏观经济", "社会发展", "科学技术", "能源环保", "国际关系", "国防安全","不限主题"]

def get_path(target_string):
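    """Return the path of the first directory under ./vector_data whose name starts
    with target_string (the pinyin of a topic); return "" if there is no match."""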
    folder_path = "./vector_data"
    all_vectors = os.listdir(folder_path)
    matching_files = [file for file in all_vectors if file.startswith(target_string)]
    if matching_files:
        return os.path.join(folder_path, matching_files[0])
    return ""

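# A minimal sketch of how the per-topic vector stores under ./vector_data are assumed to
# have been built: one persisted Chroma directory per topic, named with the topic's pinyin
# so that get_path() can find it. The folder layout, the source-PDF list, and the helper
# name build_topic_store are illustrative assumptions, not part of the running app.
def build_topic_store(topic, pdf_paths):
    persist_dir = os.path.join("./vector_data", "".join(lazy_pinyin(topic)))
    pages = []
    for path in pdf_paths:
        # Load each report PDF and split it into page-level documents
        pages.extend(PyPDFLoader(path).load_and_split())
    chunks = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0).split_documents(pages)
    db = Chroma.from_documents(chunks, embedding, persist_directory=persist_dir)
    db.persist()  # flush the index to disk (older langchain/chromadb versions require this)
    return persist_dir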

def extract_partial_message(res_message, response):
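    """Append streamed ChatCompletion chunks to res_message and yield the growing
    text so the Gradio chat window can render the reply incrementally."""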
    for chunk in response:
        delta = chunk["choices"][0]["delta"]
        # Some streamed chunks (e.g. the first role-only chunk or the final empty one)
        # carry no "content", so guard against a KeyError before appending
        if delta.get("content"):
            res_message = res_message + delta["content"]
            yield res_message


def format_messages(sys_prompt, history, message):
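    """Convert the (user, assistant) history pairs and the new message into the
    OpenAI chat-completions message list, prefixed with a system prompt."""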
    history_openai_format = [{"role": "system", "content": sys_prompt}]
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})
    return history_openai_format

def get_domain(history, message):
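    """Ask GPT-4 to classify the user's question into one of the predefined topics
    and return the topic name as plain text."""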
    sys_prompt = """
    帮我根据用户的问题划分到以下几个类别,输出最匹配的一个类别:[宗教与文化, 农业, 建筑业与制造业, 医疗卫生保健, 国家治理, 法律法规, 财政税收, 教育, 金融, 贸易, 宏观经济, 社会发展, 科学技术, 能源环保, 国际关系, 国防安全]
    """
    history_openai_format = format_messages(sys_prompt, history, message)
    print("history_openai_format:", history_openai_format)
    # logging.info(f"history_openai_format: {history_openai_format}")
    response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=False)
    domain = response['choices'][0]['message']['content']
    print("匹配领域:", domain)
    # logging.info(f"匹配领域: {domain}")
    return domain

def echo(message, history, flag1, flag2):
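    """Gradio chat handler. flag1 is the "Function" checkbox selection and flag2 the
    topic dropdown index. With nothing checked, retrieve reports from the per-topic
    vector stores and summarize them; with "研报问答" checked, answer questions over the
    previously retrieved reports; with "研报生成" checked, draft report content. Model
    output is streamed back to the UI."""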
    global target_files, topics
    print("flag1:", flag1)
    print("flag2:", flag2) 
    print("history:", history)
    print("message:", message)
    # logging.info(f"flag1: {flag1}")
    # logging.info(f"flag2: {flag2}")
    # logging.info(f"history: {history}")
    # logging.info(f"message: {message}")
    if len(flag1) == 0:  # Neither "研报问答" nor "研报生成" is checked: run report retrieval and summarization
        target_files.clear()
        history.clear()
        if flag2 not in [None, 16]:  # A specific topic was chosen from the dropdown
            domain = topics[flag2]
            message = f"{domain}领域相关内容"
        else:  # No topic (or "不限主题") selected: pick a random concrete topic when the message is empty, then classify
            message = f"{topics[random.randint(0, 15)]}领域相关内容" if message == "" else message
            domain = get_domain(history, message)
        persist_vector_path = get_path("".join(lazy_pinyin(domain)))
        
        db = Chroma(persist_directory=persist_vector_path, embedding_function=embedding)
        docs = db.similarity_search_with_score(query=message, k=5)
        contents = [doc[0] for doc in docs]
        relevance = "  ".join(doc.page_content for doc in contents)
        source = [doc.metadata for doc in contents]
        for item in source:
            target_files.add(item['source'])
        print("研报搜索结果:", target_files)
        # logging.info(f"研报搜索结果: {target_files}")

        sys_prompt = """
        你是一个研报助手,根据这篇文章:{}
        来回复用户的问题生成总结,你需要严格按照这种格式回复:以上文章总结了*,主要观点是*。你只能回复中文。
        如果这篇文章无法回答用户的问题,你必须根据你的知识面直接编造内容进行回答,不能回复不知道,不能说这段文字没有提供等话语。
        """
        sys_prompt = sys_prompt.format(relevance)
        history_openai_format = format_messages(sys_prompt, history, message)
        print("history_openai_format:", history_openai_format)
        # logging.info(f"history_openai_format: {history_openai_format}")
        response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True)
        partial_message = "搜索结果来自以下研报:\n" + '\n'.join(target_files) + '\n\n'
        for result_message in extract_partial_message(partial_message, response):
            yield result_message
        

    elif flag1 == ['研报问答']:
        print("target_files:", target_files)
        # logging.info(f"target_files: {target_files}")
        QA_pages = []
        if not target_files:
            yield "请取消选中研报问答,先进行研报检索,再进行问答。"
        else:
            for item in target_files:
                loader = PyPDFLoader(item)
                QA_pages.extend(loader.load_and_split())
            text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
            documents = text_splitter.split_documents(QA_pages)
            db = Chroma.from_documents(documents, OpenAIEmbeddings())

            docs = db.similarity_search_with_score(query=message, k=3)
            contents = [doc[0] for doc in docs]
            relevance = "  ".join(doc.page_content for doc in contents)

            sys_prompt = """
            你是一个研报助手,根据这篇文章:{}
            来回复用户的问题,如果这篇文章无法回答用户的问题,你必须根据你的知识面来编造进行专业的回答,
            不能回复不知道,不能回复这篇文章不能回答的这种话语,你只能回复中文。
            """
            sys_prompt = sys_prompt.format(relevance)
            history_openai_format = format_messages(sys_prompt, history, message)
            print("history_openai_format:", history_openai_format)
            # logging.info(f"history_openai_format: {history_openai_format}")
            response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True)
            for result_message in extract_partial_message("", response):
                yield result_message

    elif flag1 == ['研报生成']:
        target_files.clear()
        sys_prompt = """
        你是一个研报助手,请根据用户的要求回复问题。
        """
        history_openai_format = format_messages(sys_prompt, history, message)
        print("history_openai_format:", history_openai_format)
        # logging.info(f"history_openai_format: {history_openai_format}")
        response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True)
        for result_message in extract_partial_message("", response):
            yield result_message

    elif len(flag1) == 2:
        yield "请选中一个选项,进行相关问答。"


demo = gr.ChatInterface(
        echo, 
        chatbot=gr.Chatbot(height=430, label="ChatReport"),
        textbox=gr.Textbox(placeholder="请输入问题", container=False, scale=7),
        title="研报助手",
        description="清芬院研报助手",
        theme="soft",
        additional_inputs=[
            # gr.Radio(["研报问答", "研报生成"], type="index", label = "function"),
            # gr.Checkbox(label = "研报问答"),
            # gr.Checkbox(label = "研报生成"),
            gr.CheckboxGroup(["研报问答", "研报生成"], label="Function"),
            gr.Dropdown(topics, type="index"),
            # gr.Button(value="Run").click(echo, inputs=["", "", [], None], outputs=[""])
            # btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
            # gr.Blocks()
        ],
        # retry_btn="retry",
        undo_btn="清空输入框",
        clear_btn="清空聊天记录"
        ).queue()


if __name__ == "__main__":
    demo.launch(share=True)
    





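# Sample console output from a session where "研报问答" was checked before any report
# retrieval had populated target_files: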
'''
flag1: ['研报问答']
flag2: None
history: []
message: gg
target_files: set()
'''