|
|
|
# CLI help epilog. Kept in sync with the argparse configuration in the
# __main__ block below (positional names, real default values, --db_type).
USAGE_EXAMPLE = """
Example usage:

To process input *.txt files at input_path and save the vector db output at output_db:

python create_vector_db.py input_path output_db --chunk_size 100 --chunk_overlap 10

Required arguments:

- input_path: Path to the input dir containing the .txt files

- output_db: Path to the output vector db.

Optional arguments:

- --chunk_size: Size of the chunks (default: 1000).

- --chunk_overlap: Overlap between chunks (default: 200).

- --db_type: Type of vectorstore — faiss, chromadb, qdrant, or qdrant-server (default: faiss).
"""
|
|
|
import argparse |
|
import logging |
|
import os |
|
|
|
from langchain.document_loaders import DirectoryLoader |
|
from langchain.embeddings import HuggingFaceInstructEmbeddings |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.vectorstores import FAISS, Chroma, Qdrant |
|
|
|
|
|
logging.basicConfig( |
|
level=logging.INFO, |
|
format="%(asctime)s [%(levelname)s] - %(message)s", |
|
handlers=[ |
|
logging.StreamHandler(), |
|
logging.FileHandler("create_vector_db.log"), |
|
], |
|
) |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
def parse_arguments():
    """Parse the legacy command-line interface for this script.

    NOTE(review): this helper appears to be dead code — the ``__main__``
    block below builds its own parser and never calls this. Kept with its
    original flags for backward compatibility; confirm before removing.

    Returns:
        argparse.Namespace with ``input_path``, ``output_path``,
        ``chunk_size`` and ``chunk_overlap`` attributes.
    """
    parser = argparse.ArgumentParser(description="Process command line arguments.")
    parser.add_argument("-input_path", type=dir_path, help="path to input directory")
    parser.add_argument("--chunk_size", type=int, help="chunk size for splitting")
    parser.add_argument("--chunk_overlap", type=int, help="chunk overlap for splitting")
    # Fixed copy-paste bug: this help text previously said "input directory".
    parser.add_argument("-output_path", type=dir_path, help="path to output directory")

    return parser.parse_args()
|
|
|
|
|
|
|
def dir_path(path):
    """argparse type-checker: return ``path`` unchanged when it names an
    existing directory, else raise ``ArgumentTypeError`` so argparse can
    report a usage error to the user."""
    if not os.path.isdir(path):
        raise argparse.ArgumentTypeError(f"readable_dir:{path} is not a valid path")
    return path
|
|
|
|
|
def main(input_path, output_db, chunk_size, chunk_overlap, db_type):
    """Load .txt files, split them into chunks, embed them, and persist a
    vector store.

    Args:
        input_path: Directory containing the ``*.txt`` files to ingest.
        output_db: Path where the vector store is persisted.
        chunk_size: Character length of each chunk.
        chunk_overlap: Character overlap between consecutive chunks.
        db_type: One of ``"faiss"``, ``"chromadb"``, ``"qdrant"``,
            ``"qdrant-server"``.

    Raises:
        ValueError: If ``db_type`` is not one of the supported backends.
    """
    loader = DirectoryLoader(input_path, glob="*.txt")
    docs = loader.load()
    logger.info("Total %d files loaded", len(docs))

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
    )
    chunks = text_splitter.split_documents(docs)
    logger.info("Total %d chunks created", len(chunks))

    # BGE models want a query instruction but no document instruction;
    # normalized embeddings allow cosine similarity via inner product.
    encode_kwargs = {"normalize_embeddings": True}
    embedding_model = "BAAI/bge-large-en"
    embeddings = HuggingFaceInstructEmbeddings(
        model_name=embedding_model,
        embed_instruction="",
        query_instruction="Represent this sentence for searching relevant passages: ",
        encode_kwargs=encode_kwargs,
    )
    logger.info(
        "Processing embeddings using %s. This could take time depending on the number of chunks ...",
        embedding_model,
    )

    if db_type == "faiss":
        vectorstore = FAISS.from_documents(documents=chunks, embedding=embeddings)
        vectorstore.save_local(output_db)
    elif db_type == "chromadb":
        # NOTE(review): some langchain versions require an explicit
        # vectorstore.persist() to flush to persist_directory — verify
        # against the pinned langchain release.
        vectorstore = Chroma.from_documents(
            documents=chunks, embedding=embeddings, persist_directory=output_db
        )
    elif db_type == "qdrant":
        vectorstore = Qdrant.from_documents(
            documents=chunks,
            embedding=embeddings,
            path=output_db,
            collection_name="test_collection",
        )
    elif db_type == "qdrant-server":
        url = "http://localhost:6333/"
        vectorstore = Qdrant.from_documents(
            documents=chunks,
            embedding=embeddings,
            url=url,
            prefer_grpc=True,
            collection_name="anaconda",
        )
    else:
        # Previously an unknown db_type fell through silently, leaving
        # `vectorstore` unbound and logging a bogus success message.
        raise ValueError(f"Unsupported db_type: {db_type!r}")

    logger.info("Vector store saved to %s", output_db)
|
|
|
|
|
if __name__ == "__main__":
    # RawDescriptionHelpFormatter keeps the pre-formatted usage example
    # intact in --help output (previously USAGE_EXAMPLE was dead code).
    parser = argparse.ArgumentParser(
        description="Process data with optional chunking",
        epilog=USAGE_EXAMPLE,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument("input_path", type=str, help="Path to the input directory")
    parser.add_argument("output_db", type=str, help="Path to the output vectordb")

    parser.add_argument(
        "--chunk_size", type=int, default=1000, help="Chunk size (default: 1000)"
    )
    parser.add_argument(
        "--chunk_overlap", type=int, default=200, help="Chunk overlap (default: 200)"
    )
    parser.add_argument(
        "--db_type",
        type=str,
        default="faiss",
        # Restrict to the backends main() actually implements so typos are
        # rejected up front instead of failing mid-run after embedding.
        choices=["faiss", "chromadb", "qdrant", "qdrant-server"],
        help="Type of vectorstore (default: faiss)",
    )

    args = parser.parse_args()
    main(
        args.input_path,
        args.output_db,
        args.chunk_size,
        args.chunk_overlap,
        args.db_type,
    )
|
|