""" | |
.. warning:: | |
Beta Feature! | |
**Cache** provides an optional caching layer for LLMs. | |
Cache is useful for two reasons: | |
- It can save you money by reducing the number of API calls you make to the LLM | |
provider if you're often requesting the same completion multiple times. | |
- It can speed up your application by reducing the number of API calls you make | |
to the LLM provider. | |
Cache directly competes with Memory. See documentation for Pros and Cons. | |
**Class hierarchy:** | |
.. code-block:: | |
BaseCache --> <name>Cache # Examples: InMemoryCache, RedisCache, GPTCache | |
""" | |
from __future__ import annotations | |
import hashlib | |
import inspect | |
import json | |
import logging | |
import uuid | |
import warnings | |
from datetime import timedelta | |
from functools import lru_cache | |
from typing import ( | |
TYPE_CHECKING, | |
Any, | |
Callable, | |
Dict, | |
List, | |
Optional, | |
Tuple, | |
Type, | |
Union, | |
cast, | |
) | |
from sqlalchemy import Column, Integer, Row, String, create_engine, select | |
from sqlalchemy.engine.base import Engine | |
from sqlalchemy.orm import Session | |
try: | |
from sqlalchemy.orm import declarative_base | |
except ImportError: | |
from sqlalchemy.ext.declarative import declarative_base | |
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache | |
from langchain_core.embeddings import Embeddings | |
from langchain_core.load.dump import dumps | |
from langchain_core.load.load import loads | |
from langchain_core.outputs import ChatGeneration, Generation | |
from langchain.llms.base import LLM, get_prompts | |
from langchain.utils import get_from_env | |
from langchain.vectorstores.redis import Redis as RedisVectorstore | |
logger = logging.getLogger(__file__) | |
if TYPE_CHECKING: | |
import momento | |
from cassandra.cluster import Session as CassandraSession | |
def _hash(_input: str) -> str: | |
"""Use a deterministic hashing approach.""" | |
return hashlib.md5(_input.encode()).hexdigest() | |
def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str: | |
"""Dump generations to json. | |
Args: | |
generations (RETURN_VAL_TYPE): A list of language model generations. | |
Returns: | |
str: Json representing a list of generations. | |
Warning: would not work well with arbitrary subclasses of `Generation` | |
""" | |
return json.dumps([generation.dict() for generation in generations]) | |
def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE: | |
"""Load generations from json. | |
Args: | |
generations_json (str): A string of json representing a list of generations. | |
Raises: | |
ValueError: Could not decode json string to list of generations. | |
Returns: | |
RETURN_VAL_TYPE: A list of generations. | |
Warning: would not work well with arbitrary subclasses of `Generation` | |
""" | |
try: | |
results = json.loads(generations_json) | |
return [Generation(**generation_dict) for generation_dict in results] | |
except json.JSONDecodeError: | |
raise ValueError( | |
f"Could not decode json to list of generations: {generations_json}" | |
) | |
def _dumps_generations(generations: RETURN_VAL_TYPE) -> str: | |
""" | |
Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation` | |
Args: | |
generations (RETURN_VAL_TYPE): A list of language model generations. | |
Returns: | |
str: a single string representing a list of generations. | |
This function (together with its counterpart `_loads_generations`) relies on
the dumps/loads pair with Reviver, and so can handle all subclasses of
`Generation`.
Each item in the list is `dumps`ed to a string, and the resulting list of
strings is then JSON-dumped as a whole.
""" | |
return json.dumps([dumps(_item) for _item in generations]) | |
def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]: | |
""" | |
Deserialization of a string into a generic RETURN_VAL_TYPE | |
(i.e. a sequence of `Generation`). | |
See `_dumps_generations`, the inverse of this function. | |
Args: | |
generations_str (str): A string representing a list of generations. | |
Compatible with the legacy cache-blob format.
Does not raise exceptions for malformed entries; it just logs a warning
and returns None: the caller should be prepared for such a cache miss.
Returns: | |
RETURN_VAL_TYPE: A list of generations. | |
""" | |
try: | |
generations = [loads(_item_str) for _item_str in json.loads(generations_str)] | |
return generations | |
except (json.JSONDecodeError, TypeError): | |
# deferring the (soft) handling to after the legacy-format attempt | |
pass | |
try: | |
gen_dicts = json.loads(generations_str) | |
# not relying on `_load_generations_from_json` (which could disappear): | |
generations = [Generation(**generation_dict) for generation_dict in gen_dicts] | |
logger.warning( | |
f"Legacy 'Generation' cached blob encountered: '{generations_str}'" | |
) | |
return generations | |
except (json.JSONDecodeError, TypeError): | |
logger.warning( | |
f"Malformed/unparsable cached blob encountered: '{generations_str}'" | |
) | |
return None | |
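# Illustrative round-trip (not part of the original module): `_dumps_generations`
# and `_loads_generations` are the (de)serialization helpers the key-value caches
# below rely on. `_example_generations_roundtrip` is a hypothetical helper name.
def _example_generations_roundtrip() -> None:
    blob = _dumps_generations([Generation(text="Paris")])
    restored = _loads_generations(blob)
    # A malformed blob would come back as None, so guard before indexing.
    assert restored is not None and restored[0].text == "Paris"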
class InMemoryCache(BaseCache): | |
"""Cache that stores things in memory.""" | |
def __init__(self) -> None: | |
"""Initialize with empty cache.""" | |
self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {} | |
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: | |
"""Look up based on prompt and llm_string.""" | |
return self._cache.get((prompt, llm_string), None) | |
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: | |
"""Update cache based on prompt and llm_string.""" | |
self._cache[(prompt, llm_string)] = return_val | |
def clear(self, **kwargs: Any) -> None: | |
"""Clear cache.""" | |
self._cache = {} | |
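# Usage sketch (assumption: `langchain.globals.set_llm_cache` is the registration
# helper, as in the docstring examples further below; `_example_in_memory_cache`
# is a hypothetical name): a process-local cache keyed by (prompt, llm_string).
def _example_in_memory_cache() -> None:
    from langchain.globals import set_llm_cache

    cache = InMemoryCache()
    cache.update("Tell me a joke", "fake-llm-config", [Generation(text="...")])
    assert cache.lookup("Tell me a joke", "fake-llm-config") is not None

    # Register it globally so identical LLM calls are served from memory.
    set_llm_cache(cache)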
Base = declarative_base() | |
class FullLLMCache(Base): # type: ignore | |
"""SQLite table for full LLM Cache (all generations).""" | |
__tablename__ = "full_llm_cache" | |
prompt = Column(String, primary_key=True) | |
llm = Column(String, primary_key=True) | |
idx = Column(Integer, primary_key=True) | |
response = Column(String) | |
class SQLAlchemyCache(BaseCache): | |
"""Cache that uses SQAlchemy as a backend.""" | |
def __init__(self, engine: Engine, cache_schema: Type[FullLLMCache] = FullLLMCache): | |
"""Initialize by creating all tables.""" | |
self.engine = engine | |
self.cache_schema = cache_schema | |
self.cache_schema.metadata.create_all(self.engine) | |
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: | |
"""Look up based on prompt and llm_string.""" | |
stmt = ( | |
select(self.cache_schema.response) | |
.where(self.cache_schema.prompt == prompt) # type: ignore | |
.where(self.cache_schema.llm == llm_string) | |
.order_by(self.cache_schema.idx) | |
) | |
with Session(self.engine) as session: | |
rows = session.execute(stmt).fetchall() | |
if rows: | |
try: | |
return [loads(row[0]) for row in rows] | |
except Exception: | |
logger.warning( | |
"Retrieving a cache value that could not be deserialized " | |
"properly. This is likely due to the cache being in an " | |
"older format. Please recreate your cache to avoid this " | |
"error." | |
) | |
# In a previous life we stored the raw text directly | |
# in the table, so assume it's in that format. | |
return [Generation(text=row[0]) for row in rows] | |
return None | |
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: | |
"""Update based on prompt and llm_string.""" | |
items = [ | |
self.cache_schema(prompt=prompt, llm=llm_string, response=dumps(gen), idx=i) | |
for i, gen in enumerate(return_val) | |
] | |
with Session(self.engine) as session, session.begin(): | |
for item in items: | |
session.merge(item) | |
def clear(self, **kwargs: Any) -> None: | |
"""Clear cache.""" | |
with Session(self.engine) as session: | |
session.query(self.cache_schema).delete() | |
session.commit() | |
class SQLiteCache(SQLAlchemyCache): | |
"""Cache that uses SQLite as a backend.""" | |
def __init__(self, database_path: str = ".langchain.db"): | |
"""Initialize by creating the engine and all tables.""" | |
engine = create_engine(f"sqlite:///{database_path}") | |
super().__init__(engine) | |
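# Usage sketch (a minimal wiring, assuming write access to the working directory):
# unlike InMemoryCache, the SQLite-backed cache survives process restarts.
def _example_sqlite_cache() -> None:
    from langchain.globals import set_llm_cache

    set_llm_cache(SQLiteCache(database_path=".langchain.db"))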
class UpstashRedisCache(BaseCache): | |
"""Cache that uses Upstash Redis as a backend.""" | |
def __init__(self, redis_: Any, *, ttl: Optional[int] = None): | |
""" | |
Initialize an instance of UpstashRedisCache. | |
This method initializes an object with Upstash Redis caching capabilities. | |
It takes a `redis_` parameter, which should be an instance of an Upstash Redis | |
client class, allowing the object to interact with Upstash Redis | |
server for caching purposes. | |
Parameters: | |
redis_: An instance of Upstash Redis client class | |
(e.g., Redis) used for caching. | |
This allows the object to communicate with the
Upstash Redis server for caching operations.
ttl (int, optional): Time-to-live (TTL) for cached items in seconds. | |
If provided, it sets the time duration for how long cached | |
items will remain valid. If not provided, cached items will not | |
have an automatic expiration. | |
""" | |
try: | |
from upstash_redis import Redis | |
except ImportError: | |
raise ValueError( | |
"Could not import upstash_redis python package. " | |
"Please install it with `pip install upstash_redis`." | |
) | |
if not isinstance(redis_, Redis): | |
raise ValueError("Please pass in Upstash Redis object.") | |
self.redis = redis_ | |
self.ttl = ttl | |
def _key(self, prompt: str, llm_string: str) -> str: | |
"""Compute key from prompt and llm_string""" | |
return _hash(prompt + llm_string) | |
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: | |
"""Look up based on prompt and llm_string.""" | |
generations = [] | |
# Read from a HASH | |
results = self.redis.hgetall(self._key(prompt, llm_string)) | |
if results: | |
for _, text in results.items(): | |
generations.append(Generation(text=text)) | |
return generations if generations else None | |
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: | |
"""Update cache based on prompt and llm_string.""" | |
for gen in return_val: | |
if not isinstance(gen, Generation): | |
raise ValueError( | |
"UpstashRedisCache supports caching of normal LLM generations, " | |
f"got {type(gen)}" | |
) | |
if isinstance(gen, ChatGeneration): | |
warnings.warn( | |
"NOTE: Generation has not been cached. UpstashRedisCache does not" | |
" support caching ChatModel outputs." | |
) | |
return | |
# Write to a HASH | |
key = self._key(prompt, llm_string) | |
mapping = { | |
str(idx): generation.text for idx, generation in enumerate(return_val) | |
} | |
self.redis.hset(key=key, values=mapping) | |
if self.ttl is not None: | |
self.redis.expire(key, self.ttl) | |
def clear(self, **kwargs: Any) -> None: | |
""" | |
Clear cache. If `asynchronous` is True, flush asynchronously. | |
This flushes the *whole* db. | |
""" | |
asynchronous = kwargs.get("asynchronous", False) | |
if asynchronous: | |
asynchronous = "ASYNC" | |
else: | |
asynchronous = "SYNC" | |
self.redis.flushdb(flush_type=asynchronous) | |
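# Usage sketch (assumptions: the `upstash_redis` package is installed and
# UPSTASH_REDIS_REST_URL / UPSTASH_REDIS_REST_TOKEN are set, which is what the
# client's `Redis.from_env()` constructor reads):
def _example_upstash_redis_cache() -> None:
    from upstash_redis import Redis
    from langchain.globals import set_llm_cache

    set_llm_cache(UpstashRedisCache(redis_=Redis.from_env(), ttl=3600))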
class RedisCache(BaseCache): | |
"""Cache that uses Redis as a backend.""" | |
def __init__(self, redis_: Any, *, ttl: Optional[int] = None): | |
""" | |
Initialize an instance of RedisCache. | |
This method initializes an object with Redis caching capabilities. | |
It takes a `redis_` parameter, which should be an instance of a Redis | |
client class, allowing the object to interact with a Redis | |
server for caching purposes. | |
Parameters: | |
redis_ (Any): An instance of a Redis client class | |
(e.g., redis.Redis) used for caching. | |
This allows the object to communicate with a | |
Redis server for caching operations. | |
ttl (int, optional): Time-to-live (TTL) for cached items in seconds. | |
If provided, it sets the time duration for how long cached | |
items will remain valid. If not provided, cached items will not | |
have an automatic expiration. | |
""" | |
try: | |
from redis import Redis | |
except ImportError: | |
raise ValueError( | |
"Could not import redis python package. " | |
"Please install it with `pip install redis`." | |
) | |
if not isinstance(redis_, Redis): | |
raise ValueError("Please pass in Redis object.") | |
self.redis = redis_ | |
self.ttl = ttl | |
def _key(self, prompt: str, llm_string: str) -> str: | |
"""Compute key from prompt and llm_string""" | |
return _hash(prompt + llm_string) | |
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: | |
"""Look up based on prompt and llm_string.""" | |
generations = [] | |
# Read from a Redis HASH | |
results = self.redis.hgetall(self._key(prompt, llm_string)) | |
if results: | |
for _, text in results.items(): | |
try: | |
generations.append(loads(text)) | |
except Exception: | |
logger.warning( | |
"Retrieving a cache value that could not be deserialized " | |
"properly. This is likely due to the cache being in an " | |
"older format. Please recreate your cache to avoid this " | |
"error." | |
) | |
# In a previous life we stored the raw text directly | |
# in the table, so assume it's in that format. | |
generations.append(Generation(text=text)) | |
return generations if generations else None | |
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: | |
"""Update cache based on prompt and llm_string.""" | |
for gen in return_val: | |
if not isinstance(gen, Generation): | |
raise ValueError( | |
"RedisCache only supports caching of normal LLM generations, " | |
f"got {type(gen)}" | |
) | |
# Write to a Redis HASH | |
key = self._key(prompt, llm_string) | |
with self.redis.pipeline() as pipe: | |
pipe.hset( | |
key, | |
mapping={ | |
str(idx): dumps(generation) | |
for idx, generation in enumerate(return_val) | |
}, | |
) | |
if self.ttl is not None: | |
pipe.expire(key, self.ttl) | |
pipe.execute() | |
def clear(self, **kwargs: Any) -> None: | |
"""Clear cache. If `asynchronous` is True, flush asynchronously.""" | |
asynchronous = kwargs.get("asynchronous", False) | |
self.redis.flushdb(asynchronous=asynchronous, **kwargs) | |
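# Usage sketch (assumptions: the `redis` package is installed and a Redis server
# is reachable at localhost:6379):
def _example_redis_cache() -> None:
    from redis import Redis as RedisClient
    from langchain.globals import set_llm_cache

    client = RedisClient.from_url("redis://localhost:6379")
    # Entries expire after an hour thanks to the per-key TTL.
    set_llm_cache(RedisCache(redis_=client, ttl=3600))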
class RedisSemanticCache(BaseCache): | |
"""Cache that uses Redis as a vector-store backend.""" | |
# TODO - implement a TTL policy in Redis | |
DEFAULT_SCHEMA = { | |
"content_key": "prompt", | |
"text": [ | |
{"name": "prompt"}, | |
], | |
"extra": [{"name": "return_val"}, {"name": "llm_string"}], | |
} | |
def __init__( | |
self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2 | |
): | |
"""Initialize by passing in the `init` GPTCache func | |
Args: | |
redis_url (str): URL to connect to Redis. | |
embedding (Embedding): Embedding provider for semantic encoding and search. | |
score_threshold (float, 0.2): | |
Example: | |
.. code-block:: python | |
from langchain.globals import set_llm_cache | |
from langchain.cache import RedisSemanticCache | |
from langchain.embeddings import OpenAIEmbeddings | |
set_llm_cache(RedisSemanticCache( | |
redis_url="redis://localhost:6379", | |
embedding=OpenAIEmbeddings() | |
)) | |
""" | |
self._cache_dict: Dict[str, RedisVectorstore] = {} | |
self.redis_url = redis_url | |
self.embedding = embedding | |
self.score_threshold = score_threshold | |
def _index_name(self, llm_string: str) -> str: | |
hashed_index = _hash(llm_string) | |
return f"cache:{hashed_index}" | |
def _get_llm_cache(self, llm_string: str) -> RedisVectorstore: | |
index_name = self._index_name(llm_string) | |
# return vectorstore client for the specific llm string | |
if index_name in self._cache_dict: | |
return self._cache_dict[index_name] | |
# create new vectorstore client for the specific llm string | |
try: | |
self._cache_dict[index_name] = RedisVectorstore.from_existing_index( | |
embedding=self.embedding, | |
index_name=index_name, | |
redis_url=self.redis_url, | |
schema=cast(Dict, self.DEFAULT_SCHEMA), | |
) | |
except ValueError: | |
redis = RedisVectorstore( | |
embedding=self.embedding, | |
index_name=index_name, | |
redis_url=self.redis_url, | |
index_schema=cast(Dict, self.DEFAULT_SCHEMA), | |
) | |
_embedding = self.embedding.embed_query(text="test") | |
redis._create_index_if_not_exist(dim=len(_embedding)) | |
self._cache_dict[index_name] = redis | |
return self._cache_dict[index_name] | |
def clear(self, **kwargs: Any) -> None: | |
"""Clear semantic cache for a given llm_string.""" | |
index_name = self._index_name(kwargs["llm_string"]) | |
if index_name in self._cache_dict: | |
self._cache_dict[index_name].drop_index( | |
index_name=index_name, delete_documents=True, redis_url=self.redis_url | |
) | |
del self._cache_dict[index_name] | |
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: | |
"""Look up based on prompt and llm_string.""" | |
llm_cache = self._get_llm_cache(llm_string) | |
generations: List = [] | |
# Read from a Hash | |
results = llm_cache.similarity_search( | |
query=prompt, | |
k=1, | |
distance_threshold=self.score_threshold, | |
) | |
if results: | |
for document in results: | |
try: | |
generations.extend(loads(document.metadata["return_val"])) | |
except Exception: | |
logger.warning( | |
"Retrieving a cache value that could not be deserialized " | |
"properly. This is likely due to the cache being in an " | |
"older format. Please recreate your cache to avoid this " | |
"error." | |
) | |
# In a previous life we stored the raw text directly | |
# in the table, so assume it's in that format. | |
generations.extend( | |
_load_generations_from_json(document.metadata["return_val"]) | |
) | |
return generations if generations else None | |
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: | |
"""Update cache based on prompt and llm_string.""" | |
for gen in return_val: | |
if not isinstance(gen, Generation): | |
raise ValueError( | |
"RedisSemanticCache only supports caching of " | |
f"normal LLM generations, got {type(gen)}" | |
) | |
llm_cache = self._get_llm_cache(llm_string) | |
metadata = { | |
"llm_string": llm_string, | |
"prompt": prompt, | |
"return_val": dumps([g for g in return_val]), | |
} | |
llm_cache.add_texts(texts=[prompt], metadatas=[metadata]) | |
class GPTCache(BaseCache): | |
"""Cache that uses GPTCache as a backend.""" | |
def __init__( | |
self, | |
init_func: Union[ | |
Callable[[Any, str], None], Callable[[Any], None], None | |
] = None, | |
): | |
"""Initialize by passing in init function (default: `None`). | |
Args: | |
init_func (Optional[Callable[[Any], None]]): init `GPTCache` function | |
(default: `None`) | |
Example:
    .. code-block:: python

        # Initialize GPTCache with a custom init function
        import gptcache
        from gptcache.processor.pre import get_prompt
        from gptcache.manager.factory import manager_factory
        from langchain.globals import set_llm_cache

        # Avoid multiple caches using the same file,
        # causing different llm model caches to affect each other
        def init_gptcache(cache_obj: gptcache.Cache, llm: str):
            cache_obj.init(
                pre_embedding_func=get_prompt,
                data_manager=manager_factory(
                    manager="map",
                    data_dir=f"map_cache_{llm}"
                ),
            )

        set_llm_cache(GPTCache(init_gptcache))
""" | |
try: | |
import gptcache # noqa: F401 | |
except ImportError: | |
raise ImportError( | |
"Could not import gptcache python package. " | |
"Please install it with `pip install gptcache`." | |
) | |
self.init_gptcache_func: Union[ | |
Callable[[Any, str], None], Callable[[Any], None], None | |
] = init_func | |
self.gptcache_dict: Dict[str, Any] = {} | |
def _new_gptcache(self, llm_string: str) -> Any: | |
"""New gptcache object""" | |
from gptcache import Cache | |
from gptcache.manager.factory import get_data_manager | |
from gptcache.processor.pre import get_prompt | |
_gptcache = Cache() | |
if self.init_gptcache_func is not None: | |
sig = inspect.signature(self.init_gptcache_func) | |
if len(sig.parameters) == 2: | |
self.init_gptcache_func(_gptcache, llm_string) # type: ignore[call-arg] | |
else: | |
self.init_gptcache_func(_gptcache) # type: ignore[call-arg] | |
else: | |
_gptcache.init( | |
pre_embedding_func=get_prompt, | |
data_manager=get_data_manager(data_path=llm_string), | |
) | |
self.gptcache_dict[llm_string] = _gptcache | |
return _gptcache | |
def _get_gptcache(self, llm_string: str) -> Any: | |
"""Get a cache object. | |
When the corresponding llm model cache does not exist, it will be created.""" | |
_gptcache = self.gptcache_dict.get(llm_string, None) | |
if not _gptcache: | |
_gptcache = self._new_gptcache(llm_string) | |
return _gptcache | |
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: | |
"""Look up the cache data. | |
First, retrieve the corresponding cache object using the `llm_string` parameter, | |
and then retrieve the data from the cache based on the `prompt`. | |
""" | |
from gptcache.adapter.api import get | |
_gptcache = self._get_gptcache(llm_string) | |
res = get(prompt, cache_obj=_gptcache) | |
if res: | |
return [ | |
Generation(**generation_dict) for generation_dict in json.loads(res) | |
] | |
return None | |
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: | |
"""Update cache. | |
First, retrieve the corresponding cache object using the `llm_string` parameter, | |
and then store the `prompt` and `return_val` in the cache object. | |
""" | |
for gen in return_val: | |
if not isinstance(gen, Generation): | |
raise ValueError( | |
"GPTCache only supports caching of normal LLM generations, " | |
f"got {type(gen)}" | |
) | |
from gptcache.adapter.api import put | |
_gptcache = self._get_gptcache(llm_string) | |
handled_data = json.dumps([generation.dict() for generation in return_val]) | |
put(prompt, handled_data, cache_obj=_gptcache) | |
return None | |
def clear(self, **kwargs: Any) -> None: | |
"""Clear cache.""" | |
from gptcache import Cache | |
for gptcache_instance in self.gptcache_dict.values(): | |
gptcache_instance = cast(Cache, gptcache_instance) | |
gptcache_instance.flush() | |
self.gptcache_dict.clear() | |
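# Usage sketch (assumes the `gptcache` package is installed): with no init
# function, each llm_string gets its own default data manager created by
# `_new_gptcache` above.
def _example_gptcache_default() -> None:
    from langchain.globals import set_llm_cache

    set_llm_cache(GPTCache())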
def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None: | |
"""Create cache if it doesn't exist. | |
Raises: | |
SdkException: Momento service or network error | |
Exception: Unexpected response | |
""" | |
from momento.responses import CreateCache | |
create_cache_response = cache_client.create_cache(cache_name) | |
if isinstance(create_cache_response, CreateCache.Success) or isinstance( | |
create_cache_response, CreateCache.CacheAlreadyExists | |
): | |
return None | |
elif isinstance(create_cache_response, CreateCache.Error): | |
raise create_cache_response.inner_exception | |
else: | |
raise Exception(f"Unexpected response cache creation: {create_cache_response}") | |
def _validate_ttl(ttl: Optional[timedelta]) -> None: | |
if ttl is not None and ttl <= timedelta(seconds=0): | |
raise ValueError(f"ttl must be positive but was {ttl}.") | |
class MomentoCache(BaseCache): | |
"""Cache that uses Momento as a backend. See https://gomomento.com/""" | |
def __init__( | |
self, | |
cache_client: momento.CacheClient, | |
cache_name: str, | |
*, | |
ttl: Optional[timedelta] = None, | |
ensure_cache_exists: bool = True, | |
): | |
"""Instantiate a prompt cache using Momento as a backend. | |
Note: to instantiate the cache client passed to MomentoCache, | |
you must have a Momento account. See https://gomomento.com/. | |
Args: | |
cache_client (CacheClient): The Momento cache client. | |
cache_name (str): The name of the cache to use to store the data. | |
ttl (Optional[timedelta], optional): The time to live for the cache items. | |
Defaults to None, i.e. use the client default TTL.
ensure_cache_exists (bool, optional): Create the cache if it doesn't | |
exist. Defaults to True. | |
Raises: | |
ImportError: Momento python package is not installed. | |
TypeError: cache_client is not of type momento.CacheClient
ValueError: ttl is non-null and non-positive
""" | |
try: | |
from momento import CacheClient | |
except ImportError: | |
raise ImportError( | |
"Could not import momento python package. " | |
"Please install it with `pip install momento`." | |
) | |
if not isinstance(cache_client, CacheClient): | |
raise TypeError("cache_client must be a momento.CacheClient object.") | |
_validate_ttl(ttl) | |
if ensure_cache_exists: | |
_ensure_cache_exists(cache_client, cache_name) | |
self.cache_client = cache_client | |
self.cache_name = cache_name | |
self.ttl = ttl | |
@classmethod
def from_client_params(
cls, | |
cache_name: str, | |
ttl: timedelta, | |
*, | |
configuration: Optional[momento.config.Configuration] = None, | |
api_key: Optional[str] = None, | |
auth_token: Optional[str] = None, # for backwards compatibility | |
**kwargs: Any, | |
) -> MomentoCache: | |
"""Construct cache from CacheClient parameters.""" | |
try: | |
from momento import CacheClient, Configurations, CredentialProvider | |
except ImportError: | |
raise ImportError( | |
"Could not import momento python package. " | |
"Please install it with `pip install momento`." | |
) | |
if configuration is None: | |
configuration = Configurations.Laptop.v1() | |
# Try checking `MOMENTO_AUTH_TOKEN` first for backwards compatibility | |
try: | |
api_key = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN") | |
except ValueError: | |
api_key = api_key or get_from_env("api_key", "MOMENTO_API_KEY") | |
credentials = CredentialProvider.from_string(api_key) | |
cache_client = CacheClient(configuration, credentials, default_ttl=ttl) | |
return cls(cache_client, cache_name, ttl=ttl, **kwargs) | |
def __key(self, prompt: str, llm_string: str) -> str: | |
"""Compute cache key from prompt and associated model and settings. | |
Args: | |
prompt (str): The prompt run through the language model. | |
llm_string (str): The language model version and settings. | |
Returns: | |
str: The cache key. | |
""" | |
return _hash(prompt + llm_string) | |
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: | |
"""Lookup llm generations in cache by prompt and associated model and settings. | |
Args: | |
prompt (str): The prompt run through the language model. | |
llm_string (str): The language model version and settings. | |
Raises: | |
SdkException: Momento service or network error | |
Returns: | |
Optional[RETURN_VAL_TYPE]: A list of language model generations. | |
""" | |
from momento.responses import CacheGet | |
generations: RETURN_VAL_TYPE = [] | |
get_response = self.cache_client.get( | |
self.cache_name, self.__key(prompt, llm_string) | |
) | |
if isinstance(get_response, CacheGet.Hit): | |
value = get_response.value_string | |
generations = _load_generations_from_json(value) | |
elif isinstance(get_response, CacheGet.Miss): | |
pass | |
elif isinstance(get_response, CacheGet.Error): | |
raise get_response.inner_exception | |
return generations if generations else None | |
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: | |
"""Store llm generations in cache. | |
Args: | |
prompt (str): The prompt run through the language model. | |
llm_string (str): The language model string. | |
return_val (RETURN_VAL_TYPE): A list of language model generations. | |
Raises: | |
SdkException: Momento service or network error | |
Exception: Unexpected response | |
""" | |
for gen in return_val: | |
if not isinstance(gen, Generation): | |
raise ValueError( | |
"Momento only supports caching of normal LLM generations, " | |
f"got {type(gen)}" | |
) | |
key = self.__key(prompt, llm_string) | |
value = _dump_generations_to_json(return_val) | |
set_response = self.cache_client.set(self.cache_name, key, value, self.ttl) | |
from momento.responses import CacheSet | |
if isinstance(set_response, CacheSet.Success): | |
pass | |
elif isinstance(set_response, CacheSet.Error): | |
raise set_response.inner_exception | |
else: | |
raise Exception(f"Unexpected response: {set_response}") | |
def clear(self, **kwargs: Any) -> None: | |
"""Clear the cache. | |
Raises: | |
SdkException: Momento service or network error | |
""" | |
from momento.responses import CacheFlush | |
flush_response = self.cache_client.flush_cache(self.cache_name) | |
if isinstance(flush_response, CacheFlush.Success): | |
pass | |
elif isinstance(flush_response, CacheFlush.Error): | |
raise flush_response.inner_exception | |
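# Usage sketch (assumes a Momento account; `from_client_params` reads
# MOMENTO_API_KEY, or the legacy MOMENTO_AUTH_TOKEN, from the environment):
def _example_momento_cache() -> None:
    from datetime import timedelta
    from langchain.globals import set_llm_cache

    set_llm_cache(
        MomentoCache.from_client_params("langchain", ttl=timedelta(days=1))
    )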
CASSANDRA_CACHE_DEFAULT_TABLE_NAME = "langchain_llm_cache" | |
CASSANDRA_CACHE_DEFAULT_TTL_SECONDS = None | |
class CassandraCache(BaseCache): | |
""" | |
Cache that uses Cassandra / Astra DB as a backend. | |
It uses a single Cassandra table. | |
The lookup keys (which get to form the primary key) are: | |
- prompt, a string | |
- llm_string, a deterministic str representation of the model parameters. | |
(needed to prevent same-prompt-different-model collisions)
""" | |
def __init__( | |
self, | |
session: Optional[CassandraSession] = None, | |
keyspace: Optional[str] = None, | |
table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME, | |
ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS, | |
skip_provisioning: bool = False, | |
): | |
""" | |
Initialize with a ready session and a keyspace name. | |
Args: | |
session (cassandra.cluster.Session): an open Cassandra session | |
keyspace (str): the keyspace to use for storing the cache | |
table_name (str): name of the Cassandra table to use as cache | |
ttl_seconds (optional int): time-to-live for cache entries | |
(default: None, i.e. forever) | |
""" | |
try: | |
from cassio.table import ElasticCassandraTable | |
except (ImportError, ModuleNotFoundError): | |
raise ValueError( | |
"Could not import cassio python package. " | |
"Please install it with `pip install cassio`." | |
) | |
self.session = session | |
self.keyspace = keyspace | |
self.table_name = table_name | |
self.ttl_seconds = ttl_seconds | |
self.kv_cache = ElasticCassandraTable( | |
session=self.session, | |
keyspace=self.keyspace, | |
table=self.table_name, | |
keys=["llm_string", "prompt"], | |
primary_key_type=["TEXT", "TEXT"], | |
ttl_seconds=self.ttl_seconds, | |
skip_provisioning=skip_provisioning, | |
) | |
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: | |
"""Look up based on prompt and llm_string.""" | |
item = self.kv_cache.get( | |
llm_string=_hash(llm_string), | |
prompt=_hash(prompt), | |
) | |
if item is not None: | |
generations = _loads_generations(item["body_blob"]) | |
# this protects against malformed cached items: | |
if generations is not None: | |
return generations | |
else: | |
return None | |
else: | |
return None | |
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: | |
"""Update cache based on prompt and llm_string.""" | |
blob = _dumps_generations(return_val) | |
self.kv_cache.put( | |
llm_string=_hash(llm_string), | |
prompt=_hash(prompt), | |
body_blob=blob, | |
) | |
def delete_through_llm( | |
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None | |
) -> None: | |
""" | |
A wrapper around `delete` with the LLM being passed. | |
In case the llm(prompt) calls have a `stop` param, you should pass it here | |
""" | |
llm_string = get_prompts( | |
{**llm.dict(), **{"stop": stop}}, | |
[], | |
)[1] | |
return self.delete(prompt, llm_string=llm_string) | |
def delete(self, prompt: str, llm_string: str) -> None: | |
"""Evict from cache if there's an entry.""" | |
return self.kv_cache.delete( | |
llm_string=_hash(llm_string), | |
prompt=_hash(prompt), | |
) | |
def clear(self, **kwargs: Any) -> None: | |
"""Clear cache. This is for all LLMs at once.""" | |
self.kv_cache.clear() | |
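# Usage sketch (assumptions: the `cassio` and `cassandra-driver` packages are
# installed and a Cassandra node is reachable; the contact point and keyspace
# below are placeholders):
def _example_cassandra_cache() -> None:
    from cassandra.cluster import Cluster
    from langchain.globals import set_llm_cache

    session = Cluster(["127.0.0.1"]).connect()
    set_llm_cache(
        CassandraCache(session=session, keyspace="my_keyspace", ttl_seconds=3600)
    )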
CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC = "dot" | |
CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD = 0.85 | |
CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME = "langchain_llm_semantic_cache" | |
CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS = None | |
CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16 | |
class CassandraSemanticCache(BaseCache): | |
""" | |
Cache that uses Cassandra as a vector-store backend for semantic | |
(i.e. similarity-based) lookup. | |
It uses a single (vector) Cassandra table and stores, in principle, | |
cached values from several LLMs, so the LLM's llm_string is part | |
of the rows' primary keys. | |
The similarity is based on one of several distance metrics (default: "dot"). | |
If choosing another metric, the default threshold is to be re-tuned accordingly. | |
""" | |
def __init__( | |
self, | |
session: Optional[CassandraSession], | |
keyspace: Optional[str], | |
embedding: Embeddings, | |
table_name: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME, | |
distance_metric: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC, | |
score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD, | |
ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS, | |
skip_provisioning: bool = False, | |
): | |
""" | |
Initialize the cache with all relevant parameters. | |
Args: | |
session (cassandra.cluster.Session): an open Cassandra session | |
keyspace (str): the keyspace to use for storing the cache | |
embedding (Embedding): Embedding provider for semantic | |
encoding and search. | |
table_name (str): name of the Cassandra (vector) table | |
to use as cache | |
distance_metric (str, 'dot'): which measure to adopt for | |
similarity searches | |
score_threshold (optional float): numeric value to use as | |
cutoff for the similarity searches | |
ttl_seconds (optional int): time-to-live for cache entries | |
(default: None, i.e. forever) | |
The default score threshold is tuned to the default metric. | |
Tune it carefully yourself if switching to another distance metric. | |
""" | |
try: | |
from cassio.table import MetadataVectorCassandraTable | |
except (ImportError, ModuleNotFoundError): | |
raise ValueError( | |
"Could not import cassio python package. " | |
"Please install it with `pip install cassio`." | |
) | |
self.session = session | |
self.keyspace = keyspace | |
self.embedding = embedding | |
self.table_name = table_name | |
self.distance_metric = distance_metric | |
self.score_threshold = score_threshold | |
self.ttl_seconds = ttl_seconds | |
# The contract for this class has separate lookup and update: | |
# in order to spare some embedding calculations we cache them between | |
# the two calls. | |
# Note: each instance of this class has its own `_get_embedding` with | |
# its own lru. | |
@lru_cache(maxsize=CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
def _cache_embedding(text: str) -> List[float]:
return self.embedding.embed_query(text=text) | |
self._get_embedding = _cache_embedding | |
self.embedding_dimension = self._get_embedding_dimension() | |
self.table = MetadataVectorCassandraTable( | |
session=self.session, | |
keyspace=self.keyspace, | |
table=self.table_name, | |
primary_key_type=["TEXT"], | |
vector_dimension=self.embedding_dimension, | |
ttl_seconds=self.ttl_seconds, | |
metadata_indexing=("allow", {"_llm_string_hash"}), | |
skip_provisioning=skip_provisioning, | |
) | |
def _get_embedding_dimension(self) -> int: | |
return len(self._get_embedding(text="This is a sample sentence.")) | |
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: | |
"""Update cache based on prompt and llm_string.""" | |
embedding_vector = self._get_embedding(text=prompt) | |
llm_string_hash = _hash(llm_string) | |
body = _dumps_generations(return_val) | |
metadata = { | |
"_prompt": prompt, | |
"_llm_string_hash": llm_string_hash, | |
} | |
row_id = f"{_hash(prompt)}-{llm_string_hash}" | |
# | |
self.table.put( | |
body_blob=body, | |
vector=embedding_vector, | |
row_id=row_id, | |
metadata=metadata, | |
) | |
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: | |
"""Look up based on prompt and llm_string.""" | |
hit_with_id = self.lookup_with_id(prompt, llm_string) | |
if hit_with_id is not None: | |
return hit_with_id[1] | |
else: | |
return None | |
def lookup_with_id( | |
self, prompt: str, llm_string: str | |
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: | |
""" | |
Look up based on prompt and llm_string. | |
If there are hits, return (document_id, cached_entry) | |
""" | |
prompt_embedding: List[float] = self._get_embedding(text=prompt) | |
hits = list( | |
self.table.metric_ann_search( | |
vector=prompt_embedding, | |
metadata={"_llm_string_hash": _hash(llm_string)}, | |
n=1, | |
metric=self.distance_metric, | |
metric_threshold=self.score_threshold, | |
) | |
) | |
if hits: | |
hit = hits[0] | |
generations = _loads_generations(hit["body_blob"]) | |
if generations is not None: | |
# this protects against malformed cached items: | |
return ( | |
hit["row_id"], | |
generations, | |
) | |
else: | |
return None | |
else: | |
return None | |
def lookup_with_id_through_llm( | |
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None | |
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: | |
llm_string = get_prompts( | |
{**llm.dict(), **{"stop": stop}}, | |
[], | |
)[1] | |
return self.lookup_with_id(prompt, llm_string=llm_string) | |
def delete_by_document_id(self, document_id: str) -> None: | |
""" | |
Given this is a "similarity search" cache, an invalidation pattern | |
that makes sense is first a lookup to get an ID, and then deleting | |
with that ID. This is for the second step. | |
""" | |
self.table.delete(row_id=document_id) | |
def clear(self, **kwargs: Any) -> None: | |
"""Clear the *whole* semantic cache.""" | |
self.table.clear() | |
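# Invalidation sketch for the semantic cache, following the pattern described in
# `delete_by_document_id` above (hedged: `llm_string` stands for whatever
# configuration string the LLM wrapper produces):
def _example_invalidate_semantic_entry(
    cache: CassandraSemanticCache, prompt: str, llm_string: str
) -> None:
    hit = cache.lookup_with_id(prompt, llm_string)
    if hit is not None:
        document_id, _generations = hit
        cache.delete_by_document_id(document_id)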
class FullMd5LLMCache(Base): # type: ignore | |
"""SQLite table for full LLM Cache (all generations).""" | |
__tablename__ = "full_md5_llm_cache" | |
id = Column(String, primary_key=True) | |
prompt_md5 = Column(String, index=True) | |
llm = Column(String, index=True) | |
idx = Column(Integer, index=True) | |
prompt = Column(String) | |
response = Column(String) | |
class SQLAlchemyMd5Cache(BaseCache): | |
"""Cache that uses SQAlchemy as a backend.""" | |
def __init__( | |
self, engine: Engine, cache_schema: Type[FullMd5LLMCache] = FullMd5LLMCache | |
): | |
"""Initialize by creating all tables.""" | |
self.engine = engine | |
self.cache_schema = cache_schema | |
self.cache_schema.metadata.create_all(self.engine) | |
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: | |
"""Look up based on prompt and llm_string.""" | |
rows = self._search_rows(prompt, llm_string) | |
if rows: | |
return [loads(row[0]) for row in rows] | |
return None | |
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: | |
"""Update based on prompt and llm_string.""" | |
self._delete_previous(prompt, llm_string) | |
prompt_md5 = self.get_md5(prompt) | |
items = [ | |
self.cache_schema( | |
id=str(uuid.uuid1()), | |
prompt=prompt, | |
prompt_md5=prompt_md5, | |
llm=llm_string, | |
response=dumps(gen), | |
idx=i, | |
) | |
for i, gen in enumerate(return_val) | |
] | |
with Session(self.engine) as session, session.begin(): | |
for item in items: | |
session.merge(item) | |
def _delete_previous(self, prompt: str, llm_string: str) -> None: | |
stmt = ( | |
select(self.cache_schema.response) | |
.where(self.cache_schema.prompt_md5 == self.get_md5(prompt)) # type: ignore | |
.where(self.cache_schema.llm == llm_string) | |
.where(self.cache_schema.prompt == prompt) | |
.order_by(self.cache_schema.idx) | |
) | |
with Session(self.engine) as session, session.begin(): | |
rows = session.execute(stmt).fetchall() | |
for item in rows: | |
session.delete(item) | |
def _search_rows(self, prompt: str, llm_string: str) -> List[Row]: | |
prompt_md5 = self.get_md5(prompt)
stmt = ( | |
select(self.cache_schema.response) | |
.where(self.cache_schema.prompt_md5 == prompt_md5)  # type: ignore
.where(self.cache_schema.llm == llm_string) | |
.where(self.cache_schema.prompt == prompt) | |
.order_by(self.cache_schema.idx) | |
) | |
with Session(self.engine) as session: | |
return session.execute(stmt).fetchall() | |
def clear(self, **kwargs: Any) -> None: | |
"""Clear cache.""" | |
with Session(self.engine) as session: | |
session.execute(self.cache_schema.delete()) | |
@staticmethod
def get_md5(input_string: str) -> str:
return hashlib.md5(input_string.encode()).hexdigest() | |
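# Usage sketch (any SQLAlchemy-supported database works; a local SQLite file is
# used here purely for illustration): the MD5-indexed variant avoids indexing the
# full prompt text directly.
def _example_sqlalchemy_md5_cache() -> None:
    from sqlalchemy import create_engine
    from langchain.globals import set_llm_cache

    engine = create_engine("sqlite:///.langchain_md5.db")
    set_llm_cache(SQLAlchemyMd5Cache(engine=engine))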
ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_cache" | |
class AstraDBCache(BaseCache): | |
""" | |
Cache that uses Astra DB as a backend. | |
It uses a single collection as a kv store | |
The lookup keys, combined in the _id of the documents, are: | |
- prompt, a string | |
- llm_string, a deterministic str representation of the model parameters. | |
(needed to prevent same-prompt-different-model collisions) | |
""" | |
def __init__( | |
self, | |
*, | |
collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME, | |
token: Optional[str] = None, | |
api_endpoint: Optional[str] = None, | |
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed | |
namespace: Optional[str] = None, | |
): | |
""" | |
Create an AstraDB cache using a collection for storage. | |
Args (only keyword-arguments accepted): | |
collection_name (str): name of the Astra DB collection to create/use. | |
token (Optional[str]): API token for Astra DB usage. | |
api_endpoint (Optional[str]): full URL to the API endpoint, | |
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com". | |
astra_db_client (Optional[Any]): *alternative to token+api_endpoint*, | |
you can pass an already-created 'astrapy.db.AstraDB' instance. | |
namespace (Optional[str]): namespace (aka keyspace) where the | |
collection is created. Defaults to the database's "default namespace". | |
""" | |
try: | |
from astrapy.db import ( | |
AstraDB as LibAstraDB, | |
) | |
except (ImportError, ModuleNotFoundError): | |
raise ImportError( | |
"Could not import a recent astrapy python package. " | |
"Please install it with `pip install --upgrade astrapy`." | |
) | |
# Conflicting-arg checks: | |
if astra_db_client is not None: | |
if token is not None or api_endpoint is not None: | |
raise ValueError( | |
"You cannot pass 'astra_db_client' to AstraDB if passing " | |
"'token' and 'api_endpoint'." | |
) | |
self.collection_name = collection_name | |
self.token = token | |
self.api_endpoint = api_endpoint | |
self.namespace = namespace | |
if astra_db_client is not None: | |
self.astra_db = astra_db_client | |
else: | |
self.astra_db = LibAstraDB( | |
token=self.token, | |
api_endpoint=self.api_endpoint, | |
namespace=self.namespace, | |
) | |
self.collection = self.astra_db.create_collection( | |
collection_name=self.collection_name, | |
) | |
@staticmethod
def _make_id(prompt: str, llm_string: str) -> str:
return f"{_hash(prompt)}#{_hash(llm_string)}" | |
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: | |
"""Look up based on prompt and llm_string.""" | |
doc_id = self._make_id(prompt, llm_string) | |
item = self.collection.find_one( | |
filter={ | |
"_id": doc_id, | |
}, | |
projection={ | |
"body_blob": 1, | |
}, | |
)["data"]["document"] | |
if item is not None: | |
generations = _loads_generations(item["body_blob"]) | |
# this protects against malformed cached items: | |
if generations is not None: | |
return generations | |
else: | |
return None | |
else: | |
return None | |
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: | |
"""Update cache based on prompt and llm_string.""" | |
doc_id = self._make_id(prompt, llm_string) | |
blob = _dumps_generations(return_val) | |
self.collection.upsert( | |
{ | |
"_id": doc_id, | |
"body_blob": blob, | |
}, | |
) | |
def delete_through_llm( | |
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None | |
) -> None: | |
""" | |
A wrapper around `delete` with the LLM being passed. | |
In case the llm(prompt) calls have a `stop` param, you should pass it here | |
""" | |
llm_string = get_prompts( | |
{**llm.dict(), **{"stop": stop}}, | |
[], | |
)[1] | |
return self.delete(prompt, llm_string=llm_string) | |
def delete(self, prompt: str, llm_string: str) -> None: | |
"""Evict from cache if there's an entry.""" | |
doc_id = self._make_id(prompt, llm_string) | |
return self.collection.delete_one(doc_id) | |
def clear(self, **kwargs: Any) -> None: | |
"""Clear cache. This is for all LLMs at once.""" | |
self.astra_db.truncate_collection(self.collection_name) | |
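# Usage sketch (assumptions: an Astra DB instance; the endpoint below is the
# placeholder form from the docstring and ASTRA_DB_APPLICATION_TOKEN is an
# assumed environment variable holding the API token):
def _example_astra_db_cache() -> None:
    import os
    from langchain.globals import set_llm_cache

    set_llm_cache(
        AstraDBCache(
            api_endpoint="https://<DB-ID>-us-east1.apps.astra.datastax.com",
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        )
    )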
ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD = 0.85 | |
ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_semantic_cache"
ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16 | |
class AstraDBSemanticCache(BaseCache): | |
""" | |
Cache that uses Astra DB as a vector-store backend for semantic | |
(i.e. similarity-based) lookup. | |
It uses a single (vector) collection and can store | |
cached values from several LLMs, so the LLM's 'llm_string' is stored | |
in the document metadata. | |
You can choose the preferred similarity (or use the API default) -- | |
remember the threshold might require metric-dependent tuning.
""" | |
def __init__( | |
self, | |
*, | |
collection_name: str = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME,
token: Optional[str] = None, | |
api_endpoint: Optional[str] = None, | |
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed | |
namespace: Optional[str] = None, | |
embedding: Embeddings, | |
metric: Optional[str] = None, | |
similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD, | |
): | |
""" | |
Initialize the cache with all relevant parameters. | |
Args: | |
collection_name (str): name of the Astra DB collection to create/use. | |
token (Optional[str]): API token for Astra DB usage. | |
api_endpoint (Optional[str]): full URL to the API endpoint, | |
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com". | |
astra_db_client (Optional[Any]): *alternative to token+api_endpoint*, | |
you can pass an already-created 'astrapy.db.AstraDB' instance. | |
namespace (Optional[str]): namespace (aka keyspace) where the | |
collection is created. Defaults to the database's "default namespace". | |
embedding (Embedding): Embedding provider for semantic | |
encoding and search. | |
metric: the function to use for evaluating similarity of text embeddings. | |
Defaults to 'cosine' (alternatives: 'euclidean', 'dot_product') | |
similarity_threshold (float, optional): the minimum similarity | |
for accepting a (semantic-search) match. | |
The default score threshold is tuned to the default metric. | |
Tune it carefully yourself if switching to another distance metric. | |
""" | |
try: | |
from astrapy.db import ( | |
AstraDB as LibAstraDB, | |
) | |
except (ImportError, ModuleNotFoundError): | |
raise ImportError( | |
"Could not import a recent astrapy python package. " | |
"Please install it with `pip install --upgrade astrapy`." | |
) | |
# Conflicting-arg checks: | |
if astra_db_client is not None: | |
if token is not None or api_endpoint is not None: | |
raise ValueError( | |
"You cannot pass 'astra_db_client' to AstraDB if passing " | |
"'token' and 'api_endpoint'." | |
) | |
self.embedding = embedding | |
self.metric = metric | |
self.similarity_threshold = similarity_threshold | |
# The contract for this class has separate lookup and update: | |
# in order to spare some embedding calculations we cache them between | |
# the two calls. | |
# Note: each instance of this class has its own `_get_embedding` with | |
# its own lru. | |
@lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
def _cache_embedding(text: str) -> List[float]:
return self.embedding.embed_query(text=text) | |
self._get_embedding = _cache_embedding | |
self.embedding_dimension = self._get_embedding_dimension() | |
self.collection_name = collection_name | |
self.token = token | |
self.api_endpoint = api_endpoint | |
self.namespace = namespace | |
if astra_db_client is not None: | |
self.astra_db = astra_db_client | |
else: | |
self.astra_db = LibAstraDB( | |
token=self.token, | |
api_endpoint=self.api_endpoint, | |
namespace=self.namespace, | |
) | |
self.collection = self.astra_db.create_collection( | |
collection_name=self.collection_name, | |
dimension=self.embedding_dimension, | |
metric=self.metric, | |
) | |
def _get_embedding_dimension(self) -> int: | |
return len(self._get_embedding(text="This is a sample sentence.")) | |
@staticmethod
def _make_id(prompt: str, llm_string: str) -> str:
return f"{_hash(prompt)}#{_hash(llm_string)}" | |
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: | |
"""Update cache based on prompt and llm_string.""" | |
doc_id = self._make_id(prompt, llm_string) | |
llm_string_hash = _hash(llm_string) | |
embedding_vector = self._get_embedding(text=prompt) | |
body = _dumps_generations(return_val) | |
# | |
self.collection.upsert( | |
{ | |
"_id": doc_id, | |
"body_blob": body, | |
"llm_string_hash": llm_string_hash, | |
"$vector": embedding_vector, | |
} | |
) | |
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: | |
"""Look up based on prompt and llm_string.""" | |
hit_with_id = self.lookup_with_id(prompt, llm_string) | |
if hit_with_id is not None: | |
return hit_with_id[1] | |
else: | |
return None | |
def lookup_with_id( | |
self, prompt: str, llm_string: str | |
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: | |
""" | |
Look up based on prompt and llm_string. | |
If there are hits, return (document_id, cached_entry) for the top hit | |
""" | |
prompt_embedding: List[float] = self._get_embedding(text=prompt) | |
llm_string_hash = _hash(llm_string) | |
hit = self.collection.vector_find_one( | |
vector=prompt_embedding, | |
filter={ | |
"llm_string_hash": llm_string_hash, | |
}, | |
fields=["body_blob", "_id"], | |
include_similarity=True, | |
) | |
if hit is None or hit["$similarity"] < self.similarity_threshold: | |
return None | |
else: | |
generations = _loads_generations(hit["body_blob"]) | |
if generations is not None: | |
# this protects against malformed cached items: | |
return (hit["_id"], generations) | |
else: | |
return None | |
def lookup_with_id_through_llm( | |
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None | |
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: | |
llm_string = get_prompts( | |
{**llm.dict(), **{"stop": stop}}, | |
[], | |
)[1] | |
return self.lookup_with_id(prompt, llm_string=llm_string) | |
def delete_by_document_id(self, document_id: str) -> None: | |
""" | |
Given this is a "similarity search" cache, an invalidation pattern | |
that makes sense is first a lookup to get an ID, and then deleting | |
with that ID. This is for the second step. | |
""" | |
self.collection.delete_one(document_id) | |
def clear(self, **kwargs: Any) -> None: | |
"""Clear the *whole* semantic cache.""" | |
self.astra_db.truncate_collection(self.collection_name) | |
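# Usage sketch for the semantic variant (same Astra DB assumptions as above, plus
# OpenAI credentials for the embeddings):
def _example_astra_db_semantic_cache() -> None:
    import os
    from langchain.embeddings import OpenAIEmbeddings
    from langchain.globals import set_llm_cache

    set_llm_cache(
        AstraDBSemanticCache(
            api_endpoint="https://<DB-ID>-us-east1.apps.astra.datastax.com",
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
            embedding=OpenAIEmbeddings(),
            similarity_threshold=0.85,
        )
    )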