from utils.check_pydantic_version import use_pydantic_v1

use_pydantic_v1()  # This function has to run before importing haystack, as haystack requires pydantic v1.

import streamlit as st
import logging
import os

from annotated_text import annotation
from json import JSONDecodeError
from utils.config import parser
from utils.haystack import start_document_store, query, initialize_pipeline, start_preprocessor_node, start_retriever, start_reader
from utils.ui import reset_results, set_initial_state
import pandas as pd
from datetime import datetime
import streamlit_authenticator as stauth
import pickle
from streamlit_modal import Modal
import numpy as np

names = ['mlreply']
usernames = ['docwhiz']
with open('hashed_password.pkl', 'rb') as f:
    hashed_passwords = pickle.load(f)
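# NOTE: 'hashed_password.pkl' must contain the list of password hashes for the
# users above. A minimal sketch of how it could be generated, assuming the
# streamlit-authenticator Hasher API matching the Authenticate() call used
# below (the plaintext password here is a placeholder):
#
#   import pickle
#   import streamlit_authenticator as stauth
#
#   hashed = stauth.Hasher(['your-password-here']).generate()
#   with open('hashed_password.pkl', 'wb') as f:
#       pickle.dump(hashed, f)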
# Whether the file upload should be enabled or not
DISABLE_FILE_UPLOAD = bool(os.getenv("DISABLE_FILE_UPLOAD"))


def show_documents_list(retrieved_documents):
    data = []
    for document in retrieved_documents:
        data.append([document.meta['name']])
    df = pd.DataFrame(data, columns=['Uploaded Document Name'])
    df.drop_duplicates(subset=['Uploaded Document Name'], inplace=True)
    df.index = np.arange(1, len(df) + 1)
    return df


# Define a function to handle file uploads
def upload_files():
    uploaded_files = upload_container.file_uploader(
        "upload",
        type=["pdf", "txt", "docx"],
        accept_multiple_files=True,
        label_visibility="hidden",
        key=1,
    )
    return uploaded_files


# Define a function to process a single file
def process_file(data_file, preprocessor, document_store):
    # Read the file and wrap its content in a Haystack-compatible dict.
    # Note: decoding the raw bytes as UTF-8 only works for plain-text uploads.
    file_contents = data_file.read().decode("utf-8")
    docs = [{
        'content': str(file_contents),
        'meta': {'name': str(data_file.name)}
    }]
    try:
        # Skip files that are already in the document store.
        names = [item.meta.get('name') for item in document_store.get_all_documents()]
        if data_file.name in names:
            print(f"{data_file.name} already processed")
        else:
            print(f'preprocessing uploaded doc {data_file.name}.......')
            preprocessed_docs = preprocessor.process(docs)
            print('writing to document store.......')
            document_store.write_documents(preprocessed_docs)
            print('updating embeddings.......')
            document_store.update_embeddings(retriever)
    except Exception as e:
        print(e)
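# NOTE: process_file above assumes the upload can be decoded as plain text,
# even though the uploader also accepts PDF and DOCX files. A sketch of how a
# PDF could be handled instead with Haystack 1.x's PDFToTextConverter; saving
# the upload to a temporary path first is an assumption, not part of this app:
#
#   from haystack.nodes import PDFToTextConverter
#
#   converter = PDFToTextConverter(remove_numeric_tables=True)
#   docs = converter.convert(file_path=path_to_saved_pdf, meta={'name': data_file.name})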
# Define a function to upload the documents to the haystack document store
def upload_document():
    if data_files is not None:
        for data_file in data_files:
            # Upload file
            if data_file:
                try:
                    # Call the process_file function for each uploaded file
                    if args.store == 'inmemory':
                        process_file(data_file, preprocessor, document_store)
                except Exception as e:
                    logging.exception(e)
                    upload_container.write(str(data_file.name) + "    ❌ ")
                    upload_container.write("_This file could not be parsed, see the logs for more information._")


# Define a function to reset the documents in the haystack document store
def reset_documents():
    print('\nResetting documents list at ' + str(datetime.now()) + '\n')
    st.session_state.data_files = None
    document_store.delete_documents()


try:
    args = parser.parse_args()

    preprocessor = start_preprocessor_node()
    document_store = start_document_store(type=args.store)
    document_store.get_all_documents()
    retriever = start_retriever(document_store)
    reader = start_reader()

    st.set_page_config(
        page_title="MLReplySearch",
        layout="centered",
        page_icon=":shark:",
        menu_items={
            'Get Help': 'https://www.extremelycoolapp.com/help',
            'Report a bug': "https://www.extremelycoolapp.com/bug",
            'About': "# This is a header. This is an *extremely* cool app!"
        }
    )
    st.sidebar.image("ml_logo.png", use_column_width=True)

    authenticator = stauth.Authenticate(names, usernames, hashed_passwords, "document_search", "random_text", cookie_expiry_days=1)
    name, authentication_status, username = authenticator.login("Login", "main")

    if authentication_status is False:
        st.error("Username/Password is incorrect")

    if authentication_status is None:
        st.warning("Please enter your username and password")

    if authentication_status:
        # Sidebar for task selection
        st.sidebar.header('Options:')

        # OpenAI key input
        openai_key = st.sidebar.text_input("Enter LLM-authorization Key:", type="password")

        if openai_key:
            task_options = ['Extractive', 'Generative']
        else:
            task_options = ['Extractive']

        task_selection = st.sidebar.radio('Select the task:', task_options)

        # Check the task and initialize the matching pipeline
        if task_selection == 'Extractive':
            pipeline_extractive = initialize_pipeline("extractive", document_store, retriever, reader)
        elif task_selection == 'Generative' and openai_key:  # Check for openai_key to ensure the user has entered it
            pipeline_rag = initialize_pipeline("rag", document_store, retriever, reader, openai_key=openai_key)

        set_initial_state()

        modal = Modal("Manage Files", key="demo-modal")
        open_modal = st.sidebar.button("Manage Files", use_container_width=True)
        if open_modal:
            modal.open()

        st.write('# ' + args.name)

        if modal.is_open():
            with modal.container():
                if not DISABLE_FILE_UPLOAD:
                    upload_container = st.container()
                    data_files = upload_files()
                    upload_document()
                st.session_state.sidebar_state = 'collapsed'
                st.table(show_documents_list(document_store.get_all_documents()))

        # File upload block (alternative sidebar-based upload, kept for reference)
        # if not DISABLE_FILE_UPLOAD:
        #     upload_container = st.sidebar.container()
        #     upload_container.write("## File Upload:")
        #     data_files = upload_files()
        #     # Button to update files in the DocumentStore
        #     upload_container.button('Upload Files', on_click=upload_document, args=())

        # Button to reset the documents in the DocumentStore
        st.sidebar.button("Reset documents", on_click=reset_documents, args=(), use_container_width=True)

        if "question" not in st.session_state:
            st.session_state.question = ""

        # Search bar
        question = st.text_input("Question", value=st.session_state.question, max_chars=100, on_change=reset_results, label_visibility="hidden")

        run_pressed = st.button("Run")

        run_query = (
            run_pressed or question != st.session_state.question
            # or task_selection != st.session_state.task
        )

        # Get results for the query
        if run_query and question:
            if task_selection == 'Extractive':
                reset_results()
                st.session_state.question = question
                with st.spinner("🔎    Running your pipeline"):
                    try:
                        st.session_state.results_extractive = query(pipeline_extractive, question)
                        st.session_state.task = task_selection
                    except JSONDecodeError:
                        st.error(
                            "👓    An error occurred reading the results. Is the document store working?"
                        )
                    except Exception as e:
                        logging.exception(e)
                        st.error("🐞    An error occurred during the request.")
            elif task_selection == 'Generative':
                reset_results()
                st.session_state.question = question
                with st.spinner("🔎    Running your pipeline"):
                    try:
                        st.session_state.results_generative = query(pipeline_rag, question)
                        st.session_state.task = task_selection
                    except JSONDecodeError:
                        st.error(
                            "👓    An error occurred reading the results. Is the document store working?"
                        )
                    except Exception as e:
                        if "API key is invalid" in str(e):
                            logging.exception(e)
                            st.error("🐞    Incorrect API key provided. You can find your API key at https://platform.openai.com/account/api-keys.")
                        else:
                            logging.exception(e)
                            st.error("🐞    An error occurred during the request.")

        # Display results
        if (st.session_state.results_extractive or st.session_state.results_generative) and run_query:

            # Handle extractive answers
            if task_selection == 'Extractive':
                results = st.session_state.results_extractive
                st.subheader("Extracted Answers:")

                if 'answers' in results:
                    answers = results['answers']
                    threshold = 0.2
                    higher_than_threshold = any(ans.score > threshold for ans in answers)
                    if not higher_than_threshold:
                        st.markdown(f"Please note that none of the answers achieved a score higher than {int(threshold * 100)}%, which probably means that the desired answer is not in the searched documents.", unsafe_allow_html=True)
                    for count, answer in enumerate(answers):
                        if answer.answer:
                            text, context = answer.answer, answer.context
                            start_idx = context.find(text)
                            end_idx = start_idx + len(text)
                            score = round(answer.score, 3)
                            st.markdown(f"**Answer {count + 1}:**")
                            st.markdown(
                                context[:start_idx]
                                + str(annotation(body=text, label=f'SCORE {score}', background='#964448', color='#ffffff'))
                                + context[end_idx:],
                                unsafe_allow_html=True,
                            )
                        else:
                            st.info(
                                "🤔    Haystack is unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
                            )

            # Handle generative answers
            elif task_selection == 'Generative':
                results = st.session_state.results_generative
                st.subheader("Generated Answer:")
                if 'results' in results:
                    st.markdown("**Answer:**")
                    st.write(results['results'][0])

            # Handle retrieved documents
            if 'documents' in results:
                retrieved_documents = results['documents']
                st.subheader("Retriever Results:")

                data = []
                for i, document in enumerate(retrieved_documents):
                    # Truncate the content for display
                    truncated_content = (document.content[:150] + '...') if len(document.content) > 150 else document.content
                    data.append([i + 1, document.meta['name'], truncated_content])

                # Convert the data to a DataFrame and display it with Streamlit
                df = pd.DataFrame(data, columns=['Ranked Context', 'Document Name', 'Content'])
                st.table(df)

except SystemExit as e:
    os._exit(e.code)
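# Usage sketch (assumptions: this file is named app.py, and utils.config.parser
# defines --name and --store options matching the args.name / args.store
# attributes used above; the '--' separator passes the flags through Streamlit
# to this script):
#
#   streamlit run app.py -- --name "Document Search" --store inmemory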