import io
import json
import os
import sys
import time
import traceback
import types
from copy import deepcopy
from enum import Enum
from functools import partial
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union

import httpx

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.asyncify import asyncify
from litellm.litellm_core_utils.prompt_templates.factory import (
    custom_prompt,
    prompt_factory,
)
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import (
    CustomStreamWrapper,
    EmbeddingResponse,
    ModelResponse,
    Usage,
    get_secret,
)

from ..common_utils import AWSEventStreamDecoder, SagemakerError
from .transformation import SagemakerConfig

sagemaker_config = SagemakerConfig()

"""
SAGEMAKER AUTH Keys/Vars
os.environ['AWS_ACCESS_KEY_ID'] = ""
os.environ['AWS_SECRET_ACCESS_KEY'] = ""
"""

# set os.environ['AWS_REGION_NAME'] = ""
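
# Usage sketch (hypothetical endpoint name): requests reach this handler through
# litellm's public API via the "sagemaker/" model prefix, with AWS auth resolved
# from explicit aws_* params or the environment variables documented above.
#
#   import litellm
#   response = litellm.completion(
#       model="sagemaker/my-endpoint",
#       messages=[{"role": "user", "content": "Hey, how's it going?"}],
#       aws_region_name="us-west-2",
#   )
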
class SagemakerLLM(BaseAWSLLM):
    def _load_credentials(
        self,
        optional_params: dict,
    ):
        try:
            from botocore.credentials import Credentials
        except ImportError:
            raise ImportError(
                "Missing boto3 to call sagemaker. Run 'pip install boto3'."
            )

        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_session_token = optional_params.pop("aws_session_token", None)
        aws_region_name = optional_params.pop("aws_region_name", None)
        aws_role_name = optional_params.pop("aws_role_name", None)
        aws_session_name = optional_params.pop("aws_session_name", None)
        aws_profile_name = optional_params.pop("aws_profile_name", None)
        optional_params.pop(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com
        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
        aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)

        ### SET REGION NAME ###
        if aws_region_name is None:
            # check env
            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)

            if litellm_aws_region_name is not None and isinstance(
                litellm_aws_region_name, str
            ):
                aws_region_name = litellm_aws_region_name

            standard_aws_region_name = get_secret("AWS_REGION", None)
            if standard_aws_region_name is not None and isinstance(
                standard_aws_region_name, str
            ):
                aws_region_name = standard_aws_region_name

            if aws_region_name is None:
                aws_region_name = "us-west-2"

        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
        )
        return credentials, aws_region_name

    def _prepare_request(
        self,
        credentials,
        model: str,
        data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        aws_region_name: str,
        extra_headers: Optional[dict] = None,
    ):
        try:
            import boto3
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
            from botocore.credentials import Credentials
        except ImportError:
            raise ImportError(
                "Missing boto3 to call sagemaker. Run 'pip install boto3'."
            )

        sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name)
        if optional_params.get("stream") is True:
            api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations-response-stream"
        else:
            api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations"

        sagemaker_base_url = optional_params.get("sagemaker_base_url", None)
        if sagemaker_base_url is not None:
            api_base = sagemaker_base_url

        encoded_data = json.dumps(data).encode("utf-8")
        headers = sagemaker_config.validate_environment(
            headers=extra_headers,
            model=model,
            messages=messages,
            optional_params=optional_params,
        )
        request = AWSRequest(
            method="POST", url=api_base, data=encoded_data, headers=headers
        )
        sigv4.add_auth(request)
        if (
            extra_headers is not None and "Authorization" in extra_headers
        ):  # prevent sigv4 from overwriting the auth header
            request.headers["Authorization"] = extra_headers["Authorization"]

        prepped_request = request.prepare()

        return prepped_request
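
    # Request-prep sketch: _prepare_request SigV4-signs a plain HTTP POST against
    # the SageMaker runtime API, so the call itself needs no boto3 client.
    #   non-streaming: https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{endpoint}/invocations
    #   streaming:     https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{endpoint}/invocations-response-stream
    # An explicit `sagemaker_base_url` in optional_params overrides either URL.
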
    def completion(  # noqa: PLR0915
        self,
        model: str,
        messages: list,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        optional_params: dict,
        litellm_params: dict,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        custom_prompt_dict={},
        hf_model_name=None,
        logger_fn=None,
        acompletion: bool = False,
        headers: dict = {},
    ):
        # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
        credentials, aws_region_name = self._load_credentials(optional_params)
        inference_params = deepcopy(optional_params)
        stream = inference_params.pop("stream", None)
        model_id = optional_params.get("model_id", None)

        ## Load Config
        config = litellm.SagemakerConfig.get_config()
        for k, v in config.items():
            if (
                k not in inference_params
            ):  # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
                inference_params[k] = v

        if stream is True:
            if acompletion is True:
                response = self.async_streaming(
                    messages=messages,
                    model=model,
                    custom_prompt_dict=custom_prompt_dict,
                    hf_model_name=hf_model_name,
                    optional_params=optional_params,
                    encoding=encoding,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    model_id=model_id,
                    aws_region_name=aws_region_name,
                    credentials=credentials,
                    headers=headers,
                    litellm_params=litellm_params,
                )
                return response
            else:
                data = sagemaker_config.transform_request(
                    model=model,
                    messages=messages,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    headers=headers,
                )
                prepared_request = self._prepare_request(
                    model=model,
                    data=data,
                    messages=messages,
                    optional_params=optional_params,
                    credentials=credentials,
                    aws_region_name=aws_region_name,
                )
                if model_id is not None:
                    # Add model_id as InferenceComponentName header
                    # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
                    prepared_request.headers.update(
                        {"X-Amzn-SageMaker-Inference-Component": model_id}
                    )
                sync_handler = _get_httpx_client()
                sync_response = sync_handler.post(
                    url=prepared_request.url,
                    headers=prepared_request.headers,  # type: ignore
                    json=data,
                    stream=stream,
                )

                if sync_response.status_code != 200:
                    raise SagemakerError(
                        status_code=sync_response.status_code,
                        message=str(sync_response.read()),
                    )

                decoder = AWSEventStreamDecoder(model="")
                completion_stream = decoder.iter_bytes(
                    sync_response.iter_bytes(chunk_size=1024)
                )
                streaming_response = CustomStreamWrapper(
                    completion_stream=completion_stream,
                    model=model,
                    custom_llm_provider="sagemaker",
                    logging_obj=logging_obj,
                )

                ## LOGGING
                logging_obj.post_call(
                    input=messages,
                    api_key="",
                    original_response=streaming_response,
                    additional_args={"complete_input_dict": data},
                )
                return streaming_response

        # Non-Streaming Requests

        # Async completion
        if acompletion is True:
            return self.async_completion(
                messages=messages,
                model=model,
                custom_prompt_dict=custom_prompt_dict,
                hf_model_name=hf_model_name,
                model_response=model_response,
                encoding=encoding,
                logging_obj=logging_obj,
                model_id=model_id,
                optional_params=optional_params,
                credentials=credentials,
                aws_region_name=aws_region_name,
                headers=headers,
                litellm_params=litellm_params,
            )

        ## Non-Streaming completion CALL
        _data = sagemaker_config.transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        prepared_request_args = {
            "model": model,
            "data": _data,
            "optional_params": optional_params,
            "credentials": credentials,
            "aws_region_name": aws_region_name,
            "messages": messages,
        }
        prepared_request = self._prepare_request(**prepared_request_args)
        try:
            if model_id is not None:
                # Add model_id as InferenceComponentName header
                # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
                prepared_request.headers.update(
                    {"X-Amzn-SageMaker-Inference-Component": model_id}
                )

            if timeout is None:
                timeout = 300.0  # default request timeout when none is passed in
            sync_handler = _get_httpx_client()

            ## LOGGING
            logging_obj.pre_call(
                input=[],
                api_key="",
                additional_args={
                    "complete_input_dict": _data,
                    "api_base": prepared_request.url,
                    "headers": prepared_request.headers,
                },
            )

            # make sync httpx post request here
            try:
                sync_response = sync_handler.post(
                    url=prepared_request.url,
                    headers=prepared_request.headers,  # type: ignore
                    json=_data,
                    timeout=timeout,
                )

                if sync_response.status_code != 200:
                    raise SagemakerError(
                        status_code=sync_response.status_code,
                        message=sync_response.text,
                    )
            except Exception as e:
                ## LOGGING
                logging_obj.post_call(
                    input=[],
                    api_key="",
                    original_response=str(e),
                    additional_args={"complete_input_dict": _data},
                )
                raise e
        except Exception as e:
            verbose_logger.error("Sagemaker error %s", str(e))
            status_code = (
                getattr(e, "response", {})
                .get("ResponseMetadata", {})
                .get("HTTPStatusCode", 500)
            )
            error_message = (
                getattr(e, "response", {}).get("Error", {}).get("Message", str(e))
            )
            if "Inference Component Name header is required" in error_message:
                error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`"
            raise SagemakerError(status_code=status_code, message=error_message)

        return sagemaker_config.transform_response(
            model=model,
            raw_response=sync_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=_data,
            messages=messages,
            optional_params=optional_params,
            encoding=encoding,
            litellm_params=litellm_params,
        )
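
    # model_id sketch: for endpoints serving multiple inference components, pass
    # the component name so it is sent as the X-Amzn-SageMaker-Inference-Component
    # header (hypothetical names):
    #   litellm.completion(model="sagemaker/my-endpoint", model_id="my-component", ...)
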
    async def make_async_call(
        self,
        api_base: str,
        headers: dict,
        data: dict,
        logging_obj,
        client=None,
    ):
        try:
            if client is None:
                client = get_async_httpx_client(
                    llm_provider=litellm.LlmProviders.SAGEMAKER
                )  # Create a new client if none provided
            response = await client.post(
                api_base,
                headers=headers,
                json=data,
                stream=True,
            )

            if response.status_code != 200:
                raise SagemakerError(
                    status_code=response.status_code, message=response.text
                )

            decoder = AWSEventStreamDecoder(model="")
            completion_stream = decoder.aiter_bytes(
                response.aiter_bytes(chunk_size=1024)
            )

            # LOGGING
            logging_obj.post_call(
                input=[],
                api_key="",
                original_response="first stream response received",
                additional_args={"complete_input_dict": data},
            )

            return completion_stream
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise SagemakerError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise SagemakerError(status_code=408, message="Timeout error occurred.")
        except Exception as e:
            raise SagemakerError(status_code=500, message=str(e))

    async def async_streaming(
        self,
        messages: List[AllMessageValues],
        model: str,
        custom_prompt_dict: dict,
        hf_model_name: Optional[str],
        credentials,
        aws_region_name: str,
        optional_params,
        encoding,
        model_response: ModelResponse,
        model_id: Optional[str],
        logging_obj: Any,
        litellm_params: dict,
        headers: dict,
    ):
        data = await sagemaker_config.async_transform_request(
            model=model,
            messages=messages,
            optional_params={**optional_params, "stream": True},
            litellm_params=litellm_params,
            headers=headers,
        )
        asyncified_prepare_request = asyncify(self._prepare_request)
        prepared_request_args = {
            "model": model,
            "data": data,
            "optional_params": optional_params,
            "credentials": credentials,
            "aws_region_name": aws_region_name,
            "messages": messages,
        }
        prepared_request = await asyncified_prepare_request(**prepared_request_args)
        completion_stream = await self.make_async_call(
            api_base=prepared_request.url,
            headers=prepared_request.headers,  # type: ignore
            data=data,
            logging_obj=logging_obj,
        )
        streaming_response = CustomStreamWrapper(
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider="sagemaker",
            logging_obj=logging_obj,
        )

        # LOGGING
        logging_obj.post_call(
            input=[],
            api_key="",
            original_response="first stream response received",
            additional_args={"complete_input_dict": data},
        )
        return streaming_response
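
    # Async usage sketch (hypothetical endpoint name): both async paths are
    # reached via litellm.acompletion; stream=True routes through async_streaming.
    #
    #   import asyncio
    #   import litellm
    #
    #   async def main():
    #       stream = await litellm.acompletion(
    #           model="sagemaker/my-endpoint",
    #           messages=[{"role": "user", "content": "Hey!"}],
    #           stream=True,
    #       )
    #       async for chunk in stream:
    #           print(chunk)
    #
    #   asyncio.run(main())
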
    async def async_completion(
        self,
        messages: List[AllMessageValues],
        model: str,
        custom_prompt_dict: dict,
        hf_model_name: Optional[str],
        credentials,
        aws_region_name: str,
        encoding,
        model_response: ModelResponse,
        optional_params: dict,
        logging_obj: Any,
        model_id: Optional[str],
        headers: dict,
        litellm_params: dict,
    ):
        timeout = 300.0
        async_handler = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.SAGEMAKER
        )

        data = await sagemaker_config.async_transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        asyncified_prepare_request = asyncify(self._prepare_request)
        prepared_request_args = {
            "model": model,
            "data": data,
            "optional_params": optional_params,
            "credentials": credentials,
            "aws_region_name": aws_region_name,
            "messages": messages,
        }
        prepared_request = await asyncified_prepare_request(**prepared_request_args)

        ## LOGGING
        logging_obj.pre_call(
            input=[],
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": prepared_request.url,
                "headers": prepared_request.headers,
            },
        )
        try:
            if model_id is not None:
                # Add model_id as InferenceComponentName header
                # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
                prepared_request.headers.update(
                    {"X-Amzn-SageMaker-Inference-Component": model_id}
                )

            # make async httpx post request here
            try:
                response = await async_handler.post(
                    url=prepared_request.url,
                    headers=prepared_request.headers,  # type: ignore
                    json=data,
                    timeout=timeout,
                )

                if response.status_code != 200:
                    raise SagemakerError(
                        status_code=response.status_code, message=response.text
                    )
            except Exception as e:
                ## LOGGING
                logging_obj.post_call(
                    input=data["inputs"],
                    api_key="",
                    original_response=str(e),
                    additional_args={"complete_input_dict": data},
                )
                raise e
        except Exception as e:
            error_message = f"{str(e)}"
            if "Inference Component Name header is required" in error_message:
                error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`"
            raise SagemakerError(status_code=500, message=error_message)
        return sagemaker_config.transform_response(
            model=model,
            raw_response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=data,
            messages=messages,
            optional_params=optional_params,
            encoding=encoding,
            litellm_params=litellm_params,
        )
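
    # Embedding usage sketch (hypothetical endpoint name): JumpStart text-embedding
    # endpoints that accept a {"text_inputs": [...]} payload can be called as:
    #   litellm.embedding(model="sagemaker/my-embedding-endpoint", input=["good morning"])
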
    def embedding(
        self,
        model: str,
        input: list,
        model_response: EmbeddingResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        optional_params: dict,
        custom_prompt_dict={},
        litellm_params=None,
        logger_fn=None,
    ):
        """
        Supports Huggingface Jumpstart embeddings like GPT-6B
        """
        ### BOTO3 INIT
        import boto3

        # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_region_name = optional_params.pop("aws_region_name", None)

        if aws_access_key_id is not None:
            # uses auth params passed to completion
            # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
            client = boto3.client(
                service_name="sagemaker-runtime",
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                region_name=aws_region_name,
            )
        else:
            # aws_access_key_id is None, assume user is trying to auth using env variables
            # boto3 automatically reads env variables

            # we need to read region name from env
            # I assume majority of users use .env for auth
            region_name = (
                get_secret("AWS_REGION_NAME")
                or aws_region_name  # get region from config file if specified
                or "us-west-2"  # default to us-west-2 if region not specified
            )
            client = boto3.client(
                service_name="sagemaker-runtime",
                region_name=region_name,
            )

        # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
        inference_params = deepcopy(optional_params)
        inference_params.pop("stream", None)

        ## Load Config
        config = litellm.SagemakerConfig.get_config()
        for k, v in config.items():
            if (
                k not in inference_params
            ):  # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
                inference_params[k] = v

        #### HF EMBEDDING LOGIC
        data = json.dumps({"text_inputs": input}).encode("utf-8")

        ## LOGGING
        request_str = f"""
        response = client.invoke_endpoint(
            EndpointName={model},
            ContentType="application/json",
            Body={data},  # type: ignore
            CustomAttributes="accept_eula=true",
        )"""  # type: ignore
        logging_obj.pre_call(
            input=input,
            api_key="",
            additional_args={"complete_input_dict": data, "request_str": request_str},
        )
        ## EMBEDDING CALL
        try:
            response = client.invoke_endpoint(
                EndpointName=model,
                ContentType="application/json",
                Body=data,
                CustomAttributes="accept_eula=true",
            )
        except Exception as e:
            status_code = (
                getattr(e, "response", {})
                .get("ResponseMetadata", {})
                .get("HTTPStatusCode", 500)
            )
            error_message = (
                getattr(e, "response", {}).get("Error", {}).get("Message", str(e))
            )
            raise SagemakerError(status_code=status_code, message=error_message)

        response = json.loads(response["Body"].read().decode("utf8"))
        ## LOGGING
        logging_obj.post_call(
            input=input,
            api_key="",
            original_response=response,
            additional_args={"complete_input_dict": data},
        )

        print_verbose(f"raw model_response: {response}")
        if "embedding" not in response:
            raise SagemakerError(
                status_code=500, message="embedding not found in response"
            )
        embeddings = response["embedding"]

        if not isinstance(embeddings, list):
            raise SagemakerError(
                status_code=422,
                message=f"Response not in expected format - {embeddings}",
            )

        output_data = []
        for idx, embedding in enumerate(embeddings):
            output_data.append(
                {"object": "embedding", "index": idx, "embedding": embedding}
            )

        model_response.object = "list"
        model_response.data = output_data
        model_response.model = model

        input_tokens = 0
        for text in input:
            input_tokens += len(encoding.encode(text))

        setattr(
            model_response,
            "usage",
            Usage(
                prompt_tokens=input_tokens,
                completion_tokens=0,
                total_tokens=input_tokens,
            ),
        )

        return model_response
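
# Response shape sketch for SagemakerLLM.embedding (OpenAI-compatible):
#   model_response.data  -> [{"object": "embedding", "index": 0, "embedding": [...]}, ...]
#   model_response.usage -> prompt_tokens = token count of the input texts,
#                           completion_tokens = 0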