# Define the script's usage example
USAGE_EXAMPLE = """
Example usage:

To process the input *.txt files at input_path and save the vector db output at output_db:
python create_vector_db.py input_path output_db --chunk_size 1000 --chunk_overlap 200

Required arguments:
- input_path: Path to the input directory containing the .txt files.
- output_db: Path to the output vector db.

Optional arguments:
- --chunk_size: Size of the chunks in characters (default: 1000).
- --chunk_overlap: Overlap between consecutive chunks (default: 200).
- --db_type: Vector store backend: faiss, chromadb, qdrant, or qdrant-server (default: faiss).
"""

import argparse
import logging
import os

from langchain.document_loaders import DirectoryLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS, Chroma, Qdrant
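
# NOTE: these are the legacy langchain import paths; newer releases move the
# loaders, embeddings, and vector stores to the langchain_community package.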

# Configure the logger
logging.basicConfig(
    level=logging.INFO,  # Set the logging level (e.g., INFO, DEBUG)
    format="%(asctime)s [%(levelname)s] - %(message)s",  # Define the log message format
    handlers=[
        logging.StreamHandler(),  # Output logs to the console
        logging.FileHandler("create_vector_db.log"),
    ],
)
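
# With this format, a typical log line looks like:
#   2024-01-01 12:00:00,000 [INFO] - Total 12 files loaded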

# Create a logger object
logger = logging.getLogger(__name__)


# Parse the command line arguments (single parser, also used by __main__ below)
def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Process data with optional chunking",
        epilog=USAGE_EXAMPLE,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Required arguments
    parser.add_argument("input_path", type=dir_path, help="Path to the input directory")
    parser.add_argument("output_db", type=str, help="Path to the output vectordb")

    # Optional arguments
    parser.add_argument(
        "--chunk_size", type=int, default=1000, help="Chunk size (default: 1000)"
    )
    parser.add_argument(
        "--chunk_overlap", type=int, default=200, help="Chunk overlap (default: 200)"
    )
    parser.add_argument(
        "--db_type",
        type=str,
        default="faiss",
        choices=["faiss", "chromadb", "qdrant", "qdrant-server"],
        help="Type of vectorstore (default: faiss)",
    )

    return parser.parse_args()


# Validate that a path argument points to an existing directory
def dir_path(path):
    if os.path.isdir(path):
        return path
    raise argparse.ArgumentTypeError(f"{path} is not a valid directory")


def main(input_path, output_db, chunk_size, chunk_overlap, db_type):
    """Load .txt files, split them into chunks, embed them, and persist a vector store."""
    # Load the *.txt files from input_path
    loader = DirectoryLoader(input_path, glob="*.txt")
    docs = loader.load()
    logger.info(f"Total {len(docs)} files loaded")

    # Split the documents into overlapping text chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
    )
    chunks = text_splitter.split_documents(docs)
    logger.info(f"Total {len(chunks)} chunks created")

    # create vector store
    encode_kwargs = {"normalize_embeddings": True}
    embedding_model = "BAAI/bge-large-en"
    embeddings = HuggingFaceInstructEmbeddings(
        model_name=embedding_model,
        embed_instruction="",  # no instruction is needed for candidate passages
        query_instruction="Represent this sentence for searching relevant passages: ",
        encode_kwargs=encode_kwargs,
    )
    logger.info(
        f"Processing embeddings using {embedding_model}. This could take time depending on the number of chunks ..."
    )

    if db_type == "faiss":
        vectorstore = FAISS.from_documents(documents=chunks, embedding=embeddings)
        # save vectorstore
        vectorstore.save_local(output_db)
    elif db_type == "chromadb":
        vectorstore = Chroma.from_documents(
            documents=chunks, embedding=embeddings, persist_directory=output_db
        )
    elif db_type == "qdrant":
        vectorstore = Qdrant.from_documents(
            documents=chunks,
            embedding=embeddings,
            path=output_db,
            collection_name="test_collection",
        )
    elif db_type == "qdrant-server":
        # Assumes a Qdrant server is already running locally on the default port
        url = "http://localhost:6333/"
        vectorstore = Qdrant.from_documents(
            documents=chunks,
            embedding=embeddings,
            url=url,
            prefer_grpc=True,
            collection_name="anaconda",
        )
    else:
        raise ValueError(f"Unknown db_type: {db_type}")

    logger.info(f"Vector store saved to {output_db}")


if __name__ == "__main__":
    args = parse_arguments()
    main(
        args.input_path,
        args.output_db,
        args.chunk_size,
        args.chunk_overlap,
        args.db_type,
    )
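
# Illustrative only: reloading the saved store. A minimal sketch (not part of
# this script's CLI) of how the FAISS output could be read back for a
# similarity search; it assumes the default db_type of "faiss" and the same
# BAAI/bge-large-en embedding settings used above. The query string and k are
# placeholders.
#
#     embeddings = HuggingFaceInstructEmbeddings(
#         model_name="BAAI/bge-large-en",
#         query_instruction="Represent this sentence for searching relevant passages: ",
#         encode_kwargs={"normalize_embeddings": True},
#     )
#     vectorstore = FAISS.load_local("output_db", embeddings)
#     results = vectorstore.similarity_search("example query", k=4)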