Fix anthropic thinking + response_format (#9594)

* fix(anthropic/chat/transformation.py): Don't set tool_choice on response_format conversion when thinking is enabled

Anthropic rejects forced tool use when extended thinking is enabled, so the response_format tool is added without a forced tool_choice (see the usage sketch below).

Fixes https://github.com/BerriAI/litellm/issues/8901

* refactor: move test to base anthropic chat tests

Ensures consistent behaviour across Vertex AI, the Anthropic API, and Bedrock.

* fix(anthropic/chat/transformation.py): if a thinking token budget is specified and max_tokens is not, ensure the max_tokens sent to Anthropic is greater than the thinking token budget

* feat(converse_transformation.py): correctly handle thinking + response format on Bedrock Converse

Fixes https://github.com/BerriAI/litellm/issues/8901

* fix(converse_transformation.py): correctly handle adding max tokens

* test: handle service unavailable error
Krish Dholakia 2025-03-28 15:57:40 -07:00 committed by GitHub
parent 7c1026e210
commit 5f8859eda8
8 changed files with 96 additions and 6 deletions
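
Taken together, these changes mean a request that combines thinking with response_format no longer sends a forced tool choice to Anthropic, and max_tokens is derived from the thinking budget when the caller omits it. Below is an illustrative usage sketch, not part of this commit: it mirrors the test arguments added further down and assumes a valid ANTHROPIC_API_KEY in the environment.

# Illustrative usage sketch mirroring the new tests; not part of the diff.
from pydantic import BaseModel

import litellm


class QAPair(BaseModel):
    question: str
    answer: str


response = litellm.completion(
    model="anthropic/claude-3-7-sonnet-latest",
    messages=[{"role": "user", "content": "Generate 5 question + answer pairs"}],
    response_format=QAPair,
    thinking={"type": "enabled", "budget_tokens": 16000},
    # max_tokens omitted on purpose: it should default to 16000 + DEFAULT_MAX_TOKENS (4096).
)
print(response.choices[0].message.content)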

@@ -7,6 +7,7 @@ DEFAULT_MAX_RETRIES = 2
DEFAULT_FAILURE_THRESHOLD_PERCENT = (
0.5 # default cooldown a deployment if 50% of requests fail in a given minute
)
DEFAULT_MAX_TOKENS = 4096
DEFAULT_REDIS_SYNC_INTERVAL = 1
DEFAULT_COOLDOWN_TIME_SECONDS = 5
DEFAULT_REPLICATE_POLLING_RETRIES = 5

@@ -300,6 +300,15 @@ class AnthropicConfig(BaseConfig):
model: str,
drop_params: bool,
) -> dict:
is_thinking_enabled = self.is_thinking_enabled(
non_default_params=non_default_params
)
## handle thinking tokens
self.update_optional_params_with_thinking_tokens(
non_default_params=non_default_params, optional_params=optional_params
)
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
@@ -349,19 +358,23 @@
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model's perspective.
"""
-    _tool_choice = {"name": RESPONSE_FORMAT_TOOL_NAME, "type": "tool"}
+    if not is_thinking_enabled:
+        _tool_choice = {"name": RESPONSE_FORMAT_TOOL_NAME, "type": "tool"}
+        optional_params["tool_choice"] = _tool_choice
_tool = self._create_json_tool_call_for_response_format(
json_schema=json_schema,
)
optional_params = self._add_tools_to_optional_params(
optional_params=optional_params, tools=[_tool]
)
optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True
if param == "user":
optional_params["metadata"] = {"user_id": value}
if param == "thinking":
optional_params["thinking"] = value
return optional_params
def _create_json_tool_call_for_response_format(
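
A minimal sketch of the mapping above, not part of the diff: it assumes the AnthropicConfig.map_openai_params signature shown in this hunk and the module path litellm/llms/anthropic/chat/transformation.py from the commit message, and only illustrates that, with thinking enabled, the JSON tool is added without a forced tool_choice.

# Sketch only; assumes the map_openai_params signature shown above.
from litellm.llms.anthropic.chat.transformation import AnthropicConfig

params = AnthropicConfig().map_openai_params(
    non_default_params={
        "response_format": {"type": "json_object"},
        "thinking": {"type": "enabled", "budget_tokens": 16000},
    },
    optional_params={},
    model="claude-3-7-sonnet-latest",
    drop_params=False,
)
assert "tool_choice" not in params               # no forced tool use while thinking is enabled
assert params.get("json_mode") is True           # response_format still becomes a JSON tool
assert params.get("max_tokens") == 16000 + 4096  # thinking budget + DEFAULT_MAX_TOKENS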

@@ -13,12 +13,13 @@ from typing import (
Optional,
Type,
Union,
cast,
)
import httpx
from pydantic import BaseModel
- from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
+ from litellm.constants import DEFAULT_MAX_TOKENS, RESPONSE_FORMAT_TOOL_NAME
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.openai import (
AllMessageValues,
@@ -102,6 +103,29 @@ class BaseConfig(ABC):
) -> Optional[dict]:
return type_to_response_format_param(response_format=response_format)
def is_thinking_enabled(self, non_default_params: dict) -> bool:
return non_default_params.get("thinking", {}).get("type", None) == "enabled"
def update_optional_params_with_thinking_tokens(
self, non_default_params: dict, optional_params: dict
):
"""
Handles the scenario where max tokens is not specified. Anthropic models (Anthropic API / Bedrock / Vertex AI) require max_tokens to be set and to be greater than the thinking token budget.
Checks 'non_default_params' for 'thinking' and 'max_tokens'.
If 'thinking' is enabled and 'max_tokens' is not specified, set 'max_tokens' to the thinking token budget + DEFAULT_MAX_TOKENS.
"""
is_thinking_enabled = self.is_thinking_enabled(non_default_params)
if is_thinking_enabled and "max_tokens" not in non_default_params:
thinking_token_budget = cast(dict, non_default_params["thinking"]).get(
"budget_tokens", None
)
if thinking_token_budget is not None:
optional_params["max_tokens"] = (
thinking_token_budget + DEFAULT_MAX_TOKENS
)
def should_fake_stream(
self,
model: Optional[str],
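
For reference, here is a self-contained sketch of the defaulting rule implemented by the new helper; derive_max_tokens is a hypothetical stand-in written for illustration, not litellm code.

# Hypothetical stand-in for update_optional_params_with_thinking_tokens; not part of the diff.
DEFAULT_MAX_TOKENS = 4096  # mirrors the constant added in this commit


def derive_max_tokens(non_default_params: dict) -> dict:
    optional_params: dict = {}
    thinking = non_default_params.get("thinking", {})
    if thinking.get("type") == "enabled" and "max_tokens" not in non_default_params:
        budget = thinking.get("budget_tokens")
        if budget is not None:
            # max_tokens must exceed the thinking budget, so pad it with the default.
            optional_params["max_tokens"] = budget + DEFAULT_MAX_TOKENS
    return optional_params


print(derive_max_tokens({"thinking": {"type": "enabled", "budget_tokens": 16000}}))
# {'max_tokens': 20096}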

@@ -210,6 +210,10 @@ class AmazonConverseConfig(BaseConfig):
drop_params: bool,
messages: Optional[List[AllMessageValues]] = None,
) -> dict:
is_thinking_enabled = self.is_thinking_enabled(non_default_params)
self.update_optional_params_with_thinking_tokens(
non_default_params=non_default_params, optional_params=optional_params
)
for param, value in non_default_params.items():
if param == "response_format" and isinstance(value, dict):
@@ -247,8 +251,11 @@
optional_params = self._add_tools_to_optional_params(
optional_params=optional_params, tools=[_tool]
)
-    if litellm.utils.supports_tool_choice(
-        model=model, custom_llm_provider=self.custom_llm_provider
+    if (
+        litellm.utils.supports_tool_choice(
+            model=model, custom_llm_provider=self.custom_llm_provider
+        )
+        and not is_thinking_enabled
):
optional_params["tool_choice"] = ToolChoiceValuesBlock(
tool=SpecificToolChoiceBlock(
@@ -284,6 +291,7 @@
optional_params["tool_choice"] = _tool_choice_value
if param == "thinking":
optional_params["thinking"] = value
return optional_params
@overload
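
On the Bedrock side, an equivalent request can be sketched as follows; this is illustrative only, mirrors the Bedrock test arguments added below, and assumes AWS credentials with access to the model are configured in the environment.

# Illustrative Bedrock Converse call; not part of the diff.
import litellm

response = litellm.completion(
    model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    messages=[{"role": "user", "content": "Generate 5 question + answer pairs"}],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "qa_pair",
            "schema": {
                "type": "object",
                "properties": {
                    "question": {"type": "string"},
                    "answer": {"type": "string"},
                },
                "required": ["question", "answer"],
            },
        },
    },
    thinking={"type": "enabled", "budget_tokens": 16000},
    # With thinking enabled, the Converse transformation above adds the JSON tool
    # but does not force toolChoice, even for models that support it.
)
print(response.choices[0].message.content)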

@@ -1008,6 +1008,11 @@ class BaseAnthropicChatTest(ABC):
"""Must return the base completion call args"""
pass
@abstractmethod
def get_base_completion_call_args_with_thinking(self) -> dict:
"""Must return the base completion call args"""
pass
@property
def completion_function(self):
return litellm.completion
@@ -1066,3 +1071,21 @@
json.loads(built_response.choices[0].message.content).keys()
== json.loads(non_stream_response.choices[0].message.content).keys()
), f"Got={json.loads(built_response.choices[0].message.content)}, Expected={json.loads(non_stream_response.choices[0].message.content)}"
def test_completion_thinking_with_response_format(self):
from pydantic import BaseModel
class RFormat(BaseModel):
question: str
answer: str
base_completion_call_args = self.get_base_completion_call_args_with_thinking()
messages = [{"role": "user", "content": "Generate 5 question + answer pairs"}]
response = self.completion_function(
**base_completion_call_args,
messages=messages,
response_format=RFormat,
)
print(response)

@@ -467,6 +467,12 @@ class TestAnthropicCompletion(BaseLLMChatTest, BaseAnthropicChatTest):
def get_base_completion_call_args(self) -> dict:
return {"model": "anthropic/claude-3-5-sonnet-20240620"}
def get_base_completion_call_args_with_thinking(self) -> dict:
return {
"model": "anthropic/claude-3-7-sonnet-latest",
"thinking": {"type": "enabled", "budget_tokens": 16000},
}
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
from litellm.litellm_core_utils.prompt_templates.factory import (

@@ -35,7 +35,7 @@ from litellm import (
from litellm.llms.bedrock.chat import BedrockLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.litellm_core_utils.prompt_templates.factory import _bedrock_tools_pt
- from base_llm_unit_tests import BaseLLMChatTest
+ from base_llm_unit_tests import BaseLLMChatTest, BaseAnthropicChatTest
from base_rerank_unit_tests import BaseLLMRerankTest
from base_embedding_unit_tests import BaseLLMEmbeddingTest
@@ -2191,6 +2191,19 @@ class TestBedrockConverseChatCrossRegion(BaseLLMChatTest):
assert cost > 0
class TestBedrockConverseAnthropicUnitTests(BaseAnthropicChatTest):
def get_base_completion_call_args(self) -> dict:
return {
"model": "bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0",
}
def get_base_completion_call_args_with_thinking(self) -> dict:
return {
"model": "bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
"thinking": {"type": "enabled", "budget_tokens": 16000},
}
class TestBedrockConverseChatNormal(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"

@@ -56,6 +56,8 @@ async def test_chat_completion_cohere_citations(stream):
assert citations_chunk
else:
assert response.citations is not None
except litellm.ServiceUnavailableError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")