import streamlit as st

from utils.config import document_store_configs, model_configs
from haystack import Pipeline
from haystack.schema import Answer
from haystack.document_stores import BaseDocumentStore
from haystack.document_stores import InMemoryDocumentStore, OpenSearchDocumentStore, WeaviateDocumentStore
from haystack.nodes import EmbeddingRetriever, FARMReader, PromptNode, PreProcessor
from milvus_haystack import MilvusDocumentStore

# Use this file to set up your Haystack pipelines and querying logic.

@st.cache_resource(show_spinner=False)
def start_preprocessor_node():
    print('initializing preprocessor node')
    processor = PreProcessor(
        clean_empty_lines=True,
        clean_whitespace=True,
        clean_header_footer=True,
        split_by="word",
        split_length=100,
        split_respect_sentence_boundary=True,
    )
    return processor
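
# Usage sketch (assumption, not part of the original app): the preprocessor
# splits raw documents into ~100-word passages before indexing, e.g.
#
#     processor = start_preprocessor_node()
#     docs = processor.process([{"content": raw_text, "meta": {"name": "file.txt"}}])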

@st.cache_resource(show_spinner=False)
def start_document_store(type: str):
    # Start the document store of your choice, based on your command-line preference.
    print('initializing document store')
    if type == 'inmemory':
        document_store = InMemoryDocumentStore(use_bm25=True, embedding_dim=384)
        '''
        documents = [
            {
                'content': "Pi is a super dog",
                'meta': {'name': "pi.txt"}
            },
            {
                'content': "The revenue of Siemens is 5 million Euro",
                'meta': {'name': "siemens.txt"}
            },
        ]
        document_store.write_documents(documents)
        '''
    elif type == 'opensearch':
        document_store = OpenSearchDocumentStore(scheme=document_store_configs['OPENSEARCH_SCHEME'],
                                                 username=document_store_configs['OPENSEARCH_USERNAME'],
                                                 password=document_store_configs['OPENSEARCH_PASSWORD'],
                                                 host=document_store_configs['OPENSEARCH_HOST'],
                                                 port=document_store_configs['OPENSEARCH_PORT'],
                                                 index=document_store_configs['OPENSEARCH_INDEX'],
                                                 embedding_dim=document_store_configs['OPENSEARCH_EMBEDDING_DIM'])
    elif type == 'weaviate':
        document_store = WeaviateDocumentStore(host=document_store_configs['WEAVIATE_HOST'],
                                               port=document_store_configs['WEAVIATE_PORT'],
                                               index=document_store_configs['WEAVIATE_INDEX'],
                                               embedding_dim=document_store_configs['WEAVIATE_EMBEDDING_DIM'])
    elif type == 'milvus':
        document_store = MilvusDocumentStore(uri=document_store_configs['MILVUS_URI'],
                                             index=document_store_configs['MILVUS_INDEX'],
                                             embedding_dim=document_store_configs['MILVUS_EMBEDDING_DIM'],
                                             return_embedding=True)
    else:
        raise ValueError(f"Unknown document store type: {type}")
    return document_store
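
# Usage sketch (assumption): the store type is chosen once at startup, e.g.
# document_store = start_document_store('inmemory'). The OpenSearch, Weaviate
# and Milvus branches require the corresponding service to be reachable with
# the connection details from utils.config.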

# Cached so the index and models load only once at startup
@st.cache_resource(show_spinner=False)
def start_retriever(_document_store: BaseDocumentStore):
    print('initializing retriever')
    retriever = EmbeddingRetriever(document_store=_document_store,
                                   embedding_model=model_configs['EMBEDDING_MODEL'],
                                   top_k=5)
    # Embeddings are computed separately; see the indexing sketch below.
    return retriever
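
# Indexing sketch (assumption, shown for context rather than called here):
# dense retrieval only finds documents whose embeddings have been computed:
#
#     document_store.write_documents(docs)
#     document_store.update_embeddings(retriever)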


@st.cache_resource(show_spinner=False)
def start_reader():
    print('initializing reader')
    reader = FARMReader(model_name_or_path=model_configs['EXTRACTIVE_MODEL'])
    return reader



# Cached so the pipeline and models load only once at startup
@st.cache_resource(show_spinner=False)
def start_haystack_extractive(_document_store: BaseDocumentStore, _retriever: EmbeddingRetriever, _reader: FARMReader):
    print('initializing pipeline')
    pipe = Pipeline()
    pipe.add_node(component=_retriever, name="Retriever", inputs=["Query"])
    pipe.add_node(component=_reader, name="Reader", inputs=["Retriever"])
    return pipe
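
# Query sketch (assumption): the extractive pipeline returns haystack.schema.Answer
# objects with extracted spans and confidence scores, e.g.
#
#     result = pipe.run(query="Who is Pi?", params={"Retriever": {"top_k": 5}})
#     best = result["answers"][0]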

@st.cache_resource(show_spinner=False)
def start_haystack_rag(_document_store: BaseDocumentStore, _retriever: EmbeddingRetriever, openai_key):
    prompt_node = PromptNode(default_prompt_template="deepset/question-answering", 
                             model_name_or_path=model_configs['GENERATIVE_MODEL'],
                             api_key=openai_key,
                             max_length=500)
    pipe = Pipeline()

    pipe.add_node(component=_retriever, name="Retriever", inputs=["Query"])
    pipe.add_node(component=prompt_node, name="PromptNode", inputs=["Retriever"])

    return pipe
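
# RAG sketch (assumption): the PromptNode fills the "deepset/question-answering"
# template with the retrieved documents and the question, so the generated answer
# stays grounded in the store's contents, e.g.
#
#     result = pipe.run(query="What is the revenue of Siemens?")
#     print(result["results"][0])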

# Not cached: each question should trigger a fresh run of the pipeline.
def query(_pipeline, question):
    params = {}
    results = _pipeline.run(query=question, params=params)
    return results
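
# Parameter sketch (assumption): per-node settings can be passed through `params`,
# e.g. {"Retriever": {"top_k": 3}, "Reader": {"top_k": 1}} to control how many
# documents are retrieved and how many answers are returned.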

def initialize_pipeline(task, document_store, retriever, reader, openai_key=""):
    if task == 'extractive':
        return start_haystack_extractive(document_store, retriever, reader)
    elif task == 'rag':
        return start_haystack_rag(document_store, retriever, openai_key)
    else:
        raise ValueError(f"Unknown task: {task}")
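

# End-to-end sketch (assumption: illustrative wiring only; the original app
# drives these functions from a Streamlit front end instead).
if __name__ == "__main__":
    document_store = start_document_store('inmemory')
    retriever = start_retriever(document_store)
    reader = start_reader()

    # Index a toy document so the extractive pipeline has something to search.
    document_store.write_documents([{'content': "Pi is a super dog",
                                     'meta': {'name': "pi.txt"}}])
    document_store.update_embeddings(retriever)

    pipeline = initialize_pipeline('extractive', document_store, retriever, reader)
    result = query(pipeline, "Who is Pi?")
    for answer in result["answers"]:
        print(answer.answer, answer.score)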