Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
Add anthropic thinking + reasoning content support (#8778)
* feat(anthropic/chat/transformation.py): add anthropic thinking param support
* feat(anthropic/chat/transformation.py): support returning thinking content for anthropic on streaming responses
* feat(anthropic/chat/transformation.py): return list of thinking blocks (include block signature) allows usage in tool call responses
* fix(types/utils.py): extract and map reasoning_content from anthropic as content str
* test: add testing to ensure thinking_blocks are returned at the root
* fix(anthropic/chat/handler.py): return thinking blocks on streaming - include signature
* feat(factory.py): handle anthropic thinking blocks translation if in assistant response
* test: handle openai internal instability
* test: handle openai audio instability
* ci: pin anthropic dep
* test: handle openai audio instability
* fix: fix linting error
* refactor(anthropic/chat/transformation.py): refactor function to remain <50 LOC
* fix: fix linting error
* fix: fix linting error
* fix: fix linting error
* fix: fix linting error
This commit is contained in:
parent 9914c166b7
commit 142b195784
16 changed files with 332 additions and 62 deletions
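For context, a minimal usage sketch of the feature this commit adds, based on the tests further down in this diff (model name and budget_tokens copied from those tests; response fields per types/utils.py in this commit):

from litellm import completion

# the thinking param is forwarded to Anthropic; reasoning_content / thinking_blocks
# are populated from the returned thinking blocks (see the tests below)
resp = completion(
    model="anthropic/claude-3-7-sonnet-20250219",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    thinking={"type": "enabled", "budget_tokens": 1024},
)

print(resp.choices[0].message.reasoning_content)  # concatenated thinking text
print(resp.choices[0].message.thinking_blocks)    # raw blocks, including signature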
@@ -1939,7 +1939,7 @@ jobs:
            pip install "asyncio==3.4.3"
            pip install "PyGithub==1.59.1"
            pip install "google-cloud-aiplatform==1.59.0"
-           pip install anthropic
+           pip install "anthropic==0.21.3"
            # Run pytest and generate JUnit XML report
      - run:
          name: Build Docker image
@@ -1444,6 +1444,12 @@ def anthropic_messages_pt(  # noqa: PLR0915
        ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
        while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
            assistant_content_block: ChatCompletionAssistantMessage = messages[msg_i]  # type: ignore
+
+            thinking_blocks = assistant_content_block.get("thinking_blocks", None)
+            if (
+                thinking_blocks is not None
+            ):  # IMPORTANT: ADD THIS FIRST, ELSE ANTHROPIC WILL RAISE AN ERROR
+                assistant_content.extend(thinking_blocks)
            if "content" in assistant_content_block and isinstance(
                assistant_content_block["content"], list
            ):
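For reference, the assistant-message shape this new branch consumes would look roughly like the following (a hypothetical example; field names come from ChatCompletionAssistantMessage / ChatCompletionThinkingBlock elsewhere in this diff, and the signature value is illustrative):

assistant_message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [],  # tool calls from the first turn omitted here
    "thinking_blocks": [
        {
            "type": "thinking",
            "thinking": "The user wants the weather for three cities...",
            "signature_delta": "EuYBCkQYAiJA...",  # opaque signature from Anthropic
        }
    ],
}
# anthropic_messages_pt() now re-sends these blocks first, so Anthropic accepts
# the follow-up request that carries the tool results.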
@@ -30,6 +30,7 @@ from litellm.types.llms.anthropic import (
    UsageDelta,
)
from litellm.types.llms.openai import (
+    ChatCompletionThinkingBlock,
    ChatCompletionToolCallChunk,
    ChatCompletionUsageBlock,
)
@@ -507,6 +508,10 @@
        return usage_block

+    def _content_block_delta_helper(self, chunk: dict):
+        """
+        Helper function to handle the content block delta
+        """

        text = ""
        tool_use: Optional[ChatCompletionToolCallChunk] = None
        provider_specific_fields = {}
@@ -526,7 +531,17 @@
                }
            elif "citation" in content_block["delta"]:
                provider_specific_fields["citation"] = content_block["delta"]["citation"]
+            elif (
+                "thinking" in content_block["delta"]
+                or "signature_delta" == content_block["delta"]
+            ):
+                provider_specific_fields["thinking_blocks"] = [
+                    ChatCompletionThinkingBlock(
+                        type="thinking",
+                        thinking=content_block["delta"].get("thinking"),
+                        signature_delta=content_block["delta"].get("signature"),
+                    )
+                ]
+        return text, tool_use, provider_specific_fields

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
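For orientation, this is roughly the kind of streaming event the new branch above handles (payload shape assumed from Anthropic's streaming API; values are illustrative):

# assumed content_block_delta event carrying extended-thinking text
chunk = {
    "type": "content_block_delta",
    "index": 0,
    "delta": {"type": "thinking_delta", "thinking": "Let me reason about this..."},
}
# _content_block_delta_helper would then report it as:
# provider_specific_fields["thinking_blocks"] == [
#     {"type": "thinking", "thinking": "Let me reason about this...", "signature_delta": None}
# ]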
@@ -1,6 +1,6 @@
import json
import time
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast

import httpx

@@ -581,6 +581,43 @@ class AnthropicConfig(BaseConfig):
            )
        return _message

+    def extract_response_content(self, completion_response: dict) -> Tuple[
+        str,
+        Optional[List[Any]],
+        Optional[List[Dict[str, Any]]],
+        List[ChatCompletionToolCallChunk],
+    ]:
+        text_content = ""
+        citations: Optional[List[Any]] = None
+        thinking_blocks: Optional[List[Dict[str, Any]]] = None
+        tool_calls: List[ChatCompletionToolCallChunk] = []
+        for idx, content in enumerate(completion_response["content"]):
+            if content["type"] == "text":
+                text_content += content["text"]
+            ## TOOL CALLING
+            elif content["type"] == "tool_use":
+                tool_calls.append(
+                    ChatCompletionToolCallChunk(
+                        id=content["id"],
+                        type="function",
+                        function=ChatCompletionToolCallFunctionChunk(
+                            name=content["name"],
+                            arguments=json.dumps(content["input"]),
+                        ),
+                        index=idx,
+                    )
+                )
+            ## CITATIONS
+            if content.get("citations", None) is not None:
+                if citations is None:
+                    citations = []
+                citations.append(content["citations"])
+            if content.get("thinking", None) is not None:
+                if thinking_blocks is None:
+                    thinking_blocks = []
+                thinking_blocks.append(content)
+        return text_content, citations, thinking_blocks, tool_calls
+
    def transform_response(
        self,
        model: str,
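To illustrate what the new helper returns, here is a sketch with a made-up Anthropic response body (content shapes are assumed; the mapping follows the code above):

completion_response = {
    "content": [
        {"type": "thinking", "thinking": "Need the capital of France.", "signature": "Eo4..."},
        {"type": "text", "text": "The capital of France is Paris."},
    ]
}
# extract_response_content(completion_response) would yield roughly:
#   text_content    -> "The capital of France is Paris."
#   citations       -> None
#   thinking_blocks -> [{"type": "thinking", "thinking": "Need the capital of France.", "signature": "Eo4..."}]
#   tool_calls      -> []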
@@ -628,32 +665,21 @@
            )
        else:
            text_content = ""
-            citations: List[Any] = []
+            citations: Optional[List[Any]] = None
+            thinking_blocks: Optional[List[Dict[str, Any]]] = None
            tool_calls: List[ChatCompletionToolCallChunk] = []
-            for idx, content in enumerate(completion_response["content"]):
-                if content["type"] == "text":
-                    text_content += content["text"]
-                ## TOOL CALLING
-                elif content["type"] == "tool_use":
-                    tool_calls.append(
-                        ChatCompletionToolCallChunk(
-                            id=content["id"],
-                            type="function",
-                            function=ChatCompletionToolCallFunctionChunk(
-                                name=content["name"],
-                                arguments=json.dumps(content["input"]),
-                            ),
-                            index=idx,
-                        )
-                    )
-                ## CITATIONS
-                if content.get("citations", None) is not None:
-                    citations.append(content["citations"])
+
+            text_content, citations, thinking_blocks, tool_calls = (
+                self.extract_response_content(completion_response=completion_response)
+            )

        _message = litellm.Message(
            tool_calls=tool_calls,
            content=text_content or None,
-            provider_specific_fields={"citations": citations},
+            provider_specific_fields={
+                "citations": citations,
+                "thinking_blocks": thinking_blocks,
+            },
        )

        ## HANDLE JSON MODE - anthropic returns single function call
@@ -9,6 +9,7 @@ from typing import List, Optional, Type, Union
from openai.lib import _parsing, _pydantic
from pydantic import BaseModel

+from litellm._logging import verbose_logger
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderSpecificModelInfo

@@ -132,6 +133,9 @@ def map_developer_role_to_system_role(
    new_messages: List[AllMessageValues] = []
    for m in messages:
        if m["role"] == "developer":
+            verbose_logger.debug(
+                "Translating developer role to system role for non-OpenAI providers."
+            )  # ensure user knows what's happening with their input.
            new_messages.append({"role": "system", "content": m["content"]})
        else:
            new_messages.append(m)
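A quick illustration of the mapping this helper performs (example messages are made up):

messages = [
    {"role": "developer", "content": "Always answer in French."},
    {"role": "user", "content": "What is the capital of France?"},
]
# map_developer_role_to_system_role(messages=messages) returns:
# [
#     {"role": "system", "content": "Always answer in French."},
#     {"role": "user", "content": "What is the capital of France?"},
# ]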
@@ -18,7 +18,6 @@ from typing import (
import httpx
from pydantic import BaseModel

-from litellm._logging import verbose_logger
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.openai import (
@@ -121,9 +120,6 @@ class BaseConfig(ABC):

        Overriden by OpenAI/Azure
        """
-        verbose_logger.debug(
-            "Translating developer role to system role for non-OpenAI providers."
-        )  # ensure user knows what's happening with their input.
        return map_developer_role_to_system_role(messages=messages)

    def should_retry_llm_api_inside_llm_translation_on_http_error(
@@ -1,4 +1,8 @@
model_list:
+  - model_name: anthropic/claude-3-7-sonnet-20250219
+    litellm_params:
+      model: anthropic/claude-3-7-sonnet-20250219
+      api_key: os.environ/ANTHROPIC_API_KEY
  - model_name: gpt-4
    litellm_params:
      model: openai/gpt-3.5-turbo
@@ -3,7 +3,7 @@ from typing import Any, Dict, Iterable, List, Optional, Union
from pydantic import BaseModel, validator
from typing_extensions import Literal, Required, TypedDict

-from .openai import ChatCompletionCachedContent
+from .openai import ChatCompletionCachedContent, ChatCompletionThinkingBlock


class AnthropicMessagesToolChoice(TypedDict, total=False):
@@ -62,6 +62,7 @@ class AnthropicMessagesToolUseParam(TypedDict):
AnthropicMessagesAssistantMessageValues = Union[
    AnthropicMessagesTextParam,
    AnthropicMessagesToolUseParam,
+    ChatCompletionThinkingBlock,
]

@@ -357,6 +357,12 @@ class ChatCompletionCachedContent(TypedDict):
    type: Literal["ephemeral"]


+class ChatCompletionThinkingBlock(TypedDict, total=False):
+    type: Required[Literal["thinking"]]
+    thinking: str
+    signature_delta: str
+
+
class OpenAIChatCompletionTextObject(TypedDict):
    type: Literal["text"]
    text: str
@@ -450,6 +456,7 @@ class OpenAIChatCompletionAssistantMessage(TypedDict, total=False):

class ChatCompletionAssistantMessage(OpenAIChatCompletionAssistantMessage, total=False):
    cache_control: ChatCompletionCachedContent
+    thinking_blocks: Optional[List[ChatCompletionThinkingBlock]]


class ChatCompletionToolMessage(TypedDict):
@@ -457,6 +457,43 @@ Reference:
    ChatCompletionMessage(content='This is a test', role='assistant', function_call=None, tool_calls=None))
    """

+REASONING_CONTENT_COMPATIBLE_PARAMS = [
+    "thinking_blocks",
+    "reasoning_content",
+]
+
+
+def map_reasoning_content(provider_specific_fields: Dict[str, Any]) -> str:
+    """
+    Extract reasoning_content from provider_specific_fields
+    """
+
+    reasoning_content: str = ""
+    for k, v in provider_specific_fields.items():
+        if k == "thinking_blocks" and isinstance(v, list):
+            _reasoning_content = ""
+            for block in v:
+                if block.get("type") == "thinking":
+                    _reasoning_content += block.get("thinking", "")
+            reasoning_content = _reasoning_content
+        elif k == "reasoning_content":
+            reasoning_content = v
+    return reasoning_content
+
+
+def add_provider_specific_fields(
+    object: BaseModel, provider_specific_fields: Optional[Dict[str, Any]]
+):
+    if not provider_specific_fields:  # set if provider_specific_fields is not empty
+        return
+    setattr(object, "provider_specific_fields", provider_specific_fields)
+    for k, v in provider_specific_fields.items():
+        if v is not None:
+            setattr(object, k, v)
+            if k in REASONING_CONTENT_COMPATIBLE_PARAMS and k != "reasoning_content":
+                reasoning_content = map_reasoning_content({k: v})
+                setattr(object, "reasoning_content", reasoning_content)
+
+
class Message(OpenAIObject):
    content: Optional[str]
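As a quick sanity check on the mapping above (inputs are illustrative; behavior follows the code in this hunk):

fields = {
    "thinking_blocks": [
        {"type": "thinking", "thinking": "Paris is the capital.", "signature_delta": "abc"},
        {"type": "thinking", "thinking": " Done."},
    ]
}
# map_reasoning_content(fields) -> "Paris is the capital. Done."
# add_provider_specific_fields(message, fields) then sets both
# message.thinking_blocks and message.reasoning_content on the object.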
@@ -511,10 +548,7 @@ class Message(OpenAIObject):
            # OpenAI compatible APIs like mistral API will raise an error if audio is passed in
            del self.audio

-        if provider_specific_fields:  # set if provider_specific_fields is not empty
-            self.provider_specific_fields = provider_specific_fields
-            for k, v in provider_specific_fields.items():
-                setattr(self, k, v)
+        add_provider_specific_fields(self, provider_specific_fields)

    def get(self, key, default=None):
        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
@@ -551,11 +585,7 @@ class Delta(OpenAIObject):
        **params,
    ):
        super(Delta, self).__init__(**params)
-        provider_specific_fields: Dict[str, Any] = {}
-
-        if "reasoning_content" in params:
-            provider_specific_fields["reasoning_content"] = params["reasoning_content"]
-            setattr(self, "reasoning_content", params["reasoning_content"])
+        add_provider_specific_fields(self, params.get("provider_specific_fields", {}))
        self.content = content
        self.role = role
        # Set default values and correct types
@@ -563,9 +593,6 @@
        self.tool_calls: Optional[List[Union[ChatCompletionDeltaToolCall, Any]]] = None
        self.audio: Optional[ChatCompletionAudioResponse] = None

-        if provider_specific_fields:  # set if provider_specific_fields is not empty
-            self.provider_specific_fields = provider_specific_fields
-
        if function_call is not None and isinstance(function_call, dict):
            self.function_call = FunctionCall(**function_call)
        else:
@@ -1161,3 +1161,53 @@ def test_anthropic_citations_api_streaming():
                has_citations = True

    assert has_citations
+
+
+def test_anthropic_thinking_output():
+    from litellm import completion
+
+    resp = completion(
+        model="anthropic/claude-3-7-sonnet-20250219",
+        messages=[{"role": "user", "content": "What is the capital of France?"}],
+        thinking={"type": "enabled", "budget_tokens": 1024},
+    )
+
+    print(resp.choices[0].message)
+    assert (
+        resp.choices[0].message.provider_specific_fields["thinking_blocks"] is not None
+    )
+    assert resp.choices[0].message.reasoning_content is not None
+    assert isinstance(resp.choices[0].message.reasoning_content, str)
+    assert resp.choices[0].message.thinking_blocks is not None
+    assert isinstance(resp.choices[0].message.thinking_blocks, list)
+    assert len(resp.choices[0].message.thinking_blocks) > 0
+
+
+def test_anthropic_thinking_output_stream():
+    # litellm.set_verbose = True
+    try:
+        # litellm._turn_on_debug()
+        resp = litellm.completion(
+            model="anthropic/claude-3-7-sonnet-20250219",
+            messages=[{"role": "user", "content": "Tell me a joke."}],
+            stream=True,
+            thinking={"type": "enabled", "budget_tokens": 1024},
+            timeout=5,
+        )
+
+        reasoning_content_exists = False
+        for chunk in resp:
+            print(f"chunk 2: {chunk}")
+            if (
+                hasattr(chunk.choices[0].delta, "thinking_blocks")
+                and chunk.choices[0].delta.thinking_blocks is not None
+                and chunk.choices[0].delta.reasoning_content is not None
+                and isinstance(chunk.choices[0].delta.thinking_blocks, list)
+                and len(chunk.choices[0].delta.thinking_blocks) > 0
+                and isinstance(chunk.choices[0].delta.reasoning_content, str)
+            ):
+                reasoning_content_exists = True
+                break
+        assert reasoning_content_exists
+    except litellm.Timeout:
+        pytest.skip("Model is timing out")
@@ -67,6 +67,9 @@ async def test_audio_output_from_model(stream):
    except litellm.Timeout as e:
        print(e)
        pytest.skip("Skipping test due to timeout")
+    except Exception as e:
+        if "openai-internal" in str(e):
+            pytest.skip("Skipping test due to openai-internal error")

    if stream is True:
        await check_streaming_response(completion)
@@ -86,7 +89,7 @@ async def test_audio_input_to_model(stream):
    audio_format = "pcm16"
    if stream is False:
        audio_format = "wav"
-    litellm.set_verbose = True
+    litellm._turn_on_debug()
    url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav"
    response = requests.get(url)
    response.raise_for_status()
@@ -114,7 +117,9 @@
    except litellm.Timeout as e:
        print(e)
        pytest.skip("Skipping test due to timeout")
-
+    except Exception as e:
+        if "openai-internal" in str(e):
+            pytest.skip("Skipping test due to openai-internal error")
    if stream is True:
        await check_streaming_response(completion)
    else:
@@ -1320,13 +1320,19 @@ def test_standard_logging_payload_audio(turn_off_message_logging, stream):
    with patch.object(
        customHandler, "log_success_event", new=MagicMock()
    ) as mock_client:
+        try:
            response = litellm.completion(
                model="gpt-4o-audio-preview",
                modalities=["text", "audio"],
                audio={"voice": "alloy", "format": "pcm16"},
-                messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
+                messages=[
+                    {"role": "user", "content": "response in 1 word - yes or no"}
+                ],
                stream=stream,
            )
+        except Exception as e:
+            if "openai-internal" in str(e):
+                pytest.skip("Skipping test due to openai-internal error")

        if stream:
            for chunk in response:
@@ -157,6 +157,113 @@ def test_aaparallel_function_call(model):
# test_parallel_function_call()


+@pytest.mark.parametrize(
+    "model",
+    [
+        "anthropic/claude-3-7-sonnet-20250219",
+    ],
+)
+@pytest.mark.flaky(retries=3, delay=1)
+def test_aaparallel_function_call_with_anthropic_thinking(model):
+    try:
+        litellm._turn_on_debug()
+        litellm.modify_params = True
+        # Step 1: send the conversation and available functions to the model
+        messages = [
+            {
+                "role": "user",
+                "content": "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses",
+            }
+        ]
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_current_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state",
+                            },
+                            "unit": {
+                                "type": "string",
+                                "enum": ["celsius", "fahrenheit"],
+                            },
+                        },
+                        "required": ["location"],
+                    },
+                },
+            }
+        ]
+        response = litellm.completion(
+            model=model,
+            messages=messages,
+            tools=tools,
+            tool_choice="auto",  # auto is default, but we'll be explicit
+            thinking={"type": "enabled", "budget_tokens": 1024},
+        )
+        print("Response\n", response)
+        response_message = response.choices[0].message
+        tool_calls = response_message.tool_calls
+
+        print("Expecting there to be 3 tool calls")
+        assert (
+            len(tool_calls) > 0
+        )  # this has to call the function for SF, Tokyo and paris
+
+        # Step 2: check if the model wanted to call a function
+        print(f"tool_calls: {tool_calls}")
+        if tool_calls:
+            # Step 3: call the function
+            # Note: the JSON response may not always be valid; be sure to handle errors
+            available_functions = {
+                "get_current_weather": get_current_weather,
+            }  # only one function in this example, but you can have multiple
+            messages.append(
+                response_message
+            )  # extend conversation with assistant's reply
+            print("Response message\n", response_message)
+            # Step 4: send the info for each function call and function response to the model
+            for tool_call in tool_calls:
+                function_name = tool_call.function.name
+                if function_name not in available_functions:
+                    # the model called a function that does not exist in available_functions - don't try calling anything
+                    return
+                function_to_call = available_functions[function_name]
+                function_args = json.loads(tool_call.function.arguments)
+                function_response = function_to_call(
+                    location=function_args.get("location"),
+                    unit=function_args.get("unit"),
+                )
+                messages.append(
+                    {
+                        "tool_call_id": tool_call.id,
+                        "role": "tool",
+                        "name": function_name,
+                        "content": function_response,
+                    }
+                )  # extend conversation with function response
+            print(f"messages: {messages}")
+            second_response = litellm.completion(
+                model=model,
+                messages=messages,
+                seed=22,
+                # tools=tools,
+                drop_params=True,
+                thinking={"type": "enabled", "budget_tokens": 1024},
+            )  # get a new response from the model where it can see the function response
+            print("second response\n", second_response)
+    except litellm.InternalServerError as e:
+        print(e)
+    except litellm.RateLimitError as e:
+        print(e)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
from litellm.types.utils import ChatCompletionMessageToolCall, Function, Message
@@ -696,6 +696,7 @@ def test_stream_chunk_builder_openai_audio_output_usage():
        api_key=os.getenv("OPENAI_API_KEY"),
    )

+    try:
        completion = client.chat.completions.create(
            model="gpt-4o-audio-preview",
            modalities=["text", "audio"],
@@ -704,6 +705,9 @@
            stream=True,
            stream_options={"include_usage": True},
        )
+    except Exception as e:
+        if "openai-internal" in str(e):
+            pytest.skip("Skipping test due to openai-internal error")

    chunks = []
    for chunk in completion:
@@ -4065,20 +4065,32 @@ def test_mock_response_iterator_tool_use():
    assert response_chunk["tool_use"] is not None


-def test_deepseek_reasoning_content_completion():
+@pytest.mark.parametrize(
+    "model",
+    [
+        # "deepseek/deepseek-reasoner",
+        "anthropic/claude-3-7-sonnet-20250219",
+    ],
+)
+def test_deepseek_reasoning_content_completion(model):
    # litellm.set_verbose = True
    try:
        # litellm._turn_on_debug()
        resp = litellm.completion(
-            model="deepseek/deepseek-reasoner",
+            model=model,
            messages=[{"role": "user", "content": "Tell me a joke."}],
            stream=True,
+            thinking={"type": "enabled", "budget_tokens": 1024},
+            timeout=5,
        )

        reasoning_content_exists = False
        for chunk in resp:
-            print(f"chunk: {chunk}")
-            if chunk.choices[0].delta.reasoning_content is not None:
+            print(f"chunk 2: {chunk}")
+            if (
+                hasattr(chunk.choices[0].delta, "reasoning_content")
+                and chunk.choices[0].delta.reasoning_content is not None
+            ):
                reasoning_content_exists = True
                break
        assert reasoning_content_exists