"""Langchain Wrapper around Sambanova embedding APIs."""

import json
from typing import Any, Dict, Generator, List, Optional

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils import get_from_dict_or_env, pre_init


class SambaStudioEmbeddings(BaseModel, Embeddings):
    """SambaNova embedding models.

    To use, you should have the environment variables
    ``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_BASE_URI``,
    ``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``, ``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``,
    ``SAMBASTUDIO_EMBEDDINGS_API_KEY``
    set with your personal sambastudio variable or pass it as a named parameter
    to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import SambaStudioEmbeddings

            embeddings = SambaStudioEmbeddings(sambastudio_embeddings_base_url=base_url,
                                          sambastudio_embeddings_base_uri=base_uri,
                                          sambastudio_embeddings_project_id=project_id,
                                          sambastudio_embeddings_endpoint_id=endpoint_id,
                                          sambastudio_embeddings_api_key=api_key,
                                          batch_size=32)
            (or)

            embeddings = SambaStudioEmbeddings(batch_size=32)

            (or)

            # CoE example
            embeddings = SambaStudioEmbeddings(
                batch_size=1,
                model_kwargs={
                    'select_expert':'e5-mistral-7b-instruct'
                }
            )
    """

    sambastudio_embeddings_base_url: str = ''
    """Base url to use"""

    sambastudio_embeddings_base_uri: str = ''
    """endpoint base uri"""

    sambastudio_embeddings_project_id: str = ''
    """Project id on sambastudio for model"""

    sambastudio_embeddings_endpoint_id: str = ''
    """endpoint id on sambastudio for model"""

    sambastudio_embeddings_api_key: str = ''
    """sambastudio api key"""

    model_kwargs: dict = {}
    """Key word arguments to pass to the model."""

    batch_size: int = 32
    """Batch size for the embedding models"""

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve every connection setting from the kwargs or the environment.

        Each ``sambastudio_embeddings_*`` value is looked up with
        ``get_from_dict_or_env`` under its upper-cased environment variable
        name. Only the base URI is given a fallback here
        (``'api/predict/generic'``).

        Args:
            values: The raw field values handed to the model.

        Returns:
            The same dictionary with the five connection settings filled in.
        """
        values['sambastudio_embeddings_base_url'] = get_from_dict_or_env(
            values,
            'sambastudio_embeddings_base_url',
            'SAMBASTUDIO_EMBEDDINGS_BASE_URL',
        )
        values['sambastudio_embeddings_base_uri'] = get_from_dict_or_env(
            values,
            'sambastudio_embeddings_base_uri',
            'SAMBASTUDIO_EMBEDDINGS_BASE_URI',
            default='api/predict/generic',
        )
        values['sambastudio_embeddings_project_id'] = get_from_dict_or_env(
            values,
            'sambastudio_embeddings_project_id',
            'SAMBASTUDIO_EMBEDDINGS_PROJECT_ID',
        )
        values['sambastudio_embeddings_endpoint_id'] = get_from_dict_or_env(
            values,
            'sambastudio_embeddings_endpoint_id',
            'SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID',
        )
        values['sambastudio_embeddings_api_key'] = get_from_dict_or_env(
            values,
            'sambastudio_embeddings_api_key',
            'SAMBASTUDIO_EMBEDDINGS_API_KEY',
        )
        return values

    def _get_tuning_params(self) -> str:
        """Get the tuning parameters to use when calling the model.

        For a base URI containing ``api/v2/predict/generic`` the
        ``model_kwargs`` are serialized unchanged. For every other URI each
        value is wrapped as ``{'type': <type name>, 'value': <str(value)>}``.

        Returns:
            The tuning parameters as a JSON string.
        """
        if 'api/v2/predict/generic' in self.sambastudio_embeddings_base_uri:
            tuning_params_dict = self.model_kwargs
        else:
            tuning_params_dict = {
                k: {'type': type(v).__name__, 'value': str(v)}
                for k, v in self.model_kwargs.items()
            }
        return json.dumps(tuning_params_dict)

    def _get_full_url(self, path: str) -> str:
        """Return the full API URL for a given path.

        Args:
            path: The sub-path appended after the base URL and base URI.

        Returns:
            ``<base_url>/<base_uri>/<path>``.
        """
        return f'{self.sambastudio_embeddings_base_url}/{self.sambastudio_embeddings_base_uri}/{path}'  # noqa: E501

    def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator:
        """Yield consecutive slices of ``texts`` for the embed documents method.

        Args:
            texts: List of strings to embed.
            batch_size: Maximum number of strings per slice.

        Yields:
            Lists of at most ``batch_size`` strings, in the original order.
        """
        for i in range(0, len(texts), batch_size):
            yield texts[i : i + batch_size]

    def _detect_api_variant(self) -> str:
        """Return the supported API marker contained in the base URI.

        The markers are tested in a fixed order (nlp, v2 generic, generic).

        Returns:
            One of ``'api/predict/nlp'``, ``'api/v2/predict/generic'`` or
            ``'api/predict/generic'``.

        Raises:
            ValueError: If the base URI contains none of the markers.
        """
        for api_variant in (
            'api/predict/nlp',
            'api/v2/predict/generic',
            'api/predict/generic',
        ):
            if api_variant in self.sambastudio_embeddings_base_uri:
                return api_variant
        raise ValueError(
            f'handling of endpoint uri: {self.sambastudio_embeddings_base_uri} not implemented'  # noqa: E501
        )

    def _embed_batch(
        self,
        http_session: requests.Session,
        url: str,
        api_variant: str,
        batch: List[str],
        params: Dict[str, Any],
    ) -> List[List[float]]:
        """POST one batch of texts and extract the embeddings from the reply.

        Args:
            http_session: Session used to send the request.
            url: Full endpoint URL.
            api_variant: Marker returned by ``_detect_api_variant``.
            batch: Texts to embed in this request.
            params: Decoded tuning parameters sent as ``'params'``.

        Returns:
            The embeddings found in the response body for this batch.

        Raises:
            RuntimeError: If the endpoint answers with a status other than 200.
            KeyError: If the expected key is missing from the response body.
        """
        # Each API variant names its input list and its result key differently.
        if api_variant == 'api/predict/nlp':
            data: Dict[str, Any] = {'inputs': batch, 'params': params}
            response_key = 'data'
        elif api_variant == 'api/v2/predict/generic':
            items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(batch)]
            data = {'items': items, 'params': params}
            response_key = 'items'
        else:
            data = {'instances': batch, 'params': params}
            response_key = 'predictions'

        response = http_session.post(
            url,
            headers={'key': self.sambastudio_embeddings_api_key},
            json=data,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f'Sambanova /complete call failed with status code '
                f'{response.status_code}.\n Details: {response.text}'
            )

        payload = response.json()
        try:
            result = payload[response_key]
            if api_variant == 'api/v2/predict/generic':
                # v2 wraps every embedding in an item object.
                return [item['value'] for item in result]
            return result
        except KeyError as err:
            raise KeyError(
                f"'{response_key}' not found in endpoint response",
                payload,
            ) from err

    def embed_documents(
        self, texts: List[str], batch_size: Optional[int] = None
    ) -> List[List[float]]:
        """Return one embedding per input text.

        Args:
            texts: List of texts to encode.
            batch_size: Number of texts sent per request. Defaults to the
                instance's ``batch_size`` when ``None``.

        Returns:
            The embeddings returned by the endpoint, concatenated in batch
            order.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 or the base URI is
                not one of the supported API variants.
            RuntimeError: If the endpoint answers with a status other than 200.
            KeyError: If the expected key is missing from a response body.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if batch_size < 1:
            # A negative step would make the batching range empty and silently
            # drop every text; zero fails with an opaque range() error.
            raise ValueError(
                f'batch_size must be a positive integer, got {batch_size}'
            )

        url = self._get_full_url(
            f'{self.sambastudio_embeddings_project_id}/'
            f'{self.sambastudio_embeddings_endpoint_id}'
        )
        params = json.loads(self._get_tuning_params())
        # Resolved before batching so an unsupported URI fails even when
        # ``texts`` is empty.
        api_variant = self._detect_api_variant()

        embeddings: List[List[float]] = []
        # The context manager closes the session even if a request fails.
        with requests.Session() as http_session:
            for batch in self._iterate_over_batches(texts, batch_size):
                embeddings.extend(
                    self._embed_batch(http_session, url, api_variant, batch, params)
                )
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding for a single text.

        Args:
            text: The text to encode.

        Returns:
            The first embedding in the endpoint's response.

        Raises:
            ValueError: If the base URI is not one of the supported API
                variants.
            RuntimeError: If the endpoint answers with a status other than 200.
            KeyError: If the expected key is missing from the response body.
        """
        url = self._get_full_url(
            f'{self.sambastudio_embeddings_project_id}/'
            f'{self.sambastudio_embeddings_endpoint_id}'
        )
        params = json.loads(self._get_tuning_params())
        api_variant = self._detect_api_variant()

        # The context manager closes the session even if the request fails.
        with requests.Session() as http_session:
            embeddings = self._embed_batch(
                http_session, url, api_variant, [text], params
            )
        return embeddings[0]