diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py index 0f52d3abc..a32cee381 100644 --- a/litellm/llms/bedrock.py +++ b/litellm/llms/bedrock.py @@ -495,6 +495,15 @@ class AmazonStabilityConfig: } +def add_custom_header(headers): + """Closure to capture the headers and add them.""" + def callback(request, **kwargs): + """Actual callback function that Boto3 will call.""" + for header_name, header_value in headers.items(): + request.headers.add_header(header_name, header_value) + return callback + + def init_bedrock_client( region_name=None, aws_access_key_id: Optional[str] = None, @@ -504,12 +513,12 @@ def init_bedrock_client( aws_session_name: Optional[str] = None, aws_profile_name: Optional[str] = None, aws_role_name: Optional[str] = None, + extra_headers: Optional[dict] = None, timeout: Optional[int] = None, ): # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) standard_aws_region_name = get_secret("AWS_REGION", None) - ## CHECK IS 'os.environ/' passed in # Define the list of parameters to check params_to_check = [ @@ -618,6 +627,8 @@ def init_bedrock_client( endpoint_url=endpoint_url, config=config, ) + if extra_headers: + client.meta.events.register('before-sign.bedrock-runtime.*', add_custom_header(extra_headers)) return client @@ -677,6 +688,7 @@ def completion( litellm_params=None, logger_fn=None, timeout=None, + extra_headers: Optional[dict] = None, ): exception_mapping_worked = False try: @@ -704,6 +716,7 @@ def completion( aws_role_name=aws_role_name, aws_session_name=aws_session_name, aws_profile_name=aws_profile_name, + extra_headers=extra_headers, timeout=timeout, ) diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index abe340e7d..a09b988a5 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -604,13 +604,14 @@ def 
convert_to_anthropic_tool_result(message: dict) -> str: def convert_to_anthropic_tool_invoke(tool_calls: list) -> str: invokes = "" for tool in tool_calls: - if tool.type != "function": + tool = dict(tool) + if tool["type"] != "function": continue - - tool_name = tool.function.name + tool_function = dict(tool["function"]) + tool_name = tool_function["name"] parameters = "".join( f"<{param}>{val}</{param}>\n" - for param, val in json.loads(tool.function.arguments).items() + for param, val in json.loads(tool_function["arguments"]).items() ) invokes += ( "<invoke>\n" diff --git a/litellm/main.py b/litellm/main.py index 724190391..dceaf9bf5 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1749,6 +1749,7 @@ def completion( logger_fn=logger_fn, encoding=encoding, logging_obj=logging, + extra_headers=extra_headers, timeout=timeout, ) diff --git a/litellm/utils.py b/litellm/utils.py index 57327473d..4124ea437 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5146,7 +5146,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str): """ if custom_llm_provider == "bedrock": if model.startswith("anthropic.claude-3"): - return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params() + return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params() + ["extra_headers"] elif model.startswith("anthropic"): return litellm.AmazonAnthropicConfig().get_supported_openai_params() elif model.startswith("ai21"):