mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
(Bug Fix) - Bedrock completions with aws_region_name (#8384)
* test_bedrock_completion_with_region_name * test_bedrock_base_model_helper * test_bedrock_base_model_helper * fix aws_bedrock_runtime_endpoint * test_dynamic_aws_params_propagation * test_dynamic_aws_params_propagation
This commit is contained in:
parent
64ccf4cd6e
commit
88093608e3
4 changed files with 310 additions and 11 deletions
|
@ -52,6 +52,7 @@ class BaseAWSLLM:
|
|||
"aws_role_name",
|
||||
"aws_web_identity_token",
|
||||
"aws_sts_endpoint",
|
||||
"aws_bedrock_runtime_endpoint",
|
||||
]
|
||||
|
||||
def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str:
|
||||
|
|
|
@ -87,7 +87,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
optional_params=optional_params,
|
||||
)
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
aws_bedrock_runtime_endpoint = optional_params.get(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||
|
@ -125,15 +125,15 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
|
||||
extra_headers = optional_params.pop("extra_headers", None)
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_session_token = optional_params.pop("aws_session_token", None)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
|
||||
extra_headers = optional_params.get("extra_headers", None)
|
||||
aws_secret_access_key = optional_params.get("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.get("aws_access_key_id", None)
|
||||
aws_session_token = optional_params.get("aws_session_token", None)
|
||||
aws_role_name = optional_params.get("aws_role_name", None)
|
||||
aws_session_name = optional_params.get("aws_session_name", None)
|
||||
aws_profile_name = optional_params.get("aws_profile_name", None)
|
||||
aws_web_identity_token = optional_params.get("aws_web_identity_token", None)
|
||||
aws_sts_endpoint = optional_params.get("aws_sts_endpoint", None)
|
||||
aws_region_name = self._get_aws_region_name(optional_params)
|
||||
|
||||
credentials: Credentials = self.get_credentials(
|
||||
|
@ -588,7 +588,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
"""
|
||||
Get the AWS region name from the environment variables
|
||||
"""
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
aws_region_name = optional_params.get("aws_region_name", None)
|
||||
### SET REGION NAME ###
|
||||
if aws_region_name is None:
|
||||
# check env #
|
||||
|
|
|
@ -14,6 +14,7 @@ import litellm.types
|
|||
load_dotenv()
|
||||
import io
|
||||
import os
|
||||
import json
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
|
|
|
@ -0,0 +1,297 @@
|
|||
# tests/llm_translation/test_base_aws_llm.py
|
||||
import os
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from botocore.credentials import Credentials
|
||||
import sys
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from unittest.mock import Mock
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
|
||||
|
||||
def test_bedrock_completion_with_region_name():
|
||||
litellm._turn_on_debug()
|
||||
client = HTTPHandler()
|
||||
|
||||
with patch.object(client, "post") as mock_post:
|
||||
mock_response = Mock()
|
||||
# Construct a response similar to our other tests.
|
||||
mock_response.text = json.dumps(
|
||||
{
|
||||
"response_id": "379ed018/60744aff-e741-4aad-bd10-74639a4ade79",
|
||||
"text": "Hello! How's it going? I hope you're having a fantastic day!",
|
||||
"generation_id": "38709bb9-f20f-42d9-9c61-13a73b7bbc12",
|
||||
"chat_history": [
|
||||
{"role": "USER", "message": "Hello, world!"},
|
||||
{
|
||||
"role": "CHATBOT",
|
||||
"message": "Hello! How's it going? I hope you're having a fantastic day!",
|
||||
},
|
||||
],
|
||||
"finish_reason": "COMPLETE",
|
||||
}
|
||||
)
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"Content-Type": "application/json"}
|
||||
mock_response.json = lambda: json.loads(mock_response.text)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Pass the client so that the HTTP call will be intercepted.
|
||||
response = litellm.completion(
|
||||
model="cohere.command-r-v1:0",
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
aws_region_name="us-west-12",
|
||||
client=client,
|
||||
)
|
||||
|
||||
# Ensure our post method has been called.
|
||||
mock_post.assert_called_once()
|
||||
|
||||
assert (
|
||||
mock_post.call_args.kwargs["url"]
|
||||
== "https://bedrock-runtime.us-west-12.amazonaws.com/model/cohere.command-r-v1:0/invoke"
|
||||
)
|
||||
assert (
|
||||
mock_post.call_args.kwargs["data"]
|
||||
== '{"message": "Hello, world!", "chat_history": []}'
|
||||
)
|
||||
|
||||
# Print the URL and body of the HTTP request.
|
||||
# assert request was signed with the correct region
|
||||
_authorization_header = mock_post.call_args.kwargs["headers"]["Authorization"]
|
||||
import re
|
||||
|
||||
# Ensure the authorization header contains the exact region segment "us-west-12/bedrock/aws4_request"
|
||||
pattern = r"us-west-12/bedrock/aws4_request"
|
||||
assert re.search(pattern, _authorization_header) is not None
|
||||
|
||||
|
||||
def test_bedrock_completion_with_dynamic_authentication_params():
|
||||
litellm._turn_on_debug()
|
||||
client = HTTPHandler()
|
||||
|
||||
with patch.object(client, "post") as mock_post:
|
||||
mock_response = Mock()
|
||||
# Construct a response similar to our other tests.
|
||||
mock_response.text = json.dumps(
|
||||
{
|
||||
"response_id": "379ed018/60744aff-e741-4aad-bd10-74639a4ade79",
|
||||
"text": "Hello! How's it going? I hope you're having a fantastic day!",
|
||||
"generation_id": "38709bb9-f20f-42d9-9c61-13a73b7bbc12",
|
||||
"chat_history": [
|
||||
{"role": "USER", "message": "Hello, world!"},
|
||||
{
|
||||
"role": "CHATBOT",
|
||||
"message": "Hello! How's it going? I hope you're having a fantastic day!",
|
||||
},
|
||||
],
|
||||
"finish_reason": "COMPLETE",
|
||||
}
|
||||
)
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"Content-Type": "application/json"}
|
||||
mock_response.json = lambda: json.loads(mock_response.text)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Pass the client so that the HTTP call will be intercepted.
|
||||
response = litellm.completion(
|
||||
model="cohere.command-r-v1:0",
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
aws_access_key_id="dynamically_generated_access_key_id",
|
||||
aws_secret_access_key="dynamically_generated_secret_access_key",
|
||||
client=client,
|
||||
)
|
||||
|
||||
# Ensure our post method has been called.
|
||||
mock_post.assert_called_once()
|
||||
import re
|
||||
|
||||
# Get authorization header
|
||||
_authorization_header = mock_post.call_args.kwargs["headers"]["Authorization"]
|
||||
|
||||
# Check for exact credential pattern
|
||||
pattern = r"AWS4-HMAC-SHA256 Credential=dynamically_generated_access_key_id/\d{8}/[a-z0-9-]+/bedrock/aws4_request"
|
||||
assert re.search(pattern, _authorization_header) is not None
|
||||
|
||||
|
||||
def test_bedrock_completion_with_dynamic_bedrock_runtime_endpoint():
|
||||
litellm._turn_on_debug()
|
||||
client = HTTPHandler()
|
||||
|
||||
with patch.object(client, "post") as mock_post:
|
||||
mock_response = Mock()
|
||||
# Construct a response similar to our other tests.
|
||||
mock_response.text = json.dumps(
|
||||
{
|
||||
"response_id": "379ed018/60744aff-e741-4aad-bd10-74639a4ade79",
|
||||
"text": "Hello! How's it going? I hope you're having a fantastic day!",
|
||||
"generation_id": "38709bb9-f20f-42d9-9c61-13a73b7bbc12",
|
||||
"chat_history": [
|
||||
{"role": "USER", "message": "Hello, world!"},
|
||||
{
|
||||
"role": "CHATBOT",
|
||||
"message": "Hello! How's it going? I hope you're having a fantastic day!",
|
||||
},
|
||||
],
|
||||
"finish_reason": "COMPLETE",
|
||||
}
|
||||
)
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"Content-Type": "application/json"}
|
||||
mock_response.json = lambda: json.loads(mock_response.text)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Pass the client so that the HTTP call will be intercepted.
|
||||
response = litellm.completion(
|
||||
model="cohere.command-r-v1:0",
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
aws_bedrock_runtime_endpoint="https://my-fake-endpoint.com",
|
||||
client=client,
|
||||
)
|
||||
|
||||
# Ensure our post method has been called.
|
||||
mock_post.assert_called_once()
|
||||
assert (
|
||||
mock_post.call_args.kwargs["url"]
|
||||
== "https://my-fake-endpoint.com/model/cohere.command-r-v1:0/invoke"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# A dummy credentials object to return from get_credentials.
|
||||
# (It must have attributes so that SigV4Auth.add_auth doesn't break.)
|
||||
# ------------------------------------------------------------------------------
|
||||
class DummyCredentials:
|
||||
access_key = "dummy_access"
|
||||
secret_key = "dummy_secret"
|
||||
token = "dummy_token"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# This test makes sure that a given dynamic parameter is passed into the call
|
||||
# to BaseAWSLLM.get_credentials. (Some dynamic params—for example aws_region_name
|
||||
# or aws_bedrock_runtime_endpoint—are already covered by other tests.)
|
||||
# ------------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"bedrock/converse/cohere.command-r-v1:0",
|
||||
"cohere.command-r-v1:0",
|
||||
"bedrock/cohere.command-r-v1:0",
|
||||
"bedrock/invoke/cohere.command-r-v1:0",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"param_name, param_value",
|
||||
[
|
||||
("aws_session_token", "dummy_session_token"),
|
||||
("aws_session_name", "dummy_session_name"),
|
||||
("aws_profile_name", "dummy_profile_name"),
|
||||
("aws_role_name", "dummy_role_name"),
|
||||
("aws_web_identity_token", "dummy_web_identity_token"),
|
||||
("aws_sts_endpoint", "dummy_sts_endpoint"),
|
||||
],
|
||||
)
|
||||
def test_dynamic_aws_params_propagation(model, param_name, param_value):
|
||||
"""
|
||||
When passed to litellm.completion, each dynamic AWS authentication parameter
|
||||
should propagate down to the get_credentials() call in BaseAWSLLM.
|
||||
|
||||
Also tests different model parameter values.
|
||||
"""
|
||||
client = HTTPHandler()
|
||||
|
||||
# Base parameters required for the completion call.
|
||||
# (We include aws_access_key_id and aws_secret_access_key so that the correct auth
|
||||
# branch in get_credentials() is reached.)
|
||||
base_params = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": "Hello, world!"}],
|
||||
"aws_access_key_id": "dummy_access",
|
||||
"aws_secret_access_key": "dummy_secret",
|
||||
"client": client,
|
||||
}
|
||||
# For parameters such as aws_role_name or aws_web_identity_token a session name is required.
|
||||
if param_name in ("aws_role_name", "aws_web_identity_token"):
|
||||
base_params["aws_session_name"] = "dummy_session_name"
|
||||
if param_name == "aws_web_identity_token":
|
||||
# The web identity branch also requires a role name.
|
||||
base_params["aws_role_name"] = "dummy_role_name"
|
||||
# Inject the dynamic parameter under test.
|
||||
base_params[param_name] = param_value
|
||||
|
||||
# Patch SigV4Auth in the signing (so that no actual signing is done).
|
||||
with patch("botocore.auth.SigV4Auth", autospec=True) as mock_sigv4:
|
||||
instance = mock_sigv4.return_value
|
||||
instance.add_auth.return_value = None
|
||||
|
||||
# Patch BaseAWSLLM.get_credentials so that we can capture its kwargs.
|
||||
def dummy_get_credentials(**kwargs):
|
||||
dummy_get_credentials.called_kwargs = kwargs # type: ignore[attr-defined]
|
||||
return DummyCredentials()
|
||||
|
||||
with patch.object(
|
||||
BaseAWSLLM, "get_credentials", side_effect=dummy_get_credentials
|
||||
):
|
||||
# Patch the HTTP client's post method to avoid an actual HTTP call.
|
||||
with patch.object(client, "post") as mock_post:
|
||||
mock_response = Mock()
|
||||
mock_response.text = json.dumps(
|
||||
{
|
||||
"response_id": "dummy_response",
|
||||
"text": "Hello! world",
|
||||
"generation_id": "dummy_gen",
|
||||
"chat_history": [],
|
||||
"finish_reason": "COMPLETE",
|
||||
}
|
||||
)
|
||||
if "converse" in model:
|
||||
mock_response.text = json.dumps(
|
||||
{
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"text": "Here's a joke..."}],
|
||||
}
|
||||
},
|
||||
"usage": {
|
||||
"inputTokens": 12,
|
||||
"outputTokens": 6,
|
||||
"totalTokens": 18,
|
||||
},
|
||||
"stopReason": "stop",
|
||||
}
|
||||
)
|
||||
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"Content-Type": "application/json"}
|
||||
mock_response.json = lambda: json.loads(mock_response.text)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Call litellm.completion with our base & dynamic parameters.
|
||||
litellm.completion(**base_params)
|
||||
|
||||
print(
|
||||
"get_credentials.called_kwargs",
|
||||
json.dumps(dummy_get_credentials.called_kwargs, indent=4),
|
||||
)
|
||||
|
||||
# We now assert that get_credentials() was called with the dynamic param.
|
||||
assert (
|
||||
dummy_get_credentials.called_kwargs.get(param_name) == param_value
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue