"""This is an example of how to use async langchain with fastapi and return a streaming response.""" | |
import uvicorn, threading, logging, time, json, re, os | |
from fastapi import FastAPI | |
from pydantic import BaseModel | |
from langchain import LLMChain | |
from starlette.types import Send | |
from langchain.chat_models import ChatOpenAI | |
from fastapi.responses import StreamingResponse | |
from logging.handlers import RotatingFileHandler | |
from langchain.memory import ConversationBufferMemory | |
from typing import Any, Optional, Awaitable, Callable, Iterator, Union, List | |
from langchain.callbacks.base import AsyncCallbackManager, AsyncCallbackHandler | |
import openai | |
os.environ["OPENAI_API_KEY"] = "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" | |
openai.proxy = {"http": "http://127.0.0.1:7890", "https": "http://127.0.0.1:7890"} | |
from utils import (
    prompt_memory, prompt_chat_term, prompt_basic, prompt_reco,
    prompt_memory_character, prompt_chat_character, prompt_basic_character, prompt_reco_character,
    memory_chat, memory_basic, memory_basic_character, memory_reco, memory_reco_character,
    product, content_term, get_detailInfo, process_Info,
)
from config import log_path, log_file, pre_key_words

key_words = {}
# Canned-reply keys: [done, greeting, configuration, like, advertisement, explanation, multiple faces, unsuccessful]
keywords_to_extract = ['wan_bi', 'zhao_hu', 'pei_zhi', 'dian_zan', 'guang_gao', 'jiang_jie', 'multi_face', 'un_success']
for keyword in keywords_to_extract:
    key_words[keyword] = pre_key_words[keyword]
logger = logging.getLogger('my_logger')
logger.setLevel(logging.DEBUG)
log_path = os.path.join(os.path.dirname(__file__), log_path)
if not os.path.exists(log_path):
    os.makedirs(log_path)
log_file = "{}/{}".format(log_path, log_file)
# Size-based rotating log handler: max file size 20MB, keep 10 backup files
file_handler = RotatingFileHandler(log_file, maxBytes=20 * 1024 * 1024, backupCount=10)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

app = FastAPI()
Sender = Callable[[Union[str, bytes]], Awaitable[None]]

class EmptyIterator(Iterator[Union[str, bytes]]):
    def __iter__(self):
        return self

    def __next__(self):
        raise StopIteration
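# EmptyIterator is handed to StreamingResponse as its `content` below so that the base
# class receives a valid (but empty) iterator; the actual body is produced by the
# overridden stream_response method of ChatOpenAIStreamingResponse instead.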
class AsyncStreamCallbackHandler(AsyncCallbackHandler):
    """Callback handler for streaming, inherits from AsyncCallbackHandler."""

    def __init__(self, send: Sender, time: int, start_time: float):
        super().__init__()
        self.send = send
        self.tokens = []  # buffer for the generated tokens
        self.time = time
        self.start_time = start_time  # time at which the user's request arrived
        self.previous_time = start_time  # time at which the previous sentence was returned

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        async def send_response(sentence):
            response = generate_response(self.time, self.start_time, self.previous_time, content=sentence, broadcast=True, end=False)
            self.previous_time = time.time()
            await asyncio.sleep(0.2)  # was time.sleep(0.2), which would block the event loop
            logger.info(f"--return body--:{json.dumps(response, ensure_ascii=False)}")
            await self.send(f"{json.dumps(response, ensure_ascii=False)}\n")

        self.tokens.append(token.strip())
        if any(punctuation in token for punctuation in [",", ",", "。", "!", "?", ":", ":", ";", ";"]):
            if len(token) == 1:
                sentence = "".join(self.tokens)
                self.tokens = []
                await send_response(sentence)
            else:
                # OpenAI sometimes emits a token that glues a punctuation mark to the next
                # word (",请" / ",把"); flushing it verbatim makes the speech broadcast sound
                # stilted, so the glued token is held back and the rest is flushed.
                sentence = "".join(self.tokens[:-1])
                self.tokens = [self.tokens[-1]]
                await send_response(sentence)
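# Illustration of the sentence-buffering heuristic above (hypothetical token stream):
# after "今天", "天气", "不错" have been buffered, a glued token ",请" arrives; it contains
# punctuation but len(",请") > 1, so "今天天气不错" is flushed as one sentence and ",请"
# stays in the buffer as the start of the next one.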
def extract_product(text):
    """Regex-extract the recommended insurance name from the model output, then use that name
    to look up the product details among the six predefined Sunshine insurance products (`product`)."""
    match = re.search(r"推荐:(.+?)[,。]", text)
    insurance_name = match.group(1) if match else "未匹配到保险名称"
    product_name = product.get(insurance_name, insurance_name)
    return {"product": product_name}
def final_info(res):
    """Arrange the family info into the array layout expected by the UI. Could be improved:
    the current version requires the fixed order [user, father, mother, wife, son], leaving
    an empty slot wherever a member is missing."""
    family = {
        "父亲": ["父亲", "爸爸"],
        "母亲": ["母亲", "妈妈"],
        "配偶": ["老公", "丈夫", "太太", "配偶", "妻子", "夫人", "老婆"],
        "孩子": ["孩子", "儿子", "女儿", "小孩儿", "小孩"]
    }
    result = [res[0]]
    for _, keywords in family.items():
        family_member = next((item for item in res[1:] if item['name'] in keywords), {'name': '', 'age': '', 'career': '', 'health': '', 'live': ''})
        result.append(family_member)
    return result
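# Example (hypothetical input): res = [<user>, {"name": "女儿", ...}] yields
# [<user>, <empty father slot>, <empty mother slot>, <empty spouse slot>, {"name": "女儿", ...}]:
# "女儿" matches the "孩子" keyword list and lands in the last slot, while the other
# slots are filled with empty placeholder dicts.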
def generate_response(
    stamp,
    start_time,
    previous_time,
    content: str = "",
    broadcast: bool = True,
    exitInsureProductInfo: bool = False,
    existFamilyInfo: bool = False,
    insureProductInfos: list = [],  # annotation fixed: the default (and every caller) uses a list, not a dict
    familyInfos: list = [],
    end: bool = False
) -> dict:
    """Definition of the response fields returned to the client."""
    response = {
        "time": stamp,
        "totalElapsedTime": round(time.time() - start_time, 4),
        "elapsedTimeSinceLast": round(time.time() - previous_time, 4),
        "content": content,
        "broadcast": broadcast,
        "exitInsureProductInfo": exitInsureProductInfo,
        "existFamilyInfo": existFamilyInfo,
        "insureProductInfos": insureProductInfos,
        "familyInfos": familyInfos,
        "end": end
    }
    return response
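# A sample streamed line as built by generate_response (illustrative values only):
# {"time": 1, "totalElapsedTime": 0.8321, "elapsedTimeSinceLast": 0.2105, "content": "您好",
#  "broadcast": true, "exitInsureProductInfo": false, "existFamilyInfo": false,
#  "insureProductInfos": [], "familyInfos": [], "end": false}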
class ChatOpenAIStreamingResponse(StreamingResponse):
    """Streaming response for the openai chat model, inherits from StreamingResponse."""

    def __init__(
        self,
        generate: Callable[[Sender], Awaitable[None]],
        message: str,
        time: int,
        action: int,
        start_time: float,
        status_code: int = 200,
        media_type: Optional[str] = None,
    ) -> None:
        super().__init__(content=EmptyIterator(), status_code=status_code, media_type=media_type)
        self.generate = generate
        self.message = message
        self.time = time
        self.action = action
        self.start_time = start_time
        self.response_data = ''  # accumulates the content generated during this response
    async def stream_response(self, send: Send) -> None:
        """Override stream_response to push the generated chunks to the client."""
        await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})

        async def send_chunk(chunk: Union[str, bytes]):
            if not isinstance(chunk, bytes):
                chunk = chunk.encode(self.charset)
            await send({"type": "http.response.body", "body": chunk, "more_body": True})
            dict_data = json.loads(chunk.decode(self.charset).strip())
            self.response_data += dict_data['content']

        async def send_word_response(self, response):
            logger.info(f"--return body--:{json.dumps(response, ensure_ascii=False)}")
            await send({"type": "http.response.body", "body": json.dumps(response, ensure_ascii=False).encode("utf-8"), "more_body": True})

        async def send_end_response(self, previous_time):
            response = generate_response(self.time, self.start_time, previous_time, content="", broadcast=False, end=True)
            logger.info(f"--return body--:{json.dumps(response, ensure_ascii=False)}\n{'*' * 300}")
            await send({"type": "http.response.body", "body": json.dumps(response, ensure_ascii=False).encode("utf-8"), "more_body": False})
        async def process_response(self, prefix):
            """Split a canned reply into short sentences and return them one by one."""
            delimiter_pattern = r"[,。?!\.\?!]"
            sentences = re.split(delimiter_pattern, prefix)
            suffix = [i.strip() for i in sentences if len(i)]  # renamed from the misspelled `surfix`
            previous_time = time.time()
            for item in suffix:
                response = generate_response(self.time, self.start_time, previous_time, content=item, broadcast=True, end=False)
                await send_word_response(self, response)
                previous_time = time.time()
                await asyncio.sleep(0.5)  # was time.sleep(0.5), which would block the event loop
            await send_end_response(self, previous_time)
print("self.message:", self.message) | |
global flag1, flag2, all_Info | |
previous_time = time.time() | |
if self.action == 2 and len(self.message) == 0: | |
await send_end_response(self, previous_time) | |
if self.action == 4: | |
await self.generate(send_chunk) if self.message else None | |
await process_response(self, prefix = key_words.get('multi_face', '')) | |
elif "赞" in self.message: | |
await process_response(self, prefix = key_words.get('dian_zan', '')) | |
elif "招呼" in self.message: | |
await process_response(self, prefix = key_words.get('zhao_hu', '')) | |
elif "讲解" in self.message: | |
await self.generate(send_chunk) | |
await process_response(self, prefix = key_words.get('jiang_jie', '')) | |
elif "广告" in self.message: | |
await process_response(self, prefix = key_words.get('guang_gao', '')) | |
elif "配置" in self.message: | |
all_Info = [] | |
await self.generate(send_chunk) | |
await process_response(self, prefix = key_words.get('pei_zhi', '')) | |
elif "首要" in self.message: | |
if len(all_Info) == 0: | |
await process_response(self, prefix = "抱歉,未能正确记录您的家庭信息,让我们重新开始搜集您的家庭信息之后,再为您推荐保险吧。") | |
else: | |
await self.generate(send_chunk) | |
previous_time = time.time() | |
print("response_data:", self.response_data) | |
result = extract_product(self.response_data) | |
all_Info[0]["name"] = "客户" | |
doll = {'name': '', 'age': '', 'career': '', 'health': '', 'live': ''} | |
user_info = [all_Info[0], doll, doll, doll, doll] | |
response = generate_response(self.time, self.start_time, previous_time, content = "请查看以下表格:", broadcast = True, exitInsureProductInfo=True, existFamilyInfo=True, insureProductInfos=result['product'], familyInfos=user_info, end=False) | |
await send_word_response(self, response) | |
time.sleep(0.5) | |
previous_time = time.time() | |
await process_response(self, prefix = "您可以向我咨询具体的保险条款,或直接向我获取代理人联系方式,他们能提供更加详细专业的回答。〞如已阅读完毕,请按方向键上键,我将继续回答您关于保险产品的问题。") | |
else: | |
await self.generate(send_chunk) | |
previous_time = time.time() | |
print("response_data:", self.response_data) | |
if "完毕" in self.response_data: | |
thread1 = threading.Thread(target = thread_function) | |
thread1.start() | |
thread1.join() # 这一行会让主线程等待thread1线程执行完毕,调用join()方法会阻塞主线程,直到thread1线程执行完成后才会继续执行主线程的后续代码 | |
if all_Info: | |
all_Info[0]["name"] = "客户" | |
all_Info = final_info(all_Info) | |
response = generate_response(self.time, self.start_time, previous_time, content = "以下表格是根据您提供的家庭结构和成员信息,", broadcast = True, existFamilyInfo=True, familyInfos=all_Info, end=False) | |
await send_word_response(self, response) | |
time.sleep(1) | |
previous_time = time.time() | |
await process_response(self, prefix = key_words['wan_bi']) | |
else: # 用户没有提供或者没有搜集到家庭信息:all_Info = [];跳转到闲聊模式 | |
restart_function() | |
await process_response(self, prefix = key_words['un_success']) | |
else: | |
await send_end_response(self, previous_time) | |
if flag1 == 1 and flag2 == 0: | |
thread1 = threading.Thread(target = thread_function) | |
thread1.start() | |
def thread_function():
    global family_info, all_Info
    if family_info:
        detail_Info = get_detailInfo(family_info)
        print("detail_Info_1:\n", detail_Info)
        if "人物" in detail_Info:
            detail_Info = [i.strip() for i in detail_Info.strip().split('\n') if i.strip()]
            count = sum(1 for item in detail_Info if item.count('未知') >= 4)
            print("detail_Info_2:\n", len(detail_Info), detail_Info)
            # count = number of family members with 4 or more unknown fields (excluded from the final result)
            print("超过4个信息是未知的家人数(不计入最终统计)count:", count)
            if len(detail_Info) != count and "*" in detail_Info[0]:
                final = process_Info(detail_Info)
                if final:
                    all_Info.extend(final)
                    print("all_Info: ", all_Info)
def restart_function():
    global flag1, flag2, family_info, all_Info, memory_chat, memory_reco, memory_reco_character, memory_basic, memory_basic_character
    flag1, flag2, all_Info, family_info = 0, 0, [], ""
    memory_chat = ConversationBufferMemory(memory_key="chat_history", ai_prefix="")
    memory_reco = ConversationBufferMemory(memory_key="chat_history", input_key="human_input")
    memory_reco_character = ConversationBufferMemory(memory_key="chat_history", input_key="human_input")
    memory_basic = ConversationBufferMemory(memory_key="chat_history", ai_prefix="")
    memory_basic_character = ConversationBufferMemory(memory_key="chat_history", ai_prefix="")
    return
async def openai_function(send, time, start_time, prompt, message, model_name, memory=None, context: str = ""):
    # model = AzureChatOpenAI(request_timeout=8*60, deployment_name="gpt-35-turbo", openai_api_version="2023-03-15-preview", streaming=True, callback_manager=AsyncCallbackManager([AsyncStreamCallbackHandler(send, time, start_time)]), verbose=True, temperature=0.7)
    model = ChatOpenAI(request_timeout=8*60, model_name=model_name, streaming=True, callback_manager=AsyncCallbackManager([AsyncStreamCallbackHandler(send, time, start_time)]), verbose=True, temperature=0.7)
    chain_args = {
        'llm': model,
        'prompt': prompt,
        'verbose': True,
        'memory': memory if memory else None
    }
    chain = LLMChain(**chain_args)
    if context:
        await chain.apredict(human_input=message, context=context)
    else:
        await chain.apredict(human_input=message)
switch, flag1, flag2, all_Info, family_info = 0, 0, 0, [], ""

def send_message(message: str, time: int, action: int, dialogue_memory: str, start_time: float, switch: int) -> Callable[[Sender], Awaitable[None]]:
    async def generate_memory(send: Sender):
        await openai_function(model_name="gpt-4", send=send, time=time, start_time=start_time, prompt=prompt_memory, message=message, context=dialogue_memory)

    async def generate_memory_character(send: Sender):
        # fixed: the original passed `model=` instead of `model_name=` and omitted `time=`
        await openai_function(model_name="gpt-4", send=send, time=time, start_time=start_time, prompt=prompt_memory_character, message=message, context=dialogue_memory)

    async def generate_hello(send: Sender):
        await openai_function(model_name="gpt-3.5-turbo-16k", send=send, time=time, start_time=start_time, prompt=prompt_chat_term, message=message, context=content_term)

    async def generate_hello_character(send: Sender):
        await openai_function(model_name="gpt-4", send=send, time=time, start_time=start_time, memory=memory_chat, prompt=prompt_chat_character, message=message)

    async def generate_basic(send: Sender):
        global all_Info, family_info
        family_info = message
        await openai_function(model_name="gpt-4", send=send, time=time, start_time=start_time, memory=memory_basic, prompt=prompt_basic, message=message)

    async def generate_basic_character(send: Sender):
        global all_Info, family_info
        family_info = message
        await openai_function(model_name="gpt-4", send=send, time=time, start_time=start_time, memory=memory_basic_character, prompt=prompt_basic_character, message=message)

    async def generate_recommend(send: Sender):
        global all_Info
        if all_Info:  # fixed: the original conditional assignment raised IndexError when all_Info was empty
            all_Info[0]["name"] = "客户"
        else:
            logger.info("--error--:len(all_Info)==0, 家庭信息未正确搜集!")
        model = ChatOpenAI(request_timeout=8*60, model_name="gpt-4", streaming=True, callback_manager=AsyncCallbackManager([AsyncStreamCallbackHandler(send, time, start_time)]), verbose=True, temperature=0)
        llm_chain = LLMChain(llm=model, prompt=prompt_reco, verbose=True, memory=memory_reco)
        await llm_chain.apredict(human_input=message, context=all_Info[0] if all_Info else "", product=product)

    async def generate_recommend_character(send: Sender):
        global all_Info
        if all_Info:
            all_Info[0]["name"] = "客户"
        else:
            logger.info("--error--:len(all_Info)==0, 家庭信息未正确搜集!")
        model = ChatOpenAI(request_timeout=8*60, model_name="gpt-4", streaming=True, callback_manager=AsyncCallbackManager([AsyncStreamCallbackHandler(send, time, start_time)]), verbose=True, temperature=0)
        llm_chain = LLMChain(llm=model, prompt=prompt_reco_character, verbose=True, memory=memory_reco_character)
        await llm_chain.apredict(human_input=message, context=all_Info[0] if all_Info else "", product=product)
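    # Dispatch summary: flag1 marks that family-info collection ("配置") has started and
    # flag2 that recommendation ("首要") has started. action 5 routes to the memory chain
    # and action 2 resets all session state; otherwise the (flag1, flag2) pair selects
    # chit-chat, info collection, or recommendation, each with a persona variant when
    # `switch` is set.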
    global flag1, flag2
    if action == 5:
        return generate_memory_character if switch else generate_memory
    if action == 2:
        restart_function()
    if "配置" in message:
        flag1 = 1
        return generate_basic_character if switch else generate_basic
    if "首要" in message:
        flag2 = 1
        return generate_recommend_character if switch else generate_recommend
    if flag1 == 0 and flag2 == 0:
        return generate_hello_character if switch else generate_hello
    if flag1 == 1 and flag2 == 0:
        return generate_basic_character if switch else generate_basic
    if flag1 == 1 and flag2 == 1:
        return generate_recommend_character if switch else generate_recommend
    if flag1 == 0 and flag2 == 1:
        return generate_recommend_character if switch else generate_recommend
    return generate_recommend_character if switch else generate_recommend
class StreamRequest(BaseModel):
    """Request body for streaming."""
    message: str
    time: int
    action: int
    historyMessage: list[str]
@app.post("/stream")  # route path is an assumption: the decorator was missing from the original listing
def stream(body: StreamRequest):
    logger.info(f"--request.body--:{body}")
    start_time = time.time()
    if body.action == 1:
        body.message = "根据我之前提供的家庭信息,首要推荐我考虑的是什么险种?"
    elif body.action == 3:
        response = {"time": body.time, "totalElapsedTime": round(time.time() - start_time, 4), "elapsedTimeSinceLast": round(time.time() - start_time, 4), "content": "您对以上推荐的保险方案是否有疑惑需要我进行解答呢?", "broadcast": True, "exitInsureProductInfo": False, "existFamilyInfo": False, "insureProductInfos": [], "familyInfos": [], "end": True}
        logger.info(f"--return body--:{response}\n{'*' * 300}")
        return response
    if "讲解" in body.message:
        body.message = body.message + "回答不超过3句话"
        logger.info(f"--request.body--:{body}")
    if any(keyword in body.message for keyword in ["研发", "开发"]):
        response = {"time": body.time, "totalElapsedTime": round(time.time() - start_time, 4), "elapsedTimeSinceLast": round(time.time() - start_time, 4), "content": "我是由杭州华鲤智能科技有限公司研发的AI私人保险助理,具体细节请咨询我们团队技术人员。", "broadcast": True, "exitInsureProductInfo": False, "existFamilyInfo": False, "insureProductInfos": [], "familyInfos": [], "end": True}
        logger.info(f"--return body--:{response}\n{'*' * 300}")
        return response
    global switch, flag1, flag2, all_Info, family_info
    if "切换" in body.message:
        if switch == 0:
            switch = 1
            restart_function()
            response = {"time": body.time, "totalElapsedTime": round(time.time() - start_time, 4), "elapsedTimeSinceLast": round(time.time() - start_time, 4), "content": "已为您切换到精神小伙人设,老铁,请问您有什么问题要咨询的吗?", "broadcast": True, "exitInsureProductInfo": False, "existFamilyInfo": False, "insureProductInfos": [], "familyInfos": [], "end": True}
            logger.info(f"--return body--:{response}\n{'*' * 300}")
            return response
        else:
            switch = 0
            restart_function()
            response = {"time": body.time, "totalElapsedTime": round(time.time() - start_time, 4), "elapsedTimeSinceLast": round(time.time() - start_time, 4), "content": "已为您切换到默认人设,请问您有什么问题要咨询的吗?", "broadcast": True, "exitInsureProductInfo": False, "existFamilyInfo": False, "insureProductInfos": [], "familyInfos": [], "end": True}
            logger.info(f"--return body--:{response}\n{'*' * 300}")
            return response
    else:
        dialogue_memory = '\n'.join(body.historyMessage) if body.historyMessage else ""
        return ChatOpenAIStreamingResponse(send_message(body.message, body.time, body.action, dialogue_memory, start_time, switch=switch), media_type="text/event-stream", message=body.message, time=body.time, action=body.action, start_time=start_time)
if __name__ == "__main__":
    uvicorn.run(host="0.0.0.0", port=8086, app=app)
'''
action = 1 : automatically switch to insurance recommendation; the backend sends "根据我之前提供的家庭信息,首要推荐我考虑的是什么险种?"
action = 2 : reset the whole session
action = 3 : default reply "您对以上推荐的保险方案是否有疑惑需要我进行解答呢?"
action = 4 : multiple faces detected, i.e. a noisy environment
action = 5 : the user switched
action = 0 or -1 : default
'''
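# A minimal client sketch (illustrative, not part of the service). It assumes the route
# is exposed at POST /stream on port 8086 -- the path is the assumption noted at the
# @app.post decorator above -- and prints each streamed JSON chunk as it arrives:
#
#   import requests
#   body = {"message": "你好", "time": 1, "action": 0, "historyMessage": []}
#   with requests.post("http://127.0.0.1:8086/stream", json=body, stream=True) as r:
#       for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk)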