webui / final.py
zhangyi617's picture
Upload folder using huggingface_hub
129cd69
raw
history blame contribute delete
No virus
8.73 kB
import os
os.environ["OPENAI_API_KEY"] = "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"
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
from langchain.vectorstores import Chroma
from pypinyin import lazy_pinyin
import gradio as gr
import openai
import random
# import logging
# logging.basicConfig(
# filename='log/log.log',
# level=logging.INFO,
# format='%(asctime)s - %(levelname)s - %(message)s',
# datefmt='%m/%d/%Y %H:%M:%S'
# )
embedding = OpenAIEmbeddings()
target_files = set()
topics = ["农业", "宗教与文化", "建筑业与制造业", "医疗卫生保健", "国家治理", "法律法规", "财政税收", "教育", "金融", "贸易", "宏观经济", "社会发展", "科学技术", "能源环保", "国际关系", "国防安全","不限主题"]
def get_path(target_string):
folder_path = "./vector_data"
all_vectors = os.listdir(folder_path)
matching_files = [file for file in all_vectors if file.startswith(target_string)]
for file in matching_files:
file_path = os.path.join(folder_path, file)
return file_path
return ""
def extract_partial_message(res_message, response):
for chunk in response:
if len(chunk["choices"][0]["delta"]) != 0:
res_message = res_message + chunk["choices"][0]["delta"]["content"]
yield res_message
def format_messages(sys_prompt, history, message):
history_openai_format = [{"role": "system", "content": sys_prompt}]
for human, assistant in history:
history_openai_format.append({"role": "user", "content": human})
history_openai_format.append({"role": "assistant", "content": assistant})
history_openai_format.append({"role": "user", "content": message})
return history_openai_format
def get_domain(history, message):
sys_prompt = """
帮我根据用户的问题划分到以下几个类别,输出最匹配的一个类别:[宗教与文化, 农业, 建筑业与制造业, 医疗卫生保健, 国家治理, 法律法规, 财政税收, 教育, 金融, 贸易, 宏观经济, 社会发展, 科学技术, 能源环保, 国际关系, 国防安全]
"""
history_openai_format = format_messages(sys_prompt, history, message)
print("history_openai_format:", history_openai_format)
# logging.info(f"history_openai_format: {history_openai_format}")
response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=False)
domain = response['choices'][0]['message']['content']
print("匹配领域:", domain)
# logging.info(f"匹配领域: {domain}")
return domain
def echo(message, history, flag1, flag2):
global target_files, topics
print("flag1:", flag1)
print("flag2:", flag2)
print("history:", history)
print("message:", message)
# logging.info(f"flag1: {flag1}")
# logging.info(f"flag2: {flag2}")
# logging.info(f"history: {history}")
# logging.info(f"message: {message}")
if len(flag1) == 0: # 不进行研报问答&研报生成
target_files.clear()
history.clear()
if flag2 not in [None, 16]:
domain = topics[flag2]
message = f"{domain}领域相关内容"
elif flag2 in [None, 16]:
message = f"{topics[random.randint(0, 16)]}领域相关内容" if message == "" else message
domain = get_domain(history, message)
persist_vector_path = get_path("".join(lazy_pinyin(domain)))
db = Chroma(persist_directory=persist_vector_path, embedding_function=embedding)
docs = db.similarity_search_with_score(query=message, k=5)
contents = [doc[0] for doc in docs]
relevance = " ".join(doc.page_content for doc in contents)
source = [doc.metadata for doc in contents]
for item in source:
target_files.add(item['source'] )
print("研报搜索结果:", target_files)
# logging.info(f"研报搜索结果: {target_files}")
sys_prompt = """
你是一个研报助手,根据这篇文章:{}
来回复用户的问题生成总结,你需要严格按照这种格式回复:以上文章总结了*,主要观点是*。你只能回复中文。
如果这篇文章无法回答用户的问题,你必须根据你的知识面直接编造内容进行回答,不能回复不知道,不能说这段文字没有提供等话语。
"""
sys_prompt = sys_prompt.format(relevance)
history_openai_format = format_messages(sys_prompt, history, message)
print("history_openai_format:", history_openai_format)
# logging.info(f"history_openai_format: {history_openai_format}")
response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True)
partial_message = "搜索结果来自以下研报:\n" + '\n'.join(i for i in target_files) + '\n\n'
for result_message in extract_partial_message(partial_message, response):
yield result_message
elif flag1 == ['研报问答']:
print("target_files:", target_files)
# logging.info(f"target_files: {target_files}")
QA_pages = []
if not target_files:
yield "请取消选中研报问答,先进行研报检索,再进行问答。"
else:
for item in target_files:
loader = PyPDFLoader(item)
QA_pages.extend(loader.load_and_split())
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
documents = text_splitter.split_documents(QA_pages)
db = Chroma.from_documents(documents, OpenAIEmbeddings())
docs = db.similarity_search_with_score(query=message, k=3)
contents = [doc[0] for doc in docs]
relevance = " ".join(doc.page_content for doc in contents)
sys_prompt = """
你是一个研报助手,根据这篇文章:{}
来回复用户的问题,如果这篇文章无法回答用户的问题,你必须根据你的知识面来编造进行专业的回答,
不能回复不知道,不能回复这篇文章不能回答的这种话语,你只能回复中文。
"""
sys_prompt = sys_prompt.format(relevance)
history_openai_format = format_messages(sys_prompt, history, message)
print("history_openai_format:", history_openai_format)
# logging.info(f"history_openai_format: {history_openai_format}")
response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True)
for result_message in extract_partial_message("", response):
yield result_message
elif flag1 == ['研报生成']:
target_files.clear()
sys_prompt = """
你是一个研报助手,请根据用户的要求回复问题。
"""
history_openai_format = format_messages(sys_prompt, history, message)
print("history_openai_format:", history_openai_format)
# logging.info(f"history_openai_format: {history_openai_format}")
response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True)
for result_message in extract_partial_message("", response):
yield result_message
elif len(flag1) == 2:
yield "请选中一个选项,进行相关问答。"
demo = gr.ChatInterface(
echo,
chatbot=gr.Chatbot(height=430, label="ChatReport"),
textbox=gr.Textbox(placeholder="请输入问题", container=False, scale=7),
title="研报助手",
description="清芬院研报助手",
theme="soft",
additional_inputs=[
# gr.Radio(["研报问答", "研报生成"], type="index", label = "function"),
# gr.Checkbox(label = "研报问答"),
# gr.Checkbox(label = "研报生成"),
gr.CheckboxGroup(["研报问答", "研报生成"], label="Function"),
gr.Dropdown(topics, type="index"),
# gr.Button(value="Run").click(echo, inputs=["", "", [], None], outputs=[""])
# btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
# gr.Blocks()
],
# retry_btn="retry",
undo_btn="清空输入框",
clear_btn="清空聊天记录"
).queue()
if __name__ == "__main__":
demo.launch(share=True)
'''
flag1: ['研报问答']
flag2: None
history: []
message: gg
target_files: set()
'''