From acd2c1783c7b6618764460a2a3bb6b8d7eee0d66 Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Wed, 23 Apr 2025 21:56:05 -0700
Subject: [PATCH] fix(converse_transformation.py): support all bedrock - openai params for arn models (#10256)

Fixes https://github.com/BerriAI/litellm/issues/10207
---
 .../bedrock/chat/converse_transformation.py |  9 +++++++
 .../test_bedrock_completion.py              | 26 +++++++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py
index 31d7542cb4..8332463c5c 100644
--- a/litellm/llms/bedrock/chat/converse_transformation.py
+++ b/litellm/llms/bedrock/chat/converse_transformation.py
@@ -107,6 +107,15 @@ class AmazonConverseConfig(BaseConfig):
             "response_format",
         ]
 
+        if (
+            "arn" in model
+        ):  # we can't infer the model from the arn, so just add all params
+            supported_params.append("tools")
+            supported_params.append("tool_choice")
+            supported_params.append("thinking")
+            supported_params.append("reasoning_effort")
+            return supported_params
+
         ## Filter out 'cross-region' from model name
         base_model = BedrockModelInfo.get_base_model(model)
 
diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py
index d6e8ed4ff8..101e72e3db 100644
--- a/tests/llm_translation/test_bedrock_completion.py
+++ b/tests/llm_translation/test_bedrock_completion.py
@@ -2972,6 +2972,30 @@ def test_bedrock_application_inference_profile():
     client = HTTPHandler()
     client2 = HTTPHandler()
 
+    tools = [{
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    },
+                    "unit": {
+                        "type": "string",
+                        "enum": ["celsius", "fahrenheit"],
+                    },
+                },
+                "required": ["location"],
+            }
+        }
+    }
+    ]
+
+
     with patch.object(client, "post") as mock_post, patch.object(
         client2, "post"
     ) as mock_post2:
@@ -2981,6 +3005,7 @@
                 messages=[{"role": "user", "content": "Hello, how are you?"}],
                 model_id="arn:aws:bedrock:eu-central-1:000000000000:application-inference-profile/a0a0a0a0a0a0",
                 client=client,
+                tools=tools
             )
         except Exception as e:
             print(e)
@@ -2990,6 +3015,7 @@
                 model="bedrock/converse/arn:aws:bedrock:eu-central-1:000000000000:application-inference-profile/a0a0a0a0a0a0",
                 messages=[{"role": "user", "content": "Hello, how are you?"}],
                 client=client2,
+                tools=tools
            )
         except Exception as e:
             print(e)
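
Note (not part of the patch): a minimal caller-side sketch of what this change enables, mirroring the test above. With the ARN short-circuit in get_supported_openai_params, the tools parameter (and likewise tool_choice, thinking, reasoning_effort) is forwarded for ARN-based Bedrock models instead of being treated as unsupported. The ARN is the test's placeholder; an actual run assumes valid AWS credentials and a real application inference profile.

import litellm

# Placeholder ARN copied from the test; a real call needs an existing
# application inference profile and AWS credentials in the environment.
ARN = "arn:aws:bedrock:eu-central-1:000000000000:application-inference-profile/a0a0a0a0a0a0"

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        },
    }
]

# Before this patch, get_supported_openai_params() could not map the ARN to a
# base model, so tools could be rejected as an unsupported param; with the
# patch, ARN models pass these params straight through to Converse.
response = litellm.completion(
    model="bedrock/converse/" + ARN,
    messages=[{"role": "user", "content": "What's the weather in San Francisco, CA?"}],
    tools=tools,
)
print(response)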