"""Wrapper around ChromaDB (version 3) embeddings platform."""
from __future__ import annotations

import logging
import uuid
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
)

import numpy as np

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import xor_args
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance

if TYPE_CHECKING:
    import chromamigdb
    import chromamigdb.config
    from chromamigdb.api.types import ID, OneOrMany, Where, WhereDocument

logger = logging.getLogger(__name__)
DEFAULT_K = 4


def _results_to_docs(results: Any) -> List[Document]:
    return [doc for doc, _ in _results_to_docs_and_scores(results)]


def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
    return [
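        # Only the first query's hits are read (the [0] index below): every
        # caller in this module issues a single query per call.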
        (Document(page_content=result[0], metadata=result[1] or {}), result[2])
        for result in zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        )
    ]


class ChromaMig(VectorStore):
    """Wrapper around ChromaDB embeddings platform.

    To use, you should have the ``chromamigdb`` python package installed.

    Example:
        .. code-block:: python

            from langchain.embeddings.openai import OpenAIEmbeddings

            embeddings = OpenAIEmbeddings()
            vectorstore = ChromaMig("langchain_store", embeddings)
    """

    _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"

    def __init__(
        self,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        embedding_function: Optional[Embeddings] = None,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromamigdb.config.Settings] = None,
        collection_metadata: Optional[Dict] = None,
        client: Optional[chromamigdb.Client] = None,
        relevance_score_fn: Optional[Callable[[float], float]] = None,
    ) -> None:
        """Initialize with Chroma client."""
        try:
            import chromamigdb
            import chromamigdb.config
        except ImportError:
            raise ValueError(
                "Could not import chromamigdb python package. "
                "Please install it with `pip install chromamigdb`."
            )

        if client is not None:
            self._client_settings = client_settings
            self._client = client
            self._persist_directory = persist_directory
        else:
            if client_settings:
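                # An explicit persist_directory argument takes precedence over
                # the directory already set on the provided settings object.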
                client_settings.persist_directory = (
                    persist_directory or client_settings.persist_directory
                )
                if client_settings.persist_directory is not None:
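                    # chromamigdb releases before 0.4 persist through the
                    # legacy "duckdb+parquet" backend.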
                    major, minor, _ = chromamigdb.__version__.split(".")
                    if int(major) == 0 and int(minor) < 4:
                        client_settings.chroma_db_impl = "duckdb+parquet"

                _client_settings = client_settings
            elif persist_directory:
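                # Same version split: before 0.4, persistence requires the
                # "duckdb+parquet" backend; from 0.4 on, is_persistent is used.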
                major, minor, _ = chromamigdb.__version__.split(".")
                if int(major) == 0 and int(minor) < 4:
                    _client_settings = chromamigdb.config.Settings(
                        chroma_db_impl="duckdb+parquet",
                    )
                else:
                    _client_settings = chromamigdb.config.Settings(is_persistent=True)
                _client_settings.persist_directory = persist_directory
            else:
                _client_settings = chromamigdb.config.Settings()
            self._client_settings = _client_settings
            self._client = chromamigdb.Client(_client_settings)
            self._persist_directory = (
                _client_settings.persist_directory or persist_directory
            )

        self._embedding_function = embedding_function
        self._collection = self._client.get_or_create_collection(
            name=collection_name,
            embedding_function=self._embedding_function.embed_documents
            if self._embedding_function is not None
            else None,
            metadata=collection_metadata,
        )
        self.override_relevance_score_fn = relevance_score_fn

    @property
    def embeddings(self) -> Optional[Embeddings]:
        return self._embedding_function

    @xor_args(("query_texts", "query_embeddings"))
    def __query_collection(
        self,
        query_texts: Optional[List[str]] = None,
        query_embeddings: Optional[List[List[float]]] = None,
        n_results: int = 4,
        where: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Query the chroma collection and return the raw query result."""
        try:
            import chromamigdb
        except ImportError:
            raise ValueError(
                "Could not import chromamigdb python package. "
                "Please install it with `pip install chromamigdb`."
            )
        return self._collection.query(
            query_texts=query_texts,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            **kwargs,
        )

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts (Iterable[str]): Texts to add to the vectorstore.
            metadatas (Optional[List[dict]], optional): Optional list of metadatas.
            ids (Optional[List[str]], optional): Optional list of IDs.

        Returns:
            List[str]: List of IDs of the added texts.
        """
        if ids is None:
            ids = [str(uuid.uuid1()) for _ in texts]
        embeddings = None
        texts = list(texts)
        if self._embedding_function is not None:
            embeddings = self._embedding_function.embed_documents(texts)
        if metadatas:
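            # Pad metadatas with empty dicts so it lines up with texts.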
            length_diff = len(texts) - len(metadatas)
            if length_diff:
                metadatas = metadatas + [{}] * length_diff
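            # Chroma rejects empty metadata dicts, so split the texts into
            # those with and without metadata and upsert each group separately.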
            empty_ids = []
            non_empty_ids = []
            for idx, m in enumerate(metadatas):
                if m:
                    non_empty_ids.append(idx)
                else:
                    empty_ids.append(idx)
            if non_empty_ids:
                metadatas = [metadatas[idx] for idx in non_empty_ids]
                texts_with_metadatas = [texts[idx] for idx in non_empty_ids]
                embeddings_with_metadatas = (
                    [embeddings[idx] for idx in non_empty_ids] if embeddings else None
                )
                ids_with_metadata = [ids[idx] for idx in non_empty_ids]
                self._collection.upsert(
                    metadatas=metadatas,
                    embeddings=embeddings_with_metadatas,
                    documents=texts_with_metadatas,
                    ids=ids_with_metadata,
                )
            if empty_ids:
                texts_without_metadatas = [texts[j] for j in empty_ids]
                embeddings_without_metadatas = (
                    [embeddings[j] for j in empty_ids] if embeddings else None
                )
                ids_without_metadatas = [ids[j] for j in empty_ids]
                self._collection.upsert(
                    embeddings=embeddings_without_metadatas,
                    documents=texts_without_metadatas,
                    ids=ids_without_metadatas,
                )
        else:
            self._collection.upsert(
                embeddings=embeddings,
                documents=texts,
                ids=ids,
            )
        return ids

    def similarity_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Run similarity search with Chroma.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Document]: List of documents most similar to the query text.
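
        Example (illustrative sketch; assumes the store was created with an
        embedding function, as in the class-level example):
            .. code-block:: python

                docs = vectorstore.similarity_search("query about cats", k=2)
                for doc in docs:
                    print(doc.page_content)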
        """
        docs_and_scores = self.similarity_search_with_score(query, k, filter=filter)
        return [doc for doc, _ in docs_and_scores]

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding (List[float]): Embedding to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents most similar to the query vector.
        """
        results = self.__query_collection(
            query_embeddings=[embedding], n_results=k, where=filter
        )
        return _results_to_docs(results)

    def similarity_search_by_vector_with_relevance_scores(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to embedding vector, with similarity scores.

        Args:
            embedding (List[float]): Embedding to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Tuple[Document, float]]: List of documents most similar to
            the query vector, with a distance for each. Lower scores represent
            more similarity.
        """
        results = self.__query_collection(
            query_embeddings=[embedding], n_results=k, where=filter
        )
        return _results_to_docs_and_scores(results)

    def similarity_search_with_score(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Run similarity search with Chroma, returning distances.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Tuple[Document, float]]: List of documents most similar to
            the query text, with a distance for each. Lower scores represent
            more similarity.
        """
        if self._embedding_function is None:
            results = self.__query_collection(
                query_texts=[query], n_results=k, where=filter
            )
        else:
            query_embedding = self._embedding_function.embed_query(query)
            results = self.__query_collection(
                query_embeddings=[query_embedding], n_results=k, where=filter
            )
        return _results_to_docs_and_scores(results)

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """
        The 'correct' relevance function may differ depending on a few things,
        including:
        - the distance / similarity metric used by the VectorStore
        - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
        - embedding dimensionality
        - etc.
        """
        if self.override_relevance_score_fn:
            return self.override_relevance_score_fn
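
        # Chroma stores the collection's distance metric under the
        # "hnsw:space" metadata key; "l2" is Chroma's default.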
        distance = "l2"
        distance_key = "hnsw:space"
        metadata = self._collection.metadata

        if metadata and distance_key in metadata:
            distance = metadata[distance_key]

        if distance == "cosine":
            return self._cosine_relevance_score_fn
        elif distance == "l2":
            return self._euclidean_relevance_score_fn
        elif distance == "ip":
            return self._max_inner_product_relevance_score_fn
        else:
            raise ValueError(
                "No supported normalization function"
                f" for distance metric of type: {distance}."
                " Consider providing relevance_score_fn to Chroma constructor."
            )

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        results = self.__query_collection(
            query_embeddings=[embedding],
            n_results=fetch_k,
            where=filter,
            include=["metadatas", "documents", "distances", "embeddings"],
        )
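        # maximal_marginal_relevance returns the indices of the k candidates
        # that best trade off query similarity against mutual diversity
        # (weighted by lambda_mult).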
        mmr_selected = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            results["embeddings"][0],
            k=k,
            lambda_mult=lambda_mult,
        )

        candidates = _results_to_docs(results)

        selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected]
        return selected_results

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents selected by maximal marginal relevance.
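
        Example (illustrative sketch; assumes an embedding function was set
        at creation):
            .. code-block:: python

                docs = vectorstore.max_marginal_relevance_search(
                    "tell me about cats", k=4, fetch_k=20, lambda_mult=0.5
                )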
        """
        if self._embedding_function is None:
            raise ValueError(
                "For MMR search, you must specify an embedding function on "
                "creation."
            )

        embedding = self._embedding_function.embed_query(query)
        docs = self.max_marginal_relevance_search_by_vector(
            embedding, k, fetch_k, lambda_mult=lambda_mult, filter=filter
        )
        return docs

    def delete_collection(self) -> None:
        """Delete the collection."""
        self._client.delete_collection(self._collection.name)

    def get(
        self,
        ids: Optional[OneOrMany[ID]] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Get entries from the collection.

        Args:
            ids: The ids of the embeddings to get. Optional.
            where: A Where type dict used to filter results by.
                E.g. `{"color" : "red", "price": 4.20}`. Optional.
            limit: The number of documents to return. Optional.
            offset: The offset to start returning results from.
                Useful for paging results with limit. Optional.
            where_document: A WhereDocument type dict used to filter by the
                documents. E.g. `{"$contains": "hello"}`. Optional.
            include: A list of what to include in the results.
                Can contain `"embeddings"`, `"metadatas"`, `"documents"`.
                Ids are always included.
                Defaults to `["metadatas", "documents"]`. Optional.
        """
        kwargs = {
            "ids": ids,
            "where": where,
            "limit": limit,
            "offset": offset,
            "where_document": where_document,
        }

        if include is not None:
            kwargs["include"] = include

        return self._collection.get(**kwargs)

    def persist(self) -> None:
        """Persist the collection.

        This can be used to explicitly persist the data to disk.
        It will also be called automatically when the object is destroyed.
        """
        if self._persist_directory is None:
            raise ValueError(
                "You must specify a persist_directory on "
                "creation to persist the collection."
            )
        import chromamigdb
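
        # chromamigdb 0.4+ persists automatically; only pre-0.4 clients expose
        # an explicit persist() call.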
        major, minor, _ = chromamigdb.__version__.split(".")
        if int(major) == 0 and int(minor) < 4:
            self._client.persist()

    def update_document(self, document_id: str, document: Document) -> None:
        """Update a document in the collection.

        Args:
            document_id (str): ID of the document to update.
            document (Document): Document to update.
        """
        text = document.page_content
        metadata = document.metadata
        if self._embedding_function is None:
            raise ValueError(
                "For update, you must specify an embedding function on creation."
            )
        embeddings = self._embedding_function.embed_documents([text])

        self._collection.update(
            ids=[document_id],
            embeddings=embeddings,
            documents=[text],
            metadatas=[metadata],
        )

    @classmethod
    def from_texts(
        cls: Type[ChromaMig],
        texts: List[str],
        embedding: Optional[Embeddings] = None,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromamigdb.config.Settings] = None,
        client: Optional[chromamigdb.Client] = None,
        collection_metadata: Optional[Dict] = None,
        **kwargs: Any,
    ) -> ChromaMig:
        """Create a ChromaMig vectorstore from raw texts.

        If a persist_directory is specified, the collection will be persisted
        there. Otherwise, the data will be ephemeral in-memory.

        Args:
            texts (List[str]): List of texts to add to the collection.
            collection_name (str): Name of the collection to create.
            persist_directory (Optional[str]): Directory to persist the collection.
            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
            metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
            ids (Optional[List[str]]): List of document IDs. Defaults to None.
            client_settings (Optional[chromamigdb.config.Settings]): Chroma client
                settings.
            collection_metadata (Optional[Dict]): Collection configurations.
                Defaults to None.

        Returns:
            ChromaMig: ChromaMig vectorstore.
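
        Example (illustrative sketch; assumes credentials are configured for
        the chosen embedding function):
            .. code-block:: python

                from langchain.embeddings.openai import OpenAIEmbeddings

                vectorstore = ChromaMig.from_texts(
                    texts=["harrison worked at kensho"],
                    embedding=OpenAIEmbeddings(),
                    collection_name="demo",
                )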
        """
        chroma_collection = cls(
            collection_name=collection_name,
            embedding_function=embedding,
            persist_directory=persist_directory,
            client_settings=client_settings,
            client=client,
            collection_metadata=collection_metadata,
            **kwargs,
        )
        chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
        return chroma_collection

    @classmethod
    def from_documents(
        cls: Type[ChromaMig],
        documents: List[Document],
        embedding: Optional[Embeddings] = None,
        ids: Optional[List[str]] = None,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromamigdb.config.Settings] = None,
        client: Optional[chromamigdb.Client] = None,
        collection_metadata: Optional[Dict] = None,
        **kwargs: Any,
    ) -> ChromaMig:
        """Create a ChromaMig vectorstore from a list of documents.

        If a persist_directory is specified, the collection will be persisted
        there. Otherwise, the data will be ephemeral in-memory.

        Args:
            collection_name (str): Name of the collection to create.
            persist_directory (Optional[str]): Directory to persist the collection.
            ids (Optional[List[str]]): List of document IDs. Defaults to None.
            documents (List[Document]): List of documents to add to the vectorstore.
            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
            client_settings (Optional[chromamigdb.config.Settings]): Chroma client
                settings.
            collection_metadata (Optional[Dict]): Collection configurations.
                Defaults to None.

        Returns:
            ChromaMig: ChromaMig vectorstore.
        """
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return cls.from_texts(
            texts=texts,
            embedding=embedding,
            metadatas=metadatas,
            ids=ids,
            collection_name=collection_name,
            persist_directory=persist_directory,
            client_settings=client_settings,
            client=client,
            collection_metadata=collection_metadata,
            **kwargs,
        )

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
        """Delete by vector IDs.

        Args:
            ids: List of ids to delete.
        """
        self._collection.delete(ids=ids)