From 872ff6176d506dfb4f7e2ee8d22505c98b96c6c6 Mon Sep 17 00:00:00 2001 From: Lucca Zenobio Date: Wed, 20 Mar 2024 15:22:23 -0300 Subject: [PATCH 1/4] updates --- litellm/llms/prompt_templates/factory.py | 6 +++--- litellm/utils.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index b23f10315..abe340e7d 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -604,13 +604,13 @@ def convert_to_anthropic_tool_result(message: dict) -> str: def convert_to_anthropic_tool_invoke(tool_calls: list) -> str: invokes = "" for tool in tool_calls: - if tool["type"] != "function": + if tool.type != "function": continue - tool_name = tool["function"]["name"] + tool_name = tool.function.name parameters = "".join( f"<{param}>{val}\n" - for param, val in json.loads(tool["function"]["arguments"]).items() + for param, val in json.loads(tool.function.arguments).items() ) invokes += ( "\n" diff --git a/litellm/utils.py b/litellm/utils.py index a8c003181..57327473d 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -271,7 +271,10 @@ class Message(OpenAIObject): if tool_calls is not None: self.tool_calls = [] for tool_call in tool_calls: - self.tool_calls.append(ChatCompletionMessageToolCall(**tool_call)) + if isinstance(tool_call, dict): + self.tool_calls.append(ChatCompletionMessageToolCall(**tool_call)) + else: + self.tool_calls.append(tool_call) if logprobs is not None: self._logprobs = logprobs From 0c0780be83c7a4559684d44c538fea6a5b07cad3 Mon Sep 17 00:00:00 2001 From: Lucca Zenobio Date: Thu, 21 Mar 2024 10:43:27 -0300 Subject: [PATCH 2/4] extra headers --- litellm/llms/bedrock.py | 15 ++++++++++++++- litellm/llms/prompt_templates/factory.py | 9 +++++---- litellm/main.py | 1 + litellm/utils.py | 2 +- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py index 0f52d3abc..a32cee381 100644 --- a/litellm/llms/bedrock.py +++ b/litellm/llms/bedrock.py @@ -495,6 +495,15 @@ class AmazonStabilityConfig: } +def add_custom_header(headers): + """Closure to capture the headers and add them.""" + def callback(request, **kwargs): + """Actual callback function that Boto3 will call.""" + for header_name, header_value in headers.items(): + request.headers.add_header(header_name, header_value) + return callback + + def init_bedrock_client( region_name=None, aws_access_key_id: Optional[str] = None, @@ -504,12 +513,12 @@ def init_bedrock_client( aws_session_name: Optional[str] = None, aws_profile_name: Optional[str] = None, aws_role_name: Optional[str] = None, + extra_headers: Optional[dict] = None, timeout: Optional[int] = None, ): # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) standard_aws_region_name = get_secret("AWS_REGION", None) - ## CHECK IS 'os.environ/' passed in # Define the list of parameters to check params_to_check = [ @@ -618,6 +627,8 @@ def init_bedrock_client( endpoint_url=endpoint_url, config=config, ) + if extra_headers: + client.meta.events.register('before-sign.bedrock-runtime.*', add_custom_header(extra_headers)) return client @@ -677,6 +688,7 @@ def completion( litellm_params=None, logger_fn=None, timeout=None, + extra_headers: Optional[dict] = None, ): exception_mapping_worked = False try: @@ -704,6 +716,7 @@ def completion( aws_role_name=aws_role_name, aws_session_name=aws_session_name, aws_profile_name=aws_profile_name, + extra_headers=extra_headers, timeout=timeout, ) diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index abe340e7d..a09b988a5 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -604,13 +604,14 @@ def convert_to_anthropic_tool_result(message: dict) -> str: def convert_to_anthropic_tool_invoke(tool_calls: list) -> str: invokes = "" for tool in tool_calls: - if tool.type != "function": + tool = dict(tool) + if tool["type"] != "function": continue - - tool_name = tool.function.name + tool_function = dict(tool["function"]) + tool_name = tool_function["name"] parameters = "".join( f"<{param}>{val}\n" - for param, val in json.loads(tool.function.arguments).items() + for param, val in json.loads(tool_function["arguments"]).items() ) invokes += ( "\n" diff --git a/litellm/main.py b/litellm/main.py index 724190391..dceaf9bf5 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1749,6 +1749,7 @@ def completion( logger_fn=logger_fn, encoding=encoding, logging_obj=logging, + extra_headers=extra_headers, timeout=timeout, ) diff --git a/litellm/utils.py b/litellm/utils.py index 57327473d..4124ea437 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5146,7 +5146,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str): """ if custom_llm_provider == "bedrock": if model.startswith("anthropic.claude-3"): - return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params() + return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params() + ["extra_headers"] elif model.startswith("anthropic"): return litellm.AmazonAnthropicConfig().get_supported_openai_params() elif model.startswith("ai21"): From cda78a5da0b2d705fb5ac56193c7508dc1ba7c4f Mon Sep 17 00:00:00 2001 From: Lucca Zenobio Date: Mon, 25 Mar 2024 13:08:17 -0300 Subject: [PATCH 3/4] update --- litellm/llms/bedrock.py | 2 +- litellm/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py index a32cee381..d13301910 100644 --- a/litellm/llms/bedrock.py +++ b/litellm/llms/bedrock.py @@ -128,7 +128,7 @@ class AmazonAnthropicClaude3Config: } def get_supported_openai_params(self): - return ["max_tokens", "tools", "tool_choice", "stream"] + return ["max_tokens", "tools", "tool_choice", "stream", "extra_headers"] def map_openai_params(self, non_default_params: dict, optional_params: dict): for param, value in non_default_params.items(): diff --git a/litellm/utils.py b/litellm/utils.py index 4124ea437..57327473d 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5146,7 +5146,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str): """ if custom_llm_provider == "bedrock": if model.startswith("anthropic.claude-3"): - return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params() + ["extra_headers"] + return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params() elif model.startswith("anthropic"): return litellm.AmazonAnthropicConfig().get_supported_openai_params() elif model.startswith("ai21"): From a9e2ef62125c462cc62d824f3d90bbc1d0366dfe Mon Sep 17 00:00:00 2001 From: Lucca Zenobio Date: Mon, 29 Apr 2024 10:05:30 -0300 Subject: [PATCH 4/4] test --- litellm/tests/test_bedrock_completion.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index ca2ffea5f..2aab8a3b4 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -207,6 +207,25 @@ def test_completion_bedrock_claude_sts_client_auth(): # test_completion_bedrock_claude_sts_client_auth() +def test_bedrock_extra_headers(): + try: + litellm.set_verbose = True + response: ModelResponse = completion( + model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0", + messages=messages, + max_tokens=10, + temperature=0.78, + extra_headers={"x-key": "x_key_value"} + ) + # Add any assertions here to check the response + assert len(response.choices) > 0 + assert len(response.choices[0].message.content) > 0 + except RateLimitError: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + def test_bedrock_claude_3(): try: litellm.set_verbose = True