Add anthropic thinking + reasoning content support (#8778)

* feat(anthropic/chat/transformation.py): add anthropic thinking param support

* feat(anthropic/chat/transformation.py): support returning thinking content for anthropic on streaming responses

* feat(anthropic/chat/transformation.py): return list of thinking blocks (include block signature)

allows usage in tool call responses

* fix(types/utils.py): extract and map reasoning_content from anthropic as content str

* test: add testing to ensure thinking_blocks are returned at the root

* fix(anthropic/chat/handler.py): return thinking blocks on streaming - include signature

* feat(factory.py): handle anthropic thinking blocks translation if in assistant response

* test: handle openai internal instability

* test: handle openai audio instability

* ci: pin anthropic dep

* test: handle openai audio instability

* fix: fix linting error

* refactor(anthropic/chat/transformation.py): refactor function to remain <50 LOC

* fix: fix linting error

* fix: fix linting error

* fix: fix linting error

* fix: fix linting error
Krish Dholakia 2025-02-24 21:54:30 -08:00 committed by GitHub
parent 9914c166b7
commit 142b195784
16 changed files with 332 additions and 62 deletions
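
The user-facing behavior this commit adds, sketched from the tests included below (assumes this branch of litellm is installed and ANTHROPIC_API_KEY is set):

import litellm

# Non-streaming: reasoning is exposed both as a flat `reasoning_content` string
# and as `thinking_blocks` (the raw blocks, including the signature Anthropic
# needs when the turn is sent back in a tool-calling loop).
resp = litellm.completion(
    model="anthropic/claude-3-7-sonnet-20250219",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    thinking={"type": "enabled", "budget_tokens": 1024},
)
print(resp.choices[0].message.reasoning_content)
print(resp.choices[0].message.thinking_blocks)

# Streaming: the same fields show up on each chunk's delta as they arrive.
for chunk in litellm.completion(
    model="anthropic/claude-3-7-sonnet-20250219",
    messages=[{"role": "user", "content": "Tell me a joke."}],
    thinking={"type": "enabled", "budget_tokens": 1024},
    stream=True,
):
    delta = chunk.choices[0].delta
    if getattr(delta, "reasoning_content", None):
        print(delta.reasoning_content, end="")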

View file

@@ -1939,7 +1939,7 @@ jobs:
          pip install "asyncio==3.4.3"
          pip install "PyGithub==1.59.1"
          pip install "google-cloud-aiplatform==1.59.0"
-         pip install anthropic
+         pip install "anthropic==0.21.3"
      # Run pytest and generate JUnit XML report
      - run:
          name: Build Docker image

View file

@@ -1444,6 +1444,12 @@ def anthropic_messages_pt(  # noqa: PLR0915
        ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
        while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
            assistant_content_block: ChatCompletionAssistantMessage = messages[msg_i]  # type: ignore
+           thinking_blocks = assistant_content_block.get("thinking_blocks", None)
+           if (
+               thinking_blocks is not None
+           ):  # IMPORTANT: ADD THIS FIRST, ELSE ANTHROPIC WILL RAISE AN ERROR
+               assistant_content.extend(thinking_blocks)
            if "content" in assistant_content_block and isinstance(
                assistant_content_block["content"], list
            ):
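
The hunk above makes the translation layer emit an assistant turn's thinking_blocks ahead of its text and tool-use blocks when converting back to Anthropic format. An illustrative, hand-written (not captured from a live request) translated assistant entry:

# Sketch only - block keys mirror Anthropic's extended-thinking response shape.
anthropic_assistant_message = {
    "role": "assistant",
    "content": [
        # thinking blocks must come first, otherwise Anthropic rejects the request
        {
            "type": "thinking",
            "thinking": "The user wants weather for three cities, so call the tool.",
            "signature": "<signature string returned by Anthropic>",
        },
        {"type": "text", "text": "Let me check that for you."},
        {
            "type": "tool_use",
            "id": "toolu_01",
            "name": "get_current_weather",
            "input": {"location": "Paris"},
        },
    ],
}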

View file

@ -30,6 +30,7 @@ from litellm.types.llms.anthropic import (
UsageDelta, UsageDelta,
) )
from litellm.types.llms.openai import ( from litellm.types.llms.openai import (
ChatCompletionThinkingBlock,
ChatCompletionToolCallChunk, ChatCompletionToolCallChunk,
ChatCompletionUsageBlock, ChatCompletionUsageBlock,
) )
@@ -507,6 +508,10 @@ class ModelResponseIterator:
        return usage_block

    def _content_block_delta_helper(self, chunk: dict):
+       """
+       Helper function to handle the content block delta
+       """
        text = ""
        tool_use: Optional[ChatCompletionToolCallChunk] = None
        provider_specific_fields = {}
@@ -526,7 +531,17 @@ class ModelResponseIterator:
            }
        elif "citation" in content_block["delta"]:
            provider_specific_fields["citation"] = content_block["delta"]["citation"]
+       elif (
+           "thinking" in content_block["delta"]
+           or "signature_delta" == content_block["delta"]
+       ):
+           provider_specific_fields["thinking_blocks"] = [
+               ChatCompletionThinkingBlock(
+                   type="thinking",
+                   thinking=content_block["delta"].get("thinking"),
+                   signature_delta=content_block["delta"].get("signature"),
+               )
+           ]
        return text, tool_use, provider_specific_fields

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
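
For reference, the two Anthropic stream events the new elif branch is meant to capture look roughly like this (paraphrased from Anthropic's extended-thinking streaming output; an illustration, not a spec). Each one is mapped into a single-element thinking_blocks list under provider_specific_fields:

# Incremental thinking text
thinking_delta_event = {
    "type": "content_block_delta",
    "index": 0,
    "delta": {"type": "thinking_delta", "thinking": "Let me break this down..."},
}

# Signature for the completed thinking block (arrives near the end of the block)
signature_delta_event = {
    "type": "content_block_delta",
    "index": 0,
    "delta": {"type": "signature_delta", "signature": "EqQBCgIYAh..."},
}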

View file

@@ -1,6 +1,6 @@
import json
import time
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast

import httpx
@@ -581,6 +581,43 @@ class AnthropicConfig(BaseConfig):
        )
        return _message

+   def extract_response_content(self, completion_response: dict) -> Tuple[
+       str,
+       Optional[List[Any]],
+       Optional[List[Dict[str, Any]]],
+       List[ChatCompletionToolCallChunk],
+   ]:
+       text_content = ""
+       citations: Optional[List[Any]] = None
+       thinking_blocks: Optional[List[Dict[str, Any]]] = None
+       tool_calls: List[ChatCompletionToolCallChunk] = []
+       for idx, content in enumerate(completion_response["content"]):
+           if content["type"] == "text":
+               text_content += content["text"]
+           ## TOOL CALLING
+           elif content["type"] == "tool_use":
+               tool_calls.append(
+                   ChatCompletionToolCallChunk(
+                       id=content["id"],
+                       type="function",
+                       function=ChatCompletionToolCallFunctionChunk(
+                           name=content["name"],
+                           arguments=json.dumps(content["input"]),
+                       ),
+                       index=idx,
+                   )
+               )
+           ## CITATIONS
+           if content.get("citations", None) is not None:
+               if citations is None:
+                   citations = []
+               citations.append(content["citations"])
+           if content.get("thinking", None) is not None:
+               if thinking_blocks is None:
+                   thinking_blocks = []
+               thinking_blocks.append(content)
+       return text_content, citations, thinking_blocks, tool_calls
+
    def transform_response(
        self,
        model: str,
@@ -628,32 +665,21 @@ class AnthropicConfig(BaseConfig):
            )
        else:
            text_content = ""
-           citations: List[Any] = []
+           citations: Optional[List[Any]] = None
+           thinking_blocks: Optional[List[Dict[str, Any]]] = None
            tool_calls: List[ChatCompletionToolCallChunk] = []
-           for idx, content in enumerate(completion_response["content"]):
-               if content["type"] == "text":
-                   text_content += content["text"]
-               ## TOOL CALLING
-               elif content["type"] == "tool_use":
-                   tool_calls.append(
-                       ChatCompletionToolCallChunk(
-                           id=content["id"],
-                           type="function",
-                           function=ChatCompletionToolCallFunctionChunk(
-                               name=content["name"],
-                               arguments=json.dumps(content["input"]),
-                           ),
-                           index=idx,
-                       )
-                   )
-               ## CITATIONS
-               if content.get("citations", None) is not None:
-                   citations.append(content["citations"])
+           text_content, citations, thinking_blocks, tool_calls = (
+               self.extract_response_content(completion_response=completion_response)
+           )

            _message = litellm.Message(
                tool_calls=tool_calls,
                content=text_content or None,
-               provider_specific_fields={"citations": citations},
+               provider_specific_fields={
+                   "citations": citations,
+                   "thinking_blocks": thinking_blocks,
+               },
            )

        ## HANDLE JSON MODE - anthropic returns single function call
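
A sketch of what the new extract_response_content helper produces for a typical extended-thinking response with one tool call (input values are made up; only the shapes matter, and the module path is inferred from the commit message):

from litellm.llms.anthropic.chat.transformation import AnthropicConfig

completion_response = {
    "content": [
        {"type": "thinking", "thinking": "I should call the weather tool.", "signature": "..."},
        {"type": "text", "text": "Checking the weather now."},
        {
            "type": "tool_use",
            "id": "toolu_01",
            "name": "get_current_weather",
            "input": {"location": "Paris"},
        },
    ]
}

text, citations, thinking_blocks, tool_calls = AnthropicConfig().extract_response_content(
    completion_response=completion_response
)
# text == "Checking the weather now."
# citations is None, thinking_blocks holds the raw thinking block as-is, and
# tool_calls holds one ChatCompletionToolCallChunk whose index matches the
# block's position in the content list.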

View file

@@ -9,6 +9,7 @@ from typing import List, Optional, Type, Union
from openai.lib import _parsing, _pydantic
from pydantic import BaseModel

+from litellm._logging import verbose_logger
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderSpecificModelInfo
@@ -132,6 +133,9 @@ def map_developer_role_to_system_role(
    new_messages: List[AllMessageValues] = []
    for m in messages:
        if m["role"] == "developer":
+           verbose_logger.debug(
+               "Translating developer role to system role for non-OpenAI providers."
+           )  # ensure user knows what's happening with their input.
            new_messages.append({"role": "system", "content": m["content"]})
        else:
            new_messages.append(m)

View file

@@ -18,7 +18,6 @@ from typing import (
import httpx
from pydantic import BaseModel

-from litellm._logging import verbose_logger
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.openai import (
@@ -121,9 +120,6 @@ class BaseConfig(ABC):
        Overriden by OpenAI/Azure
        """
-       verbose_logger.debug(
-           "Translating developer role to system role for non-OpenAI providers."
-       )  # ensure user knows what's happening with their input.
        return map_developer_role_to_system_role(messages=messages)

    def should_retry_llm_api_inside_llm_translation_on_http_error(

View file

@@ -1,4 +1,8 @@
model_list:
+  - model_name: anthropic/claude-3-7-sonnet-20250219
+    litellm_params:
+      model: anthropic/claude-3-7-sonnet-20250219
+      api_key: os.environ/ANTHROPIC_API_KEY
  - model_name: gpt-4
    litellm_params:
      model: openai/gpt-3.5-turbo
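
With the model registered above, the thinking parameter can be sent through the proxy from an OpenAI-compatible client as an extra body field. A hypothetical client call — the base_url and api_key are placeholders for your proxy deployment, and pass-through of this provider-specific param is assumed rather than confirmed by this diff:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")  # placeholder proxy URL/key

resp = client.chat.completions.create(
    model="anthropic/claude-3-7-sonnet-20250219",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    extra_body={"thinking": {"type": "enabled", "budget_tokens": 1024}},
)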

View file

@@ -3,7 +3,7 @@ from typing import Any, Dict, Iterable, List, Optional, Union
from pydantic import BaseModel, validator
from typing_extensions import Literal, Required, TypedDict

-from .openai import ChatCompletionCachedContent
+from .openai import ChatCompletionCachedContent, ChatCompletionThinkingBlock


class AnthropicMessagesToolChoice(TypedDict, total=False):
@@ -62,6 +62,7 @@ class AnthropicMessagesToolUseParam(TypedDict):
AnthropicMessagesAssistantMessageValues = Union[
    AnthropicMessagesTextParam,
    AnthropicMessagesToolUseParam,
+   ChatCompletionThinkingBlock,
]

View file

@@ -357,6 +357,12 @@ class ChatCompletionCachedContent(TypedDict):
    type: Literal["ephemeral"]


+class ChatCompletionThinkingBlock(TypedDict, total=False):
+    type: Required[Literal["thinking"]]
+    thinking: str
+    signature_delta: str
+
+
class OpenAIChatCompletionTextObject(TypedDict):
    type: Literal["text"]
    text: str
@@ -450,6 +456,7 @@ class OpenAIChatCompletionAssistantMessage(TypedDict, total=False):

class ChatCompletionAssistantMessage(OpenAIChatCompletionAssistantMessage, total=False):
    cache_control: ChatCompletionCachedContent
+   thinking_blocks: Optional[List[ChatCompletionThinkingBlock]]


class ChatCompletionToolMessage(TypedDict):
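
How the new types compose, as a small illustrative sketch (values are made up):

from litellm.types.llms.openai import (
    ChatCompletionAssistantMessage,
    ChatCompletionThinkingBlock,
)

block = ChatCompletionThinkingBlock(
    type="thinking",
    thinking="Reasoning text emitted by the model...",
    signature_delta="EqQBCgIYAh...",  # placeholder signature fragment
)

msg = ChatCompletionAssistantMessage(
    role="assistant",
    content="Final answer.",
    thinking_blocks=[block],
)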

View file

@@ -457,6 +457,43 @@ Reference:
ChatCompletionMessage(content='This is a test', role='assistant', function_call=None, tool_calls=None))
"""

+REASONING_CONTENT_COMPATIBLE_PARAMS = [
+    "thinking_blocks",
+    "reasoning_content",
+]
+
+
+def map_reasoning_content(provider_specific_fields: Dict[str, Any]) -> str:
+    """
+    Extract reasoning_content from provider_specific_fields
+    """
+    reasoning_content: str = ""
+    for k, v in provider_specific_fields.items():
+        if k == "thinking_blocks" and isinstance(v, list):
+            _reasoning_content = ""
+            for block in v:
+                if block.get("type") == "thinking":
+                    _reasoning_content += block.get("thinking", "")
+            reasoning_content = _reasoning_content
+        elif k == "reasoning_content":
+            reasoning_content = v
+    return reasoning_content
+
+
+def add_provider_specific_fields(
+    object: BaseModel, provider_specific_fields: Optional[Dict[str, Any]]
+):
+    if not provider_specific_fields:  # set if provider_specific_fields is not empty
+        return
+    setattr(object, "provider_specific_fields", provider_specific_fields)
+    for k, v in provider_specific_fields.items():
+        if v is not None:
+            setattr(object, k, v)
+            if k in REASONING_CONTENT_COMPATIBLE_PARAMS and k != "reasoning_content":
+                reasoning_content = map_reasoning_content({k: v})
+                setattr(object, "reasoning_content", reasoning_content)
+
+
class Message(OpenAIObject):
    content: Optional[str]
@@ -511,10 +548,7 @@ class Message(OpenAIObject):
            # OpenAI compatible APIs like mistral API will raise an error if audio is passed in
            del self.audio

-       if provider_specific_fields:  # set if provider_specific_fields is not empty
-           self.provider_specific_fields = provider_specific_fields
-           for k, v in provider_specific_fields.items():
-               setattr(self, k, v)
+       add_provider_specific_fields(self, provider_specific_fields)

    def get(self, key, default=None):
        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
@@ -551,11 +585,7 @@ class Delta(OpenAIObject):
        **params,
    ):
        super(Delta, self).__init__(**params)
-       provider_specific_fields: Dict[str, Any] = {}
-
-       if "reasoning_content" in params:
-           provider_specific_fields["reasoning_content"] = params["reasoning_content"]
-           setattr(self, "reasoning_content", params["reasoning_content"])
+       add_provider_specific_fields(self, params.get("provider_specific_fields", {}))
        self.content = content
        self.role = role
        # Set default values and correct types
@@ -563,9 +593,6 @@ class Delta(OpenAIObject):
        self.tool_calls: Optional[List[Union[ChatCompletionDeltaToolCall, Any]]] = None
        self.audio: Optional[ChatCompletionAudioResponse] = None

-       if provider_specific_fields:  # set if provider_specific_fields is not empty
-           self.provider_specific_fields = provider_specific_fields
-
        if function_call is not None and isinstance(function_call, dict):
            self.function_call = FunctionCall(**function_call)
        else:
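
Roughly how the new helpers behave, as a sketch rather than a test from this commit:

from litellm.types.utils import Message, map_reasoning_content

blocks = [
    {"type": "thinking", "thinking": "Paris is the capital "},
    {"type": "thinking", "thinking": "of France."},
]

# Thinking text from every block is concatenated into one flat string.
assert map_reasoning_content({"thinking_blocks": blocks}) == "Paris is the capital of France."

# Building a Message with thinking_blocks in provider_specific_fields now sets
# both the raw blocks and the flattened reasoning_content on the object.
msg = Message(content="Paris.", provider_specific_fields={"thinking_blocks": blocks})
assert msg.thinking_blocks == blocks
assert msg.reasoning_content == "Paris is the capital of France."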

View file

@@ -1161,3 +1161,53 @@ def test_anthropic_citations_api_streaming():
            has_citations = True

    assert has_citations
+
+
+def test_anthropic_thinking_output():
+    from litellm import completion
+
+    resp = completion(
+        model="anthropic/claude-3-7-sonnet-20250219",
+        messages=[{"role": "user", "content": "What is the capital of France?"}],
+        thinking={"type": "enabled", "budget_tokens": 1024},
+    )
+
+    print(resp.choices[0].message)
+    assert (
+        resp.choices[0].message.provider_specific_fields["thinking_blocks"] is not None
+    )
+    assert resp.choices[0].message.reasoning_content is not None
+    assert isinstance(resp.choices[0].message.reasoning_content, str)
+    assert resp.choices[0].message.thinking_blocks is not None
+    assert isinstance(resp.choices[0].message.thinking_blocks, list)
+    assert len(resp.choices[0].message.thinking_blocks) > 0
+
+
+def test_anthropic_thinking_output_stream():
+    # litellm.set_verbose = True
+    try:
+        # litellm._turn_on_debug()
+        resp = litellm.completion(
+            model="anthropic/claude-3-7-sonnet-20250219",
+            messages=[{"role": "user", "content": "Tell me a joke."}],
+            stream=True,
+            thinking={"type": "enabled", "budget_tokens": 1024},
+            timeout=5,
+        )
+
+        reasoning_content_exists = False
+        for chunk in resp:
+            print(f"chunk 2: {chunk}")
+            if (
+                hasattr(chunk.choices[0].delta, "thinking_blocks")
+                and chunk.choices[0].delta.thinking_blocks is not None
+                and chunk.choices[0].delta.reasoning_content is not None
+                and isinstance(chunk.choices[0].delta.thinking_blocks, list)
+                and len(chunk.choices[0].delta.thinking_blocks) > 0
+                and isinstance(chunk.choices[0].delta.reasoning_content, str)
+            ):
+                reasoning_content_exists = True
+                break
+        assert reasoning_content_exists
+    except litellm.Timeout:
+        pytest.skip("Model is timing out")

View file

@@ -67,6 +67,9 @@ async def test_audio_output_from_model(stream):
    except litellm.Timeout as e:
        print(e)
        pytest.skip("Skipping test due to timeout")
+   except Exception as e:
+       if "openai-internal" in str(e):
+           pytest.skip("Skipping test due to openai-internal error")

    if stream is True:
        await check_streaming_response(completion)
@@ -86,7 +89,7 @@ async def test_audio_input_to_model(stream):
    audio_format = "pcm16"
    if stream is False:
        audio_format = "wav"
-   litellm.set_verbose = True
+   litellm._turn_on_debug()
    url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav"
    response = requests.get(url)
    response.raise_for_status()
@@ -114,7 +117,9 @@ async def test_audio_input_to_model(stream):
    except litellm.Timeout as e:
        print(e)
        pytest.skip("Skipping test due to timeout")
+   except Exception as e:
+       if "openai-internal" in str(e):
+           pytest.skip("Skipping test due to openai-internal error")
    if stream is True:
        await check_streaming_response(completion)
    else:

View file

@@ -1320,13 +1320,19 @@ def test_standard_logging_payload_audio(turn_off_message_logging, stream):
    with patch.object(
        customHandler, "log_success_event", new=MagicMock()
    ) as mock_client:
+       try:
-       response = litellm.completion(
-           model="gpt-4o-audio-preview",
-           modalities=["text", "audio"],
-           audio={"voice": "alloy", "format": "pcm16"},
-           messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
-           stream=stream,
-       )
+           response = litellm.completion(
+               model="gpt-4o-audio-preview",
+               modalities=["text", "audio"],
+               audio={"voice": "alloy", "format": "pcm16"},
+               messages=[
+                   {"role": "user", "content": "response in 1 word - yes or no"}
+               ],
+               stream=stream,
+           )
+       except Exception as e:
+           if "openai-internal" in str(e):
+               pytest.skip("Skipping test due to openai-internal error")

        if stream:
            for chunk in response:

View file

@@ -157,6 +157,113 @@ def test_aaparallel_function_call(model):
# test_parallel_function_call()


+@pytest.mark.parametrize(
+    "model",
+    [
+        "anthropic/claude-3-7-sonnet-20250219",
+    ],
+)
+@pytest.mark.flaky(retries=3, delay=1)
+def test_aaparallel_function_call_with_anthropic_thinking(model):
+    try:
+        litellm._turn_on_debug()
+        litellm.modify_params = True
+        # Step 1: send the conversation and available functions to the model
+        messages = [
+            {
+                "role": "user",
+                "content": "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses",
+            }
+        ]
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_current_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state",
+                            },
+                            "unit": {
+                                "type": "string",
+                                "enum": ["celsius", "fahrenheit"],
+                            },
+                        },
+                        "required": ["location"],
+                    },
+                },
+            }
+        ]
+        response = litellm.completion(
+            model=model,
+            messages=messages,
+            tools=tools,
+            tool_choice="auto",  # auto is default, but we'll be explicit
+            thinking={"type": "enabled", "budget_tokens": 1024},
+        )
+        print("Response\n", response)
+        response_message = response.choices[0].message
+        tool_calls = response_message.tool_calls
+
+        print("Expecting there to be 3 tool calls")
+        assert (
+            len(tool_calls) > 0
+        )  # this has to call the function for SF, Tokyo and paris
+
+        # Step 2: check if the model wanted to call a function
+        print(f"tool_calls: {tool_calls}")
+        if tool_calls:
+            # Step 3: call the function
+            # Note: the JSON response may not always be valid; be sure to handle errors
+            available_functions = {
+                "get_current_weather": get_current_weather,
+            }  # only one function in this example, but you can have multiple
+            messages.append(
+                response_message
+            )  # extend conversation with assistant's reply
+            print("Response message\n", response_message)
+            # Step 4: send the info for each function call and function response to the model
+            for tool_call in tool_calls:
+                function_name = tool_call.function.name
+                if function_name not in available_functions:
+                    # the model called a function that does not exist in available_functions - don't try calling anything
+                    return
+                function_to_call = available_functions[function_name]
+                function_args = json.loads(tool_call.function.arguments)
+                function_response = function_to_call(
+                    location=function_args.get("location"),
+                    unit=function_args.get("unit"),
+                )
+                messages.append(
+                    {
+                        "tool_call_id": tool_call.id,
+                        "role": "tool",
+                        "name": function_name,
+                        "content": function_response,
+                    }
+                )  # extend conversation with function response
+            print(f"messages: {messages}")
+            second_response = litellm.completion(
+                model=model,
+                messages=messages,
+                seed=22,
+                # tools=tools,
+                drop_params=True,
+                thinking={"type": "enabled", "budget_tokens": 1024},
+            )  # get a new response from the model where it can see the function response
+            print("second response\n", second_response)
+    except litellm.InternalServerError as e:
+        print(e)
+    except litellm.RateLimitError as e:
+        print(e)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
from litellm.types.utils import ChatCompletionMessageToolCall, Function, Message

View file

@@ -696,6 +696,7 @@ def test_stream_chunk_builder_openai_audio_output_usage():
        api_key=os.getenv("OPENAI_API_KEY"),
    )

+   try:
        completion = client.chat.completions.create(
            model="gpt-4o-audio-preview",
            modalities=["text", "audio"],
@@ -704,6 +705,9 @@ def test_stream_chunk_builder_openai_audio_output_usage():
            stream=True,
            stream_options={"include_usage": True},
        )
+   except Exception as e:
+       if "openai-internal" in str(e):
+           pytest.skip("Skipping test due to openai-internal error")

    chunks = []
    for chunk in completion:

View file

@@ -4065,20 +4065,32 @@ def test_mock_response_iterator_tool_use():
    assert response_chunk["tool_use"] is not None


-def test_deepseek_reasoning_content_completion():
+@pytest.mark.parametrize(
+    "model",
+    [
+        # "deepseek/deepseek-reasoner",
+        "anthropic/claude-3-7-sonnet-20250219",
+    ],
+)
+def test_deepseek_reasoning_content_completion(model):
    # litellm.set_verbose = True
    try:
+       # litellm._turn_on_debug()
        resp = litellm.completion(
-           model="deepseek/deepseek-reasoner",
+           model=model,
            messages=[{"role": "user", "content": "Tell me a joke."}],
            stream=True,
+           thinking={"type": "enabled", "budget_tokens": 1024},
            timeout=5,
        )

        reasoning_content_exists = False
        for chunk in resp:
-           print(f"chunk: {chunk}")
-           if chunk.choices[0].delta.reasoning_content is not None:
+           print(f"chunk 2: {chunk}")
+           if (
+               hasattr(chunk.choices[0].delta, "reasoning_content")
+               and chunk.choices[0].delta.reasoning_content is not None
+           ):
                reasoning_content_exists = True
                break
        assert reasoning_content_exists