diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py
index e3a58a7675..14097bb22c 100644
--- a/litellm/llms/sagemaker.py
+++ b/litellm/llms/sagemaker.py
@@ -104,17 +104,11 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""  # set os.environ['AWS_REGION_NAME'] =
 
 
 class SagemakerLLM(BaseAWSLLM):
-    def _prepare_request(
+    def _load_credentials(
         self,
-        model: str,
-        data: dict,
         optional_params: dict,
-        extra_headers: Optional[dict] = None,
     ):
         try:
-            import boto3
-            from botocore.auth import SigV4Auth
-            from botocore.awsrequest import AWSRequest
             from botocore.credentials import Credentials
         except ImportError as e:
             raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
@@ -163,6 +157,25 @@ class SagemakerLLM(BaseAWSLLM):
             aws_web_identity_token=aws_web_identity_token,
             aws_sts_endpoint=aws_sts_endpoint,
         )
+        return credentials, aws_region_name
+
+    def _prepare_request(
+        self,
+        credentials,
+        model: str,
+        data: dict,
+        optional_params: dict,
+        aws_region_name: str,
+        extra_headers: Optional[dict] = None,
+    ):
+        try:
+            import boto3
+            from botocore.auth import SigV4Auth
+            from botocore.awsrequest import AWSRequest
+            from botocore.credentials import Credentials
+        except ImportError as e:
+            raise ImportError("Missing boto3 to call sagemaker. Run 'pip install boto3'.")
+
         sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name)
         if optional_params.get("stream") is True:
             api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations-response-stream"
@@ -198,6 +211,7 @@ class SagemakerLLM(BaseAWSLLM):
     ):
 
         # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
+        credentials, aws_region_name = self._load_credentials(optional_params)
         inference_params = deepcopy(optional_params)
 
         ## Load Config
@@ -250,6 +264,8 @@ class SagemakerLLM(BaseAWSLLM):
                 model=model,
                 data=data,
                 optional_params=optional_params,
+                credentials=credentials,
+                aws_region_name=aws_region_name,
             )
             if model_id is not None:
                 # Add model_id as InferenceComponentName header
@@ -313,6 +329,8 @@ class SagemakerLLM(BaseAWSLLM):
             model=model,
             data=_data,
             optional_params=optional_params,
+            credentials=credentials,
+            aws_region_name=aws_region_name,
         )
 
         # Async completion
@@ -357,6 +375,12 @@ class SagemakerLLM(BaseAWSLLM):
                     json=_data,
                     timeout=timeout,
                 )
+
+                if sync_response.status_code != 200:
+                    raise SagemakerError(
+                        status_code=sync_response.status_code,
+                        message=sync_response.text,
+                    )
             except Exception as e:
                 ## LOGGING
                 logging_obj.post_call(
@@ -367,6 +391,7 @@ class SagemakerLLM(BaseAWSLLM):
                 )
                 raise e
         except Exception as e:
+            verbose_logger.error("Sagemaker error %s", str(e))
             status_code = (
                 getattr(e, "response", {})
                 .get("ResponseMetadata", {})
@@ -547,6 +572,11 @@ class SagemakerLLM(BaseAWSLLM):
                 json=data,
                 timeout=timeout,
             )
+
+            if response.status_code != 200:
+                raise SagemakerError(
+                    status_code=response.status_code, message=response.text
+                )
         except Exception as e:
             ## LOGGING
             logging_obj.post_call(
diff --git a/litellm/tests/test_sagemaker.py b/litellm/tests/test_sagemaker.py
index 3f8fb6557c..b6b4251c6b 100644
--- a/litellm/tests/test_sagemaker.py
+++ b/litellm/tests/test_sagemaker.py
@@ -156,6 +156,7 @@ async def test_acompletion_sagemaker_non_stream():
     }
 
     mock_response.json = return_val
+    mock_response.status_code = 200
 
     expected_payload = {
         "inputs": "hi",
@@ -215,6 +216,7 @@ async def test_completion_sagemaker_non_stream():
     }
 
     mock_response.json = return_val
+    mock_response.status_code = 200
 
     expected_payload = {
         "inputs": "hi",
@@ -249,3 +251,66 @@ async def test_completion_sagemaker_non_stream():
         kwargs["url"]
         == "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
     )
+
+
+@pytest.mark.asyncio
+async def test_completion_sagemaker_non_stream_with_aws_params():
+    mock_response = MagicMock()
+
+    def return_val():
+        return {
+            "generated_text": "This is a mock response from SageMaker.",
+            "id": "cmpl-mockid",
+            "object": "text_completion",
+            "created": 1629800000,
+            "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+            "choices": [
+                {
+                    "text": "This is a mock response from SageMaker.",
+                    "index": 0,
+                    "logprobs": None,
+                    "finish_reason": "length",
+                }
+            ],
+            "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
+        }
+
+    mock_response.json = return_val
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "inputs": "hi",
+        "parameters": {"temperature": 0.2, "max_new_tokens": 80},
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        # Act: Call the litellm.completion function
+        response = litellm.completion(
+            model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+            messages=[
+                {"role": "user", "content": "hi"},
+            ],
+            temperature=0.2,
+            max_tokens=80,
+            input_cost_per_second=0.000420,
+            aws_access_key_id="gm",
+            aws_secret_access_key="s",
+            aws_region_name="us-west-5",
+        )
+
+        # Print what was called on the mock
+        print("call args=", mock_post.call_args)
+
+        # Assert
+        mock_post.assert_called_once()
+        _, kwargs = mock_post.call_args
+        args_to_sagemaker = kwargs["json"]
+        print("Arguments passed to sagemaker=", args_to_sagemaker)
+        assert args_to_sagemaker == expected_payload
+        assert (
+            kwargs["url"]
+            == "https://runtime.sagemaker.us-west-5.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
+        )
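
For context on the sagemaker.py refactor above: _load_credentials now resolves the AWS credentials and region once per call, and _prepare_request only builds and SigV4-signs the HTTP request to the SageMaker runtime. Below is a minimal sketch of that signing step, assuming static credentials and a hypothetical endpoint name (the real code resolves credentials through BaseAWSLLM and the caller's optional_params):

import json

from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials

# Hypothetical inputs; the real code resolves these from optional_params
# and the environment via BaseAWSLLM.
credentials = Credentials(access_key="AKIA...", secret_key="...")
aws_region_name = "us-west-2"
model = "my-sagemaker-endpoint"  # hypothetical endpoint name
data = {"inputs": "hi", "parameters": {"max_new_tokens": 80}}

# Sign a SageMaker runtime invocation the same way _prepare_request does.
sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name)
api_base = (
    f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com"
    f"/endpoints/{model}/invocations"
)
request = AWSRequest(
    method="POST",
    url=api_base,
    data=json.dumps(data),
    headers={"Content-Type": "application/json"},
)
sigv4.add_auth(request)  # adds the Authorization and X-Amz-Date headers
prepped = request.prepare()  # prepped.url / prepped.headers go to the HTTP client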
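
And a sketch of what the new status-code checks change for callers, mirroring the test added above (dummy credentials and the test's endpoint name): a non-200 response from the SageMaker runtime now raises a SagemakerError carrying the response body, instead of failing later while parsing the JSON.

import litellm

try:
    response = litellm.completion(
        model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
        messages=[{"role": "user", "content": "hi"}],
        temperature=0.2,
        max_tokens=80,
        aws_access_key_id="...",  # dummy values, as in the test
        aws_secret_access_key="...",
        aws_region_name="us-west-5",
    )
    print(response.choices[0].message.content)
except Exception as e:
    # With this patch, an HTTP error surfaces here as a SagemakerError
    # (possibly wrapped by litellm's exception mapping) whose message is
    # the raw response text.
    print("SageMaker call failed:", e)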