fix(streaming_handler.py): fix completion start time tracking (#9688)

* fix(streaming_handler.py): fix completion start time tracking

Fixes https://github.com/BerriAI/litellm/issues/9210
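
The fix sets completion_start_time exactly once, on the first streamed chunk, so logging callbacks can compute time-to-first-token for streaming calls. A minimal sketch of the pattern (the class and method names below are illustrative, not litellm's API):

import datetime
from typing import Optional

class StreamTimingSketch:
    """Illustrative only: capture the wall-clock time of the first chunk a single time."""

    def __init__(self) -> None:
        self.completion_start_time: Optional[datetime.datetime] = None

    def on_chunk(self) -> None:
        # Only the first chunk records a timestamp; later chunks leave it untouched,
        # mirroring the `if self.logging_obj.completion_start_time is None` guard in the diff.
        if self.completion_start_time is None:
            self.completion_start_time = datetime.datetime.now()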

* feat(anthropic/chat/transformation.py): map openai 'reasoning_effort' to anthropic 'thinking' param

Fixes https://github.com/BerriAI/litellm/issues/9022
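
With this change, OpenAI-style reasoning_effort values are translated into Anthropic's thinking parameter (low -> 1024, medium -> 2048, high -> 4096 budget tokens). A hedged usage sketch; the model id is only an example of a thinking-capable Claude 3.7 deployment:

import litellm

# 'reasoning_effort' is accepted as an OpenAI-compatible parameter and rewritten to
# {"type": "enabled", "budget_tokens": 1024} before the Anthropic request is sent.
response = litellm.completion(
    model="anthropic/claude-3-7-sonnet-20250219",  # example model id
    messages=[{"role": "user", "content": "Think through 17 * 24 step by step."}],
    reasoning_effort="low",
)
print(response.choices[0].message.content)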

* feat: map 'reasoning_effort' to 'thinking' param across bedrock + vertex

Closes https://github.com/BerriAI/litellm/issues/9022#issuecomment-2705260808
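
The same mapping now applies to Bedrock and Vertex AI Claude 3.7 models, as exercised by the new test at the bottom of this diff. A small sketch of the parameter translation (the import path mirrors litellm's utils module and may differ by version):

from litellm.utils import get_optional_params

# "medium" is expected to map to a 2048-token thinking budget via the new
# AnthropicConfig._map_reasoning_effort helper.
optional_params = get_optional_params(
    model="anthropic.claude-3-7-sonnet-20250219-v1:0",
    custom_llm_provider="bedrock",
    reasoning_effort="medium",
)
assert optional_params["thinking"] == {"type": "enabled", "budget_tokens": 2048}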
Authored by Krish Dholakia on 2025-04-01 22:00:56 -07:00; committed by GitHub
parent 0690f7a3cb
commit 23051d89dd
9 changed files with 135 additions and 11 deletions


@@ -290,6 +290,7 @@ class Logging(LiteLLMLoggingBaseClass):
"input": _input,
"litellm_params": litellm_params,
"applied_guardrails": applied_guardrails,
"model": model,
}
def process_dynamic_callbacks(self):
@@ -1010,6 +1011,10 @@ class Logging(LiteLLMLoggingBaseClass):
return False
return True
def _update_completion_start_time(self, completion_start_time: datetime.datetime):
self.completion_start_time = completion_start_time
self.model_call_details["completion_start_time"] = self.completion_start_time
def _success_handler_helper_fn(
self,
result=None,


@@ -1,5 +1,6 @@
import asyncio
import collections.abc
import datetime
import json
import threading
import time
@@ -1567,6 +1568,10 @@ class CustomStreamWrapper:
if response is None:
continue
if self.logging_obj.completion_start_time is None:
self.logging_obj._update_completion_start_time(
completion_start_time=datetime.datetime.now()
)
## LOGGING
executor.submit(
self.run_success_logging_and_cache_storage,
@@ -1721,6 +1726,11 @@ class CustomStreamWrapper:
if processed_chunk is None:
continue
if self.logging_obj.completion_start_time is None:
self.logging_obj._update_completion_start_time(
completion_start_time=datetime.datetime.now()
)
choice = processed_chunk.choices[0]
if isinstance(choice, StreamingChoices):
self.response_uptil_now += choice.delta.get("content", "") or ""


@@ -18,8 +18,10 @@ from litellm.types.llms.anthropic import (
AnthropicMessagesTool,
AnthropicMessagesToolChoice,
AnthropicSystemMessageContent,
AnthropicThinkingParam,
)
from litellm.types.llms.openai import (
REASONING_EFFORT,
AllMessageValues,
ChatCompletionCachedContent,
ChatCompletionSystemMessage,
@@ -94,6 +96,7 @@ class AnthropicConfig(BaseConfig):
"parallel_tool_calls",
"response_format",
"user",
"reasoning_effort",
]
if "claude-3-7-sonnet" in model:
@@ -291,6 +294,21 @@ class AnthropicConfig(BaseConfig):
new_stop = new_v
return new_stop
@staticmethod
def _map_reasoning_effort(
reasoning_effort: Optional[Union[REASONING_EFFORT, str]]
) -> Optional[AnthropicThinkingParam]:
if reasoning_effort is None:
return None
elif reasoning_effort == "low":
return AnthropicThinkingParam(type="enabled", budget_tokens=1024)
elif reasoning_effort == "medium":
return AnthropicThinkingParam(type="enabled", budget_tokens=2048)
elif reasoning_effort == "high":
return AnthropicThinkingParam(type="enabled", budget_tokens=4096)
else:
raise ValueError(f"Unmapped reasoning effort: {reasoning_effort}")
def map_openai_params(
self,
non_default_params: dict,
@@ -302,10 +320,6 @@
non_default_params=non_default_params
)
## handle thinking tokens
self.update_optional_params_with_thinking_tokens(
non_default_params=non_default_params, optional_params=optional_params
)
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
@@ -370,7 +384,15 @@
optional_params["metadata"] = {"user_id": value}
if param == "thinking":
optional_params["thinking"] = value
elif param == "reasoning_effort" and isinstance(value, str):
optional_params["thinking"] = AnthropicConfig._map_reasoning_effort(
value
)
## handle thinking tokens
self.update_optional_params_with_thinking_tokens(
non_default_params=non_default_params, optional_params=optional_params
)
return optional_params
def _create_json_tool_call_for_response_format(


@@ -104,7 +104,10 @@ class BaseConfig(ABC):
return type_to_response_format_param(response_format=response_format)
def is_thinking_enabled(self, non_default_params: dict) -> bool:
return non_default_params.get("thinking", {}).get("type", None) == "enabled"
return (
non_default_params.get("thinking", {}).get("type") == "enabled"
or non_default_params.get("reasoning_effort") is not None
)
def update_optional_params_with_thinking_tokens(
self, non_default_params: dict, optional_params: dict
@@ -116,9 +119,9 @@
if 'thinking' is enabled and 'max_tokens' is not specified, set 'max_tokens' to the thinking token budget + DEFAULT_MAX_TOKENS
"""
is_thinking_enabled = self.is_thinking_enabled(non_default_params)
is_thinking_enabled = self.is_thinking_enabled(optional_params)
if is_thinking_enabled and "max_tokens" not in non_default_params:
thinking_token_budget = cast(dict, non_default_params["thinking"]).get(
thinking_token_budget = cast(dict, optional_params["thinking"]).get(
"budget_tokens", None
)
if thinking_token_budget is not None:


@@ -17,6 +17,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
_bedrock_converse_messages_pt,
_bedrock_tools_pt,
)
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.bedrock import *
from litellm.types.llms.openai import (
@@ -128,6 +129,7 @@ class AmazonConverseConfig(BaseConfig):
"claude-3-7" in model
): # [TODO]: move to a 'supports_reasoning_content' param from model cost map
supported_params.append("thinking")
supported_params.append("reasoning_effort")
return supported_params
def map_tool_choice_values(
@@ -218,9 +220,7 @@
messages: Optional[List[AllMessageValues]] = None,
) -> dict:
is_thinking_enabled = self.is_thinking_enabled(non_default_params)
self.update_optional_params_with_thinking_tokens(
non_default_params=non_default_params, optional_params=optional_params
)
for param, value in non_default_params.items():
if param == "response_format" and isinstance(value, dict):
ignore_response_format_types = ["text"]
@@ -297,6 +297,14 @@
optional_params["tool_choice"] = _tool_choice_value
if param == "thinking":
optional_params["thinking"] = value
elif param == "reasoning_effort" and isinstance(value, str):
optional_params["thinking"] = AnthropicConfig._map_reasoning_effort(
value
)
self.update_optional_params_with_thinking_tokens(
non_default_params=non_default_params, optional_params=optional_params
)
return optional_params


@@ -1113,3 +1113,6 @@ ResponsesAPIStreamingResponse = Annotated[
],
Discriminator("type"),
]
REASONING_EFFORT = Literal["low", "medium", "high"]


@@ -5901,9 +5901,10 @@ class ModelResponseIterator:
class ModelResponseListIterator:
def __init__(self, model_responses):
def __init__(self, model_responses, delay: Optional[float] = None):
self.model_responses = model_responses
self.index = 0
self.delay = delay
# Sync iterator
def __iter__(self):
@@ -5914,6 +5915,8 @@ class ModelResponseListIterator:
raise StopIteration
model_response = self.model_responses[self.index]
self.index += 1
if self.delay:
time.sleep(self.delay)
return model_response
# Async iterator
@@ -5925,6 +5928,8 @@ class ModelResponseListIterator:
raise StopAsyncIteration
model_response = self.model_responses[self.index]
self.index += 1
if self.delay:
await asyncio.sleep(self.delay)
return model_response


@@ -1,6 +1,7 @@
import json
import os
import sys
import time
from unittest.mock import MagicMock, Mock, patch
import pytest
@@ -19,6 +20,7 @@ from litellm.types.utils import (
Delta,
ModelResponseStream,
PromptTokensDetailsWrapper,
StandardLoggingPayload,
StreamingChoices,
Usage,
)
@@ -36,6 +38,22 @@ def initialized_custom_stream_wrapper() -> CustomStreamWrapper:
return streaming_handler
@pytest.fixture
def logging_obj() -> Logging:
import time
logging_obj = Logging(
model="my-random-model",
messages=[{"role": "user", "content": "Hey"}],
stream=True,
call_type="completion",
start_time=time.time(),
litellm_call_id="12345",
function_id="1245",
)
return logging_obj
bedrock_chunks = [
ModelResponseStream(
id="chatcmpl-d249def8-a78b-464c-87b5-3a6f43565292",
@@ -577,3 +595,36 @@ def test_streaming_handler_with_stop_chunk(
**args, model_response=ModelResponseStream()
)
assert returned_chunk is None
@pytest.mark.asyncio
async def test_streaming_completion_start_time(logging_obj: Logging):
"""Test that the start time is set correctly"""
from litellm.integrations.custom_logger import CustomLogger
class MockCallback(CustomLogger):
pass
mock_callback = MockCallback()
litellm.success_callback = [mock_callback, "langfuse"]
completion_stream = ModelResponseListIterator(
model_responses=bedrock_chunks, delay=0.1
)
response = CustomStreamWrapper(
completion_stream=completion_stream,
model="bedrock/claude-3-5-sonnet-20240620-v1:0",
logging_obj=logging_obj,
)
async for chunk in response:
print(chunk)
await asyncio.sleep(2)
assert logging_obj.model_call_details["completion_start_time"] is not None
assert (
logging_obj.model_call_details["completion_start_time"]
< logging_obj.model_call_details["end_time"]
)


@@ -1379,3 +1379,20 @@ def test_azure_modalities_param():
)
assert optional_params["modalities"] == ["text", "audio"]
assert optional_params["audio"] == {"type": "audio_input", "input": "test.wav"}
@pytest.mark.parametrize(
"model, provider",
[
("claude-3-7-sonnet-20240620-v1:0", "anthropic"),
("anthropic.claude-3-7-sonnet-20250219-v1:0", "bedrock"),
("invoke/anthropic.claude-3-7-sonnet-20240620-v1:0", "bedrock"),
("claude-3-7-sonnet@20250219", "vertex_ai"),
],
)
def test_anthropic_unified_reasoning_content(model, provider):
optional_params = get_optional_params(
model=model,
custom_llm_provider=provider,
reasoning_effort="high",
)
assert optional_params["thinking"] == {"type": "enabled", "budget_tokens": 4096}