File size: 2,068 Bytes
99d65c0
 
 
 
 
 
 
 
 
4359eb6
 
 
99d65c0
30bf870
99d65c0
 
 
4359eb6
99d65c0
 
 
 
 
 
 
 
 
 
 
4359eb6
 
 
 
99d65c0
4359eb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a1c3b1
4359eb6
 
 
 
 
 
 
 
 
 
9a1c3b1
4359eb6
 
 
99d65c0
 
 
4359eb6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Main entrypoint for the app."""
import json
import os
from timeit import default_timer as timer
from typing import List, Optional

from lcserve import serving
from pydantic import BaseModel

from app_modules.init import app_init
from app_modules.llm_chat_chain import ChatChain
from app_modules.utils import print_llm_response

# Build the LLM loader and the QA chain once, when this module is imported;
# every call to the request handler below reuses these same objects.
llm_loader, qa_chain = app_init(True)

# Prior conversation turns are forwarded to the QA chain only when this
# environment variable is set to exactly "true" (unset or any other value -> off).
chat_history_enabled = "true" == os.environ.get("CHAT_HISTORY_ENABLED")

# Session cache: client-supplied uuid -> ChatChain instance for that session.
uuid_to_chat_chain_mapping = {}


class ChatResponse(BaseModel):
    """Chat response schema.

    ``chat`` builds one of these and returns ``json.dumps(resp.dict())``, so
    the field names below are the JSON keys the client receives. All fields
    default to ``None``.
    """

    # Not assigned anywhere in the visible code of this module; presumably
    # reserved for token payloads (tokens appear to be streamed via the
    # streaming handler instead) -- TODO confirm against the client.
    token: Optional[str] = None
    # Not assigned anywhere in the visible code of this module; presumably an
    # error message slot for the client -- verify before relying on it.
    error: Optional[str] = None
    # The QA chain's ``result["source_documents"]`` in stateless QA mode; an
    # empty list for uuid-keyed chat sessions.
    sourceDocs: Optional[List] = None


@serving(websocket=True)
def chat(
    question: str,
    history: Optional[List] = None,
    uuid: Optional[str] = None,
    **kwargs,
) -> str:
    """Answer ``question`` and return a JSON-encoded ``ChatResponse``.

    The mode is selected by ``uuid``:

    * ``uuid is None``: stateless call through ``qa_chain``. When chat history
      is enabled (``CHAT_HISTORY_ENABLED == "true"``), ``history`` is turned
      into a list of 2-tuples and passed as ``chat_history``. The response
      carries the chain's ``source_documents``.
    * otherwise: call through a ``ChatChain`` bound to this ``uuid``. It is
      created on first use and reused by later calls with the same ``uuid``.
      The response carries an empty ``sourceDocs`` list.

    Args:
        question: The user's message.
        history: Earlier turns as a sequence of pairs (presumably
            ``[human, ai]``; verify against the client). ``None`` elements
            inside a pair become ``""``. ``None`` (the default) or an empty
            sequence means no history. Only read in the stateless mode, and
            only when chat history is enabled.
        uuid: Optional session id; when given, the conversational
            ``ChatChain`` path is used instead of ``qa_chain``.
        **kwargs: Extra arguments supplied by the serving layer. A
            ``streaming_handler`` entry, if present, is forwarded to the
            chain and is used to stream data to the client.

    Returns:
        ``json.dumps`` of ``ChatResponse(...).dict()``.
    """
    print(f"uuid: {uuid}")
    # Get the `streaming_handler` from `kwargs`. This is used to stream data to the client.
    streaming_handler = kwargs.get("streaming_handler")

    if uuid is not None:
        # Conversational mode: reuse (or lazily create) the chain for this session.
        # NOTE(review): nothing in this module evicts entries, so the mapping
        # grows with every distinct uuid for the lifetime of the process.
        chat_chain = uuid_to_chat_chain_mapping.get(uuid)
        if chat_chain is None:
            chat_chain = ChatChain(llm_loader)
            uuid_to_chat_chain_mapping[uuid] = chat_chain
        result = chat_chain.call_chain({"question": question}, streaming_handler)
        print(f"chat result: {result}")

        resp = ChatResponse(sourceDocs=[])
        return json.dumps(resp.dict())

    # Stateless QA mode. `history` may legitimately be None (it is Optional), so
    # guard before iterating; missing turn text is normalised to "".
    chat_history = []
    if chat_history_enabled and history:
        chat_history = [(element[0] or "", element[1] or "") for element in history]

    start = timer()
    result = qa_chain.call_chain(
        {"question": question, "chat_history": chat_history}, streaming_handler
    )
    end = timer()
    print(f"Completed in {end - start:.3f}s")

    print(f"qa_chain result: {result}")
    resp = ChatResponse(sourceDocs=result["source_documents"])

    return json.dumps(resp.dict())


if __name__ == "__main__":
    # Local smoke run: one stateless question with an empty history, then
    # decode the JSON reply and hand it to print_llm_response.
    raw_reply = chat("What's deep learning?", [])
    decoded_reply = json.loads(raw_reply)
    print_llm_response(decoded_reply)