Spaces:
Runtime error
Runtime error
File size: 8,727 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import os
import random

import gradio as gr
import openai
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from pypinyin import lazy_pinyin

# SECURITY: an OpenAI API key used to be hard-coded on this line. Any key
# committed to source control is compromised and must be revoked/rotated.
# Supply a fresh key through the environment instead, e.g.:
#   export OPENAI_API_KEY=sk-...
if "OPENAI_API_KEY" not in os.environ:
    raise RuntimeError("Set the OPENAI_API_KEY environment variable before starting the app.")

# Shared embedding model used for every Chroma vector-store lookup.
embedding = OpenAIEmbeddings()
# Paths of the source PDFs matched by the most recent retrieval; consumed by
# the '研报问答' (report QA) branch of echo().
target_files = set()
# Report topic list; index 16 ("不限主题") means "no topic restriction".
topics = ["农业", "宗教与文化", "建筑业与制造业", "医疗卫生保健", "国家治理", "法律法规", "财政税收", "教育", "金融", "贸易", "宏观经济", "社会发展", "科学技术", "能源环保", "国际关系", "国防安全","不限主题"]
def get_path(target_string, folder_path="./vector_data"):
    """Return the path of the first file in *folder_path* whose filename
    starts with *target_string*, or "" when no such file exists.

    Parameters
    ----------
    target_string : str
        Filename prefix to look for (here: the pinyin of a topic name).
    folder_path : str, optional
        Directory to search. Defaults to the persisted vector-store folder,
        matching the original hard-coded behavior.

    Notes
    -----
    When several files share the prefix, the one returned is whichever
    os.listdir() yields first (platform-dependent order), as before.
    """
    for name in os.listdir(folder_path):
        if name.startswith(target_string):
            return os.path.join(folder_path, name)
    return ""
def extract_partial_message(res_message, response):
    """Accumulate streamed ChatCompletion chunks onto *res_message*.

    Parameters
    ----------
    res_message : str
        Prefix text (e.g. the "sources" header) the stream is appended to.
    response : iterable
        Streamed OpenAI ChatCompletion chunks (stream=True).

    Yields
    ------
    str
        The running concatenation after each chunk that carries text.

    Fix: the first streamed delta is typically {"role": "assistant"} with no
    "content" key — the original indexed delta["content"] and would raise
    KeyError on such chunks. Use .get() and skip content-less deltas.
    """
    for chunk in response:
        content = chunk["choices"][0]["delta"].get("content")
        if content:
            res_message += content
            yield res_message
def format_messages(sys_prompt, history, message):
    """Build an OpenAI chat-completion message list.

    Starts with *sys_prompt* as the system message, replays each
    (user, assistant) turn from *history*, and ends with *message*
    as the newest user turn.
    """
    messages = [{"role": "system", "content": sys_prompt}]
    for user_turn, assistant_turn in history:
        messages.extend((
            {"role": "user", "content": user_turn},
            {"role": "assistant", "content": assistant_turn},
        ))
    messages.append({"role": "user", "content": message})
    return messages
def get_domain(history, message):
    """Classify the user's question into one report-topic category via GPT-4.

    Returns the model's raw text reply, which is expected to be exactly one
    category name from the list in the system prompt.
    NOTE(review): the reply is used verbatim to locate a vector store — an
    off-list answer makes get_path() return ""; confirm downstream handling.
    """
    # System prompt (Chinese): "assign the user's question to the single
    # best-matching category from this list and output that category".
    sys_prompt = """
    帮我根据用户的问题划分到以下几个类别,输出最匹配的一个类别:[宗教与文化, 农业, 建筑业与制造业, 医疗卫生保健, 国家治理, 法律法规, 财政税收, 教育, 金融, 贸易, 宏观经济, 社会发展, 科学技术, 能源环保, 国际关系, 国防安全]
    """
    history_openai_format = format_messages(sys_prompt, history, message)
    print("history_openai_format:", history_openai_format)
    # logging.info(f"history_openai_format: {history_openai_format}")
    # stream=False: the full classification answer is needed before lookup.
    response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=False)
    domain = response['choices'][0]['message']['content']
    print("匹配领域:", domain)
    # logging.info(f"匹配领域: {domain}")
    return domain
def echo(message, history, flag1, flag2):
    """Gradio ChatInterface handler — streams replies in one of three modes.

    Parameters
    ----------
    message : str
        Current user input from the textbox.
    history : list[tuple[str, str]]
        Previous (user, assistant) turns, supplied by gr.ChatInterface.
    flag1 : list[str]
        CheckboxGroup selection: [] = report retrieval + summary,
        ['研报问答'] = QA over previously retrieved reports,
        ['研报生成'] = free-form report generation, both selected = invalid.
    flag2 : int | None
        Dropdown index into `topics` (type="index"); None when nothing is
        chosen. Index 16 is the "不限主题" (unrestricted) entry.

    Yields
    ------
    str
        Growing partial assistant messages as tokens stream from the API.
    """
    global target_files, topics
    print("flag1:", flag1)
    print("flag2:", flag2)
    print("history:", history)
    print("message:", message)
    # logging.info(f"flag1: {flag1}")
    # logging.info(f"flag2: {flag2}")
    # logging.info(f"history: {history}")
    # logging.info(f"message: {message}")
    if len(flag1) == 0:  # neither report-QA nor report-generation selected: fresh retrieval
        # A new search invalidates both the remembered source PDFs and the
        # chat history (note: mutates Gradio's history list in place).
        target_files.clear()
        history.clear()
        if flag2 not in [None, 16]:
            # An explicit topic was chosen: force the query onto that topic.
            domain = topics[flag2]
            message = f"{domain}领域相关内容"
        elif flag2 in [None, 16]:
            # No / unrestricted topic: substitute a random topic for an empty
            # message, then let GPT-4 classify the query.
            # NOTE(review): randint(0, 16) is inclusive and can pick index 16
            # ("不限主题"), which has no per-topic vector store — confirm.
            message = f"{topics[random.randint(0, 16)]}领域相关内容" if message == "" else message
            domain = get_domain(history, message)
        # Per-topic stores are named by the topic's pinyin transliteration.
        persist_vector_path = get_path("".join(lazy_pinyin(domain)))
        db = Chroma(persist_directory=persist_vector_path, embedding_function=embedding)
        # similarity_search_with_score returns (Document, score) pairs.
        docs = db.similarity_search_with_score(query=message, k=5)
        contents = [doc[0] for doc in docs]
        relevance = " ".join(doc.page_content for doc in contents)
        source = [doc.metadata for doc in contents]
        for item in source:
            # Remember the source PDF paths so the QA branch can reuse them.
            target_files.add(item['source'])
        print("研报搜索结果:", target_files)
        # logging.info(f"研报搜索结果: {target_files}")
        # Summarization prompt; {} is filled with the retrieved passages below.
        sys_prompt = """
        你是一个研报助手,根据这篇文章:{}
        来回复用户的问题生成总结,你需要严格按照这种格式回复:以上文章总结了*,主要观点是*。你只能回复中文。
        如果这篇文章无法回答用户的问题,你必须根据你的知识面直接编造内容进行回答,不能回复不知道,不能说这段文字没有提供等话语。
        """
        sys_prompt = sys_prompt.format(relevance)
        history_openai_format = format_messages(sys_prompt, history, message)
        print("history_openai_format:", history_openai_format)
        # logging.info(f"history_openai_format: {history_openai_format}")
        response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True)
        # Prepend the list of matched report files, then stream the summary.
        partial_message = "搜索结果来自以下研报:\n" + '\n'.join(i for i in target_files) + '\n\n'
        for result_message in extract_partial_message(partial_message, response):
            yield result_message
    elif flag1 == ['研报问答']:  # QA over the PDFs found by a previous retrieval
        print("target_files:", target_files)
        # logging.info(f"target_files: {target_files}")
        QA_pages = []
        if not target_files:
            # No prior retrieval: ask the user to run a search first.
            yield "请取消选中研报问答,先进行研报检索,再进行问答。"
        else:
            # Build an ad-hoc in-memory vector store over the remembered PDFs.
            for item in target_files:
                loader = PyPDFLoader(item)
                QA_pages.extend(loader.load_and_split())
            text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
            documents = text_splitter.split_documents(QA_pages)
            db = Chroma.from_documents(documents, OpenAIEmbeddings())
            docs = db.similarity_search_with_score(query=message, k=3)
            contents = [doc[0] for doc in docs]
            relevance = " ".join(doc.page_content for doc in contents)
            # QA prompt; {} is filled with the retrieved passages below.
            sys_prompt = """
            你是一个研报助手,根据这篇文章:{}
            来回复用户的问题,如果这篇文章无法回答用户的问题,你必须根据你的知识面来编造进行专业的回答,
            不能回复不知道,不能回复这篇文章不能回答的这种话语,你只能回复中文。
            """
            sys_prompt = sys_prompt.format(relevance)
            history_openai_format = format_messages(sys_prompt, history, message)
            print("history_openai_format:", history_openai_format)
            # logging.info(f"history_openai_format: {history_openai_format}")
            response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True)
            for result_message in extract_partial_message("", response):
                yield result_message
    elif flag1 == ['研报生成']:  # free-form report generation, no retrieval
        target_files.clear()
        sys_prompt = """
        你是一个研报助手,请根据用户的要求回复问题。
        """
        history_openai_format = format_messages(sys_prompt, history, message)
        print("history_openai_format:", history_openai_format)
        # logging.info(f"history_openai_format: {history_openai_format}")
        response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True)
        for result_message in extract_partial_message("", response):
            yield result_message
    elif len(flag1) == 2:
        # Both checkboxes ticked: ambiguous — ask the user to pick one.
        yield "请选中一个选项,进行相关问答。"
# Assemble the Gradio chat UI around echo(). .queue() is required because
# echo() is a generator that streams partial responses.
demo = gr.ChatInterface(
    echo,
    chatbot=gr.Chatbot(height=430, label="ChatReport"),
    textbox=gr.Textbox(placeholder="请输入问题", container=False, scale=7),
    title="研报助手",
    description="清芬院研报助手",
    theme="soft",
    additional_inputs=[
        # gr.Radio(["研报问答", "研报生成"], type="index", label = "function"),
        # gr.Checkbox(label = "研报问答"),
        # gr.Checkbox(label = "研报生成"),
        # Passed to echo() as flag1 (list of selected labels).
        gr.CheckboxGroup(["研报问答", "研报生成"], label="Function"),
        # type="index" makes the dropdown pass an int index into `topics`
        # (or None when unset) — this is echo()'s flag2.
        gr.Dropdown(topics, type="index"),
        # gr.Button(value="Run").click(echo, inputs=["", "", [], None], outputs=[""])
        # btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
        # gr.Blocks()
    ],
    # retry_btn="retry",
    undo_btn="清空输入框",
    clear_btn="清空聊天记录"
).queue()

if __name__ == "__main__":
    # share=True exposes a temporary public gradio.live link.
    demo.launch(share=True)
'''
flag1: ['研报问答']
flag2: None
history: []
message: gg
target_files: set()
'''