# learn-ai / server.py
# NOTE: Hugging Face file-viewer residue removed from this header
# ("dh-mc's picture / added debug logs / 9a1c3b1 / raw / history blame /
# No virus / 2.07 kB") — it was not Python and broke the file's syntax.
"""Main entrypoint for the app."""
import json
import os
from timeit import default_timer as timer
from typing import List, Optional
from lcserve import serving
from pydantic import BaseModel
from app_modules.init import app_init
from app_modules.llm_chat_chain import ChatChain
from app_modules.utils import print_llm_response
llm_loader, qa_chain = app_init(True)
chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") == "true"
uuid_to_chat_chain_mapping = dict()
class ChatResponse(BaseModel):
"""Chat response schema."""
token: Optional[str] = None
error: Optional[str] = None
sourceDocs: Optional[List] = None
@serving(websocket=True)
def chat(
    question: str, history: Optional[List] = None, uuid: Optional[str] = None, **kwargs
) -> str:
    """Answer *question* over the websocket and return a JSON ChatResponse.

    Args:
        question: The user's question.
        history: Optional list of (question, answer) pairs from the client;
            only replayed into the chain when CHAT_HISTORY_ENABLED=true.
            Defaults to None (treated as empty) — the original used a
            mutable ``[]`` default, which is a shared-object hazard.
        uuid: Optional session id. When None, runs the stateless QA chain;
            otherwise reuses (or creates) a per-uuid ChatChain.
        **kwargs: Carries ``streaming_handler``, injected by the serving
            layer and passed through to the chain to stream tokens.

    Returns:
        A JSON string of ChatResponse (sourceDocs only; tokens are streamed).
    """
    print(f"uuid: {uuid}")
    # Get the `streaming_handler` from `kwargs`. This is used to stream data to the client.
    streaming_handler = kwargs.get("streaming_handler")

    if uuid is None:
        # Stateless QA path. Normalize any None entries in the client
        # history to empty strings before handing them to the chain.
        chat_history = []
        if chat_history_enabled:
            chat_history = [
                (element[0] or "", element[1] or "") for element in (history or [])
            ]

        start = timer()
        result = qa_chain.call_chain(
            {"question": question, "chat_history": chat_history}, streaming_handler
        )
        end = timer()
        print(f"Completed in {end - start:.3f}s")
        print(f"qa_chain result: {result}")

        resp = ChatResponse(sourceDocs=result["source_documents"])
        return json.dumps(resp.dict())

    # Stateful path: look the session up once (the original did a
    # membership test plus a second indexing lookup). `chat_chain` also
    # no longer shadows this function's own name, which the original did.
    chat_chain = uuid_to_chat_chain_mapping.get(uuid)
    if chat_chain is None:
        chat_chain = ChatChain(llm_loader)
        uuid_to_chat_chain_mapping[uuid] = chat_chain

    result = chat_chain.call_chain({"question": question}, streaming_handler)
    print(f"chat result: {result}")

    # Chat-only path exposes no source documents.
    resp = ChatResponse(sourceDocs=[])
    return json.dumps(resp.dict())
if __name__ == "__main__":
    # Manual smoke test: run one stateless question and pretty-print the
    # decoded ChatResponse payload.
    response_json = chat("What's deep learning?", [])
    print_llm_response(json.loads(response_json))