import abc
import os
import time
import urllib.parse
from queue import Queue
from threading import Thread
from typing import Optional

from langchain.callbacks.tracers import LangChainTracer
from langchain.chains.base import Chain

from app_modules.llm_loader import LLMLoader, TextIteratorStreamer
from app_modules.utils import remove_extra_spaces


class LLMInference(metaclass=abc.ABCMeta):
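    """Abstract base class for chain-based LLM inference.

    Subclasses implement create_chain(); this class handles chain caching,
    loader locking, and token streaming during calls.
    """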
    llm_loader: LLMLoader
    chain: Chain

    def __init__(self, llm_loader):
        self.llm_loader = llm_loader
        self.chain = None

    @abc.abstractmethod
    def create_chain(self) -> Chain:
        pass

    def get_chain(self, tracing: bool = False) -> Chain:
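        """Build the chain on first use and cache it for subsequent calls.

        When tracing is enabled, a default LangChainTracer session is loaded
        before the chain is created.
        """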
        if self.chain is None:
            if tracing:
                tracer = LangChainTracer()
                tracer.load_default_session()

            self.chain = self.create_chain()

        return self.chain

    def call_chain(
        self,
        inputs,
        streaming_handler,
        q: Optional[Queue] = None,
        tracing: bool = False,
    ):
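        """Run the chain on the given inputs while holding the loader lock.

        Tokens are streamed through streaming_handler (and the optional queue)
        for Hugging Face pipelines; otherwise the chain is invoked directly.
        When PDF_FILE_BASE_URL is set, a URL is attached to each source
        document's metadata.
        """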
        print(inputs)
        self.llm_loader.lock.acquire()

        try:
            self.llm_loader.streamer.reset(q)

            chain = self.get_chain(tracing)
            result = (
                self._run_chain(
                    chain,
                    inputs,
                    streaming_handler,
                )
                if streaming_handler is not None
                and self.llm_loader.streamer.for_huggingface
                else chain(inputs)
            )

            if "answer" in result:
                result["answer"] = remove_extra_spaces(result["answer"])

                base_url = os.environ.get("PDF_FILE_BASE_URL")
                if base_url is not None and len(base_url) > 0:
                    documents = result["source_documents"]
                    for doc in documents:
                        source = doc.metadata["source"]
                        title = source.split("/")[-1]
                        doc.metadata["url"] = f"{base_url}{urllib.parse.quote(title)}"

            return result
        finally:
            self.llm_loader.lock.release()

    def _execute_chain(self, chain, inputs, q, sh):
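        """Invoke the chain with the streaming callback; put the result on q."""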
        q.put(chain(inputs, callbacks=[sh]))

    def _run_chain(self, chain, inputs, streaming_handler):
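        """Run the chain in a background thread, forwarding streamed tokens.

        When chat history is present, two streaming passes are expected
        (question rewriting followed by answering), so the streamer is
        drained twice before joining the worker thread.
        """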
        que = Queue()

        t = Thread(
            target=self._execute_chain,
            args=(chain, inputs, que, streaming_handler),
        )
        t.start()

        count = (
            2 if "chat_history" in inputs and len(inputs.get("chat_history")) > 0 else 1
        )

        while count > 0:
            try:
                for token in self.llm_loader.streamer:
                    streaming_handler.on_llm_new_token(token)

                self.llm_loader.streamer.reset()
                count -= 1
            except Exception:
                print("nothing generated yet - retry in 0.5s")
                time.sleep(0.5)

        t.join()
        return que.get()
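

# ---------------------------------------------------------------------------
# Illustrative usage sketch (added here for clarity; not part of the original
# module). It shows how a concrete subclass might implement create_chain()
# and drive call_chain(). The subclass name, the prompt text, and the
# assumption that LLMLoader exposes its model as `llm_loader.llm` are
# hypothetical.
# ---------------------------------------------------------------------------
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate


class ExampleQAInference(LLMInference):
    def create_chain(self) -> Chain:
        # Build a simple single-prompt chain around the loaded model.
        prompt = PromptTemplate.from_template("Answer concisely: {question}")
        return LLMChain(llm=self.llm_loader.llm, prompt=prompt)


# Example call (requires a configured LLMLoader instance):
#     inference = ExampleQAInference(llm_loader)
#     result = inference.call_chain({"question": "What does this module do?"}, None)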