forked from phoenix/litellm-mirror
fix(main.py): get the region name from boto3 client if dynamic var not set
This commit is contained in:
parent
5aae2313f3
commit
79670ab82e
2 changed files with 64 additions and 0 deletions
|
@ -2212,15 +2212,26 @@ def completion(
|
|||
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
||||
|
||||
if "aws_bedrock_client" in optional_params:
|
||||
verbose_logger.warning(
|
||||
"'aws_bedrock_client' is a deprecated param. Please move to another auth method - https://docs.litellm.ai/docs/providers/bedrock#boto3---authentication."
|
||||
)
|
||||
# Extract credentials for legacy boto3 client and pass thru to httpx
|
||||
aws_bedrock_client = optional_params.pop("aws_bedrock_client")
|
||||
creds = aws_bedrock_client._get_credentials().get_frozen_credentials()
|
||||
|
||||
if creds.access_key:
|
||||
optional_params["aws_access_key_id"] = creds.access_key
|
||||
if creds.secret_key:
|
||||
optional_params["aws_secret_access_key"] = creds.secret_key
|
||||
if creds.token:
|
||||
optional_params["aws_session_token"] = creds.token
|
||||
if (
|
||||
"aws_region_name" not in optional_params
|
||||
or optional_params["aws_region_name"] is None
|
||||
):
|
||||
optional_params["aws_region_name"] = (
|
||||
aws_bedrock_client.meta.region_name
|
||||
)
|
||||
|
||||
if model in litellm.BEDROCK_CONVERSE_MODELS:
|
||||
response = bedrock_converse_chat_completion.completion(
|
||||
|
|
|
@ -856,3 +856,56 @@ async def test_bedrock_custom_prompt_template():
|
|||
prompt = json.loads(mock_client_post.call_args.kwargs["data"])["prompt"]
|
||||
assert prompt == "<|im_start|>user\nWhat's AWS?<|im_end|>"
|
||||
mock_client_post.assert_called_once()
|
||||
|
||||
|
||||
def test_completion_bedrock_external_client_region():
|
||||
print("\ncalling bedrock claude external client auth")
|
||||
import os
|
||||
|
||||
aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
|
||||
aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
|
||||
aws_region_name = "us-east-1"
|
||||
|
||||
os.environ.pop("AWS_ACCESS_KEY_ID", None)
|
||||
os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
|
||||
|
||||
client = HTTPHandler()
|
||||
|
||||
try:
|
||||
import boto3
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
bedrock = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
region_name=aws_region_name,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
endpoint_url=f"https://bedrock-runtime.{aws_region_name}.amazonaws.com",
|
||||
)
|
||||
with patch.object(client, "post", new=Mock()) as mock_client_post:
|
||||
try:
|
||||
response = completion(
|
||||
model="bedrock/anthropic.claude-instant-v1",
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.1,
|
||||
aws_bedrock_client=bedrock,
|
||||
client=client,
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
print(f"mock_client_post.call_args: {mock_client_post.call_args}")
|
||||
assert "us-east-1" in mock_client_post.call_args.kwargs["url"]
|
||||
|
||||
mock_client_post.assert_called_once()
|
||||
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
|
||||
except RateLimitError:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue