|
"""Langchain Wrapper around Sambanova embedding APIs.""" |
|
|
|
import json |
|
from typing import Dict, Generator, List, Optional |
|
|
|
import requests |
|
from langchain_core.embeddings import Embeddings |
|
from langchain_core.pydantic_v1 import BaseModel |
|
from langchain_core.utils import get_from_dict_or_env, pre_init |
|
|
|
|
|
class SambaStudioEmbeddings(BaseModel, Embeddings):
    """SambaNova embedding models.

    To use, you should have the environment variables
    ``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_BASE_URI``,
    ``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``, ``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``,
    ``SAMBASTUDIO_EMBEDDINGS_API_KEY``
    set with your personal sambastudio variable or pass it as a named parameter
    to the constructor.

    Three endpoint flavors are recognised, selected by a substring match on
    ``sambastudio_embeddings_base_uri``: ``api/predict/nlp``,
    ``api/v2/predict/generic`` and ``api/predict/generic``. Any other URI makes
    the embed methods raise ``ValueError``.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import SambaStudioEmbeddings

            embeddings = SambaStudioEmbeddings(sambastudio_embeddings_base_url=base_url,
                                  sambastudio_embeddings_base_uri=base_uri,
                                  sambastudio_embeddings_project_id=project_id,
                                  sambastudio_embeddings_endpoint_id=endpoint_id,
                                  sambastudio_embeddings_api_key=api_key,
                                  batch_size=32)
            (or)

            embeddings = SambaStudioEmbeddings(batch_size=32)

            (or)

            # CoE example
            embeddings = SambaStudioEmbeddings(
                batch_size=1,
                model_kwargs={
                    'select_expert':'e5-mistral-7b-instruct'
                }
            )
    """

    sambastudio_embeddings_base_url: str = ''
    """Base url to use"""

    sambastudio_embeddings_base_uri: str = ''
    """endpoint base uri"""

    sambastudio_embeddings_project_id: str = ''
    """Project id on sambastudio for model"""

    sambastudio_embeddings_endpoint_id: str = ''
    """endpoint id on sambastudio for model"""

    sambastudio_embeddings_api_key: str = ''
    """sambastudio api key"""

    model_kwargs: dict = {}
    """Key word arguments to pass to the model."""

    batch_size: int = 32
    """Batch size for the embedding models"""

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve every connection setting from ``values`` or the environment.

        Each ``sambastudio_embeddings_*`` setting is looked up through
        ``get_from_dict_or_env`` using the matching upper-case environment
        variable name. Only the base URI is given a fallback
        (``'api/predict/generic'``).

        Args:
            values: the field values being validated.

        Returns:
            The same dict with the resolved settings written back.
        """
        values['sambastudio_embeddings_base_url'] = get_from_dict_or_env(
            values, 'sambastudio_embeddings_base_url', 'SAMBASTUDIO_EMBEDDINGS_BASE_URL'
        )
        values['sambastudio_embeddings_base_uri'] = get_from_dict_or_env(
            values,
            'sambastudio_embeddings_base_uri',
            'SAMBASTUDIO_EMBEDDINGS_BASE_URI',
            default='api/predict/generic',
        )
        values['sambastudio_embeddings_project_id'] = get_from_dict_or_env(
            values,
            'sambastudio_embeddings_project_id',
            'SAMBASTUDIO_EMBEDDINGS_PROJECT_ID',
        )
        values['sambastudio_embeddings_endpoint_id'] = get_from_dict_or_env(
            values,
            'sambastudio_embeddings_endpoint_id',
            'SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID',
        )
        values['sambastudio_embeddings_api_key'] = get_from_dict_or_env(
            values, 'sambastudio_embeddings_api_key', 'SAMBASTUDIO_EMBEDDINGS_API_KEY'
        )
        return values

    def _get_tuning_params(self) -> str:
        """
        Get the tuning parameters to use when calling the model

        For ``api/v2/predict/generic`` endpoints ``model_kwargs`` is serialised
        as-is; for every other endpoint each value is wrapped as
        ``{'type': <type name>, 'value': <str(value)>}``.

        Returns:
            The tuning parameters as a JSON string.
        """
        if 'api/v2/predict/generic' in self.sambastudio_embeddings_base_uri:
            tuning_params_dict = self.model_kwargs
        else:
            tuning_params_dict = {
                k: {'type': type(v).__name__, 'value': str(v)} for k, v in (self.model_kwargs.items())
            }
        tuning_params = json.dumps(tuning_params_dict)
        return tuning_params

    def _get_full_url(self, path: str) -> str:
        """
        Return the full API URL for a given path.

        :param str path: the sub-path
        :returns: the full API URL for the sub-path
        :rtype: str
        """
        return f'{self.sambastudio_embeddings_base_url}/{self.sambastudio_embeddings_base_uri}/{path}'

    def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator:
        """Generator for creating batches in the embed documents method

        Args:
            texts (List[str]): list of strings to embed
            batch_size (int): batch size to be used for the embedding model.
                Will depend on the RDU endpoint used.

        Yields:
            List[str]: consecutive slices of ``texts`` holding at most
            ``batch_size`` strings each (the last one may be shorter).
        """
        for i in range(0, len(texts), batch_size):
            yield texts[i : i + batch_size]

    def _resolve_api_flavor(self) -> str:
        """Return which supported endpoint flavor the configured base URI uses.

        Returns:
            One of ``'api/predict/nlp'``, ``'api/v2/predict/generic'`` or
            ``'api/predict/generic'``.

        Raises:
            ValueError: if the base URI contains none of the supported flavors.
        """
        uri = self.sambastudio_embeddings_base_uri
        # Checked in the same order as the historical if/elif chain.
        for flavor in ('api/predict/nlp', 'api/v2/predict/generic', 'api/predict/generic'):
            if flavor in uri:
                return flavor
        raise ValueError(f'handling of endpoint uri: {uri} not implemented')

    def _request_embeddings(
        self,
        http_session: requests.Session,
        url: str,
        flavor: str,
        params: Dict,
        batch: List[str],
    ) -> List[List[float]]:
        """POST one batch of texts and return the embeddings from the response.

        Args:
            http_session: session used to send the request.
            url: full endpoint URL.
            flavor: value returned by ``_resolve_api_flavor``.
            params: tuning parameters sent under the ``'params'`` key.
            batch: texts to embed in this request.

        Returns:
            The embeddings extracted from the response body, one per item the
            endpoint returned.

        Raises:
            RuntimeError: if the endpoint answers with a status other than 200.
            KeyError: if the expected key is missing from the response body.
        """
        if flavor == 'api/predict/nlp':
            data = {'inputs': batch, 'params': params}
            response_key = 'data'
        elif flavor == 'api/v2/predict/generic':
            items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(batch)]
            data = {'items': items, 'params': params}
            response_key = 'items'
        else:
            data = {'instances': batch, 'params': params}
            response_key = 'predictions'

        response = http_session.post(
            url,
            headers={'key': self.sambastudio_embeddings_api_key},
            json=data,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f'Sambanova /complete call failed with status code '
                f'{response.status_code}.\n Details: {response.text}'
            )
        try:
            payload = response.json()[response_key]
            if response_key == 'items':
                # v2 responses wrap every embedding in an item object.
                return [item['value'] for item in payload]
            return payload
        except KeyError as err:
            raise KeyError(
                f"'{response_key}' not found in endpoint response",
                response.json(),
            ) from err

    def embed_documents(self, texts: List[str], batch_size: Optional[int] = None) -> List[List[float]]:
        """Returns a list of embeddings for the given sentences.

        Args:
            texts (`List[str]`): List of texts to encode
            batch_size (`int`): Batch size for the encoding; defaults to
                ``self.batch_size`` when ``None``.

        Returns:
            `List[List[float]]`: the embeddings returned by the endpoint,
            concatenated in batch order.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1 or the configured
                endpoint URI is not supported.
            RuntimeError: if a request returns a non-200 status code.
            KeyError: if a response lacks the expected payload key.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if batch_size < 1:
            # A negative step would make range() empty and silently drop all texts.
            raise ValueError(f'batch_size must be a positive integer, got {batch_size}')

        flavor = self._resolve_api_flavor()
        url = self._get_full_url(f'{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}')
        params = json.loads(self._get_tuning_params())
        embeddings: List[List[float]] = []

        # The context manager closes the session's connections when done.
        with requests.Session() as http_session:
            for batch in self._iterate_over_batches(texts, batch_size):
                embeddings.extend(self._request_embeddings(http_session, url, flavor, params, batch))

        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Returns the embedding for a single text.

        Args:
            text (`str`): the text to encode

        Returns:
            `List[float]`: the first embedding returned by the endpoint for
            this text.

        Raises:
            ValueError: if the configured endpoint URI is not supported.
            RuntimeError: if the request returns a non-200 status code.
            KeyError: if the response lacks the expected payload key.
        """
        flavor = self._resolve_api_flavor()
        url = self._get_full_url(f'{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}')
        params = json.loads(self._get_tuning_params())

        # The context manager closes the session's connections when done.
        with requests.Session() as http_session:
            embedding = self._request_embeddings(http_session, url, flavor, params, [text])[0]

        return embedding
|
|