"""Coursera QA bot.

Builds (or loads from disk) a FAISS vector index over the course documents in
``docs/`` and answers questions with a LangChain ConversationalRetrievalChain.
"""

import os
import pickle

import faiss
import langchain
from langchain.cache import InMemoryCache
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import DirectoryLoader, TextLoader, UnstructuredHTMLLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores.faiss import FAISS

# Cache identical LLM calls in process memory to avoid repeat API charges.
langchain.llm_cache = InMemoryCache()

# On-disk persistence: the pickled vector store and the raw faiss index.
pickle_file = "open_ai.pkl"
index_file = "open_ai.index"

# NOTE(review): the variable is named gpt_3_5 but is configured with 'gpt-4'.
# The name is kept unchanged for backward compatibility with any importer.
gpt_3_5 = ChatOpenAI(model_name='gpt-4', temperature=0.1)
embeddings = OpenAIEmbeddings(model='text-embedding-ada-002')

# Conversation state: list of (question, answer) tuples fed back to the chain.
chat_history = []
memory = ConversationBufferWindowMemory(memory_key="chat_history")
# Populated lazily by get_search_index().
gpt_3_5_index = None

system_template = """You are Coursera QA Bot. Have a conversation with a human, answering the following questions as best you can. You are a teaching assistant for a Coursera Course: The 3D Printing Evolution and can answer any question about that using vectorstore. Use the following pieces of context to answer the users question. 
---------------- {context}"""

messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)


def get_search_index():
    """Return the FAISS search index, loading from disk when both persisted
    files exist and are non-empty, otherwise building it from scratch.

    Side effect: stores the result in the module global ``gpt_3_5_index``.
    """
    global gpt_3_5_index
    if os.path.isfile(pickle_file) and os.path.isfile(index_file) and os.path.getsize(pickle_file) > 0:
        # SECURITY: pickle.load executes arbitrary code on load — only ever
        # load files this process itself wrote.
        with open(pickle_file, "rb") as f:
            search_index = pickle.load(f)
        # BUGFIX: create_index() persists the raw faiss index separately via
        # faiss.write_index(); re-attach it here, otherwise the unpickled
        # store may carry a stale or empty index.
        search_index.index = faiss.read_index(index_file)
    else:
        search_index = create_index()
    gpt_3_5_index = search_index
    return search_index


def create_index():
    """Build the FAISS store from the docs, persist it to disk, return it."""
    source_chunks = create_chunk_documents()
    search_index = search_index_from_docs(source_chunks)
    # Persist the raw faiss index and the wrapping store separately;
    # get_search_index() re-joins them on load.
    faiss.write_index(search_index.index, index_file)
    with open(pickle_file, "wb") as f:
        pickle.dump(search_index, f)
    return search_index


def search_index_from_docs(source_chunks):
    """Embed the given document chunks into a new FAISS vector store."""
    return FAISS.from_documents(source_chunks, embeddings)


def get_html_files():
    """Load every HTML document under docs/ (recursively)."""
    loader = DirectoryLoader('docs', glob="**/*.html",
                             loader_cls=UnstructuredHTMLLoader, recursive=True)
    return loader.load()


def get_text_files():
    """Load every plain-text document under docs/ (recursively)."""
    loader = DirectoryLoader('docs', glob="**/*.txt",
                             loader_cls=TextLoader, recursive=True)
    return loader.load()


def fetch_data_for_embeddings():
    """Gather all text and HTML documents to be embedded."""
    document_list = get_text_files()
    document_list.extend(get_html_files())
    print("document list" + str(len(document_list)))
    return document_list


def create_chunk_documents():
    """Split the source documents into ~800-character chunks for embedding."""
    sources = fetch_data_for_embeddings()
    splitter = CharacterTextSplitter(separator=" ", chunk_size=800, chunk_overlap=0)
    source_chunks = splitter.split_documents(sources)
    print("sources" + str(len(source_chunks)))
    return source_chunks


def get_qa_chain(gpt_3_5_index):
    """Build a ConversationalRetrievalChain over the given vector index.

    Returns source documents alongside each answer so callers can cite them.
    """
    global gpt_3_5
    chain = ConversationalRetrievalChain.from_llm(
        gpt_3_5,
        gpt_3_5_index.as_retriever(),
        return_source_documents=True,
        verbose=True,
        get_chat_history=get_chat_history,
        combine_docs_chain_kwargs={"prompt": CHAT_PROMPT},
    )
    return chain


def get_chat_history(inputs) -> str:
    """Render (human, ai) turn pairs as the chain's expected transcript text."""
    res = []
    for human, ai in inputs:
        res.append(f"Human:{human}\nAI:{ai}")
    return "\n".join(res)


def generate_answer(question) -> str:
    """Answer *question* against the index; append de-duplicated sources.

    Side effect: resets the module-level ``chat_history`` to only the latest
    (question, answer) turn — a one-turn memory window.
    """
    global chat_history, gpt_3_5_index
    gpt_3_5_chain = get_qa_chain(gpt_3_5_index)
    result = gpt_3_5_chain(
        {"question": question,
         "chat_history": chat_history,
         "vectordbkwargs": {"search_distance": 0.6}})
    # Deliberately keep only the most recent exchange.
    chat_history = [(question, result["answer"])]
    print(result['answer'])
    sources = []
    for document in result['source_documents']:
        source = document.metadata['source']
        # Keep just the file stem, e.g. "docs/a/b.html" -> "b".
        sources.append(source.split('/')[-1].split('.')[0])
    # BUGFIX: sort for deterministic output; a bare set has unstable order.
    source = ',\n'.join(sorted(set(sources)))
    return result['answer'] + '\nSOURCES: ' + source