# Define the script's usage example
USAGE_EXAMPLE = """
Example usage:

To process input *.txt files at input_path and save the vector db output at output_db:

python create_vector_db.py --input_path input_path --output_db output_db --chunk_size 100 --chunk_overlap 10

Required arguments:
- --input_path: Path to the input dir containing the .txt files
- --output_db: Path to the output vector db.

Optional arguments:
- --chunk_size: Size of the chunks (default: 1000).
- --chunk_overlap: Overlap between chunks (default: 200).
- --db_type: Type of vector store (default: faiss).
"""

import argparse
import logging
import os
import sys
import uuid

from langchain_community.document_loaders import DirectoryLoader, UnstructuredURLLoader
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain_community.vectorstores import FAISS, Chroma, Qdrant

vectordb_dir = os.path.dirname(os.path.abspath(__file__))
utils_dir = os.path.abspath(os.path.join(vectordb_dir, ".."))
repo_dir = os.path.abspath(os.path.join(utils_dir, ".."))
sys.path.append(repo_dir)
sys.path.append(utils_dir)

from utils.model_wrappers.api_gateway import APIGateway

import streamlit as st

EMBEDDING_MODEL = "intfloat/e5-large-v2"
NORMALIZE_EMBEDDINGS = True
VECTORDB_LOG_FILE_NAME = "vector_db.log"

# Configure the logger
logging.basicConfig(
    level=logging.INFO,  # Set the logging level (e.g., INFO, DEBUG)
    format="%(asctime)s [%(levelname)s] - %(message)s",  # Define the log message format
    handlers=[
        logging.StreamHandler(),  # Output logs to the console
        logging.FileHandler(VECTORDB_LOG_FILE_NAME),  # Also write logs to a file
    ],
)

# Create a logger object
logger = logging.getLogger(__name__)


class VectorDb:
    """
    A class for creating, updating and loading FAISS, Chroma or Qdrant vector databases,
    to use them with retrieval augmented generation tasks with langchain.

    Args:
        None

    Attributes:
        None

    Methods:
        load_files: Load files from an input directory as langchain documents
        get_text_chunks: Get text chunks from a list of documents
        get_token_chunks: Get token chunks from a list of documents
        create_vector_store: Create a vector store from chunks and an embedding model
        load_vdb: Load a previously stored vector database
        update_vdb: Update an existing vector store with new chunks
        create_vdb: Create a vector database from the raw files in a specific input directory
    """

    def __init__(self) -> None:
        self.collection_id = str(uuid.uuid4())
        self.vector_collections = set()

    def load_files(self, input_path, recursive=False, load_txt=True, load_pdf=False, urls=None) -> list:
        """Load files from input location

        Args:
            input_path: input location of files
            recursive (bool, optional): flag to load files recursively. Defaults to False.
            load_txt (bool, optional): flag to load txt files. Defaults to True.
            load_pdf (bool, optional): flag to load pdf files. Defaults to False.
            urls (list, optional): list of urls to load. Defaults to None.
        Returns:
            list: list of documents
        """
        docs = []
        text_loader_kwargs = {"autodetect_encoding": True}
        if input_path is not None:
            if load_txt:
                loader = DirectoryLoader(
                    input_path,
                    glob="*.txt",
                    recursive=recursive,
                    show_progress=True,
                    loader_kwargs=text_loader_kwargs,
                )
                docs.extend(loader.load())
            if load_pdf:
                loader = DirectoryLoader(
                    input_path,
                    glob="*.pdf",
                    recursive=recursive,
                    show_progress=True,
                    loader_kwargs=text_loader_kwargs,
                )
                docs.extend(loader.load())
        if urls:
            loader = UnstructuredURLLoader(urls=urls)
            docs.extend(loader.load())
        logger.info(f"Total {len(docs)} files loaded")
        return docs

    def get_text_chunks(self, docs: list, chunk_size: int, chunk_overlap: int, meta_data: list = None) -> list:
        """Gets text chunks. If meta_data is not None, it will create chunks with metadata elements.

        Args:
            docs (list): list of documents or texts. If no metadata is passed, this parameter is a list
                of documents. If metadata is passed, this parameter is a list of texts.
            chunk_size (int): chunk size in number of characters
            chunk_overlap (int): chunk overlap in number of characters
            meta_data (list, optional): list of metadata in dictionary format. Defaults to None.

        Returns:
            list: list of documents
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
        )
        if meta_data is None:
            logger.info("Splitter: splitting documents")
            chunks = text_splitter.split_documents(docs)
        else:
            logger.info("Splitter: creating documents with metadata")
            chunks = text_splitter.create_documents(docs, meta_data)
        logger.info(f"Total {len(chunks)} chunks created")
        return chunks

    def get_token_chunks(self, docs: list, chunk_size: int, chunk_overlap: int, tokenizer) -> list:
        """Gets token chunks, splitting documents by token count using the given tokenizer.

        Args:
            docs (list): list of documents
            chunk_size (int): chunk size in number of tokens
            chunk_overlap (int): chunk overlap in number of tokens
            tokenizer: HuggingFace tokenizer used to measure chunk lengths

        Returns:
            list: list of documents
        """
        text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
            tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        logger.info("Splitter: splitting documents")
        chunks = text_splitter.split_documents(docs)
        logger.info(f"Total {len(chunks)} chunks created")
        return chunks
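    # Illustrative sketch only (not executed): how the two chunking helpers above are
    # typically called. The texts, metadata and chunk sizes below are hypothetical.
    #
    #     vectordb = VectorDb()
    #     texts = ["first document text", "second document text"]
    #     metas = [{"source": "doc1.txt"}, {"source": "doc2.txt"}]
    #     # character-based chunks, one Document per chunk with its metadata attached
    #     chunks = vectordb.get_text_chunks(texts, chunk_size=1000, chunk_overlap=200, meta_data=metas)
    #     # token-based chunks need a HuggingFace tokenizer, e.g. AutoTokenizer.from_pretrained(...)
    #     # chunks = vectordb.get_token_chunks(docs, chunk_size=256, chunk_overlap=32, tokenizer=tokenizer)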
""" if collection_name is None: collection_name = f"collection_{self.collection_id}" logger.info(f'This is the collection name: {collection_name}') if db_type == "faiss": vector_store = FAISS.from_documents( documents=chunks, embedding=embeddings ) if output_db: vector_store.save_local(output_db) elif db_type == "chroma": if output_db: vector_store = Chroma() vector_store.delete_collection() vector_store = Chroma.from_documents( documents=chunks, embedding=embeddings, persist_directory=output_db, collection_name=collection_name ) else: vector_store = Chroma() vector_store.delete_collection() vector_store = Chroma.from_documents( documents=chunks, embedding=embeddings, collection_name=collection_name ) self.vector_collections.add(collection_name) elif db_type == "qdrant": if output_db: vector_store = Qdrant.from_documents( documents=chunks, embedding=embeddings, path=output_db, collection_name="test_collection", ) else: vector_store = Qdrant.from_documents( documents=chunks, embedding=embeddings, collection_name="test_collection", ) logger.info(f"Vector store saved to {output_db}") return vector_store def load_vdb(self, persist_directory, embedding_model, db_type="chroma", collection_name=None): if db_type == "faiss": vector_store = FAISS.load_local(persist_directory, embedding_model, allow_dangerous_deserialization=True) elif db_type == "chroma": if collection_name: vector_store = Chroma( persist_directory=persist_directory, embedding_function=embedding_model, collection_name=collection_name ) else: vector_store = Chroma( persist_directory=persist_directory, embedding_function=embedding_model ) elif db_type == "qdrant": # TODO: Implement Qdrant loading pass else: raise ValueError(f"Unsupported database type: {db_type}") return vector_store def update_vdb(self, chunks: list, embeddings, db_type: str, input_db: str = None, output_db: str = None): if db_type == "faiss": vector_store = FAISS.load_local(input_db, embeddings, allow_dangerous_deserialization=True) new_vector_store = self.create_vector_store(chunks, embeddings, db_type, None) vector_store.merge_from(new_vector_store) if output_db: vector_store.save_local(output_db) elif db_type == "chroma": # TODO implement update method for chroma pass elif db_type == "qdrant": # TODO implement update method for qdrant pass return vector_store def create_vdb( self, input_path, chunk_size, chunk_overlap, db_type, output_db=None, recursive=False, tokenizer=None, load_txt=True, load_pdf=False, urls=None, embedding_type="cpu", batch_size= None, coe = None, select_expert = None ): docs = self.load_files(input_path, recursive=recursive, load_txt=load_txt, load_pdf=load_pdf, urls=urls) if tokenizer is None: chunks = self.get_text_chunks(docs, chunk_size, chunk_overlap) else: chunks = self.get_token_chunks(docs, chunk_size, chunk_overlap, tokenizer) embeddings = APIGateway.load_embedding_model( type=embedding_type, batch_size=batch_size, coe=coe, select_expert=select_expert ) vector_store = self.create_vector_store(chunks, embeddings, db_type, output_db) return vector_store def dir_path(path): if os.path.isdir(path): return path else: raise argparse.ArgumentTypeError(f"readable_dir:{path} is not a valid path") # Parse the arguments def parse_arguments(): parser = argparse.ArgumentParser(description="Process command line arguments.") parser.add_argument("-input_path", type=dir_path, help="path to input directory") parser.add_argument("--chunk_size", type=int, help="chunk size for splitting") parser.add_argument("--chunk_overlap", type=int, help="chunk 
overlap for splitting") parser.add_argument("-output_path", type=dir_path, help="path to input directory") return parser.parse_args() if __name__ == "__main__": parser = argparse.ArgumentParser(description="Process data with optional chunking") # Required arguments parser.add_argument("--input_path", type=str, help="Path to the input directory") parser.add_argument("--output_db", type=str, help="Path to the output vectordb") # Optional arguments parser.add_argument( "--chunk_size", type=int, default=1000, help="Chunk size (default: 1000)" ) parser.add_argument( "--chunk_overlap", type=int, default=200, help="Chunk overlap (default: 200)" ) parser.add_argument( "--db_type", type=str, default="faiss", help="Type of vector store (default: faiss)", ) args = parser.parse_args() vectordb = VectorDb() vectordb.create_vdb( args.input_path, args.output_db, args.chunk_size, args.chunk_overlap, args.db_type, )