Fix bedrock passing response_format: {"type": "text"} (#8900)

* fix(converse_transformation.py): ignore a `{"type": "text"}` value in response_format (a no-op for bedrock)

* fix(converse_transformation.py): handle adding response format value to tools

* fix(base_invoke_transformation.py): fix `get_bedrock_invoke_provider` to handle cross-region inference models

* test(test_bedrock_completion.py): add unit testing for bedrock invoke provider logic

* test: update test

* fix(exception_mapping_utils.py): add context window exceeded error handling for databricks provider route

* fix(fireworks_ai/): support passing tools + response_format together

* fix: cleanup

* fix(base_invoke_transformation.py): fix imports
Author: Krish Dholakia (2025-02-28 20:09:59 -08:00, committed by GitHub)
Parent: c8dc4f3eec
Commit: c84b489d58
8 changed files with 194 additions and 24 deletions

exception_mapping_utils.py:

```diff
@@ -278,6 +278,7 @@ def exception_type(  # type: ignore  # noqa: PLR0915
                 "This model's maximum context length is" in error_str
                 or "string too long. Expected a string with maximum length"
                 in error_str
+                or "model's maximum context limit" in error_str
             ):
                 exception_mapping_worked = True
                 raise ContextWindowExceededError(
@@ -692,6 +693,13 @@ def exception_type(  # type: ignore  # noqa: PLR0915
                         response=getattr(original_exception, "response", None),
                         litellm_debug_info=extra_information,
                     )
+                elif "model's maximum context limit" in error_str:
+                    exception_mapping_worked = True
+                    raise ContextWindowExceededError(
+                        message=f"{custom_llm_provider}Exception: Context Window Error - {error_str}",
+                        model=model,
+                        llm_provider=custom_llm_provider,
+                    )
                 elif "token_quota_reached" in error_str:
                     exception_mapping_worked = True
                     raise RateLimitError(
```
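The observable effect, as a hedged sketch (the model name below is illustrative, not taken from this commit): a provider error message containing "model's maximum context limit", which the databricks route returns, now maps to litellm's typed `ContextWindowExceededError` instead of falling through to a generic exception.

```python
import litellm

try:
    litellm.completion(
        model="databricks/databricks-meta-llama-3-1-70b-instruct",  # illustrative model name
        messages=[{"role": "user", "content": "word " * 500_000}],  # deliberately oversized prompt
    )
except litellm.ContextWindowExceededError as e:
    # With this commit, the "model's maximum context limit" error string is
    # recognized and mapped to a typed exception callers can catch.
    print(f"context window exceeded: {e}")
```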

anthropic chat transformation (AnthropicConfig):

```diff
@@ -293,18 +293,6 @@ class AnthropicConfig(BaseConfig):
             new_stop = new_v
         return new_stop

-    def _add_tools_to_optional_params(
-        self, optional_params: dict, tools: List[AllAnthropicToolsValues]
-    ) -> dict:
-        if "tools" not in optional_params:
-            optional_params["tools"] = tools
-        else:
-            optional_params["tools"] = [
-                *optional_params["tools"],
-                *tools,
-            ]
-        return optional_params
-
     def map_openai_params(
         self,
         non_default_params: dict,
```

base LLM chat transformation (BaseConfig):

```diff
@@ -111,6 +111,19 @@ class BaseConfig(ABC):
         """
         return False

+    def _add_tools_to_optional_params(self, optional_params: dict, tools: List) -> dict:
+        """
+        Helper util to add tools to optional_params.
+        """
+        if "tools" not in optional_params:
+            optional_params["tools"] = tools
+        else:
+            optional_params["tools"] = [
+                *optional_params["tools"],
+                *tools,
+            ]
+        return optional_params
+
     def translate_developer_role_to_system_role(
         self,
         messages: List[AllMessageValues],
@@ -158,6 +171,7 @@
         optional_params: dict,
         value: dict,
         is_response_format_supported: bool,
+        enforce_tool_choice: bool = True,
     ) -> dict:
         """
         Follow similar approach to anthropic - translate to a single tool call.
@@ -195,9 +209,11 @@
             optional_params.setdefault("tools", [])
             optional_params["tools"].append(_tool)

-            optional_params["tool_choice"] = _tool_choice
+            if enforce_tool_choice:
+                optional_params["tool_choice"] = _tool_choice
             optional_params["json_mode"] = True
-        else:
+        elif is_response_format_supported:
             optional_params["response_format"] = value
         return optional_params
```
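The point of hoisting the helper into BaseConfig is its merge semantics: new tools are appended to any existing `optional_params["tools"]` instead of clobbering them. A minimal standalone replica of the logic above:

```python
from typing import List


def add_tools_to_optional_params(optional_params: dict, tools: List) -> dict:
    # Mirrors BaseConfig._add_tools_to_optional_params: append, don't overwrite.
    if "tools" not in optional_params:
        optional_params["tools"] = tools
    else:
        optional_params["tools"] = [*optional_params["tools"], *tools]
    return optional_params


params = {"tools": [{"name": "user_tool"}]}
params = add_tools_to_optional_params(params, [{"name": "json_tool_call"}])
assert [t["name"] for t in params["tools"]] == ["user_tool", "json_tool_call"]
```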

converse_transformation.py (AmazonConverseConfig):

```diff
@@ -227,6 +227,10 @@ class AmazonConverseConfig(BaseConfig):
                     json_schema = value["json_schema"]["schema"]
                     schema_name = value["json_schema"]["name"]
                     description = value["json_schema"].get("description")
+
+                if "type" in value and value["type"] == "text":
+                    continue
+
                 """
                 Follow similar approach to anthropic - translate to a single tool call.
@@ -240,7 +244,9 @@
                     schema_name=schema_name if schema_name != "" else "json_tool_call",
                     description=description,
                 )
-                optional_params["tools"] = [_tool]
+                optional_params = self._add_tools_to_optional_params(
+                    optional_params=optional_params, tools=[_tool]
+                )
                 if litellm.utils.supports_tool_choice(
                     model=model, custom_llm_provider=self.custom_llm_provider
                 ):
@@ -267,7 +273,9 @@
             if param == "top_p":
                 optional_params["topP"] = value
             if param == "tools":
-                optional_params["tools"] = value
+                optional_params = self._add_tools_to_optional_params(
+                    optional_params=optional_params, tools=value
+                )
             if param == "tool_choice":
                 _tool_choice_value = self.map_tool_choice_values(
                     model=model, tool_choice=value, drop_params=drop_params  # type: ignore
```
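A hedged sketch of the request shape this fixes, mirroring the new test added further down (the exact model id is illustrative): `response_format={"type": "text"}` is now treated as a no-op instead of being translated into a tool call, and because the converse config routes through `_add_tools_to_optional_params`, a json-schema response_format no longer overwrites user-supplied tools.

```python
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "What's the weather like in Boston today?"}],
    response_format={"type": "text"},  # silently dropped; tools pass through intact
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
    drop_params=True,
)
```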

base_invoke_transformation.py (AmazonInvokeConfig):

```diff
@@ -3,7 +3,7 @@ import json
 import time
 import urllib.parse
 from functools import partial
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args

 import httpx
@@ -531,6 +531,60 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         """
         return False

+    @staticmethod
+    def get_bedrock_invoke_provider(
+        model: str,
+    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+        """
+        Helper function to get the bedrock provider from the model.
+
+        Handles 4 scenarios:
+        1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> returns `anthropic`
+        2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> returns `anthropic`
+        3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> returns `llama`
+        4. model=us.amazon.nova-pro-v1:0 -> returns `nova`
+        """
+        if model.startswith("invoke/"):
+            model = model.replace("invoke/", "", 1)
+
+        _split_model = model.split(".")[0]
+        if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
+            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
+
+        # If not a known provider, check for pattern with two slashes
+        provider = AmazonInvokeConfig._get_provider_from_model_path(model)
+        if provider is not None:
+            return provider
+
+        # check if provider == "nova"
+        if "nova" in model:
+            return "nova"
+        for provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
+            if provider in model:
+                return provider
+        return None
+
+    @staticmethod
+    def _get_provider_from_model_path(
+        model_path: str,
+    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+        """
+        Helper function to get the provider from a model path with format: provider/model-name
+
+        Args:
+            model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
+
+        Returns:
+            Optional[str]: The provider name, or None if no valid provider found
+        """
+        parts = model_path.split("/")
+        if len(parts) >= 1:
+            provider = parts[0]
+            if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
+                return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
+        return None
+
     def get_bedrock_model_id(
         self,
         optional_params: dict,
```
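The resolution order of the new helper, shown with the docstring's own examples (accessed via `litellm.AmazonInvokeConfig`, as in the unit test at the bottom of this commit):

```python
import litellm

cfg = litellm.AmazonInvokeConfig()

# 1. "invoke/" prefix is stripped, then the dotted prefix matches a known provider
assert cfg.get_bedrock_invoke_provider("invoke/anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
# 2. the dotted prefix matches directly
assert cfg.get_bedrock_invoke_provider("anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
# 3. no dotted match, so the leading segment of provider/model-path is used
assert cfg.get_bedrock_invoke_provider("llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n") == "llama"
# 4. cross-region inference ids fall back to the "nova" substring check
assert cfg.get_bedrock_invoke_provider("us.amazon.nova-pro-v1:0") == "nova"
```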

fireworks_ai chat transformation (FireworksAIConfig):

```diff
@@ -90,6 +90,11 @@ class FireworksAIConfig(OpenAIGPTConfig):
     ) -> dict:
         supported_openai_params = self.get_supported_openai_params(model=model)

+        is_tools_set = any(
+            param == "tools" and value is not None
+            for param, value in non_default_params.items()
+        )
+
         for param, value in non_default_params.items():
             if param == "tool_choice":
                 if value == "required":
@@ -98,18 +103,30 @@
                 else:
                     # pass through the value of tool choice
                     optional_params["tool_choice"] = value
-            elif (
-                param == "response_format" and value.get("type", None) == "json_schema"
-            ):
-                optional_params["response_format"] = {
-                    "type": "json_object",
-                    "schema": value["json_schema"]["schema"],
-                }
+            elif param == "response_format":
+                if (
+                    is_tools_set
+                ):  # fireworks ai doesn't support tools and response_format together
+                    optional_params = self._add_response_format_to_tools(
+                        optional_params=optional_params,
+                        value=value,
+                        is_response_format_supported=False,
+                        enforce_tool_choice=False,  # tools and response_format are both set, don't enforce tool_choice
+                    )
+                elif "json_schema" in value:
+                    optional_params["response_format"] = {
+                        "type": "json_object",
+                        "schema": value["json_schema"]["schema"],
+                    }
+                else:
+                    optional_params["response_format"] = value
             elif param == "max_completion_tokens":
                 optional_params["max_tokens"] = value
             elif param in supported_openai_params:
                 if value is not None:
                     optional_params[param] = value

         return optional_params

     def _add_transform_inline_image_block(
```
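A hedged sketch of the call this enables (the model id is illustrative): when tools are also set, the response_format is folded into the tool list via `_add_response_format_to_tools`, without forcing `tool_choice`, since fireworks ai does not accept both fields at once.

```python
import litellm

weather_tool = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}

# tools + response_format together: the schema rides along as an extra tool
# rather than as a bare response_format field.
response = litellm.completion(
    model="fireworks_ai/accounts/fireworks/models/llama-v3p1-70b-instruct",  # illustrative
    messages=[{"role": "user", "content": "What's the weather like in Boston today?"}],
    tools=[weather_tool],
    response_format={
        "type": "json_schema",
        "json_schema": {"name": "weather_report", "schema": {"type": "object"}},
    },
)
```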

base_llm_unit_tests (BaseLLMChatTest):

```diff
@@ -254,6 +254,56 @@ class BaseLLMChatTest(ABC):
         # relevant issue: https://github.com/BerriAI/litellm/issues/6741
         assert response.choices[0].message.content is not None

+    @pytest.mark.parametrize(
+        "response_format",
+        [
+            {"type": "text"},
+        ],
+    )
+    @pytest.mark.flaky(retries=6, delay=1)
+    def test_response_format_type_text_with_tool_calls_no_tool_choice(
+        self, response_format
+    ):
+        base_completion_call_args = self.get_base_completion_call_args()
+        messages = [
+            {"role": "user", "content": "What's the weather like in Boston today?"},
+        ]
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_current_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g. San Francisco, CA",
+                            },
+                            "unit": {
+                                "type": "string",
+                                "enum": ["celsius", "fahrenheit"],
+                            },
+                        },
+                        "required": ["location"],
+                    },
+                },
+            }
+        ]
+
+        try:
+            response = self.completion_function(
+                **base_completion_call_args,
+                messages=messages,
+                response_format=response_format,
+                tools=tools,
+                drop_params=True,
+            )
+        except litellm.ContextWindowExceededError:
+            pytest.skip("Model exceeded context window")
+
+        assert response is not None
+
     def test_response_format_type_text(self):
         """
         Test that the response format type text does not lead to tool calls
@@ -287,6 +337,7 @@
         print(f"translated_params={translated_params}")

     @pytest.mark.flaky(retries=6, delay=1)
     def test_json_response_pydantic_obj(self):
         litellm.set_verbose = True
```

test_bedrock_completion.py:

```diff
@@ -2717,6 +2717,33 @@ def test_bedrock_top_k_param(model, expected_params):
     assert data["additionalModelRequestFields"] == expected_params

+def test_bedrock_invoke_provider():
+    assert (
+        litellm.AmazonInvokeConfig().get_bedrock_invoke_provider(
+            "bedrock/invoke/us.anthropic.claude-3-5-sonnet-20240620-v1:0"
+        )
+        == "anthropic"
+    )
+    assert (
+        litellm.AmazonInvokeConfig().get_bedrock_invoke_provider(
+            "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0"
+        )
+        == "anthropic"
+    )
+    assert (
+        litellm.AmazonInvokeConfig().get_bedrock_invoke_provider(
+            "bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"
+        )
+        == "llama"
+    )
+    assert (
+        litellm.AmazonInvokeConfig().get_bedrock_invoke_provider(
+            "us.amazon.nova-pro-v1:0"
+        )
+        == "nova"
+    )
+
 def test_bedrock_description_param():
     from litellm import completion
     from litellm.llms.custom_httpx.http_handler import HTTPHandler
@@ -2754,3 +2781,4 @@ def test_bedrock_description_param():
     assert (
         "Find the meaning inside a poem" in request_body_str
     )  # assert description is passed
```