"""Langchain Wrapper around Sambanova LLM APIs."""

import json
from typing import Any, Dict, Generator, Iterator, List, Optional, Union

import requests
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Extra
from langchain_core.utils import get_from_dict_or_env, pre_init


class SSEndpointHandler:
    """
    SambaNova Systems Interface for SambaStudio model endpoints.

    The request payload and streaming format are selected by which of the
    substrings 'api/predict/nlp', 'api/v2/predict/generic' or
    'api/predict/generic' appears in ``api_base_uri``.

    :param str host_url: Base URL of the DaaS API service
    :param str api_base_uri: Base URI of the DaaS API service
    """

    def __init__(self, host_url: str, api_base_uri: str):
        """
        Initialize the SSEndpointHandler.

        :param str host_url: Base URL of the DaaS API service
        :param str api_base_uri: Base URI of the DaaS API service
        """
        self.host_url = host_url
        self.api_base_uri = api_base_uri
        self.http_session = requests.Session()

    def _process_response(self, response: requests.Response) -> Dict:
        """
        Processes the API response and returns the resulting dict.

        The decoded JSON body is returned as-is. All resulting dicts will
        contain the `status_code` key: if the body does not already provide
        one, the HTTP status code of the response is added.

        If the body cannot be decoded as JSON, the resulting dict contains the
        key `detail` with the decoding error message (plus `status_code`).

        :param requests.Response response: the response object to process
        :return: the response dict
        :type: dict
        """
        result: Dict[str, Any] = {}
        try:
            result = response.json()
        except ValueError as e:
            # requests' JSON decoding errors are ValueError subclasses.
            result['detail'] = str(e)
        if 'status_code' not in result:
            result['status_code'] = response.status_code
        return result

    def _process_streaming_response(
        self,
        response: requests.Response,
    ) -> Generator[Dict, None, None]:
        """
        Process the streaming response.

        Yields one dict per streamed event (api/predict/nlp, read as
        server-sent events) or per non-blank JSON line (generic APIs). Every
        yielded dict carries a `status_code` key. The response is closed once
        iteration finishes, fails, or the generator is closed early.

        :param requests.Response response: a response opened with stream=True
        :raises ImportError: sseclient-py is not installed (api/predict/nlp)
        :raises RuntimeError: a line of a generic-API stream could not be processed
        :raises ValueError: the base URI is not a supported endpoint type
        """
        try:
            if 'api/predict/nlp' in self.api_base_uri:
                try:
                    import sseclient
                except ImportError as e:
                    raise ImportError(
                        'could not import sseclient library. '
                        'Please install it with `pip install sseclient-py`.'
                    ) from e
                client = sseclient.SSEClient(response)
                for event in client.events():
                    yield {
                        'event': event.event,
                        'data': event.data,
                        'status_code': response.status_code,
                    }
                    # An 'error_event' ends the stream: close the connection after
                    # handing the event to the caller instead of continuing to
                    # read from an already-closed source.
                    if event.event == 'error_event':
                        client.close()
                        break
            elif 'api/v2/predict/generic' in self.api_base_uri or 'api/predict/generic' in self.api_base_uri:
                try:
                    for line in response.iter_lines():
                        # Blank lines (e.g. keep-alives) carry no JSON payload.
                        if not line:
                            continue
                        chunk = json.loads(line)
                        if 'status_code' not in chunk:
                            chunk['status_code'] = response.status_code
                        yield chunk
                except Exception as e:
                    raise RuntimeError(f'Error processing streaming response: {e}') from e
            else:
                raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
        finally:
            # Release the connection on every exit path.
            response.close()

    def _get_full_url(self, path: str) -> str:
        """
        Return the full API URL for a given path.

        :param str path: the sub-path
        :returns: the full API URL for the sub-path
        :type: str
        """
        return f'{self.host_url}/{self.api_base_uri}/{path}'

    @staticmethod
    def _with_params(payload: Dict[str, Any], params: Optional[str]) -> Dict[str, Any]:
        """Add the JSON-decoded ``params`` string to ``payload`` under 'params' when non-empty."""
        if params:
            payload['params'] = json.loads(params)
        return payload

    def nlp_predict(
        self,
        project: str,
        endpoint: str,
        key: str,
        input: Union[List[str], str],
        params: Optional[str] = '',
        stream: bool = False,
    ) -> Dict:
        """
        NLP predict using inline input string(s).

        :param str project: Project ID in which the endpoint exists
        :param str endpoint: Endpoint ID
        :param str key: API Key
        :param input: Input string or list of input strings
        :param str params: Input params as a JSON-encoded string (optional)
        :param bool stream: not used by this method
        :returns: Prediction results (see `_process_response`)
        :type: dict
        :raises ValueError: the base URI is not a supported endpoint type, or
            `params` is not valid JSON
        """
        if isinstance(input, str):
            input = [input]
        if 'api/predict/nlp' in self.api_base_uri:
            data: Dict[str, Any] = {'inputs': input}
        elif 'api/v2/predict/generic' in self.api_base_uri:
            data = {'items': [{'id': f'item{i}', 'value': item} for i, item in enumerate(input)]}
        elif 'api/predict/generic' in self.api_base_uri:
            data = {'instances': input}
        else:
            raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
        response = self.http_session.post(
            self._get_full_url(f'{project}/{endpoint}'),
            headers={'key': key},
            json=self._with_params(data, params),
        )
        return self._process_response(response)

    def nlp_predict_stream(
        self,
        project: str,
        endpoint: str,
        key: str,
        input: Union[List[str], str],
        params: Optional[str] = '',
    ) -> Iterator[Dict]:
        """
        Streaming NLP predict using inline input string(s).

        This is a generator: the request is only sent once iteration starts.

        :param str project: Project ID in which the endpoint exists
        :param str endpoint: Endpoint ID
        :param str key: API Key
        :param input: Input string or list of input strings. The
            api/predict/generic streaming payload takes a single instance, so
            only the first element of a list is sent there.
        :param str params: Input params as a JSON-encoded string (optional)
        :returns: Iterator over streamed chunk dicts
            (see `_process_streaming_response`)
        :raises ValueError: the base URI is not a supported endpoint type, or
            `params` is not valid JSON
        """
        if 'api/predict/nlp' in self.api_base_uri:
            if isinstance(input, str):
                input = [input]
            data: Dict[str, Any] = {'inputs': input}
        elif 'api/v2/predict/generic' in self.api_base_uri:
            if isinstance(input, str):
                input = [input]
            data = {'items': [{'id': f'item{i}', 'value': item} for i, item in enumerate(input)]}
        elif 'api/predict/generic' in self.api_base_uri:
            if isinstance(input, list):
                input = input[0]
            data = {'instance': input}
        else:
            raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
        # Streaming output
        response = self.http_session.post(
            self._get_full_url(f'stream/{project}/{endpoint}'),
            headers={'key': key},
            json=self._with_params(data, params),
            stream=True,
        )
        yield from self._process_streaming_response(response)


class SambaStudio(LLM):
    """
    SambaStudio large language models.

    To use, you should have the environment variables
    ``SAMBASTUDIO_BASE_URL`` set with your SambaStudio environment URL.
    ``SAMBASTUDIO_BASE_URI`` set with your SambaStudio api base URI
    (defaults to 'api/predict/generic').
    ``SAMBASTUDIO_PROJECT_ID`` set with your SambaStudio project ID.
    ``SAMBASTUDIO_ENDPOINT_ID`` set with your SambaStudio endpoint ID.
    ``SAMBASTUDIO_API_KEY`` set with your SambaStudio endpoint API key.

    https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite

    read extra documentation in https://docs.sambanova.ai/sambastudio/latest/index.html

    Example:
    .. code-block:: python

        from langchain_community.llms.sambanova import SambaStudio
        SambaStudio(
            sambastudio_base_url="your-SambaStudio-environment-URL",
            sambastudio_base_uri="your-SambaStudio-base-URI",
            sambastudio_project_id="your-SambaStudio-project-ID",
            sambastudio_endpoint_id="your-SambaStudio-endpoint-ID",
            sambastudio_api_key="your-SambaStudio-endpoint-API-key",
            streaming=False,
            model_kwargs={
                "do_sample": False,
                "max_tokens_to_generate": 1000,
                "temperature": 0.7,
                "top_p": 1.0,
                "repetition_penalty": 1,
                "top_k": 50,
                # "process_prompt": False,
                # "select_expert": "Meta-Llama-3-8B-Instruct"
            },
        )
    """

    sambastudio_base_url: str = ''
    """Base url to use"""

    sambastudio_base_uri: str = ''
    """endpoint base uri"""

    sambastudio_project_id: str = ''
    """Project id on sambastudio for model"""

    sambastudio_endpoint_id: str = ''
    """endpoint id on sambastudio for model"""

    sambastudio_api_key: str = ''
    """sambastudio api key"""

    model_kwargs: Optional[dict] = None
    """Key word arguments to pass to the model."""

    streaming: Optional[bool] = False
    """Streaming flag to get streamed response."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {'model_kwargs': self.model_kwargs}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return 'Sambastudio LLM'

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Fill the connection settings from the init values or the environment."""
        values['sambastudio_base_url'] = get_from_dict_or_env(values, 'sambastudio_base_url', 'SAMBASTUDIO_BASE_URL')
        values['sambastudio_base_uri'] = get_from_dict_or_env(
            values,
            'sambastudio_base_uri',
            'SAMBASTUDIO_BASE_URI',
            default='api/predict/generic',
        )
        values['sambastudio_project_id'] = get_from_dict_or_env(
            values, 'sambastudio_project_id', 'SAMBASTUDIO_PROJECT_ID'
        )
        values['sambastudio_endpoint_id'] = get_from_dict_or_env(
            values, 'sambastudio_endpoint_id', 'SAMBASTUDIO_ENDPOINT_ID'
        )
        values['sambastudio_api_key'] = get_from_dict_or_env(values, 'sambastudio_api_key', 'SAMBASTUDIO_API_KEY')
        return values

    def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
        """
        Get the tuning parameters to use when calling the LLM.

        Args:
            stop: Stop words to use when generating. NOTE: currently not
                forwarded to the endpoint; only `model_kwargs` are serialized.

        Returns:
            The tuning parameters as a JSON string. For the v2 generic API the
            model kwargs are sent unchanged; for the other APIs every value is
            wrapped as {'type': <python type name>, 'value': <str(value)>}.
        """
        model_kwargs = self.model_kwargs or {}
        if 'api/v2/predict/generic' in self.sambastudio_base_uri:
            tuning_params_dict = model_kwargs
        else:
            tuning_params_dict = {k: {'type': type(v).__name__, 'value': str(v)} for k, v in model_kwargs.items()}
        return json.dumps(tuning_params_dict)

    def _handle_nlp_predict(self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str) -> str:
        """
        Perform an NLP prediction using the SambaStudio endpoint handler.

        Args:
            sdk: The SSEndpointHandler to use for the prediction.
            prompt: The prompt to use for the prediction.
            tuning_params: The tuning parameters to use for the prediction.

        Returns:
            The completion of the first result.

        Raises:
            RuntimeError: If the endpoint did not return status code 200.
            ValueError: If the base URI is not a supported endpoint type.
        """
        response = sdk.nlp_predict(
            self.sambastudio_project_id,
            self.sambastudio_endpoint_id,
            self.sambastudio_api_key,
            prompt,
            tuning_params,
        )
        if response['status_code'] != 200:
            optional_detail = response.get('detail')
            if optional_detail:
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code "
                    f"{response['status_code']}.\n Details: {optional_detail}"
                )
            raise RuntimeError(
                f"Sambanova /complete call failed with status code "
                f"{response['status_code']}.\n response {response}"
            )
        if 'api/predict/nlp' in self.sambastudio_base_uri:
            return response['data'][0]['completion']
        elif 'api/v2/predict/generic' in self.sambastudio_base_uri:
            return response['items'][0]['value']['completion']
        elif 'api/predict/generic' in self.sambastudio_base_uri:
            return response['predictions'][0]['completion']
        else:
            raise ValueError(f'handling of endpoint uri: {self.sambastudio_base_uri} not implemented')

    def _handle_completion_requests(self, prompt: Union[List[str], str], stop: Optional[List[str]]) -> str:
        """
        Perform a non-streaming prediction using the SambaStudio endpoint handler.

        Args:
            prompt: The prompt to use for the prediction.
            stop: stop sequences (passed to `_get_tuning_params`).

        Returns:
            The prediction result.

        Raises:
            RuntimeError: If the endpoint did not return status code 200.
            ValueError: If the base URI is not a supported endpoint type.
        """
        ss_endpoint = SSEndpointHandler(self.sambastudio_base_url, self.sambastudio_base_uri)
        tuning_params = self._get_tuning_params(stop)
        return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)

    def _handle_nlp_predict_stream(
        self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
    ) -> Iterator[GenerationChunk]:
        """
        Perform a streaming request to the LLM.

        Args:
            sdk: The SSEndpointHandler to use for the prediction.
            prompt: The prompt to use for the prediction.
            tuning_params: The tuning parameters to use for the prediction.

        Returns:
            An iterator of GenerationChunks, one per streamed token chunk.

        Raises:
            ValueError: A chunk with a non-200 status carried an `error` entry,
                or the base URI is not a supported endpoint type.
            RuntimeError: A chunk with a non-200 status had no `error` entry.
        """
        for chunk in sdk.nlp_predict_stream(
            self.sambastudio_project_id,
            self.sambastudio_endpoint_id,
            self.sambastudio_api_key,
            prompt,
            tuning_params,
        ):
            if chunk['status_code'] != 200:
                error = chunk.get('error')
                if error:
                    optional_code = error.get('code')
                    optional_details = error.get('details')
                    optional_message = error.get('message')
                    raise ValueError(
                        f"Sambanova /complete call failed with status code "
                        f"{chunk['status_code']}.\n"
                        f"Message: {optional_message}\n"
                        f"Details: {optional_details}\n"
                        f"Code: {optional_code}\n"
                    )
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code "
                    f"{chunk['status_code']}."
                    f"{chunk}."
                )
            if 'api/predict/nlp' in self.sambastudio_base_uri:
                # SSE events carry their payload as a JSON string.
                text = json.loads(chunk['data'])['stream_token']
            elif 'api/v2/predict/generic' in self.sambastudio_base_uri:
                text = chunk['result']['items'][0]['value']['stream_token']
            elif 'api/predict/generic' in self.sambastudio_base_uri:
                responses = chunk['result']['responses']
                text = responses[0]['stream_token'] if responses else ''
            else:
                raise ValueError(f'handling of endpoint uri: {self.sambastudio_base_uri} not implemented')
            yield GenerationChunk(text=text)

    def _stream(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call out to Sambanova's streaming complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Yields:
            GenerationChunks as they arrive; nothing at all when
            `self.streaming` is False.

        Raises:
            ValueError: wrapping any error raised while streaming.
        """
        ss_endpoint = SSEndpointHandler(self.sambastudio_base_url, self.sambastudio_base_uri)
        tuning_params = self._get_tuning_params(stop)
        try:
            if self.streaming:
                for chunk in self._handle_nlp_predict_stream(ss_endpoint, prompt, tuning_params):
                    if run_manager:
                        run_manager.on_llm_new_token(chunk.text)
                    yield chunk
            else:
                return
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f'Error raised by the inference endpoint: {e}') from e

    def _handle_stream_request(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]],
        run_manager: Optional[CallbackManagerForLLMRun],
        kwargs: Dict[str, Any],
    ) -> str:
        """
        Perform a streaming request to the LLM and concatenate the chunks.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating.
            run_manager: Callback manager for the run.
            kwargs: Additional keyword arguments forwarded to `_stream`.

        Returns:
            The model output as a string.
        """
        return ''.join(chunk.text for chunk in self._stream(prompt=prompt, stop=stop, run_manager=run_manager, **kwargs))

    def _call(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Sambanova's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Must be None; stop words are not supported.

        Returns:
            The string generated by the model.

        Raises:
            Exception: If `stop` is not None.
            ValueError: wrapping any error raised by the endpoint.
        """
        if stop is not None:
            raise Exception('stop not implemented')
        try:
            if self.streaming:
                return self._handle_stream_request(prompt, stop, run_manager, kwargs)
            return self._handle_completion_requests(prompt, stop)
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f'Error raised by the inference endpoint: {e}') from e


class SambaNovaCloud(LLM):
    """
    SambaNova Cloud large language models.

    To use, you should have the environment variables
    ``SAMBANOVA_URL`` set with your SambaNova Cloud URL
    (defaults to 'https://api.sambanova.ai/v1/chat/completions').
    ``SAMBANOVA_API_KEY`` set with your SambaNova Cloud API Key.

    http://cloud.sambanova.ai/

    Example:
    .. code-block:: python

        SambaNovaCloud(
            sambanova_url="your-SambaNova-cloud-endpoint-URL",
            sambanova_api_key="your-SambaNova-cloud-API-key",
            max_tokens=1024,  # max number of tokens to generate
            stop_tokens=["<|eot_id|>"],  # list of stop tokens
            model="llama3-8b",  # model name
        )
    """

    sambanova_url: str = ''
    """SambaNova Cloud Url"""

    sambanova_api_key: str = ''
    """SambaNova Cloud api key"""

    max_tokens: int = 1024
    """max tokens to generate"""

    stop_tokens: list = ['<|eot_id|>']
    """Stop tokens"""

    model: str = 'llama3-8b'
    """LLM model expert to use"""

    temperature: float = 0.0
    """model temperature"""

    top_p: float = 0.0
    """model top p"""

    top_k: int = 1
    """model top k"""

    stream_api: bool = True
    """use stream api"""

    stream_options: dict = {'include_usage': True}
    """stream options, include usage to get generation metrics"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            'model': self.model,
            'max_tokens': self.max_tokens,
            'stop': self.stop_tokens,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'top_k': self.top_k,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return 'SambaNova Cloud'

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Fill the URL and API key from the init values or the environment."""
        values['sambanova_url'] = get_from_dict_or_env(
            values, 'sambanova_url', 'SAMBANOVA_URL', default='https://api.sambanova.ai/v1/chat/completions'
        )
        values['sambanova_api_key'] = get_from_dict_or_env(values, 'sambanova_api_key', 'SAMBANOVA_API_KEY')
        return values

    @staticmethod
    def _parse_stream_chunk(chunk: Dict[str, Any]) -> Optional[str]:
        """
        Extract the content text from one streamed SSE chunk.

        Returns None for chunks that carry no content: the final '[DONE]'
        event, the usage-statistics response, the end-of-text response
        (non-null `finish_reason`) and deltas whose content is null.

        Raises:
            RuntimeError: the decoded payload contains an `error` entry.
            Exception: the payload does not have the expected JSON structure.
        """
        # The final event's data is the literal string '[DONE]'.
        if chunk['data'] == '[DONE]':
            return None
        try:
            data = json.loads(chunk['data'])
            api_error = data.get('error')
        except (ValueError, TypeError, AttributeError) as e:
            raise Exception(f'Error getting content chunk raw streamed response: {chunk}') from e
        if api_error:
            raise RuntimeError(
                f"Sambanova /complete call failed with status code "
                f"{chunk['status_code']}."
                f"{chunk}."
            )
        # The final response with usage stats does not include content.
        if data.get('usage') is not None:
            return None
        try:
            choice = data['choices'][0]
            # A non-null finish_reason marks the "end of text" response.
            if choice['finish_reason'] is not None:
                return None
            return choice['delta']['content']
        except (KeyError, IndexError, TypeError) as e:
            raise Exception(f'Error getting content chunk raw streamed response: {chunk}') from e

    def _handle_nlp_predict_stream(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]],
    ) -> Iterator[GenerationChunk]:
        """
        Perform a streaming request to the LLM.

        Args:
            prompt: The prompt to use for the prediction. A string holding a
                JSON list (presumably chat messages) is sent as the `messages`
                payload; anything else is wrapped in a single user message.
            stop: list of stop tokens; `self.stop_tokens` is used when empty
                or None.

        Returns:
            An iterator of GenerationChunks.

        Raises:
            ImportError: sseclient-py is not installed.
            RuntimeError: non-200 status, an 'error_event', or an error payload.
            Exception: a streamed chunk could not be parsed.
        """
        try:
            import sseclient
        except ImportError as e:
            raise ImportError(
                'could not import sseclient library. '
                'Please install it with `pip install sseclient-py`.'
            ) from e

        try:
            formatted_prompt = json.loads(prompt)
        except (TypeError, ValueError):
            formatted_prompt = None
        # Only a JSON list is treated as a ready-made message list; plain text
        # that merely happens to be valid JSON (e.g. "42") is a user message.
        if not isinstance(formatted_prompt, list):
            formatted_prompt = [{'role': 'user', 'content': prompt}]

        if not stop:
            stop = self.stop_tokens
        data = {
            'messages': formatted_prompt,
            'max_tokens': self.max_tokens,
            'stop': stop,
            'model': self.model,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'top_k': self.top_k,
            'stream': self.stream_api,
            'stream_options': self.stream_options,
        }

        with requests.Session() as http_session:
            # Streaming output
            response = http_session.post(
                self.sambanova_url,
                headers={'Authorization': f'Bearer {self.sambanova_api_key}', 'Content-Type': 'application/json'},
                json=data,
                stream=True,
            )
            try:
                if response.status_code != 200:
                    raise RuntimeError(
                        f'Sambanova /complete call failed with status code '
                        f'{response.status_code}.'
                        f'{response.text}.'
                    )
                client = sseclient.SSEClient(response)
                for event in client.events():
                    chunk = {
                        'event': event.event,
                        'data': event.data,
                        'status_code': response.status_code,
                    }
                    if event.event == 'error_event':
                        raise RuntimeError(
                            f"Sambanova /complete call failed with status code "
                            f"{chunk['status_code']}."
                            f"{chunk}."
                        )
                    text = self._parse_stream_chunk(chunk)
                    if text is not None:
                        yield GenerationChunk(text=text)
            finally:
                # Release the connection even when iteration stops early.
                response.close()

    def _stream(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call out to Sambanova's streaming complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Yields:
            GenerationChunks as they arrive.

        Raises:
            ValueError: wrapping any error raised while streaming.
        """
        try:
            for chunk in self._handle_nlp_predict_stream(prompt, stop):
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)
                yield chunk
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f'Error raised by the inference endpoint: {e}') from e

    def _handle_stream_request(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]],
        run_manager: Optional[CallbackManagerForLLMRun],
        kwargs: Dict[str, Any],
    ) -> str:
        """
        Perform a streaming request to the LLM and concatenate the chunks.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating.
            run_manager: Callback manager for the run.
            kwargs: Additional keyword arguments forwarded to `_stream`.

        Returns:
            The model output as a string.
        """
        return ''.join(chunk.text for chunk in self._stream(prompt=prompt, stop=stop, run_manager=run_manager, **kwargs))

    def _call(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Sambanova's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: wrapping any error raised by the endpoint.
        """
        try:
            return self._handle_stream_request(prompt, stop, run_manager, kwargs)
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f'Error raised by the inference endpoint: {e}') from e