# Source: petrojm — "add EKR files" (commit a6c26b1), 4.75 kB
# Define the script's usage example (kept in sync with the argparse defaults below)
USAGE_EXAMPLE = """
Example usage:
To process input *.txt files at input_path and save the vector db output at output_db:
python create_vector_db.py input_path output_db --chunk_size 100 --chunk_overlap 10
Required arguments:
- input_path: Path to the input dir containing the .txt files
- output_db: Path to the output vector db.
Optional arguments:
- --chunk_size: Size of the chunks (default: 1000).
- --chunk_overlap: Overlap between chunks (default: 200).
- --db_type: Type of vectorstore (default: faiss).
"""
import argparse
import logging
import os
from langchain.document_loaders import DirectoryLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS, Chroma, Qdrant
# Configure the root logger once at import time; module loggers inherit these handlers.
logging.basicConfig(
    level=logging.INFO,  # Set the logging level (e.g., INFO, DEBUG)
    format="%(asctime)s [%(levelname)s] - %(message)s",  # Define the log message format
    handlers=[
        logging.StreamHandler(),  # Output logs to the console
        logging.FileHandler("create_vector_db.log"),  # Also persist logs to a file in the CWD
    ],
)
# Create a logger object scoped to this module
logger = logging.getLogger(__name__)
# Parse the arguments
def parse_arguments(argv=None):
    """Parse command line arguments for the vector-db creation script.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse reads ``sys.argv[1:]`` (backward-compatible
            with the original zero-argument call).

    Returns:
        argparse.Namespace with ``input_path``, ``output_path``,
        ``chunk_size`` and ``chunk_overlap`` attributes.
    """
    parser = argparse.ArgumentParser(description="Process command line arguments.")
    parser.add_argument("-input_path", type=dir_path, help="path to input directory")
    parser.add_argument("--chunk_size", type=int, help="chunk size for splitting")
    parser.add_argument("--chunk_overlap", type=int, help="chunk overlap for splitting")
    # Fixed copy-paste bug: help text previously said "path to input directory".
    parser.add_argument("-output_path", type=dir_path, help="path to output directory")
    return parser.parse_args(argv)


# Check valid path
def dir_path(path):
    """Validate that *path* names an existing directory (argparse ``type=``).

    Args:
        path: Candidate directory path string.

    Returns:
        The unchanged ``path`` when it is an existing directory.

    Raises:
        argparse.ArgumentTypeError: If ``path`` is not an existing directory.
    """
    if os.path.isdir(path):
        return path
    raise argparse.ArgumentTypeError(f"readable_dir:{path} is not a valid path")
def main(input_path, output_db, chunk_size, chunk_overlap, db_type):
    """Build a vector store from the ``*.txt`` files under ``input_path``.

    Args:
        input_path: Directory containing the ``*.txt`` source documents.
        output_db: Destination path / persist directory for the vector store
            (ignored by the "qdrant-server" backend, which writes to a server).
        chunk_size: Maximum characters per text chunk.
        chunk_overlap: Characters of overlap between consecutive chunks.
        db_type: One of "faiss", "chromadb", "qdrant", "qdrant-server".

    Raises:
        ValueError: If ``db_type`` is not one of the supported backends.
    """
    # Load files from input_location
    loader = DirectoryLoader(input_path, glob="*.txt")
    docs = loader.load()
    logger.info(f"Total {len(docs)} files loaded")

    # get the text chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
    )
    chunks = text_splitter.split_documents(docs)
    logger.info(f"Total {len(chunks)} chunks created")

    # create vector store
    encode_kwargs = {"normalize_embeddings": True}  # unit-norm vectors for cosine similarity
    embedding_model = "BAAI/bge-large-en"
    embeddings = HuggingFaceInstructEmbeddings(
        model_name=embedding_model,
        embed_instruction="",  # no instruction is needed for candidate passages
        query_instruction="Represent this sentence for searching relevant passages: ",
        encode_kwargs=encode_kwargs,
    )
    logger.info(
        f"Processing embeddings using {embedding_model}. This could take time depending on the number of chunks ..."
    )

    if db_type == "faiss":
        vectorstore = FAISS.from_documents(documents=chunks, embedding=embeddings)
        # save vectorstore
        vectorstore.save_local(output_db)
    elif db_type == "chromadb":
        vectorstore = Chroma.from_documents(
            documents=chunks, embedding=embeddings, persist_directory=output_db
        )
    elif db_type == "qdrant":
        vectorstore = Qdrant.from_documents(
            documents=chunks,
            embedding=embeddings,
            path=output_db,
            collection_name="test_collection",
        )
    elif db_type == "qdrant-server":
        # NOTE: writes to a running Qdrant instance, not to output_db.
        url = "http://localhost:6333/"
        vectorstore = Qdrant.from_documents(
            documents=chunks,
            embedding=embeddings,
            url=url,
            prefer_grpc=True,
            collection_name="anaconda",
        )
    else:
        # Previously an unknown db_type fell through silently (no store was
        # built) yet the "saved" message below was still logged; fail loudly.
        raise ValueError(f"Unsupported db_type: {db_type!r}")

    logger.info(f"Vector store saved to {output_db}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Process data with optional chunking",
        epilog=USAGE_EXAMPLE,  # previously defined but never used; now shown in --help
        formatter_class=argparse.RawDescriptionHelpFormatter,  # keep epilog line breaks
    )
    # Required arguments
    parser.add_argument("input_path", type=str, help="Path to the input directory")
    parser.add_argument("output_db", type=str, help="Path to the output vectordb")
    # Optional arguments
    parser.add_argument(
        "--chunk_size", type=int, default=1000, help="Chunk size (default: 1000)"
    )
    parser.add_argument(
        "--chunk_overlap", type=int, default=200, help="Chunk overlap (default: 200)"
    )
    parser.add_argument(
        "--db_type",
        type=str,
        default="faiss",
        # Reject unsupported stores at the CLI instead of failing inside main().
        choices=["faiss", "chromadb", "qdrant", "qdrant-server"],
        help="Type of vectorstore (default: faiss)",
    )
    args = parser.parse_args()
    main(
        args.input_path,
        args.output_db,
        args.chunk_size,
        args.chunk_overlap,
        args.db_type,
    )