feat add support for aws_region_name

This commit is contained in:
Ishaan Jaff 2024-08-15 19:32:59 -07:00
parent b4ba12e22c
commit fa569aaf6f
2 changed files with 102 additions and 7 deletions

View file

@ -104,17 +104,11 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
# set os.environ['AWS_REGION_NAME'] = <your-region_name> # set os.environ['AWS_REGION_NAME'] = <your-region_name>
class SagemakerLLM(BaseAWSLLM): class SagemakerLLM(BaseAWSLLM):
def _prepare_request( def _load_credentials(
self, self,
model: str,
data: dict,
optional_params: dict, optional_params: dict,
extra_headers: Optional[dict] = None,
): ):
try: try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials from botocore.credentials import Credentials
except ImportError as e: except ImportError as e:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
@ -163,6 +157,25 @@ class SagemakerLLM(BaseAWSLLM):
aws_web_identity_token=aws_web_identity_token, aws_web_identity_token=aws_web_identity_token,
aws_sts_endpoint=aws_sts_endpoint, aws_sts_endpoint=aws_sts_endpoint,
) )
return credentials, aws_region_name
def _prepare_request(
self,
credentials,
model: str,
data: dict,
optional_params: dict,
aws_region_name: str,
extra_headers: Optional[dict] = None,
):
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError as e:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name) sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name)
if optional_params.get("stream") is True: if optional_params.get("stream") is True:
api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations-response-stream" api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations-response-stream"
@ -198,6 +211,7 @@ class SagemakerLLM(BaseAWSLLM):
): ):
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
credentials, aws_region_name = self._load_credentials(optional_params)
inference_params = deepcopy(optional_params) inference_params = deepcopy(optional_params)
## Load Config ## Load Config
@ -250,6 +264,8 @@ class SagemakerLLM(BaseAWSLLM):
model=model, model=model,
data=data, data=data,
optional_params=optional_params, optional_params=optional_params,
credentials=credentials,
aws_region_name=aws_region_name,
) )
if model_id is not None: if model_id is not None:
# Add model_id as InferenceComponentName header # Add model_id as InferenceComponentName header
@ -313,6 +329,8 @@ class SagemakerLLM(BaseAWSLLM):
model=model, model=model,
data=_data, data=_data,
optional_params=optional_params, optional_params=optional_params,
credentials=credentials,
aws_region_name=aws_region_name,
) )
# Async completion # Async completion
@ -357,6 +375,12 @@ class SagemakerLLM(BaseAWSLLM):
json=_data, json=_data,
timeout=timeout, timeout=timeout,
) )
if sync_response.status_code != 200:
raise SagemakerError(
status_code=sync_response.status_code,
message=sync_response.text,
)
except Exception as e: except Exception as e:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@ -367,6 +391,7 @@ class SagemakerLLM(BaseAWSLLM):
) )
raise e raise e
except Exception as e: except Exception as e:
verbose_logger.error("Sagemaker error %s", str(e))
status_code = ( status_code = (
getattr(e, "response", {}) getattr(e, "response", {})
.get("ResponseMetadata", {}) .get("ResponseMetadata", {})
@ -547,6 +572,11 @@ class SagemakerLLM(BaseAWSLLM):
json=data, json=data,
timeout=timeout, timeout=timeout,
) )
if response.status_code != 200:
raise SagemakerError(
status_code=response.status_code, message=response.text
)
except Exception as e: except Exception as e:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(

View file

@ -156,6 +156,7 @@ async def test_acompletion_sagemaker_non_stream():
} }
mock_response.json = return_val mock_response.json = return_val
mock_response.status_code = 200
expected_payload = { expected_payload = {
"inputs": "hi", "inputs": "hi",
@ -215,6 +216,7 @@ async def test_completion_sagemaker_non_stream():
} }
mock_response.json = return_val mock_response.json = return_val
mock_response.status_code = 200
expected_payload = { expected_payload = {
"inputs": "hi", "inputs": "hi",
@ -249,3 +251,66 @@ async def test_completion_sagemaker_non_stream():
kwargs["url"] kwargs["url"]
== "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations" == "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
) )
@pytest.mark.asyncio
async def test_completion_sagemaker_non_stream_with_aws_params():
    """Check that per-call AWS params reach the SageMaker request.

    ``HTTPHandler.post`` is patched so no network call is made. The test
    passes ``aws_access_key_id``, ``aws_secret_access_key`` and
    ``aws_region_name`` directly to ``litellm.completion`` and then verifies
    that the JSON body and the endpoint URL handed to the mocked ``post``
    match the expected payload and the region given in the call.
    """
    model_name = "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614"
    expected_url = "https://runtime.sagemaker.us-west-5.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
    expected_payload = {
        "inputs": "hi",
        "parameters": {"temperature": 0.2, "max_new_tokens": 80},
    }

    def fake_json():
        # Body served by the mocked HTTP response's ``.json()``.
        return {
            "generated_text": "This is a mock response from SageMaker.",
            "id": "cmpl-mockid",
            "object": "text_completion",
            "created": 1629800000,
            "model": model_name,
            "choices": [
                {
                    "text": "This is a mock response from SageMaker.",
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "length",
                }
            ],
            "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
        }

    # Arrange: a 200 response whose ``.json()`` yields the canned body above.
    fake_response = MagicMock()
    fake_response.json = fake_json
    fake_response.status_code = 200

    with patch(
        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
        return_value=fake_response,
    ) as post_spy:
        # Act: synchronous completion call with explicit AWS credentials/region.
        litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": "hi"}],
            temperature=0.2,
            max_tokens=80,
            input_cost_per_second=0.000420,
            aws_access_key_id="gm",
            aws_secret_access_key="s",
            aws_region_name="us-west-5",
        )

        # Show what the mock received, to ease debugging on failure.
        print("call args=", post_spy.call_args)

        # Assert: exactly one POST, carrying the expected body and URL.
        post_spy.assert_called_once()
        call_kwargs = post_spy.call_args[1]
        sent_body = call_kwargs["json"]
        print("Arguments passed to sagemaker=", sent_body)
        assert sent_body == expected_payload
        assert call_kwargs["url"] == expected_url