diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md index e76412b1bb..ebf18f8fb9 100644 --- a/docs/my-website/docs/providers/bedrock.md +++ b/docs/my-website/docs/providers/bedrock.md @@ -49,8 +49,10 @@ response = completion( ) ``` -### Passing a BedrockClient as a parameter - Completion() -Pass an existing BedrockClient object to litellm.completion. Useful when using AWS SSO sessions or assumed role sessions. +### Passing an external BedrockRuntime.Client as a parameter - Completion() +Pass an external BedrockRuntime.Client object as a parameter to litellm.completion. Useful when using an AWS credentials profile, SSO session, assumed role session, or if environment variables are not available for auth. + +Create a client from session credentials: ```python import boto3 from litellm import completion @@ -61,12 +63,30 @@ bedrock = boto3.client( aws_access_key_id="", aws_secret_access_key_id="", aws_session_token="", - ) +) response = completion( model="bedrock/anthropic.claude-instant-v1", messages=[{ "content": "Hello, how are you?","role": "user"}], - aws_bedrock_client=bedrock + aws_bedrock_client=bedrock, +) +``` + +Create a client from AWS profile in `~/.aws/config`: +```python +import boto3 +from litellm import completion + +dev_session = boto3.Session(profile_name="dev-profile") +bedrock = dev_session.client( + service_name="bedrock-runtime", + region_name="us-east-1", +) + +response = completion( + model="bedrock/anthropic.claude-instant-v1", + messages=[{ "content": "Hello, how are you?","role": "user"}], + aws_bedrock_client=bedrock, ) ``` diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py index 1fd7e59482..2ccb6f9d8b 100644 --- a/litellm/llms/bedrock.py +++ b/litellm/llms/bedrock.py @@ -269,7 +269,7 @@ def completion( aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_region_name = optional_params.pop("aws_region_name", None) - # use passed in BedrockClient if provided, otherwise init a new one + # use passed in BedrockRuntime.Client if provided, otherwise create a new one client = optional_params.pop( "aws_bedrock_client", # only pass variables that are not None diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 7dc23bcc47..73b56e8ac4 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -846,7 +846,7 @@ def test_completion_bedrock_claude_completion_auth(): # Add any assertions here to check the response print(response) - os.environ["AWS_ACCESS_KEY_ID"] = aws_secret_access_key + os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key os.environ["AWS_REGION_NAME"] = aws_region_name @@ -854,6 +854,49 @@ def test_completion_bedrock_claude_completion_auth(): pytest.fail(f"Error occurred: {e}") # test_completion_bedrock_claude_completion_auth() +def test_completion_bedrock_claude_external_client_auth(): + print("calling bedrock claude external client auth") + import os + + aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"] + aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"] + aws_session_token = os.environ["AWS_SESSION_TOKEN"] + aws_region_name = os.environ["AWS_REGION_NAME"] + + os.environ["AWS_ACCESS_KEY_ID"] = "" + os.environ["AWS_SECRET_ACCESS_KEY"] = "" + os.environ["AWS_REGION_NAME"] = "" + + try: + import boto3 + bedrock = boto3.client( + service_name="bedrock-runtime", + region_name=aws_region_name, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + endpoint_url=f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" + ) + + response = completion( + model="bedrock/anthropic.claude-instant-v1", + messages=messages, + max_tokens=10, + temperature=0.1, + logger_fn=logger_fn, + aws_bedrock_client=bedrock, + ) + # Add any assertions here to check the response + print(response) + + os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id + os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key + os.environ["AWS_REGION_NAME"] = aws_region_name + + except Exception as e: + pytest.fail(f"Error occurred: {e}") +# test_completion_bedrock_claude_external_client_auth() + def test_completion_bedrock_claude_stream(): print("calling claude") litellm.set_verbose = False