|
"""Langchain Wrapper around Sambanova LLM APIs.""" |
|
|
|
import json |
|
from typing import Any, Dict, Generator, Iterator, List, Optional, Union |
|
|
|
import requests |
|
from langchain_core.callbacks.manager import CallbackManagerForLLMRun |
|
from langchain_core.language_models.llms import LLM |
|
from langchain_core.outputs import GenerationChunk |
|
from langchain_core.pydantic_v1 import Extra |
|
from langchain_core.utils import get_from_dict_or_env, pre_init |
|
|
|
|
|
class SSEndpointHandler:
    """
    SambaNova Systems Interface for SambaStudio model endpoints.

    :param str host_url: Base URL of the DaaS API service
    """

    def __init__(self, host_url: str, api_base_uri: str):
        """
        Initialize the SSEndpointHandler.

        :param str host_url: Base URL of the DaaS API service
        :param str api_base_uri: Base URI of the DaaS API service
        """
        self.host_url = host_url
        self.api_base_uri = api_base_uri
        # One shared session so repeated predict calls reuse pooled TCP
        # connections.
        self.http_session = requests.Session()

    def _process_response(self, response: requests.Response) -> Dict:
        """
        Process the API response and return the resulting dict.

        All resulting dicts, regardless of success or failure, will contain the
        `status_code` key with the API response status code.

        If the API returned an error, the resulting dict will contain the key
        `detail` with the error message.

        If the API call was successful, the resulting dict will contain the key
        `data` with the response data.

        :param requests.Response response: the response object to process
        :return: the response dict
        :type: dict
        """
        result: Dict[str, Any] = {}
        try:
            result = response.json()
        except Exception as e:
            # Body was not valid JSON (e.g. an HTML error page); report the
            # decode error rather than raising.
            result['detail'] = str(e)
        if 'status_code' not in result:
            result['status_code'] = response.status_code
        return result

    def _process_streaming_response(
        self,
        response: requests.Response,
    ) -> Generator[Dict, None, None]:
        """
        Yield chunk dicts from a streaming API response.

        Every yielded chunk carries a `status_code` key.

        :param requests.Response response: the streaming response to process
        :raises ImportError: if the nlp endpoint is used and sseclient-py
            is not installed
        :raises RuntimeError: if a generic-endpoint stream line cannot be parsed
        :raises ValueError: if the configured api_base_uri is not supported
        """
        if 'api/predict/nlp' in self.api_base_uri:
            # The nlp API streams via Server-Sent Events.
            try:
                import sseclient
            except ImportError:
                raise ImportError(
                    'could not import sseclient library. '
                    'Please install it with `pip install sseclient-py`.'
                )
            client = sseclient.SSEClient(response)
            close_conn = False
            for event in client.events():
                if event.event == 'error_event':
                    # Yield the error chunk to the caller first, then close
                    # the SSE connection below.
                    close_conn = True
                chunk = {
                    'event': event.event,
                    'data': event.data,
                    'status_code': response.status_code,
                }
                yield chunk
                if close_conn:
                    client.close()
        elif 'api/v2/predict/generic' in self.api_base_uri or 'api/predict/generic' in self.api_base_uri:
            # Generic endpoints stream newline-delimited JSON objects.
            try:
                for line in response.iter_lines():
                    chunk = json.loads(line)
                    if 'status_code' not in chunk:
                        chunk['status_code'] = response.status_code
                    yield chunk
            except Exception as e:
                raise RuntimeError(f'Error processing streaming response: {e}')
        else:
            raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')

    def _get_full_url(self, path: str) -> str:
        """
        Return the full API URL for a given path.

        :param str path: the sub-path
        :returns: the full API URL for the sub-path
        :type: str
        """
        return f'{self.host_url}/{self.api_base_uri}/{path}'

    def nlp_predict(
        self,
        project: str,
        endpoint: str,
        key: str,
        input: Union[List[str], str],
        params: Optional[str] = '',
        stream: bool = False,
    ) -> Dict:
        """
        NLP predict using an inline input string or list of strings.

        :param str project: Project ID in which the endpoint exists
        :param str endpoint: Endpoint ID
        :param str key: API Key
        :param input: Input string or list of input strings
        :param str params: Input params as a JSON string
        :param bool stream: unused; kept for signature symmetry with
            nlp_predict_stream
        :returns: Prediction results
        :type: dict
        """
        if isinstance(input, str):
            input = [input]
        # Each API flavor expects a different payload shape.
        if 'api/predict/nlp' in self.api_base_uri:
            if params:
                data = {'inputs': input, 'params': json.loads(params)}
            else:
                data = {'inputs': input}
        elif 'api/v2/predict/generic' in self.api_base_uri:
            items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(input)]
            if params:
                data = {'items': items, 'params': json.loads(params)}
            else:
                data = {'items': items}
        elif 'api/predict/generic' in self.api_base_uri:
            if params:
                data = {'instances': input, 'params': json.loads(params)}
            else:
                data = {'instances': input}
        else:
            raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
        response = self.http_session.post(
            self._get_full_url(f'{project}/{endpoint}'),
            headers={'key': key},
            json=data,
        )
        return self._process_response(response)

    def nlp_predict_stream(
        self,
        project: str,
        endpoint: str,
        key: str,
        input: Union[List[str], str],
        params: Optional[str] = '',
    ) -> Iterator[Dict]:
        """
        Streaming NLP predict using an inline input string or list of strings.

        :param str project: Project ID in which the endpoint exists
        :param str endpoint: Endpoint ID
        :param str key: API Key
        :param input: Input string or list of input strings
        :param str params: Input params as a JSON string
        :returns: Iterator over prediction result chunks
        :type: dict
        """
        # Each API flavor expects a different payload shape; note the v1
        # generic streaming endpoint takes a single instance, not a list.
        if 'api/predict/nlp' in self.api_base_uri:
            if isinstance(input, str):
                input = [input]
            if params:
                data = {'inputs': input, 'params': json.loads(params)}
            else:
                data = {'inputs': input}
        elif 'api/v2/predict/generic' in self.api_base_uri:
            if isinstance(input, str):
                input = [input]
            items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(input)]
            if params:
                data = {'items': items, 'params': json.loads(params)}
            else:
                data = {'items': items}
        elif 'api/predict/generic' in self.api_base_uri:
            if isinstance(input, list):
                input = input[0]
            if params:
                data = {'instance': input, 'params': json.loads(params)}
            else:
                data = {'instance': input}
        else:
            raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')

        response = self.http_session.post(
            self._get_full_url(f'stream/{project}/{endpoint}'),
            headers={'key': key},
            json=data,
            stream=True,
        )
        for chunk in self._process_streaming_response(response):
            yield chunk
|
|
|
|
|
class SambaStudio(LLM):
    """
    SambaStudio large language models.

    To use, you should have the environment variables
    ``SAMBASTUDIO_BASE_URL`` set with your SambaStudio environment URL.
    ``SAMBASTUDIO_BASE_URI`` set with your SambaStudio api base URI.
    ``SAMBASTUDIO_PROJECT_ID`` set with your SambaStudio project ID.
    ``SAMBASTUDIO_ENDPOINT_ID`` set with your SambaStudio endpoint ID.
    ``SAMBASTUDIO_API_KEY`` set with your SambaStudio endpoint API key.

    https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite

    read extra documentation in https://docs.sambanova.ai/sambastudio/latest/index.html

    Example:
    .. code-block:: python

        from langchain_community.llms.sambanova import SambaStudio
        SambaStudio(
            sambastudio_base_url="your-SambaStudio-environment-URL",
            sambastudio_base_uri="your-SambaStudio-base-URI",
            sambastudio_project_id="your-SambaStudio-project-ID",
            sambastudio_endpoint_id="your-SambaStudio-endpoint-ID",
            sambastudio_api_key="your-SambaStudio-endpoint-API-key",
            streaming=False,
            model_kwargs={
                "do_sample": False,
                "max_tokens_to_generate": 1000,
                "temperature": 0.7,
                "top_p": 1.0,
                "repetition_penalty": 1,
                "top_k": 50,
                #"process_prompt": False,
                #"select_expert": "Meta-Llama-3-8B-Instruct"
            },
        )
    """

    sambastudio_base_url: str = ''
    """Base url to use"""

    sambastudio_base_uri: str = ''
    """endpoint base uri"""

    sambastudio_project_id: str = ''
    """Project id on sambastudio for model"""

    sambastudio_endpoint_id: str = ''
    """endpoint id on sambastudio for model"""

    sambastudio_api_key: str = ''
    """sambastudio api key"""

    model_kwargs: Optional[dict] = None
    """Key word arguments to pass to the model."""

    streaming: Optional[bool] = False
    """Streaming flag to get streamed response."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return True: this class is serializable by LangChain."""
        return True

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {'model_kwargs': self.model_kwargs}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return 'Sambastudio LLM'

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values['sambastudio_base_url'] = get_from_dict_or_env(values, 'sambastudio_base_url', 'SAMBASTUDIO_BASE_URL')
        values['sambastudio_base_uri'] = get_from_dict_or_env(
            values,
            'sambastudio_base_uri',
            'SAMBASTUDIO_BASE_URI',
            default='api/predict/generic',
        )
        values['sambastudio_project_id'] = get_from_dict_or_env(
            values, 'sambastudio_project_id', 'SAMBASTUDIO_PROJECT_ID'
        )
        values['sambastudio_endpoint_id'] = get_from_dict_or_env(
            values, 'sambastudio_endpoint_id', 'SAMBASTUDIO_ENDPOINT_ID'
        )
        values['sambastudio_api_key'] = get_from_dict_or_env(values, 'sambastudio_api_key', 'SAMBASTUDIO_API_KEY')
        return values

    def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
        """
        Get the tuning parameters to use when calling the LLM.

        Args:
            stop: Stop words to use when generating. Model output is cut off at the
            first occurrence of any of the stop substrings.

        Returns:
            The tuning parameters as a JSON string.
        """
        _model_kwargs = self.model_kwargs or {}
        _kwarg_stop_sequences = _model_kwargs.get('stop_sequences', [])
        _stop_sequences = stop or _kwarg_stop_sequences
        # NOTE(review): _stop_sequences is computed but never forwarded to the
        # endpoint, so stop sequences are effectively ignored here — confirm
        # whether the target endpoints support them before wiring through.
        if 'api/v2/predict/generic' in self.sambastudio_base_uri:
            # v2 generic endpoints accept plain key/value params.
            tuning_params_dict = _model_kwargs
        else:
            # v1 endpoints expect typed parameter descriptors.
            tuning_params_dict = {k: {'type': type(v).__name__, 'value': str(v)} for k, v in (_model_kwargs.items())}
        tuning_params = json.dumps(tuning_params_dict)
        return tuning_params

    def _handle_nlp_predict(self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str) -> str:
        """
        Perform an NLP prediction using the SambaStudio endpoint handler.

        Args:
            sdk: The SSEndpointHandler to use for the prediction.
            prompt: The prompt to use for the prediction.
            tuning_params: The tuning parameters to use for the prediction.

        Returns:
            The prediction result.

        Raises:
            RuntimeError: If the endpoint returns a non-200 status code.
            ValueError: If the configured base URI is not supported.
        """
        response = sdk.nlp_predict(
            self.sambastudio_project_id,
            self.sambastudio_endpoint_id,
            self.sambastudio_api_key,
            prompt,
            tuning_params,
        )
        if response['status_code'] != 200:
            optional_detail = response.get('detail')
            if optional_detail:
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code "
                    f"{response['status_code']}.\n Details: {optional_detail}"
                )
            else:
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code "
                    f"{response['status_code']}.\n response {response}"
                )
        # The completion lives at a different place in each API flavor's
        # response schema.
        if 'api/predict/nlp' in self.sambastudio_base_uri:
            return response['data'][0]['completion']
        elif 'api/v2/predict/generic' in self.sambastudio_base_uri:
            return response['items'][0]['value']['completion']
        elif 'api/predict/generic' in self.sambastudio_base_uri:
            return response['predictions'][0]['completion']
        else:
            raise ValueError(f'handling of endpoint uri: {self.sambastudio_base_uri} not implemented')

    def _handle_completion_requests(self, prompt: Union[List[str], str], stop: Optional[List[str]]) -> str:
        """
        Perform a (non-streaming) prediction using the SambaStudio endpoint handler.

        Args:
            prompt: The prompt to use for the prediction.
            stop: stop sequences.

        Returns:
            The prediction result.

        Raises:
            ValueError: If the prediction fails.
        """
        ss_endpoint = SSEndpointHandler(self.sambastudio_base_url, self.sambastudio_base_uri)
        tuning_params = self._get_tuning_params(stop)
        return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)

    def _handle_nlp_predict_stream(
        self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
    ) -> Iterator[GenerationChunk]:
        """
        Perform a streaming request to the LLM.

        Args:
            sdk: The SSEndpointHandler to use for the prediction.
            prompt: The prompt to use for the prediction.
            tuning_params: The tuning parameters to use for the prediction.

        Returns:
            An iterator of GenerationChunks.
        """
        for chunk in sdk.nlp_predict_stream(
            self.sambastudio_project_id,
            self.sambastudio_endpoint_id,
            self.sambastudio_api_key,
            prompt,
            tuning_params,
        ):
            if chunk['status_code'] != 200:
                error = chunk.get('error')
                if error:
                    optional_code = error.get('code')
                    optional_details = error.get('details')
                    optional_message = error.get('message')
                    raise ValueError(
                        f"Sambanova /complete call failed with status code "
                        f"{chunk['status_code']}.\n"
                        f"Message: {optional_message}\n"
                        f"Details: {optional_details}\n"
                        f"Code: {optional_code}\n"
                    )
                else:
                    raise RuntimeError(
                        f"Sambanova /complete call failed with status code " f"{chunk['status_code']}." f"{chunk}."
                    )
            # The stream token lives at a different place in each API
            # flavor's chunk schema.
            if 'api/predict/nlp' in self.sambastudio_base_uri:
                text = json.loads(chunk['data'])['stream_token']
            elif 'api/v2/predict/generic' in self.sambastudio_base_uri:
                text = chunk['result']['items'][0]['value']['stream_token']
            elif 'api/predict/generic' in self.sambastudio_base_uri:
                if len(chunk['result']['responses']) > 0:
                    text = chunk['result']['responses'][0]['stream_token']
                else:
                    text = ''
            else:
                raise ValueError(f'handling of endpoint uri: {self.sambastudio_base_uri}' f'not implemented')
            generated_chunk = GenerationChunk(text=text)
            yield generated_chunk

    def _stream(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call out to Sambanova's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.
        """
        ss_endpoint = SSEndpointHandler(self.sambastudio_base_url, self.sambastudio_base_uri)
        tuning_params = self._get_tuning_params(stop)
        try:
            if self.streaming:
                for chunk in self._handle_nlp_predict_stream(ss_endpoint, prompt, tuning_params):
                    if run_manager:
                        run_manager.on_llm_new_token(chunk.text)
                    yield chunk
            else:
                # Non-streaming mode is served by _handle_completion_requests
                # (see _call); this generator intentionally yields nothing.
                return
        except Exception as e:
            # Wrap transport and parsing errors into a single endpoint error.
            raise ValueError(f'Error raised by the inference endpoint: {e}') from e

    def _handle_stream_request(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]],
        run_manager: Optional[CallbackManagerForLLMRun],
        kwargs: Dict[str, Any],
    ) -> str:
        """
        Perform a streaming request to the LLM.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at the
            first occurrence of any of the stop substrings.
            run_manager: Callback manager for the run.
            **kwargs: Additional keyword arguments. directly passed
            to the sambastudio model in API call.

        Returns:
            The model output as a string.
        """
        completion = ''
        for chunk in self._stream(prompt=prompt, stop=stop, run_manager=run_manager, **kwargs):
            completion += chunk.text
        return completion

    def _call(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Sambanova's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: if stop words are supplied (not supported), or if the
            inference endpoint call fails.
        """
        if stop is not None:
            # Stop sequences are not wired through to the endpoint yet
            # (see _get_tuning_params); fail loudly rather than ignore them.
            raise ValueError('stop not implemented')
        try:
            if self.streaming:
                return self._handle_stream_request(prompt, stop, run_manager, kwargs)
            return self._handle_completion_requests(prompt, stop)
        except Exception as e:
            raise ValueError(f'Error raised by the inference endpoint: {e}') from e
|
|
|
|
|
class SambaNovaCloud(LLM):
    """
    SambaNova Cloud large language models.

    To use, you should have the environment variables
    ``SAMBANOVA_URL`` set with your SambaNova Cloud URL.
    ``SAMBANOVA_API_KEY`` set with your SambaNova Cloud API Key.

    http://cloud.sambanova.ai/

    Example:
    .. code-block:: python

        SambaNovaCloud(
            sambanova_url = SambaNova cloud endpoint URL,
            sambanova_api_key = set with your SambaNova cloud API key,
            max_tokens = max number of tokens to generate,
            stop_tokens = list of stop tokens,
            model = model name
        )
    """

    sambanova_url: str = ''
    """SambaNova Cloud Url"""

    sambanova_api_key: str = ''
    """SambaNova Cloud api key"""

    max_tokens: int = 1024
    """max tokens to generate"""

    stop_tokens: list = ['<|eot_id|>']
    """Stop tokens"""

    model: str = 'llama3-8b'
    """LLM model expert to use"""

    temperature: float = 0.0
    """model temperature"""

    top_p: float = 0.0
    """model top p"""

    top_k: int = 1
    """model top k"""

    stream_api: bool = True
    """use stream api"""

    stream_options: dict = {'include_usage': True}
    """stream options, include usage to get generation metrics"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return True: this class is serializable by LangChain."""
        return True

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            'model': self.model,
            'max_tokens': self.max_tokens,
            'stop': self.stop_tokens,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'top_k': self.top_k,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return 'SambaNova Cloud'

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values['sambanova_url'] = get_from_dict_or_env(
            values, 'sambanova_url', 'SAMBANOVA_URL', default='https://api.sambanova.ai/v1/chat/completions'
        )
        values['sambanova_api_key'] = get_from_dict_or_env(values, 'sambanova_api_key', 'SAMBANOVA_API_KEY')
        return values

    def _handle_nlp_predict_stream(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]],
    ) -> Iterator[GenerationChunk]:
        """
        Perform a streaming request to the LLM.

        Args:
            prompt: The prompt to use for the prediction.
            stop: list of stop tokens

        Returns:
            An iterator of GenerationChunks.

        Raises:
            ImportError: if sseclient-py is not installed.
            RuntimeError: if the API call fails or a streamed chunk cannot
            be parsed.
        """
        try:
            import sseclient
        except ImportError:
            raise ImportError(
                'could not import sseclient library. ' 'Please install it with `pip install sseclient-py`.'
            )
        # Accept either an already-formatted messages payload passed as a
        # JSON string, or a plain prompt wrapped as a single user message.
        try:
            formatted_prompt = json.loads(prompt)
        except (json.JSONDecodeError, TypeError):
            formatted_prompt = [{'role': 'user', 'content': prompt}]

        http_session = requests.Session()
        if not stop:
            stop = self.stop_tokens
        data = {
            'messages': formatted_prompt,
            'max_tokens': self.max_tokens,
            'stop': stop,
            'model': self.model,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'top_k': self.top_k,
            'stream': self.stream_api,
            'stream_options': self.stream_options,
        }

        response = http_session.post(
            self.sambanova_url,
            headers={'Authorization': f'Bearer {self.sambanova_api_key}', 'Content-Type': 'application/json'},
            json=data,
            stream=True,
        )

        # Fail before wrapping the response in an SSE client.
        if response.status_code != 200:
            raise RuntimeError(
                f'Sambanova /complete call failed with status code ' f'{response.status_code}.' f'{response.text}.'
            )

        client = sseclient.SSEClient(response)

        for event in client.events():
            chunk = {
                'event': event.event,
                'data': event.data,
                'status_code': response.status_code,
            }
            # '[DONE]' is the end-of-stream sentinel; it carries no payload.
            if chunk['data'] == '[DONE]':
                continue
            try:
                parsed = json.loads(chunk['data'])
            except json.JSONDecodeError as e:
                raise RuntimeError(f'Error getting content chunk raw streamed response: {chunk}') from e
            if parsed.get('error'):
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code " f"{chunk['status_code']}." f"{chunk}."
                )
            # Chunks carrying usage metrics have no generated text.
            if parsed.get('usage') is None:
                try:
                    if parsed['choices'][0]['finish_reason'] is None:
                        text = parsed['choices'][0]['delta']['content']
                        generated_chunk = GenerationChunk(text=text)
                        yield generated_chunk
                except (KeyError, IndexError, TypeError) as e:
                    raise RuntimeError(f'Error getting content chunk raw streamed response: {chunk}') from e

    def _stream(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call out to Sambanova's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.
        """
        try:
            for chunk in self._handle_nlp_predict_stream(prompt, stop):
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)
                yield chunk
        except Exception as e:
            # Wrap transport and parsing errors into a single endpoint error.
            raise ValueError(f'Error raised by the inference endpoint: {e}') from e

    def _handle_stream_request(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]],
        run_manager: Optional[CallbackManagerForLLMRun],
        kwargs: Dict[str, Any],
    ) -> str:
        """
        Perform a streaming request to the LLM.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at the
            first occurrence of any of the stop substrings.
            run_manager: Callback manager for the run.
            **kwargs: Additional keyword arguments. directly passed
            to the Sambanova Cloud model in API call.

        Returns:
            The model output as a string.
        """
        completion = ''
        for chunk in self._stream(prompt=prompt, stop=stop, run_manager=run_manager, **kwargs):
            completion += chunk.text
        return completion

    def _call(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Sambanova's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.
        """
        try:
            # The cloud API is always consumed as a stream; the chunks are
            # accumulated into a single string here.
            return self._handle_stream_request(prompt, stop, run_manager, kwargs)
        except Exception as e:
            raise ValueError(f'Error raised by the inference endpoint: {e}') from e
|