diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md index 7485ae1576..ebf18f8fb9 100644 --- a/docs/my-website/docs/providers/bedrock.md +++ b/docs/my-website/docs/providers/bedrock.md @@ -49,6 +49,47 @@ response = completion( ) ``` +### Passing an external BedrockRuntime.Client as a parameter - Completion() +Pass an external BedrockRuntime.Client object as a parameter to litellm.completion. Useful when using an AWS credentials profile, SSO session, assumed role session, or if environment variables are not available for auth. + +Create a client from session credentials: +```python +import boto3 +from litellm import completion + +bedrock = boto3.client( + service_name="bedrock-runtime", + region_name="us-east-1", + aws_access_key_id="", + aws_secret_access_key_id="", + aws_session_token="", +) + +response = completion( + model="bedrock/anthropic.claude-instant-v1", + messages=[{ "content": "Hello, how are you?","role": "user"}], + aws_bedrock_client=bedrock, +) +``` + +Create a client from AWS profile in `~/.aws/config`: +```python +import boto3 +from litellm import completion + +dev_session = boto3.Session(profile_name="dev-profile") +bedrock = dev_session.client( + service_name="bedrock-runtime", + region_name="us-east-1", +) + +response = completion( + model="bedrock/anthropic.claude-instant-v1", + messages=[{ "content": "Hello, how are you?","role": "user"}], + aws_bedrock_client=bedrock, +) +``` + ## Supported AWS Bedrock Models Here's an example of using a bedrock model with LiteLLM diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py index 04587262c4..2ccb6f9d8b 100644 --- a/litellm/llms/bedrock.py +++ b/litellm/llms/bedrock.py @@ -269,11 +269,15 @@ def completion( aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_region_name = optional_params.pop("aws_region_name", None) - # only pass variables that are not None - client = init_bedrock_client( - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_region_name=aws_region_name, + # use passed in BedrockRuntime.Client if provided, otherwise create a new one + client = optional_params.pop( + "aws_bedrock_client", + # only pass variables that are not None + init_bedrock_client( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_region_name=aws_region_name, + ), ) model = model diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 7dc23bcc47..73b56e8ac4 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -846,7 +846,7 @@ def test_completion_bedrock_claude_completion_auth(): # Add any assertions here to check the response print(response) - os.environ["AWS_ACCESS_KEY_ID"] = aws_secret_access_key + os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key os.environ["AWS_REGION_NAME"] = aws_region_name @@ -854,6 +854,49 @@ def test_completion_bedrock_claude_completion_auth(): pytest.fail(f"Error occurred: {e}") # test_completion_bedrock_claude_completion_auth() +def test_completion_bedrock_claude_external_client_auth(): + print("calling bedrock claude external client auth") + import os + + aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"] + aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"] + aws_session_token = os.environ["AWS_SESSION_TOKEN"] + aws_region_name = os.environ["AWS_REGION_NAME"] + + os.environ["AWS_ACCESS_KEY_ID"] = "" + os.environ["AWS_SECRET_ACCESS_KEY"] = "" + os.environ["AWS_REGION_NAME"] = "" + + try: + import boto3 + bedrock = boto3.client( + service_name="bedrock-runtime", + region_name=aws_region_name, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + endpoint_url=f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" + ) + + response = completion( + model="bedrock/anthropic.claude-instant-v1", + messages=messages, + max_tokens=10, + temperature=0.1, + logger_fn=logger_fn, + aws_bedrock_client=bedrock, + ) + # Add any assertions here to check the response + print(response) + + os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id + os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key + os.environ["AWS_REGION_NAME"] = aws_region_name + + except Exception as e: + pytest.fail(f"Error occurred: {e}") +# test_completion_bedrock_claude_external_client_auth() + def test_completion_bedrock_claude_stream(): print("calling claude") litellm.set_verbose = False