petrojm's picture
add EKR files
a6c26b1
raw
history blame
12.7 kB
# Define the script's usage example
USAGE_EXAMPLE = """
Example usage:
To process input *.txt files at input_path and save the vector db output at output_db:
python create_vector_db.py input_path output_db --chunk_size 100 --chunk_overlap 10
Required arguments:
- input_path: Path to the input dir containing the .txt files
- output_path: Path to the output vector db.
Optional arguments:
- --chunk_size: Size of the chunks (default: None).
- --chunk_overlap: Overlap between chunks (default: None).
"""
import os
import sys
import argparse
import logging
from langchain_community.document_loaders import DirectoryLoader, UnstructuredURLLoader
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain_community.vectorstores import FAISS, Chroma, Qdrant
# Put this file's parent directory (utils_dir) and grandparent directory
# (repo_dir) on sys.path, in that append order: repo_dir first, then utils_dir.
# This lets the `utils.model_wrappers.api_gateway` import below resolve when
# the file is run directly as a script.
vectordb_dir = os.path.dirname(os.path.abspath(__file__))
utils_dir = os.path.abspath(os.path.join(vectordb_dir, os.pardir))
repo_dir = os.path.abspath(os.path.join(utils_dir, os.pardir))
sys.path.extend([repo_dir, utils_dir])
from utils.model_wrappers.api_gateway import APIGateway
import uuid
import streamlit as st
# Module-level embedding defaults and log-file name.
EMBEDDING_MODEL = "intfloat/e5-large-v2"
NORMALIZE_EMBEDDINGS = True
VECTORDB_LOG_FILE_NAME = "vector_db.log"

# Root logging setup, executed at import time: INFO level, timestamped
# "<time> [<LEVEL>] - <message>" lines, sent both to the console and to
# VECTORDB_LOG_FILE_NAME (a path relative to the current working directory).
# NOTE: logging.basicConfig does nothing if the root logger already has
# handlers, but the handler list is still built, so the log file is opened
# on import regardless.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] - %(message)s",
    level=logging.INFO,
    handlers=[logging.StreamHandler(), logging.FileHandler(VECTORDB_LOG_FILE_NAME)],
)

# Module logger used throughout this file.
logger = logging.getLogger(__name__)
class VectorDb:
    """
    A class for creating, updating and loading FAISS, Chroma or Qdrant vector databases,
    to use them with retrieval augmented generation tasks with langchain

    Args:
        None

    Attributes:
        collection_id (str): random UUID4 string created per instance; the default
            collection name is ``collection_<collection_id>``.
        vector_collections (set): names of the Chroma collections created by this instance.

    Methods:
        load_files: Load files from an input directory (and/or URLs) as langchain documents
        get_text_chunks: Get character-based text chunks from a list of documents
        get_token_chunks: Get token-based chunks from a list of documents
        create_vector_store: Create a vector store from chunks and an embedding model
        load_vdb: Load a previously stored vector database
        update_vdb: Update an existing vector store with new chunks (FAISS only)
        create_vdb: Create a vector database from the raw files in a specific input directory
    """

    def __init__(self) -> None:
        self.collection_id = str(uuid.uuid4())
        self.vector_collections = set()

    def load_files(self, input_path, recursive=False, load_txt=True, load_pdf=False, urls=None) -> list:
        """Load files from input location as langchain documents.

        Args:
            input_path: input directory of files; when None, only ``urls`` are loaded.
            recursive (bool, optional): flag to load files recursively. Defaults to False.
            load_txt (bool, optional): flag to load ``*.txt`` files. Defaults to True.
            load_pdf (bool, optional): flag to load ``*.pdf`` files. Defaults to False.
            urls (list, optional): list of urls to load. Defaults to None.

        Returns:
            list: list of documents
        """
        docs = []
        text_loader_kwargs = {'autodetect_encoding': True}
        if input_path is not None:
            if load_txt:
                loader = DirectoryLoader(input_path, glob="*.txt", recursive=recursive,
                                         show_progress=True, loader_kwargs=text_loader_kwargs)
                docs.extend(loader.load())
            if load_pdf:
                # NOTE(review): the text-loader kwargs are also forwarded to the PDF
                # loader; confirm DirectoryLoader's default loader accepts them.
                loader = DirectoryLoader(input_path, glob="*.pdf", recursive=recursive,
                                         show_progress=True, loader_kwargs=text_loader_kwargs)
                docs.extend(loader.load())
        if urls:
            loader = UnstructuredURLLoader(urls=urls)
            docs.extend(loader.load())
        logger.info(f"Total {len(docs)} files loaded")
        return docs

    def get_text_chunks(self, docs: list, chunk_size: int, chunk_overlap: int, meta_data: list = None) -> list:
        """Gets character-based text chunks, optionally attaching metadata.

        Args:
            docs (list): list of documents or texts. If ``meta_data`` is None, this is a
                list of langchain documents; otherwise it is a list of raw texts.
            chunk_size (int): chunk size in number of characters
            chunk_overlap (int): chunk overlap in number of characters
            meta_data (list, optional): list of metadata dictionaries, one per text in
                ``docs``. Defaults to None.

        Returns:
            list: list of documents
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
        )
        if meta_data is None:
            logger.info("Splitter: splitting documents")
            chunks = text_splitter.split_documents(docs)
        else:
            logger.info("Splitter: creating documents with metadata")
            chunks = text_splitter.create_documents(docs, meta_data)
        logger.info(f"Total {len(chunks)} chunks created")
        return chunks

    def get_token_chunks(self, docs: list, chunk_size: int, chunk_overlap: int, tokenizer) -> list:
        """Gets token-based chunks, measuring length with a HuggingFace tokenizer.

        Args:
            docs (list): list of langchain documents
            chunk_size (int): chunk size in number of tokens
            chunk_overlap (int): chunk overlap in number of tokens
            tokenizer: HuggingFace tokenizer used to count tokens

        Returns:
            list: list of documents
        """
        text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
            tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        logger.info("Splitter: splitting documents")
        chunks = text_splitter.split_documents(docs)
        logger.info(f"Total {len(chunks)} chunks created")
        return chunks

    def create_vector_store(self, chunks: list, embeddings: HuggingFaceInstructEmbeddings, db_type: str,
                            output_db: str = None, collection_name: str = None):
        """Creates a vector store from document chunks.

        Args:
            chunks (list): list of chunks (langchain documents)
            embeddings (HuggingFaceInstructEmbeddings): embedding model
            db_type (str): vector db type: "faiss", "chroma" or "qdrant"
            output_db (str, optional): output path to save the vector db. Defaults to None,
                in which case FAISS is not saved and Chroma/Qdrant get no directory/path.
            collection_name (str, optional): collection name. Defaults to
                ``collection_<collection_id>``; for Qdrant, an omitted name falls back
                to "test_collection".

        Returns:
            The langchain vector store.

        Raises:
            ValueError: if ``db_type`` is not supported.
        """
        qdrant_collection = collection_name if collection_name is not None else "test_collection"
        if collection_name is None:
            collection_name = f"collection_{self.collection_id}"
        logger.info(f'This is the collection name: {collection_name}')

        if db_type == "faiss":
            vector_store = FAISS.from_documents(documents=chunks, embedding=embeddings)
            if output_db:
                vector_store.save_local(output_db)
        elif db_type == "chroma":
            # Delete the collection of a default-constructed Chroma store first
            # (presumably to clear stale data from an earlier run — TODO confirm intent).
            vector_store = Chroma()
            vector_store.delete_collection()
            vector_store = Chroma.from_documents(
                documents=chunks,
                embedding=embeddings,
                # A falsy output_db means "no persist directory" (None is the default).
                persist_directory=output_db or None,
                collection_name=collection_name,
            )
            self.vector_collections.add(collection_name)
        elif db_type == "qdrant":
            vector_store = Qdrant.from_documents(
                documents=chunks,
                embedding=embeddings,
                path=output_db or None,
                collection_name=qdrant_collection,
            )
        else:
            raise ValueError(f"Unsupported database type: {db_type}")

        if output_db:
            logger.info(f"Vector store saved to {output_db}")
        return vector_store

    def load_vdb(self, persist_directory, embedding_model, db_type="chroma", collection_name=None):
        """Load a previously stored vector database.

        Args:
            persist_directory: directory the vector db was saved to.
            embedding_model: embedding model passed to the loaded store.
            db_type (str, optional): "faiss" or "chroma". Defaults to "chroma".
            collection_name (str, optional): Chroma collection to open; when falsy,
                Chroma's default collection is used.

        Returns:
            The langchain vector store.

        Raises:
            NotImplementedError: for "qdrant", which cannot be loaded yet.
            ValueError: for any other unsupported ``db_type``.
        """
        if db_type == "faiss":
            # SECURITY: FAISS.load_local deserializes pickled data when
            # allow_dangerous_deserialization=True; only load trusted directories.
            vector_store = FAISS.load_local(persist_directory, embedding_model, allow_dangerous_deserialization=True)
        elif db_type == "chroma":
            if collection_name:
                vector_store = Chroma(
                    persist_directory=persist_directory,
                    embedding_function=embedding_model,
                    collection_name=collection_name
                )
            else:
                vector_store = Chroma(
                    persist_directory=persist_directory,
                    embedding_function=embedding_model
                )
        elif db_type == "qdrant":
            # TODO: Implement Qdrant loading
            raise NotImplementedError("Loading a qdrant vector store is not implemented yet")
        else:
            raise ValueError(f"Unsupported database type: {db_type}")
        return vector_store

    def update_vdb(self, chunks: list, embeddings, db_type: str, input_db: str = None,
                   output_db: str = None):
        """Add new chunks to an existing vector store (FAISS only).

        Args:
            chunks (list): new chunks to add
            embeddings: embedding model
            db_type (str): vector db type; only "faiss" is supported
            input_db (str, optional): path of the existing vector db to load
            output_db (str, optional): path to save the merged vector db to

        Returns:
            The merged vector store.

        Raises:
            NotImplementedError: for "chroma" and "qdrant".
            ValueError: for any other unsupported ``db_type``.
        """
        if db_type == "faiss":
            # SECURITY: loading a FAISS store unpickles data; only use trusted input_db paths.
            vector_store = FAISS.load_local(input_db, embeddings, allow_dangerous_deserialization=True)
            new_vector_store = self.create_vector_store(chunks, embeddings, db_type, None)
            vector_store.merge_from(new_vector_store)
            if output_db:
                vector_store.save_local(output_db)
        elif db_type in ("chroma", "qdrant"):
            # TODO implement update method for chroma and qdrant
            raise NotImplementedError(f"Updating a {db_type} vector store is not implemented yet")
        else:
            raise ValueError(f"Unsupported database type: {db_type}")
        return vector_store

    def create_vdb(
        self,
        input_path,
        chunk_size,
        chunk_overlap,
        db_type,
        output_db=None,
        recursive=False,
        tokenizer=None,
        load_txt=True,
        load_pdf=False,
        urls=None,
        embedding_type="cpu",
        batch_size=None,
        coe=None,
        select_expert=None
    ):
        """Create a vector database from the raw files in an input directory (and/or URLs).

        Args:
            input_path: input directory; see ``load_files``.
            chunk_size: chunk size (characters, or tokens when ``tokenizer`` is given).
            chunk_overlap: chunk overlap (same unit as ``chunk_size``).
            db_type: "faiss", "chroma" or "qdrant".
            output_db (optional): output path of the vector db. Defaults to None.
            recursive, load_txt, load_pdf, urls: forwarded to ``load_files``.
            tokenizer (optional): HuggingFace tokenizer; when given, ``get_token_chunks``
                is used instead of ``get_text_chunks``.
            embedding_type, batch_size, coe, select_expert: forwarded to
                ``APIGateway.load_embedding_model``.

        Returns:
            The langchain vector store.
        """
        docs = self.load_files(input_path, recursive=recursive, load_txt=load_txt, load_pdf=load_pdf, urls=urls)
        if tokenizer is None:
            chunks = self.get_text_chunks(docs, chunk_size, chunk_overlap)
        else:
            chunks = self.get_token_chunks(docs, chunk_size, chunk_overlap, tokenizer)
        embeddings = APIGateway.load_embedding_model(
            type=embedding_type,
            batch_size=batch_size,
            coe=coe,
            select_expert=select_expert
        )
        vector_store = self.create_vector_store(chunks, embeddings, db_type, output_db)
        return vector_store
def dir_path(path):
    """argparse ``type=`` callable that accepts only existing directories.

    Returns *path* unchanged when it names an existing directory; otherwise
    raises argparse.ArgumentTypeError so argparse reports a usage error.
    """
    if not os.path.isdir(path):
        raise argparse.ArgumentTypeError(f"readable_dir:{path} is not a valid path")
    return path
# Parse the arguments
def parse_arguments():
    """Build a directory-based command-line parser and parse ``sys.argv``.

    ``-input_path`` and ``-output_path`` are validated with ``dir_path``, so each
    must name an existing directory when given. The ``__main__`` block in this
    file builds its own, different parser instead of calling this function.

    Returns:
        argparse.Namespace: with ``input_path``, ``chunk_size``, ``chunk_overlap``
        and ``output_path`` attributes (None when an option is not given).
    """
    parser = argparse.ArgumentParser(description="Process command line arguments.")
    parser.add_argument("-input_path", type=dir_path, help="path to input directory")
    parser.add_argument("--chunk_size", type=int, help="chunk size for splitting")
    parser.add_argument("--chunk_overlap", type=int, help="chunk overlap for splitting")
    parser.add_argument("-output_path", type=dir_path, help="path to output directory")
    return parser.parse_args()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process data with optional chunking")
    # Required argument
    parser.add_argument("--input_path", type=str, required=True, help="Path to the input directory")
    # Optional arguments
    parser.add_argument("--output_db", type=str, help="Path to the output vectordb")
    parser.add_argument(
        "--chunk_size", type=int, default=1000, help="Chunk size (default: 1000)"
    )
    parser.add_argument(
        "--chunk_overlap", type=int, default=200, help="Chunk overlap (default: 200)"
    )
    parser.add_argument(
        "--db_type",
        type=str,
        default="faiss",
        choices=["faiss", "chroma", "qdrant"],
        help="Type of vector store (default: faiss)",
    )
    args = parser.parse_args()

    vectordb = VectorDb()
    # Pass everything by keyword: create_vdb's positional order is
    # (input_path, chunk_size, chunk_overlap, db_type, output_db, ...), which
    # differs from the order the CLI options are declared in above.
    vectordb.create_vdb(
        input_path=args.input_path,
        chunk_size=args.chunk_size,
        chunk_overlap=args.chunk_overlap,
        db_type=args.db_type,
        output_db=args.output_db,
    )