diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py index 1ce118469..2f26ae4a9 100644 --- a/litellm/llms/bedrock.py +++ b/litellm/llms/bedrock.py @@ -163,8 +163,10 @@ class AmazonAnthropicClaude3Config: "stop", "temperature", "top_p", + "extra_headers" ] + def map_openai_params(self, non_default_params: dict, optional_params: dict): for param, value in non_default_params.items(): if param == "max_tokens": @@ -530,6 +532,15 @@ class AmazonStabilityConfig: } +def add_custom_header(headers): + """Closure to capture the headers and add them.""" + def callback(request, **kwargs): + """Actual callback function that Boto3 will call.""" + for header_name, header_value in headers.items(): + request.headers.add_header(header_name, header_value) + return callback + + def init_bedrock_client( region_name=None, aws_access_key_id: Optional[str] = None, @@ -539,12 +550,12 @@ def init_bedrock_client( aws_session_name: Optional[str] = None, aws_profile_name: Optional[str] = None, aws_role_name: Optional[str] = None, + extra_headers: Optional[dict] = None, timeout: Optional[Union[float, httpx.Timeout]] = None, ): # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) standard_aws_region_name = get_secret("AWS_REGION", None) - ## CHECK IS 'os.environ/' passed in # Define the list of parameters to check params_to_check = [ @@ -660,6 +671,8 @@ def init_bedrock_client( endpoint_url=endpoint_url, config=config, ) + if extra_headers: + client.meta.events.register('before-sign.bedrock-runtime.*', add_custom_header(extra_headers)) return client @@ -723,6 +736,7 @@ def completion( litellm_params=None, logger_fn=None, timeout=None, + extra_headers: Optional[dict] = None, ): exception_mapping_worked = False _is_function_call = False @@ -752,6 +766,7 @@ def completion( aws_role_name=aws_role_name, aws_session_name=aws_session_name, aws_profile_name=aws_profile_name, + extra_headers=extra_headers, timeout=timeout, ) diff --git a/litellm/main.py b/litellm/main.py index 8717af570..de35dbfd0 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1868,6 +1868,7 @@ def completion( logger_fn=logger_fn, encoding=encoding, logging_obj=logging, + extra_headers=extra_headers, timeout=timeout, ) diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index 8eb467542..3f5c831d7 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -207,6 +207,25 @@ def test_completion_bedrock_claude_sts_client_auth(): # test_completion_bedrock_claude_sts_client_auth() +def test_bedrock_extra_headers(): + try: + litellm.set_verbose = True + response: ModelResponse = completion( + model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0", + messages=messages, + max_tokens=10, + temperature=0.78, + extra_headers={"x-key": "x_key_value"} + ) + # Add any assertions here to check the response + assert len(response.choices) > 0 + assert len(response.choices[0].message.content) > 0 + except RateLimitError: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + def test_bedrock_claude_3(): try: litellm.set_verbose = True