mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
feat: add support for aws_region_name
This commit is contained in:
parent
b4ba12e22c
commit
fa569aaf6f
2 changed files with 102 additions and 7 deletions
|
@ -156,6 +156,7 @@ async def test_acompletion_sagemaker_non_stream():
|
|||
}
|
||||
|
||||
mock_response.json = return_val
|
||||
mock_response.status_code = 200
|
||||
|
||||
expected_payload = {
|
||||
"inputs": "hi",
|
||||
|
@ -215,6 +216,7 @@ async def test_completion_sagemaker_non_stream():
|
|||
}
|
||||
|
||||
mock_response.json = return_val
|
||||
mock_response.status_code = 200
|
||||
|
||||
expected_payload = {
|
||||
"inputs": "hi",
|
||||
|
@ -249,3 +251,66 @@ async def test_completion_sagemaker_non_stream():
|
|||
kwargs["url"]
|
||||
== "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_completion_sagemaker_non_stream_with_aws_params():
    """Verify that explicit AWS credentials/region kwargs are honored for SageMaker.

    Patches the underlying HTTP client so no network call happens, then asserts:
      * the JSON payload sent to SageMaker matches the prompt + mapped params, and
      * the request URL targets the region passed via ``aws_region_name``
        (``us-west-5``), proving the kwarg overrides any default region.
    """
    mock_response = MagicMock()

    def return_val():
        # Canned SageMaker text-generation response body returned by the mock.
        return {
            "generated_text": "This is a mock response from SageMaker.",
            "id": "cmpl-mockid",
            "object": "text_completion",
            "created": 1629800000,
            "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
            "choices": [
                {
                    "text": "This is a mock response from SageMaker.",
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "length",
                }
            ],
            "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
        }

    mock_response.json = return_val
    mock_response.status_code = 200

    # Payload litellm is expected to POST: `max_tokens` must be translated to
    # SageMaker's `max_new_tokens` parameter name.
    expected_payload = {
        "inputs": "hi",
        "parameters": {"temperature": 0.2, "max_new_tokens": 80},
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        # Act: call the (sync) litellm.completion function with explicit
        # AWS credentials and a non-default region.
        response = litellm.completion(
            model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
            messages=[
                {"role": "user", "content": "hi"},
            ],
            temperature=0.2,
            max_tokens=80,
            input_cost_per_second=0.000420,
            aws_access_key_id="gm",
            aws_secret_access_key="s",
            aws_region_name="us-west-5",
        )

        # Print what was called on the mock (debug aid for CI logs)
        print("call args=", mock_post.call_args)

        # Assert
        mock_post.assert_called_once()
        _, kwargs = mock_post.call_args
        args_to_sagemaker = kwargs["json"]
        print("Arguments passed to sagemaker=", args_to_sagemaker)
        assert args_to_sagemaker == expected_payload
        # The runtime endpoint hostname must embed the region passed above.
        assert (
            kwargs["url"]
            == "https://runtime.sagemaker.us-west-5.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
        )
|
|
Loading…
Add table
Add a link
Reference in a new issue