forked from phoenix/litellm-mirror
fix(main.py): get the region name from boto3 client if dynamic var not set
This commit is contained in:
parent
5aae2313f3
commit
79670ab82e
2 changed files with 64 additions and 0 deletions
|
@ -2212,15 +2212,26 @@ def completion(
|
||||||
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
||||||
|
|
||||||
if "aws_bedrock_client" in optional_params:
|
if "aws_bedrock_client" in optional_params:
|
||||||
|
verbose_logger.warning(
|
||||||
|
"'aws_bedrock_client' is a deprecated param. Please move to another auth method - https://docs.litellm.ai/docs/providers/bedrock#boto3---authentication."
|
||||||
|
)
|
||||||
# Extract credentials for legacy boto3 client and pass thru to httpx
|
# Extract credentials for legacy boto3 client and pass thru to httpx
|
||||||
aws_bedrock_client = optional_params.pop("aws_bedrock_client")
|
aws_bedrock_client = optional_params.pop("aws_bedrock_client")
|
||||||
creds = aws_bedrock_client._get_credentials().get_frozen_credentials()
|
creds = aws_bedrock_client._get_credentials().get_frozen_credentials()
|
||||||
|
|
||||||
if creds.access_key:
|
if creds.access_key:
|
||||||
optional_params["aws_access_key_id"] = creds.access_key
|
optional_params["aws_access_key_id"] = creds.access_key
|
||||||
if creds.secret_key:
|
if creds.secret_key:
|
||||||
optional_params["aws_secret_access_key"] = creds.secret_key
|
optional_params["aws_secret_access_key"] = creds.secret_key
|
||||||
if creds.token:
|
if creds.token:
|
||||||
optional_params["aws_session_token"] = creds.token
|
optional_params["aws_session_token"] = creds.token
|
||||||
|
if (
|
||||||
|
"aws_region_name" not in optional_params
|
||||||
|
or optional_params["aws_region_name"] is None
|
||||||
|
):
|
||||||
|
optional_params["aws_region_name"] = (
|
||||||
|
aws_bedrock_client.meta.region_name
|
||||||
|
)
|
||||||
|
|
||||||
if model in litellm.BEDROCK_CONVERSE_MODELS:
|
if model in litellm.BEDROCK_CONVERSE_MODELS:
|
||||||
response = bedrock_converse_chat_completion.completion(
|
response = bedrock_converse_chat_completion.completion(
|
||||||
|
|
|
@ -856,3 +856,56 @@ async def test_bedrock_custom_prompt_template():
|
||||||
prompt = json.loads(mock_client_post.call_args.kwargs["data"])["prompt"]
|
prompt = json.loads(mock_client_post.call_args.kwargs["data"])["prompt"]
|
||||||
assert prompt == "<|im_start|>user\nWhat's AWS?<|im_end|>"
|
assert prompt == "<|im_start|>user\nWhat's AWS?<|im_end|>"
|
||||||
mock_client_post.assert_called_once()
|
mock_client_post.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_completion_bedrock_external_client_region():
    """Check that the region of a legacy boto3 client ends up in the request URL.

    A boto3 ``bedrock-runtime`` client is built for ``us-east-1`` and passed to
    ``completion`` as ``aws_bedrock_client`` without an explicit
    ``aws_region_name``. The HTTP handler's ``post`` is mocked, and the test
    asserts that it was called exactly once with a URL containing the
    client's region.

    Requires ``AWS_ACCESS_KEY_ID`` and ``AWS_SECRET_ACCESS_KEY`` to be set in
    the environment (a ``KeyError`` is raised otherwise). Both variables are
    removed for the duration of the test and always restored afterwards.
    """
    print("\ncalling bedrock claude external client auth")
    import os

    aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
    aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
    aws_region_name = "us-east-1"

    # Drop the credentials from the environment; presumably so that the only
    # credentials available to the call are those held by the boto3 client.
    os.environ.pop("AWS_ACCESS_KEY_ID", None)
    os.environ.pop("AWS_SECRET_ACCESS_KEY", None)

    try:
        import boto3

        litellm.set_verbose = True

        client = HTTPHandler()

        bedrock = boto3.client(
            service_name="bedrock-runtime",
            region_name=aws_region_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            endpoint_url=f"https://bedrock-runtime.{aws_region_name}.amazonaws.com",
        )
        with patch.object(client, "post", new=Mock()) as mock_client_post:
            try:
                response = completion(
                    model="bedrock/anthropic.claude-instant-v1",
                    messages=messages,
                    max_tokens=10,
                    temperature=0.1,
                    aws_bedrock_client=bedrock,
                    client=client,
                )
                print(response)
            except Exception:
                # Deliberately ignored: `post` is a Mock, so handling its
                # return value may fail. Only the outgoing request is
                # verified below.
                pass

            print(f"mock_client_post.call_args: {mock_client_post.call_args}")

            # Check the call happened before inspecting its arguments, so a
            # missing call fails with a clear message instead of an
            # AttributeError on `None`.
            mock_client_post.assert_called_once()
            assert "us-east-1" in mock_client_post.call_args.kwargs["url"]
    except RateLimitError:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
    finally:
        # Always restore the credentials, even when the test fails, so later
        # tests in the same process still see them.
        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue