Fix bedrock passing response_format: {"type": "text"} (#8900)
* fix(converse_transformation.py): ignore `type: text` value in response_format (no-op for bedrock)
* fix(converse_transformation.py): handle adding response_format value to tools
* fix(base_invoke_transformation.py): fix `get_bedrock_invoke_provider` to handle cross-region-inference models
* test(test_bedrock_completion.py): add unit tests for bedrock invoke provider logic
* test: update test
* fix(exception_mapping_utils.py): add context-window-exceeded error handling for the databricks provider route
* fix(fireworks_ai/): support passing tools + response_format together
* fix: cleanup
* fix(base_invoke_transformation.py): fix imports
parent c8dc4f3eec
commit c84b489d58

8 changed files with 194 additions and 24 deletions
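For context, the headline fix addresses calls like the following: OpenAI clients routinely send the default response_format={"type": "text"}, which previously broke Bedrock requests. A hedged repro sketch (the model name is an example, not prescribed by the commit):

import litellm

# Before this commit, {"type": "text"} was translated into a bogus
# json_tool_call tool for Bedrock Converse; it is now treated as a no-op.
response = litellm.completion(
    model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",  # example model
    messages=[{"role": "user", "content": "Say hello."}],
    response_format={"type": "text"},  # previously broke the request; now ignored
)
print(response.choices[0].message.content)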
litellm/litellm_core_utils/exception_mapping_utils.py

@@ -278,6 +278,7 @@ def exception_type(  # type: ignore  # noqa: PLR0915
                     "This model's maximum context length is" in error_str
                     or "string too long. Expected a string with maximum length"
                     in error_str
+                    or "model's maximum context limit" in error_str
                 ):
                     exception_mapping_worked = True
                     raise ContextWindowExceededError(
@@ -692,6 +693,13 @@ def exception_type(  # type: ignore  # noqa: PLR0915
                         response=getattr(original_exception, "response", None),
                         litellm_debug_info=extra_information,
                     )
+                elif "model's maximum context limit" in error_str:
+                    exception_mapping_worked = True
+                    raise ContextWindowExceededError(
+                        message=f"{custom_llm_provider}Exception: Context Window Error - {error_str}",
+                        model=model,
+                        llm_provider=custom_llm_provider,
+                    )
                 elif "token_quota_reached" in error_str:
                     exception_mapping_worked = True
                     raise RateLimitError(
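Downstream effect of the new mapping branch, as a hedged sketch (the Databricks model name and oversized prompt are illustrative):

import litellm

try:
    response = litellm.completion(
        model="databricks/databricks-meta-llama-3-1-70b-instruct",  # example model
        messages=[{"role": "user", "content": "words " * 500_000}],  # deliberately oversized
    )
except litellm.ContextWindowExceededError as e:
    # A Databricks "model's maximum context limit" rejection now surfaces as
    # ContextWindowExceededError instead of a generic API error, so callers
    # can shrink the prompt and retry rather than treating it as a hard failure.
    print(f"context window exceeded: {e}")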
AnthropicConfig (anthropic chat transformation):

@@ -293,18 +293,6 @@ class AnthropicConfig(BaseConfig):
                 new_stop = new_v
         return new_stop

-    def _add_tools_to_optional_params(
-        self, optional_params: dict, tools: List[AllAnthropicToolsValues]
-    ) -> dict:
-        if "tools" not in optional_params:
-            optional_params["tools"] = tools
-        else:
-            optional_params["tools"] = [
-                *optional_params["tools"],
-                *tools,
-            ]
-        return optional_params
-
     def map_openai_params(
         self,
         non_default_params: dict,
BaseConfig (base LLM chat transformation; the Anthropic helper above is hoisted here):

@@ -111,6 +111,19 @@ class BaseConfig(ABC):
        """
        return False

+    def _add_tools_to_optional_params(self, optional_params: dict, tools: List) -> dict:
+        """
+        Helper util to add tools to optional_params.
+        """
+        if "tools" not in optional_params:
+            optional_params["tools"] = tools
+        else:
+            optional_params["tools"] = [
+                *optional_params["tools"],
+                *tools,
+            ]
+        return optional_params
+
     def translate_developer_role_to_system_role(
         self,
         messages: List[AllMessageValues],
@@ -158,6 +171,7 @@ class BaseConfig(ABC):
         optional_params: dict,
         value: dict,
         is_response_format_supported: bool,
+        enforce_tool_choice: bool = True,
     ) -> dict:
         """
         Follow similar approach to anthropic - translate to a single tool call.
@@ -195,9 +209,11 @@ class BaseConfig(ABC):

             optional_params.setdefault("tools", [])
             optional_params["tools"].append(_tool)
-            optional_params["tool_choice"] = _tool_choice
+            if enforce_tool_choice:
+                optional_params["tool_choice"] = _tool_choice

             optional_params["json_mode"] = True
-        else:
+        elif is_response_format_supported:
             optional_params["response_format"] = value
         return optional_params
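The hoisted helper appends rather than overwrites, which is what lets a synthetic json_tool_call coexist with user-supplied tools. A standalone sketch of the merge semantics (the function body mirrors the diff above; the tool dicts are illustrative):

from typing import List

def add_tools_to_optional_params(optional_params: dict, tools: List) -> dict:
    # Second and later calls extend the existing "tools" list instead of clobbering it.
    if "tools" not in optional_params:
        optional_params["tools"] = tools
    else:
        optional_params["tools"] = [
            *optional_params["tools"],
            *tools,
        ]
    return optional_params

params: dict = {}
add_tools_to_optional_params(params, [{"name": "get_current_weather"}])
add_tools_to_optional_params(params, [{"name": "json_tool_call"}])  # e.g. derived from response_format
assert [t["name"] for t in params["tools"]] == ["get_current_weather", "json_tool_call"]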
litellm/llms/bedrock/chat/converse_transformation.py

@@ -227,6 +227,10 @@ class AmazonConverseConfig(BaseConfig):
                 json_schema = value["json_schema"]["schema"]
                 schema_name = value["json_schema"]["name"]
                 description = value["json_schema"].get("description")
+
+            if "type" in value and value["type"] == "text":
+                continue
+
             """
             Follow similar approach to anthropic - translate to a single tool call.
             """
@@ -240,7 +244,9 @@ class AmazonConverseConfig(BaseConfig):
                 schema_name=schema_name if schema_name != "" else "json_tool_call",
                 description=description,
             )
-            optional_params["tools"] = [_tool]
+            optional_params = self._add_tools_to_optional_params(
+                optional_params=optional_params, tools=[_tool]
+            )
             if litellm.utils.supports_tool_choice(
                 model=model, custom_llm_provider=self.custom_llm_provider
             ):
@@ -267,7 +273,9 @@ class AmazonConverseConfig(BaseConfig):
             if param == "top_p":
                 optional_params["topP"] = value
             if param == "tools":
-                optional_params["tools"] = value
+                optional_params = self._add_tools_to_optional_params(
+                    optional_params=optional_params, tools=value
+                )
             if param == "tool_choice":
                 _tool_choice_value = self.map_tool_choice_values(
                     model=model, tool_choice=value, drop_params=drop_params  # type: ignore
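With `_add_tools_to_optional_params` in place, a structured-output request should no longer discard user tools on Bedrock: the json_schema is translated into a json_tool_call tool and appended alongside them, while {"type": "text"} is simply skipped. A hedged usage sketch (model name and tool illustrative):

import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",  # example model
    messages=[{"role": "user", "content": "Extract the city from: 'I live in Boston.'"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "parameters": {"type": "object", "properties": {"location": {"type": "string"}}},
            },
        }
    ],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "city",
            "schema": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    },
)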
base_invoke_transformation.py (AmazonInvokeConfig)

@@ -3,7 +3,7 @@ import json
 import time
 import urllib.parse
 from functools import partial
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args

 import httpx
@@ -531,6 +531,60 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         """
         return False

+    @staticmethod
+    def get_bedrock_invoke_provider(
+        model: str,
+    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+        """
+        Helper function to get the bedrock provider from the model
+
+        handles 4 scenarios:
+        1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
+        2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
+        3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
+        4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
+        """
+        if model.startswith("invoke/"):
+            model = model.replace("invoke/", "", 1)
+
+        _split_model = model.split(".")[0]
+        if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
+            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
+
+        # If not a known provider, check for pattern with two slashes
+        provider = AmazonInvokeConfig._get_provider_from_model_path(model)
+        if provider is not None:
+            return provider
+
+        # check if provider == "nova"
+        if "nova" in model:
+            return "nova"
+
+        for provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
+            if provider in model:
+                return provider
+        return None
+
+    @staticmethod
+    def _get_provider_from_model_path(
+        model_path: str,
+    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+        """
+        Helper function to get the provider from a model path with format: provider/model-name
+
+        Args:
+            model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
+
+        Returns:
+            Optional[str]: The provider name, or None if no valid provider found
+        """
+        parts = model_path.split("/")
+        if len(parts) >= 1:
+            provider = parts[0]
+            if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
+                return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
+        return None
+
     def get_bedrock_model_id(
         self,
         optional_params: dict,
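The cross-region fix is the final substring scan: an inference-profile id like "us.anthropic.claude-..." starts with a region prefix, so the dot-split and path checks miss, and the loop over known providers catches it. A standalone sketch of that resolution order (BEDROCK_INVOKE_PROVIDERS is a stand-in for litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL; its exact members are assumed):

from typing import Literal, Optional, get_args

# Stand-in literal; the real code uses litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL.
BEDROCK_INVOKE_PROVIDERS = Literal["anthropic", "cohere", "llama", "mistral", "amazon", "nova"]

def resolve_invoke_provider(model: str) -> Optional[str]:
    if model.startswith("invoke/"):
        model = model.replace("invoke/", "", 1)
    prefix = model.split(".")[0]  # cross-region ids start with "us."/"eu.", so this misses
    if prefix in get_args(BEDROCK_INVOKE_PROVIDERS):
        return prefix
    head = model.split("/")[0]  # provider/arn-style imported models
    if head in get_args(BEDROCK_INVOKE_PROVIDERS):
        return head
    if "nova" in model:
        return "nova"
    for provider in get_args(BEDROCK_INVOKE_PROVIDERS):  # new cross-region fallback: substring scan
        if provider in model:
            return provider
    return None

assert resolve_invoke_provider("us.anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
assert resolve_invoke_provider("us.amazon.nova-pro-v1:0") == "nova"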
fireworks_ai/ chat transformation (FireworksAIConfig):

@@ -90,6 +90,11 @@ class FireworksAIConfig(OpenAIGPTConfig):
     ) -> dict:

         supported_openai_params = self.get_supported_openai_params(model=model)
+        is_tools_set = any(
+            param == "tools" and value is not None
+            for param, value in non_default_params.items()
+        )
+
         for param, value in non_default_params.items():
             if param == "tool_choice":
                 if value == "required":
@@ -98,18 +103,30 @@ class FireworksAIConfig(OpenAIGPTConfig):
                 else:
                     # pass through the value of tool choice
                     optional_params["tool_choice"] = value
-            elif (
-                param == "response_format" and value.get("type", None) == "json_schema"
-            ):
-                optional_params["response_format"] = {
-                    "type": "json_object",
-                    "schema": value["json_schema"]["schema"],
-                }
+            elif param == "response_format":
+                if (
+                    is_tools_set
+                ):  # fireworks ai doesn't support tools and response_format together
+                    optional_params = self._add_response_format_to_tools(
+                        optional_params=optional_params,
+                        value=value,
+                        is_response_format_supported=False,
+                        enforce_tool_choice=False,  # tools and response_format are both set, don't enforce tool_choice
+                    )
+                elif "json_schema" in value:
+                    optional_params["response_format"] = {
+                        "type": "json_object",
+                        "schema": value["json_schema"]["schema"],
+                    }
+                else:
+                    optional_params["response_format"] = value
             elif param == "max_completion_tokens":
                 optional_params["max_tokens"] = value
             elif param in supported_openai_params:
                 if value is not None:
                     optional_params[param] = value

         return optional_params

     def _add_transform_inline_image_block(
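End to end, the Fireworks change means both parameters can now be sent in one call: since Fireworks AI does not accept tools and response_format together, litellm folds the schema into a tool (json mode) and, because the user already passed tools, leaves tool_choice unset. A hedged sketch (model name illustrative):

import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

# Both parameters together no longer error; the schema rides along as a tool.
response = litellm.completion(
    model="fireworks_ai/accounts/fireworks/models/llama-v3p1-70b-instruct",  # example model
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=tools,
    response_format={
        "type": "json_schema",
        "json_schema": {"name": "weather", "schema": {"type": "object"}},
    },
)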
BaseLLMChatTest (shared chat unit tests):

@@ -254,6 +254,56 @@ class BaseLLMChatTest(ABC):
         # relevant issue: https://github.com/BerriAI/litellm/issues/6741
         assert response.choices[0].message.content is not None

+    @pytest.mark.parametrize(
+        "response_format",
+        [
+            {"type": "text"},
+        ],
+    )
+    @pytest.mark.flaky(retries=6, delay=1)
+    def test_response_format_type_text_with_tool_calls_no_tool_choice(
+        self, response_format
+    ):
+        base_completion_call_args = self.get_base_completion_call_args()
+        messages = [
+            {"role": "user", "content": "What's the weather like in Boston today?"},
+        ]
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_current_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g. San Francisco, CA",
+                            },
+                            "unit": {
+                                "type": "string",
+                                "enum": ["celsius", "fahrenheit"],
+                            },
+                        },
+                        "required": ["location"],
+                    },
+                },
+            }
+        ]
+        try:
+            response = self.completion_function(
+                **base_completion_call_args,
+                messages=messages,
+                response_format=response_format,
+                tools=tools,
+                drop_params=True,
+            )
+        except litellm.ContextWindowExceededError:
+            pytest.skip("Model exceeded context window")
+        assert response is not None
+
     def test_response_format_type_text(self):
         """
         Test that the response format type text does not lead to tool calls
@@ -287,6 +337,7 @@ class BaseLLMChatTest(ABC):

         print(f"translated_params={translated_params}")

+    @pytest.mark.flaky(retries=6, delay=1)
     def test_json_response_pydantic_obj(self):
         litellm.set_verbose = True
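Because provider suites inherit BaseLLMChatTest, the new parametrized test fans out to every provider automatically. A hypothetical subclass sketch (class name, model, and import path are illustrative, not from the diff):

from base_llm_unit_tests import BaseLLMChatTest  # import path assumed

class TestBedrockConverseChat(BaseLLMChatTest):
    # Each suite only supplies its base call args; the inherited test then
    # exercises response_format={"type": "text"} + tools against that provider.
    def get_base_completion_call_args(self) -> dict:
        return {"model": "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0"}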
test_bedrock_completion.py

@@ -2717,6 +2717,33 @@ def test_bedrock_top_k_param(model, expected_params):
     assert data["additionalModelRequestFields"] == expected_params


+def test_bedrock_invoke_provider():
+    assert (
+        litellm.AmazonInvokeConfig().get_bedrock_invoke_provider(
+            "bedrock/invoke/us.anthropic.claude-3-5-sonnet-20240620-v1:0"
+        )
+        == "anthropic"
+    )
+    assert (
+        litellm.AmazonInvokeConfig().get_bedrock_invoke_provider(
+            "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0"
+        )
+        == "anthropic"
+    )
+    assert (
+        litellm.AmazonInvokeConfig().get_bedrock_invoke_provider(
+            "bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"
+        )
+        == "llama"
+    )
+    assert (
+        litellm.AmazonInvokeConfig().get_bedrock_invoke_provider(
+            "us.amazon.nova-pro-v1:0"
+        )
+        == "nova"
+    )
+
+
 def test_bedrock_description_param():
     from litellm import completion
     from litellm.llms.custom_httpx.http_handler import HTTPHandler
@@ -2754,3 +2781,4 @@ def test_bedrock_description_param():
     assert (
         "Find the meaning inside a poem" in request_body_str
     )  # assert description is passed
+