# VectorStoreFlowModule / ChromaDBFlow.py
import os
from typing import Dict, List, Any
import uuid
from copy import deepcopy
from langchain.embeddings import OpenAIEmbeddings
from chromadb import Client as ChromaClient
from aiflows.base_flows import AtomicFlow
import hydra


class ChromaDBFlow(AtomicFlow):
""" A flow that uses the ChromaDB model to write and read memories stored in a database
*Configuration Parameters*:
- `name` (str): The name of the flow. Default: "chroma_db"
- `description` (str): A description of the flow. This description is used to generate the help message of the flow.
Default: "ChromaDB is a document store that uses vector embeddings to store and retrieve documents."
- `backend` (Dict[str, Any]): The configuration of the backend which is used to fetch api keys. Default: LiteLLMBackend with the
default parameters of LiteLLMBackend (see aiflows.backends.LiteLLMBackend). Except for the following parameter whose default value is overwritten:
- `api_infos` (List[Dict[str, Any]]): The list of api infos. Default: No default value, this parameter is required.
- `model_name` (str): The name of the model. Default: "". In the current implementation, this parameter is not used.
- `n_results` (int): The number of results to retrieve when reading from the database. Default: 5
- Other parameters are inherited from the default configuration of AtomicFlow (see AtomicFlow)

    *Input Interface*:

    - `operation` (str): The operation to perform. It can be "write" or "read".
    - `content` (str or List[str]): The content to write or read. If operation is "write", it must be a string or a list of strings. If operation is "read", it must be a string.

    *Output Interface*:

    - `retrieved` (str or List[str]): The retrieved content. If operation is "write", it is an empty string. If operation is "read", it is a string or a list of strings.
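
    *Example* (illustrative payloads only; on a read, "retrieved" holds whatever the collection returns for the top `n_results` matches):

    - write: ``{"operation": "write", "content": ["Alice plays tennis."]}`` -> ``{"retrieved": ""}``
    - read: ``{"operation": "read", "content": "What does Alice play?"}`` -> ``{"retrieved": [["Alice plays tennis.", ...]]}``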

    :param backend: The backend of the flow (used to retrieve the API key).
    :type backend: LiteLLMBackend
    :param \**kwargs: Additional arguments to pass to the flow.
    """

    def __init__(self, backend, **kwargs):
        super().__init__(**kwargs)
        # Connect to ChromaDB and get (or create) the collection named after this flow
        self.client = ChromaClient()
        self.collection = self.client.get_or_create_collection(name=self.flow_config["name"])
        self.backend = backend

    @classmethod
    def _set_up_backend(cls, config):
        """ This instantiates the backend of the flow from its configuration.

        :param config: The configuration of the flow (containing the backend configuration under the "backend" key).
        :type config: Dict[str, Any]
        :return: A kwargs dictionary containing the instantiated backend.
        :rtype: Dict[str, LiteLLMBackend]
        """
        kwargs = {}
        kwargs["backend"] = \
            hydra.utils.instantiate(config['backend'], _convert_="partial")
        return kwargs

    @classmethod
    def instantiate_from_config(cls, config):
        """ This method instantiates the flow from a configuration file.

        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: The instantiated flow.
        :rtype: ChromaDBFlow
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up backend ~~~
        kwargs.update(cls._set_up_backend(flow_config))

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def get_input_keys(self) -> List[str]:
        """ This method returns the input keys of the flow.

        :return: The input keys of the flow.
        :rtype: List[str]
        """
        return self.flow_config["input_keys"]

    def get_output_keys(self) -> List[str]:
        """ This method returns the output keys of the flow.

        :return: The output keys of the flow.
        :rtype: List[str]
        """
        return self.flow_config["output_keys"]

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """ Runs the flow: it either writes memories to the database or reads memories from it.

        :param input_data: The input data of the flow.
        :type input_data: Dict[str, Any]
        :return: The output data of the flow.
        :rtype: Dict[str, Any]
        """
        api_information = self.backend.get_key()

        if api_information.backend_used == "openai":
            embeddings = OpenAIEmbeddings(openai_api_key=api_information.api_key)
        else:
            # ToDo: Add support for Azure
            embeddings = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"))

        response = {}

        operation = input_data["operation"]
        if operation not in ["write", "read"]:
            raise ValueError(f"Operation '{operation}' not supported")

        content = input_data["content"]
        if operation == "read":
            if not isinstance(content, str):
                raise ValueError(f"content (query) must be a string during read, got {type(content)}: {content}")
            if content == "":
                # Nothing to query: return an empty result without hitting the database
                response["retrieved"] = [[""]]
                return response
            query = content
            # Embed the query and retrieve the n_results most similar documents
            query_result = self.collection.query(
                query_embeddings=embeddings.embed_query(query),
                n_results=self.flow_config["n_results"]
            )
            response["retrieved"] = [doc for doc in query_result["documents"]]
        elif operation == "write":
            if content != "":
                if not isinstance(content, list):
                    content = [content]
                documents = content
                # Embed the documents and store them in the collection under fresh UUIDs
                self.collection.add(
                    ids=[str(uuid.uuid4()) for _ in range(len(documents))],
                    embeddings=embeddings.embed_documents(documents),
                    documents=documents
                )
            response["retrieved"] = ""
return response
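

# -----------------------------------------------------------------------------
# Illustrative usage sketch, not part of the original module. In practice the
# flow is instantiated from the module's default YAML configuration; the dict
# below only lists the keys this file itself reads plus an *assumed* `_target_`
# path for LiteLLMBackend, so treat it as a template rather than a guaranteed
# working configuration.
if __name__ == "__main__":
    example_config = {
        "name": "chroma_db",
        "description": "ChromaDB is a document store that uses vector embeddings to store and retrieve documents.",
        "n_results": 5,
        "input_keys": ["operation", "content"],
        "output_keys": ["retrieved"],
        "backend": {
            # Assumed instantiation target; check your aiflows installation for the exact path.
            "_target_": "aiflows.backends.llm_lite.LiteLLMBackend",
            "api_infos": [{"backend_used": "openai", "api_key": os.getenv("OPENAI_API_KEY")}],
            "model_name": "",
        },
    }

    flow = ChromaDBFlow.instantiate_from_config(example_config)

    # Store a memory, then retrieve the closest matches for a query.
    flow.run({"operation": "write", "content": ["Alice plays tennis on Sundays."]})
    result = flow.run({"operation": "read", "content": "What does Alice play?"})
    print(result["retrieved"])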