Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
Fix anthropic thinking + response_format (#9594)
* fix(anthropic/chat/transformation.py): don't set tool choice on the response_format conversion when thinking is enabled, since that is not allowed by Anthropic. Fixes https://github.com/BerriAI/litellm/issues/8901
* refactor: move the test to the base Anthropic chat tests to ensure consistent behaviour across vertex/anthropic/bedrock
* fix(anthropic/chat/transformation.py): if a thinking token budget is specified and max tokens is not, ensure the max tokens sent to Anthropic is higher than the thinking token budget
* feat(converse_transformation.py): correctly handle thinking + response format on Bedrock Converse. Fixes https://github.com/BerriAI/litellm/issues/8901
* fix(converse_transformation.py): correctly handle adding max tokens
* test: handle the service unavailable error
Parent: 7c1026e210
Commit: 5f8859eda8
8 changed files with 96 additions and 6 deletions
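For orientation before the diff hunks, here is a minimal sketch of the call pattern this commit fixes: Anthropic extended thinking combined with a structured response_format. It mirrors the test added below; the model name and budget come from that test, and it assumes a configured Anthropic API key.

# Illustrative usage only; mirrors the new test added in this commit.
import litellm
from pydantic import BaseModel


class QA(BaseModel):
    question: str
    answer: str


# Before this fix, combining thinking with response_format forced a tool_choice,
# which Anthropic rejects when thinking is enabled.
response = litellm.completion(
    model="anthropic/claude-3-7-sonnet-latest",
    messages=[{"role": "user", "content": "Generate 5 question + answer pairs"}],
    thinking={"type": "enabled", "budget_tokens": 16000},
    response_format=QA,
)
print(response.choices[0].message.content)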
@@ -7,6 +7,7 @@ DEFAULT_MAX_RETRIES = 2
DEFAULT_FAILURE_THRESHOLD_PERCENT = (
    0.5  # default cooldown a deployment if 50% of requests fail in a given minute
)
+DEFAULT_MAX_TOKENS = 4096
DEFAULT_REDIS_SYNC_INTERVAL = 1
DEFAULT_COOLDOWN_TIME_SECONDS = 5
DEFAULT_REPLICATE_POLLING_RETRIES = 5
@@ -300,6 +300,15 @@ class AnthropicConfig(BaseConfig):
        model: str,
        drop_params: bool,
    ) -> dict:
        is_thinking_enabled = self.is_thinking_enabled(
            non_default_params=non_default_params
        )

        ## handle thinking tokens
        self.update_optional_params_with_thinking_tokens(
            non_default_params=non_default_params, optional_params=optional_params
        )
        for param, value in non_default_params.items():
            if param == "max_tokens":
                optional_params["max_tokens"] = value
@@ -349,19 +358,23 @@ class AnthropicConfig(BaseConfig):
                - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
                """

+               if not is_thinking_enabled:
+                   _tool_choice = {"name": RESPONSE_FORMAT_TOOL_NAME, "type": "tool"}
+                   optional_params["tool_choice"] = _tool_choice

                _tool = self._create_json_tool_call_for_response_format(
                    json_schema=json_schema,
                )
                optional_params = self._add_tools_to_optional_params(
                    optional_params=optional_params, tools=[_tool]
                )
-               optional_params["tool_choice"] = _tool_choice
                optional_params["json_mode"] = True
            if param == "user":
                optional_params["metadata"] = {"user_id": value}
            if param == "thinking":
                optional_params["thinking"] = value

        return optional_params

    def _create_json_tool_call_for_response_format(
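To make the behaviour change concrete, an illustrative sketch (not part of the diff) of the optional_params produced by the branch above, assuming RESPONSE_FORMAT_TOOL_NAME resolves to a tool named "json_tool_call" and a trivial schema:

# Illustrative shapes only; field values are assumptions, not taken from the diff.
without_thinking = {
    "tools": [{"name": "json_tool_call", "input_schema": {"type": "object"}}],
    "tool_choice": {"type": "tool", "name": "json_tool_call"},  # forced tool use
    "json_mode": True,
}

with_thinking = {
    "thinking": {"type": "enabled", "budget_tokens": 16000},
    "tools": [{"name": "json_tool_call", "input_schema": {"type": "object"}}],
    "json_mode": True,
    # "tool_choice" is deliberately omitted: Anthropic rejects forced tool use
    # while extended thinking is enabled.
}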
@@ -13,12 +13,13 @@ from typing import (
    Optional,
    Type,
    Union,
    cast,
)

import httpx
from pydantic import BaseModel

-from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
+from litellm.constants import DEFAULT_MAX_TOKENS, RESPONSE_FORMAT_TOOL_NAME
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.openai import (
    AllMessageValues,
@@ -102,6 +103,29 @@ class BaseConfig(ABC):
    ) -> Optional[dict]:
        return type_to_response_format_param(response_format=response_format)

    def is_thinking_enabled(self, non_default_params: dict) -> bool:
        return non_default_params.get("thinking", {}).get("type", None) == "enabled"

    def update_optional_params_with_thinking_tokens(
        self, non_default_params: dict, optional_params: dict
    ):
        """
        Handles scenario where max tokens is not specified. For anthropic models (anthropic api/bedrock/vertex ai), this requires having the max tokens being set and being greater than the thinking token budget.

        Checks 'non_default_params' for 'thinking' and 'max_tokens'

        if 'thinking' is enabled and 'max_tokens' is not specified, set 'max_tokens' to the thinking token budget + DEFAULT_MAX_TOKENS
        """
        is_thinking_enabled = self.is_thinking_enabled(non_default_params)
        if is_thinking_enabled and "max_tokens" not in non_default_params:
            thinking_token_budget = cast(dict, non_default_params["thinking"]).get(
                "budget_tokens", None
            )
            if thinking_token_budget is not None:
                optional_params["max_tokens"] = (
                    thinking_token_budget + DEFAULT_MAX_TOKENS
                )

    def should_fake_stream(
        self,
        model: Optional[str],
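The helper above boils down to a simple fallback. A minimal standalone sketch of the same rule (the function name resolve_max_tokens is hypothetical; DEFAULT_MAX_TOKENS is 4096 per the constants hunk at the top of this diff):

# Standalone re-statement of the fallback above (illustrative, not library code).
DEFAULT_MAX_TOKENS = 4096  # value from the constants hunk above


def resolve_max_tokens(non_default_params: dict) -> dict:
    optional_params: dict = {}
    thinking = non_default_params.get("thinking", {})
    if thinking.get("type") == "enabled" and "max_tokens" not in non_default_params:
        budget = thinking.get("budget_tokens")
        if budget is not None:
            # max_tokens must exceed the thinking budget, so pad the budget
            # with the default completion allowance.
            optional_params["max_tokens"] = budget + DEFAULT_MAX_TOKENS
    return optional_params


# {"thinking": {"type": "enabled", "budget_tokens": 16000}} -> {"max_tokens": 20096}
print(resolve_max_tokens({"thinking": {"type": "enabled", "budget_tokens": 16000}}))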
@@ -210,6 +210,10 @@ class AmazonConverseConfig(BaseConfig):
        drop_params: bool,
        messages: Optional[List[AllMessageValues]] = None,
    ) -> dict:
        is_thinking_enabled = self.is_thinking_enabled(non_default_params)
        self.update_optional_params_with_thinking_tokens(
            non_default_params=non_default_params, optional_params=optional_params
        )
        for param, value in non_default_params.items():
            if param == "response_format" and isinstance(value, dict):
@@ -247,8 +251,11 @@ class AmazonConverseConfig(BaseConfig):
                optional_params = self._add_tools_to_optional_params(
                    optional_params=optional_params, tools=[_tool]
                )
-               if litellm.utils.supports_tool_choice(
+               if (
+                   litellm.utils.supports_tool_choice(
+                       model=model, custom_llm_provider=self.custom_llm_provider
+                   )
+                   and not is_thinking_enabled
+               ):
                    optional_params["tool_choice"] = ToolChoiceValuesBlock(
                        tool=SpecificToolChoiceBlock(
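For reference, an illustrative sketch of what this guard changes in the Converse tool configuration; the "toolSpec"/"toolChoice" field names follow the typed blocks used above and the values are assumptions, not output captured from the library:

# Illustrative only: Converse tool config with vs. without forced tool choice.
converse_without_thinking = {
    "tools": [{"toolSpec": {"name": "json_tool_call"}}],
    "toolChoice": {"tool": {"name": "json_tool_call"}},  # forced, as before this change
}

converse_with_thinking = {
    "tools": [{"toolSpec": {"name": "json_tool_call"}}],
    # toolChoice omitted: forced tool use is skipped when thinking is enabled,
    # mirroring the Anthropic-native behaviour above.
}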
@@ -284,6 +291,7 @@ class AmazonConverseConfig(BaseConfig):
                optional_params["tool_choice"] = _tool_choice_value
            if param == "thinking":
                optional_params["thinking"] = value

        return optional_params

    @overload
@@ -1008,6 +1008,11 @@ class BaseAnthropicChatTest(ABC):
        """Must return the base completion call args"""
        pass

    @abstractmethod
    def get_base_completion_call_args_with_thinking(self) -> dict:
        """Must return the base completion call args"""
        pass

    @property
    def completion_function(self):
        return litellm.completion
@@ -1066,3 +1071,21 @@ class BaseAnthropicChatTest(ABC):
            json.loads(built_response.choices[0].message.content).keys()
            == json.loads(non_stream_response.choices[0].message.content).keys()
        ), f"Got={json.loads(built_response.choices[0].message.content)}, Expected={json.loads(non_stream_response.choices[0].message.content)}"

    def test_completion_thinking_with_response_format(self):
        from pydantic import BaseModel

        class RFormat(BaseModel):
            question: str
            answer: str

        base_completion_call_args = self.get_base_completion_call_args_with_thinking()

        messages = [{"role": "user", "content": "Generate 5 question + answer pairs"}]
        response = self.completion_function(
            **base_completion_call_args,
            messages=messages,
            response_format=RFormat,
        )

        print(response)
@@ -467,6 +467,12 @@ class TestAnthropicCompletion(BaseLLMChatTest, BaseAnthropicChatTest):
    def get_base_completion_call_args(self) -> dict:
        return {"model": "anthropic/claude-3-5-sonnet-20240620"}

    def get_base_completion_call_args_with_thinking(self) -> dict:
        return {
            "model": "anthropic/claude-3-7-sonnet-latest",
            "thinking": {"type": "enabled", "budget_tokens": 16000},
        }

    def test_tool_call_no_arguments(self, tool_call_no_arguments):
        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
        from litellm.litellm_core_utils.prompt_templates.factory import (
@@ -35,7 +35,7 @@ from litellm import (
from litellm.llms.bedrock.chat import BedrockLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.litellm_core_utils.prompt_templates.factory import _bedrock_tools_pt
-from base_llm_unit_tests import BaseLLMChatTest
+from base_llm_unit_tests import BaseLLMChatTest, BaseAnthropicChatTest
from base_rerank_unit_tests import BaseLLMRerankTest
from base_embedding_unit_tests import BaseLLMEmbeddingTest
@@ -2191,6 +2191,19 @@ class TestBedrockConverseChatCrossRegion(BaseLLMChatTest):
        assert cost > 0


class TestBedrockConverseAnthropicUnitTests(BaseAnthropicChatTest):
    def get_base_completion_call_args(self) -> dict:
        return {
            "model": "bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0",
        }

    def get_base_completion_call_args_with_thinking(self) -> dict:
        return {
            "model": "bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
            "thinking": {"type": "enabled", "budget_tokens": 16000},
        }


class TestBedrockConverseChatNormal(BaseLLMChatTest):
    def get_base_completion_call_args(self) -> dict:
        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
@@ -56,6 +56,8 @@ async def test_chat_completion_cohere_citations(stream):
            assert citations_chunk
        else:
            assert response.citations is not None
+   except litellm.ServiceUnavailableError:
+       pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")