mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
feat add support for aws_region_name
This commit is contained in:
parent
b4ba12e22c
commit
fa569aaf6f
2 changed files with 102 additions and 7 deletions
|
@ -104,17 +104,11 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
|
||||||
# set os.environ['AWS_REGION_NAME'] = <your-region_name>
|
# set os.environ['AWS_REGION_NAME'] = <your-region_name>
|
||||||
class SagemakerLLM(BaseAWSLLM):
|
class SagemakerLLM(BaseAWSLLM):
|
||||||
|
|
||||||
def _prepare_request(
|
def _load_credentials(
|
||||||
self,
|
self,
|
||||||
model: str,
|
|
||||||
data: dict,
|
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
extra_headers: Optional[dict] = None,
|
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
import boto3
|
|
||||||
from botocore.auth import SigV4Auth
|
|
||||||
from botocore.awsrequest import AWSRequest
|
|
||||||
from botocore.credentials import Credentials
|
from botocore.credentials import Credentials
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||||
|
@ -163,6 +157,25 @@ class SagemakerLLM(BaseAWSLLM):
|
||||||
aws_web_identity_token=aws_web_identity_token,
|
aws_web_identity_token=aws_web_identity_token,
|
||||||
aws_sts_endpoint=aws_sts_endpoint,
|
aws_sts_endpoint=aws_sts_endpoint,
|
||||||
)
|
)
|
||||||
|
return credentials, aws_region_name
|
||||||
|
|
||||||
|
def _prepare_request(
|
||||||
|
self,
|
||||||
|
credentials,
|
||||||
|
model: str,
|
||||||
|
data: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
aws_region_name: str,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
from botocore.auth import SigV4Auth
|
||||||
|
from botocore.awsrequest import AWSRequest
|
||||||
|
from botocore.credentials import Credentials
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||||
|
|
||||||
sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name)
|
sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name)
|
||||||
if optional_params.get("stream") is True:
|
if optional_params.get("stream") is True:
|
||||||
api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations-response-stream"
|
api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations-response-stream"
|
||||||
|
@ -198,6 +211,7 @@ class SagemakerLLM(BaseAWSLLM):
|
||||||
):
|
):
|
||||||
|
|
||||||
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
|
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
|
||||||
|
credentials, aws_region_name = self._load_credentials(optional_params)
|
||||||
inference_params = deepcopy(optional_params)
|
inference_params = deepcopy(optional_params)
|
||||||
|
|
||||||
## Load Config
|
## Load Config
|
||||||
|
@ -250,6 +264,8 @@ class SagemakerLLM(BaseAWSLLM):
|
||||||
model=model,
|
model=model,
|
||||||
data=data,
|
data=data,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
credentials=credentials,
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
)
|
)
|
||||||
if model_id is not None:
|
if model_id is not None:
|
||||||
# Add model_id as InferenceComponentName header
|
# Add model_id as InferenceComponentName header
|
||||||
|
@ -313,6 +329,8 @@ class SagemakerLLM(BaseAWSLLM):
|
||||||
model=model,
|
model=model,
|
||||||
data=_data,
|
data=_data,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
credentials=credentials,
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Async completion
|
# Async completion
|
||||||
|
@ -357,6 +375,12 @@ class SagemakerLLM(BaseAWSLLM):
|
||||||
json=_data,
|
json=_data,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if sync_response.status_code != 200:
|
||||||
|
raise SagemakerError(
|
||||||
|
status_code=sync_response.status_code,
|
||||||
|
message=sync_response.text,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.post_call(
|
logging_obj.post_call(
|
||||||
|
@ -367,6 +391,7 @@ class SagemakerLLM(BaseAWSLLM):
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
verbose_logger.error("Sagemaker error %s", str(e))
|
||||||
status_code = (
|
status_code = (
|
||||||
getattr(e, "response", {})
|
getattr(e, "response", {})
|
||||||
.get("ResponseMetadata", {})
|
.get("ResponseMetadata", {})
|
||||||
|
@ -547,6 +572,11 @@ class SagemakerLLM(BaseAWSLLM):
|
||||||
json=data,
|
json=data,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise SagemakerError(
|
||||||
|
status_code=response.status_code, message=response.text
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.post_call(
|
logging_obj.post_call(
|
||||||
|
|
|
@ -156,6 +156,7 @@ async def test_acompletion_sagemaker_non_stream():
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_response.json = return_val
|
mock_response.json = return_val
|
||||||
|
mock_response.status_code = 200
|
||||||
|
|
||||||
expected_payload = {
|
expected_payload = {
|
||||||
"inputs": "hi",
|
"inputs": "hi",
|
||||||
|
@ -215,6 +216,7 @@ async def test_completion_sagemaker_non_stream():
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_response.json = return_val
|
mock_response.json = return_val
|
||||||
|
mock_response.status_code = 200
|
||||||
|
|
||||||
expected_payload = {
|
expected_payload = {
|
||||||
"inputs": "hi",
|
"inputs": "hi",
|
||||||
|
@ -249,3 +251,66 @@ async def test_completion_sagemaker_non_stream():
|
||||||
kwargs["url"]
|
kwargs["url"]
|
||||||
== "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
|
== "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_completion_sagemaker_non_stream_with_aws_params():
|
||||||
|
mock_response = MagicMock()
|
||||||
|
|
||||||
|
def return_val():
|
||||||
|
return {
|
||||||
|
"generated_text": "This is a mock response from SageMaker.",
|
||||||
|
"id": "cmpl-mockid",
|
||||||
|
"object": "text_completion",
|
||||||
|
"created": 1629800000,
|
||||||
|
"model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": "This is a mock response from SageMaker.",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": None,
|
||||||
|
"finish_reason": "length",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_response.json = return_val
|
||||||
|
mock_response.status_code = 200
|
||||||
|
|
||||||
|
expected_payload = {
|
||||||
|
"inputs": "hi",
|
||||||
|
"parameters": {"temperature": 0.2, "max_new_tokens": 80},
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
|
||||||
|
return_value=mock_response,
|
||||||
|
) as mock_post:
|
||||||
|
# Act: Call the synchronous litellm.completion function with explicit AWS params
|
||||||
|
response = litellm.completion(
|
||||||
|
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
],
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=80,
|
||||||
|
input_cost_per_second=0.000420,
|
||||||
|
aws_access_key_id="gm",
|
||||||
|
aws_secret_access_key="s",
|
||||||
|
aws_region_name="us-west-5",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print what was called on the mock
|
||||||
|
print("call args=", mock_post.call_args)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_post.assert_called_once()
|
||||||
|
_, kwargs = mock_post.call_args
|
||||||
|
args_to_sagemaker = kwargs["json"]
|
||||||
|
print("Arguments passed to sagemaker=", args_to_sagemaker)
|
||||||
|
assert args_to_sagemaker == expected_payload
|
||||||
|
assert (
|
||||||
|
kwargs["url"]
|
||||||
|
== "https://runtime.sagemaker.us-west-5.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue