Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
Fix bedrock passing response_format: {"type": "text"} (#8900)
* fix(converse_transformation.py): ignore type: text, value in response_format no-op for bedrock
* fix(converse_transformation.py): handle adding response format value to tools
* fix(base_invoke_transformation.py): fix 'get_bedrock_invoke_provider' to handle cross-region-inferencing models
* test(test_bedrock_completion.py): add unit testing for bedrock invoke provider logic
* test: update test
* fix(exception_mapping_utils.py): add context window exceeded error handling for databricks provider route
* fix(fireworks_ai/): support passing tools + response_format together
* fix: cleanup
* fix(base_invoke_transformation.py): fix imports
This commit is contained in:
parent c8dc4f3eec
commit c84b489d58

8 changed files with 194 additions and 24 deletions
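For context, a minimal repro of the scenario this commit fixes (a sketch only: the model name is illustrative, AWS credentials are assumed to be configured, and drop_params=True mirrors the new unit test added below):

# Sketch: before this commit, passing response_format={"type": "text"} alongside
# tools could break Bedrock Converse requests; it is now treated as a no-op.
import litellm

response = litellm.completion(
    model="bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0",  # illustrative
    messages=[{"role": "user", "content": "What's the weather like in Boston today?"}],
    response_format={"type": "text"},  # previously passed through; now ignored
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
    drop_params=True,
)
print(response.choices[0].message)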
@@ -278,6 +278,7 @@ def exception_type(  # type: ignore  # noqa: PLR0915
                 "This model's maximum context length is" in error_str
                 or "string too long. Expected a string with maximum length"
                 in error_str
+                or "model's maximum context limit" in error_str
             ):
                 exception_mapping_worked = True
                 raise ContextWindowExceededError(

@@ -692,6 +693,13 @@ def exception_type(  # type: ignore  # noqa: PLR0915
                     response=getattr(original_exception, "response", None),
                     litellm_debug_info=extra_information,
                 )
+            elif "model's maximum context limit" in error_str:
+                exception_mapping_worked = True
+                raise ContextWindowExceededError(
+                    message=f"{custom_llm_provider}Exception: Context Window Error - {error_str}",
+                    model=model,
+                    llm_provider=custom_llm_provider,
+                )
             elif "token_quota_reached" in error_str:
                 exception_mapping_worked = True
                 raise RateLimitError(
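Distilled, the databricks mapping added above is a substring check on the provider error. A standalone sketch (the real ContextWindowExceededError comes from litellm.exceptions and also carries model and llm_provider; the sample error text is illustrative):

# Stand-in for litellm.exceptions.ContextWindowExceededError.
class ContextWindowExceededError(Exception):
    pass

def map_context_window_error(error_str: str, custom_llm_provider: str) -> None:
    # New in this commit: messages containing "model's maximum context limit"
    # (seen on the databricks provider route) map to a context-window error
    # instead of a generic API error.
    if "model's maximum context limit" in error_str:
        raise ContextWindowExceededError(
            f"{custom_llm_provider}Exception: Context Window Error - {error_str}"
        )

try:
    map_context_window_error(
        "Input exceeds the model's maximum context limit of 128000 tokens.",  # illustrative
        custom_llm_provider="databricks",
    )
except ContextWindowExceededError as e:
    print(e)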
@@ -293,18 +293,6 @@ class AnthropicConfig(BaseConfig):
                     new_stop = new_v
         return new_stop

-    def _add_tools_to_optional_params(
-        self, optional_params: dict, tools: List[AllAnthropicToolsValues]
-    ) -> dict:
-        if "tools" not in optional_params:
-            optional_params["tools"] = tools
-        else:
-            optional_params["tools"] = [
-                *optional_params["tools"],
-                *tools,
-            ]
-        return optional_params
-
     def map_openai_params(
         self,
         non_default_params: dict,
@@ -111,6 +111,19 @@ class BaseConfig(ABC):
         """
         return False

+    def _add_tools_to_optional_params(self, optional_params: dict, tools: List) -> dict:
+        """
+        Helper util to add tools to optional_params.
+        """
+        if "tools" not in optional_params:
+            optional_params["tools"] = tools
+        else:
+            optional_params["tools"] = [
+                *optional_params["tools"],
+                *tools,
+            ]
+        return optional_params
+
     def translate_developer_role_to_system_role(
         self,
         messages: List[AllMessageValues],
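The helper hoisted from AnthropicConfig into BaseConfig merges rather than overwrites, so a tool derived from response_format no longer clobbers user-supplied tools. A standalone sketch of the same merge logic with a usage example:

# Sketch of _add_tools_to_optional_params: existing tools are preserved
# and new ones appended.
from typing import List

def add_tools_to_optional_params(optional_params: dict, tools: List) -> dict:
    if "tools" not in optional_params:
        optional_params["tools"] = tools
    else:
        optional_params["tools"] = [*optional_params["tools"], *tools]
    return optional_params

params = {"tools": [{"name": "get_current_weather"}]}
params = add_tools_to_optional_params(params, [{"name": "json_tool_call"}])
print([t["name"] for t in params["tools"]])  # ['get_current_weather', 'json_tool_call']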
@@ -158,6 +171,7 @@ class BaseConfig(ABC):
         optional_params: dict,
         value: dict,
         is_response_format_supported: bool,
+        enforce_tool_choice: bool = True,
     ) -> dict:
         """
         Follow similar approach to anthropic - translate to a single tool call.

@@ -195,9 +209,11 @@

             optional_params.setdefault("tools", [])
             optional_params["tools"].append(_tool)
-            optional_params["tool_choice"] = _tool_choice
+            if enforce_tool_choice:
+                optional_params["tool_choice"] = _tool_choice

             optional_params["json_mode"] = True
-        else:
+        elif is_response_format_supported:
             optional_params["response_format"] = value
         return optional_params
@@ -227,6 +227,10 @@ class AmazonConverseConfig(BaseConfig):
                 json_schema = value["json_schema"]["schema"]
                 schema_name = value["json_schema"]["name"]
                 description = value["json_schema"].get("description")
+
+            if "type" in value and value["type"] == "text":
+                continue
+
             """
             Follow similar approach to anthropic - translate to a single tool call.

@@ -240,7 +244,9 @@ class AmazonConverseConfig(BaseConfig):
                 schema_name=schema_name if schema_name != "" else "json_tool_call",
                 description=description,
             )
-            optional_params["tools"] = [_tool]
+            optional_params = self._add_tools_to_optional_params(
+                optional_params=optional_params, tools=[_tool]
+            )
             if litellm.utils.supports_tool_choice(
                 model=model, custom_llm_provider=self.custom_llm_provider
             ):

@@ -267,7 +273,9 @@ class AmazonConverseConfig(BaseConfig):
             if param == "top_p":
                 optional_params["topP"] = value
             if param == "tools":
-                optional_params["tools"] = value
+                optional_params = self._add_tools_to_optional_params(
+                    optional_params=optional_params, tools=value
+                )
             if param == "tool_choice":
                 _tool_choice_value = self.map_tool_choice_values(
                     model=model, tool_choice=value, drop_params=drop_params  # type: ignore
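Taken together, the converse changes above make {"type": "text"} a no-op while a json_schema response_format still becomes a single tool call. A behavioral sketch (the toolSpec shape here is illustrative, not the exact converse payload):

# Sketch of the branch structure: "text" is skipped (mirrors the `continue`),
# json_schema is translated into a tool appended to any existing tools.
def map_response_format(value: dict, optional_params: dict) -> dict:
    if "type" in value and value["type"] == "text":
        return optional_params  # no-op for bedrock
    if value.get("type") == "json_schema":
        _tool = {
            "toolSpec": {  # illustrative shape
                "name": value["json_schema"].get("name") or "json_tool_call",
                "inputSchema": {"json": value["json_schema"]["schema"]},
            }
        }
        optional_params.setdefault("tools", []).append(_tool)
    return optional_params

print(map_response_format({"type": "text"}, {}))  # {} -> unchanged
print(map_response_format(
    {"type": "json_schema", "json_schema": {"name": "calc", "schema": {"type": "object"}}},
    {},
))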
@@ -3,7 +3,7 @@ import json
 import time
 import urllib.parse
 from functools import partial
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args

 import httpx

@@ -531,6 +531,60 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         """
         return False

+    @staticmethod
+    def get_bedrock_invoke_provider(
+        model: str,
+    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+        """
+        Helper function to get the bedrock provider from the model
+
+        handles 4 scenarios:
+        1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
+        2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
+        3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
+        4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
+        """
+        if model.startswith("invoke/"):
+            model = model.replace("invoke/", "", 1)
+
+        _split_model = model.split(".")[0]
+        if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
+            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
+
+        # If not a known provider, check for pattern with two slashes
+        provider = AmazonInvokeConfig._get_provider_from_model_path(model)
+        if provider is not None:
+            return provider
+
+        # check if provider == "nova"
+        if "nova" in model:
+            return "nova"
+
+        for provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
+            if provider in model:
+                return provider
+        return None
+
+    @staticmethod
+    def _get_provider_from_model_path(
+        model_path: str,
+    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+        """
+        Helper function to get the provider from a model path with format: provider/model-name
+
+        Args:
+            model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
+
+        Returns:
+            Optional[str]: The provider name, or None if no valid provider found
+        """
+        parts = model_path.split("/")
+        if len(parts) >= 1:
+            provider = parts[0]
+            if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
+                return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
+        return None
+
     def get_bedrock_model_id(
         self,
         optional_params: dict,
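The new get_bedrock_invoke_provider resolution order can be exercised in isolation. A sketch with a stand-in Literal (the real litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL lists more providers); note how cross-region-inferencing models like us.anthropic.* or us.amazon.nova-* fall through the dot-prefix check to the later heuristics:

# Sketch of the four-step lookup; the Literal below is a reduced stand-in.
from typing import Literal, Optional, get_args

BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal["anthropic", "llama", "nova", "cohere"]

def get_bedrock_invoke_provider(model: str) -> Optional[str]:
    if model.startswith("invoke/"):
        model = model.replace("invoke/", "", 1)
    # 1. provider as the first dot-segment, e.g. anthropic.claude-3-5-sonnet-...
    _split_model = model.split(".")[0]
    if _split_model in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
        return _split_model
    # 2. provider as a path prefix, e.g. llama/arn:aws:bedrock:...
    prefix = model.split("/")[0]
    if prefix in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
        return prefix
    # 3. cross-region models like us.amazon.nova-pro-v1:0
    if "nova" in model:
        return "nova"
    # 4. last resort: substring match against known providers
    for provider in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
        if provider in model:
            return provider
    return None

assert get_bedrock_invoke_provider("invoke/anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
assert get_bedrock_invoke_provider("llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n") == "llama"
assert get_bedrock_invoke_provider("us.amazon.nova-pro-v1:0") == "nova"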
@@ -90,6 +90,11 @@ class FireworksAIConfig(OpenAIGPTConfig):
     ) -> dict:

         supported_openai_params = self.get_supported_openai_params(model=model)
+        is_tools_set = any(
+            param == "tools" and value is not None
+            for param, value in non_default_params.items()
+        )
+
         for param, value in non_default_params.items():
             if param == "tool_choice":
                 if value == "required":

@@ -98,18 +103,30 @@ class FireworksAIConfig(OpenAIGPTConfig):
                 else:
                     # pass through the value of tool choice
                     optional_params["tool_choice"] = value
-            elif (
-                param == "response_format" and value.get("type", None) == "json_schema"
-            ):
-                optional_params["response_format"] = {
-                    "type": "json_object",
-                    "schema": value["json_schema"]["schema"],
-                }
+            elif param == "response_format":
+                if (
+                    is_tools_set
+                ):  # fireworks ai doesn't support tools and response_format together
+                    optional_params = self._add_response_format_to_tools(
+                        optional_params=optional_params,
+                        value=value,
+                        is_response_format_supported=False,
+                        enforce_tool_choice=False,  # tools and response_format are both set, don't enforce tool_choice
+                    )
+                elif "json_schema" in value:
+                    optional_params["response_format"] = {
+                        "type": "json_object",
+                        "schema": value["json_schema"]["schema"],
+                    }
+                else:
+                    optional_params["response_format"] = value
             elif param == "max_completion_tokens":
                 optional_params["max_tokens"] = value
             elif param in supported_openai_params:
                 if value is not None:
                     optional_params[param] = value

         return optional_params

     def _add_transform_inline_image_block(
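End to end, the fireworks mapping now folds response_format into the tools list when both are set, instead of sending an unsupported combination. A behavioral sketch (the json tool shape is illustrative; the real translation goes through _add_response_format_to_tools on BaseConfig):

# Sketch: with tools present, a json_schema response_format becomes an extra
# tool plus json_mode=True, and tool_choice is deliberately left alone.
def map_fireworks_params(non_default_params: dict) -> dict:
    optional_params: dict = {}
    is_tools_set = any(
        param == "tools" and value is not None
        for param, value in non_default_params.items()
    )
    for param, value in non_default_params.items():
        if param == "tools" and value is not None:
            optional_params["tools"] = value
        elif param == "response_format":
            if is_tools_set:
                _tool = {  # illustrative json-tool shape
                    "type": "function",
                    "function": {
                        "name": "json_tool_call",
                        "parameters": value.get("json_schema", {}).get("schema"),
                    },
                }
                optional_params.setdefault("tools", []).append(_tool)
                optional_params["json_mode"] = True  # no tool_choice enforced
            elif "json_schema" in value:
                optional_params["response_format"] = {
                    "type": "json_object",
                    "schema": value["json_schema"]["schema"],
                }
            else:
                optional_params["response_format"] = value
    return optional_params

out = map_fireworks_params({
    "tools": [{"type": "function", "function": {"name": "get_current_weather"}}],
    "response_format": {"type": "json_schema", "json_schema": {"schema": {"type": "object"}}},
})
print(len(out["tools"]), out.get("json_mode"))  # 2 True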
@@ -254,6 +254,56 @@ class BaseLLMChatTest(ABC):
         # relevant issue: https://github.com/BerriAI/litellm/issues/6741
         assert response.choices[0].message.content is not None

+    @pytest.mark.parametrize(
+        "response_format",
+        [
+            {"type": "text"},
+        ],
+    )
+    @pytest.mark.flaky(retries=6, delay=1)
+    def test_response_format_type_text_with_tool_calls_no_tool_choice(
+        self, response_format
+    ):
+        base_completion_call_args = self.get_base_completion_call_args()
+        messages = [
+            {"role": "user", "content": "What's the weather like in Boston today?"},
+        ]
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_current_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g. San Francisco, CA",
+                            },
+                            "unit": {
+                                "type": "string",
+                                "enum": ["celsius", "fahrenheit"],
+                            },
+                        },
+                        "required": ["location"],
+                    },
+                },
+            }
+        ]
+        try:
+            response = self.completion_function(
+                **base_completion_call_args,
+                messages=messages,
+                response_format=response_format,
+                tools=tools,
+                drop_params=True,
+            )
+        except litellm.ContextWindowExceededError:
+            pytest.skip("Model exceeded context window")
+        assert response is not None
+
     def test_response_format_type_text(self):
         """
         Test that the response format type text does not lead to tool calls

@@ -287,6 +337,7 @@ class BaseLLMChatTest(ABC):

         print(f"translated_params={translated_params}")

+
     @pytest.mark.flaky(retries=6, delay=1)
     def test_json_response_pydantic_obj(self):
         litellm.set_verbose = True
@@ -2717,6 +2717,33 @@ def test_bedrock_top_k_param(model, expected_params):
     assert data["additionalModelRequestFields"] == expected_params


+def test_bedrock_invoke_provider():
+    assert (
+        litellm.AmazonInvokeConfig().get_bedrock_invoke_provider(
+            "bedrock/invoke/us.anthropic.claude-3-5-sonnet-20240620-v1:0"
+        )
+        == "anthropic"
+    )
+    assert (
+        litellm.AmazonInvokeConfig().get_bedrock_invoke_provider(
+            "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0"
+        )
+        == "anthropic"
+    )
+    assert (
+        litellm.AmazonInvokeConfig().get_bedrock_invoke_provider(
+            "bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"
+        )
+        == "llama"
+    )
+    assert (
+        litellm.AmazonInvokeConfig().get_bedrock_invoke_provider(
+            "us.amazon.nova-pro-v1:0"
+        )
+        == "nova"
+    )
+
+
 def test_bedrock_description_param():
     from litellm import completion
     from litellm.llms.custom_httpx.http_handler import HTTPHandler

@@ -2754,3 +2781,4 @@ def test_bedrock_description_param():
     assert (
         "Find the meaning inside a poem" in request_body_str
     )  # assert description is passed
+