Mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
fix(streaming_handler.py): fix completion start time tracking (#9688)
* fix(streaming_handler.py): fix completion start time tracking
  Fixes https://github.com/BerriAI/litellm/issues/9210
* feat(anthropic/chat/transformation.py): map openai 'reasoning_effort' to anthropic 'thinking' param
  Fixes https://github.com/BerriAI/litellm/issues/9022
* feat: map 'reasoning_effort' to 'thinking' param across bedrock + vertex
  Closes https://github.com/BerriAI/litellm/issues/9022#issuecomment-2705260808
commit 23051d89dd (parent 0690f7a3cb)
9 changed files with 135 additions and 11 deletions
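
Taken together, the changes let an OpenAI-style reasoning_effort argument drive Anthropic's 'thinking' parameter across the anthropic, bedrock, and vertex_ai providers. A minimal usage sketch (the model name is illustrative; the low/medium/high budgets of 1024/2048/4096 tokens come from the _map_reasoning_effort helper added below):

    import litellm

    # "high" is mapped to {"type": "enabled", "budget_tokens": 4096}
    response = litellm.completion(
        model="anthropic/claude-3-7-sonnet-20250219",
        messages=[{"role": "user", "content": "Hey"}],
        reasoning_effort="high",
    )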
@@ -290,6 +290,7 @@ class Logging(LiteLLMLoggingBaseClass):
            "input": _input,
            "litellm_params": litellm_params,
            "applied_guardrails": applied_guardrails,
            "model": model,
        }

    def process_dynamic_callbacks(self):
@@ -1010,6 +1011,10 @@ class Logging(LiteLLMLoggingBaseClass):
            return False
        return True

    def _update_completion_start_time(self, completion_start_time: datetime.datetime):
        self.completion_start_time = completion_start_time
        self.model_call_details["completion_start_time"] = self.completion_start_time

    def _success_handler_helper_fn(
        self,
        result=None,
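The streaming hunks below call _update_completion_start_time on the first chunk they yield, so model_call_details["completion_start_time"] reflects time-to-first-token rather than stream completion. As a hypothetical illustration of what a success callback could derive from it (the "start_time" key is assumed here; only "completion_start_time" and "end_time" are exercised in this diff's test):

    # hypothetical: time-to-first-token from the logged timestamps
    start = logging_obj.model_call_details.get("start_time")
    first_token = logging_obj.model_call_details.get("completion_start_time")
    if start and first_token:
        ttft_seconds = (first_token - start).total_seconds()
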
@@ -1,5 +1,6 @@
import asyncio
import collections.abc
import datetime
import json
import threading
import time
@@ -1567,6 +1568,10 @@ class CustomStreamWrapper:

                if response is None:
                    continue
                if self.logging_obj.completion_start_time is None:
                    self.logging_obj._update_completion_start_time(
                        completion_start_time=datetime.datetime.now()
                    )
                ## LOGGING
                executor.submit(
                    self.run_success_logging_and_cache_storage,
@@ -1721,6 +1726,11 @@ class CustomStreamWrapper:
                if processed_chunk is None:
                    continue

                if self.logging_obj.completion_start_time is None:
                    self.logging_obj._update_completion_start_time(
                        completion_start_time=datetime.datetime.now()
                    )

                choice = processed_chunk.choices[0]
                if isinstance(choice, StreamingChoices):
                    self.response_uptil_now += choice.delta.get("content", "") or ""
@@ -18,8 +18,10 @@ from litellm.types.llms.anthropic import (
    AnthropicMessagesTool,
    AnthropicMessagesToolChoice,
    AnthropicSystemMessageContent,
    AnthropicThinkingParam,
)
from litellm.types.llms.openai import (
    REASONING_EFFORT,
    AllMessageValues,
    ChatCompletionCachedContent,
    ChatCompletionSystemMessage,
@@ -94,6 +96,7 @@ class AnthropicConfig(BaseConfig):
            "parallel_tool_calls",
            "response_format",
            "user",
            "reasoning_effort",
        ]

        if "claude-3-7-sonnet" in model:
@@ -291,6 +294,21 @@ class AnthropicConfig(BaseConfig):
                new_stop = new_v
        return new_stop

    @staticmethod
    def _map_reasoning_effort(
        reasoning_effort: Optional[Union[REASONING_EFFORT, str]]
    ) -> Optional[AnthropicThinkingParam]:
        if reasoning_effort is None:
            return None
        elif reasoning_effort == "low":
            return AnthropicThinkingParam(type="enabled", budget_tokens=1024)
        elif reasoning_effort == "medium":
            return AnthropicThinkingParam(type="enabled", budget_tokens=2048)
        elif reasoning_effort == "high":
            return AnthropicThinkingParam(type="enabled", budget_tokens=4096)
        else:
            raise ValueError(f"Unmapped reasoning effort: {reasoning_effort}")

    def map_openai_params(
        self,
        non_default_params: dict,
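For illustration, the helper in isolation (the import path is the one used by the bedrock transformation below; AnthropicThinkingParam is assumed to be a TypedDict, so the result prints as a plain dict):

    from litellm.llms.anthropic.chat.transformation import AnthropicConfig

    # low -> 1024, medium -> 2048, high -> 4096 budget_tokens
    print(AnthropicConfig._map_reasoning_effort("medium"))
    # {'type': 'enabled', 'budget_tokens': 2048}
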
@@ -302,10 +320,6 @@ class AnthropicConfig(BaseConfig):
            non_default_params=non_default_params
        )

        ## handle thinking tokens
        self.update_optional_params_with_thinking_tokens(
            non_default_params=non_default_params, optional_params=optional_params
        )
        for param, value in non_default_params.items():
            if param == "max_tokens":
                optional_params["max_tokens"] = value
@@ -370,7 +384,15 @@ class AnthropicConfig(BaseConfig):
                optional_params["metadata"] = {"user_id": value}
            if param == "thinking":
                optional_params["thinking"] = value
            elif param == "reasoning_effort" and isinstance(value, str):
                optional_params["thinking"] = AnthropicConfig._map_reasoning_effort(
                    value
                )

        ## handle thinking tokens
        self.update_optional_params_with_thinking_tokens(
            non_default_params=non_default_params, optional_params=optional_params
        )
        return optional_params

    def _create_json_tool_call_for_response_format(
@@ -104,7 +104,10 @@ class BaseConfig(ABC):
         return type_to_response_format_param(response_format=response_format)

     def is_thinking_enabled(self, non_default_params: dict) -> bool:
-        return non_default_params.get("thinking", {}).get("type", None) == "enabled"
+        return (
+            non_default_params.get("thinking", {}).get("type") == "enabled"
+            or non_default_params.get("reasoning_effort") is not None
+        )

     def update_optional_params_with_thinking_tokens(
         self, non_default_params: dict, optional_params: dict
@@ -116,9 +119,9 @@ class BaseConfig(ABC):

         if 'thinking' is enabled and 'max_tokens' is not specified, set 'max_tokens' to the thinking token budget + DEFAULT_MAX_TOKENS
         """
-        is_thinking_enabled = self.is_thinking_enabled(non_default_params)
+        is_thinking_enabled = self.is_thinking_enabled(optional_params)
         if is_thinking_enabled and "max_tokens" not in non_default_params:
-            thinking_token_budget = cast(dict, non_default_params["thinking"]).get(
+            thinking_token_budget = cast(dict, optional_params["thinking"]).get(
                 "budget_tokens", None
             )
             if thinking_token_budget is not None:
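The switch from non_default_params to optional_params matters for reasoning_effort-only calls: the caller never sends a 'thinking' dict, so it exists only in optional_params after the provider config has mapped it, which is why the provider configs now run this helper after their mapping loop. A hypothetical walk-through under those assumptions:

    # what the two dicts look like mid-call when only reasoning_effort was passed
    non_default_params = {"reasoning_effort": "high"}
    optional_params = {"thinking": {"type": "enabled", "budget_tokens": 4096}}

    # old lookup: non_default_params["thinking"] -> KeyError
    # new lookup: optional_params["thinking"]["budget_tokens"] -> 4096,
    # from which max_tokens can default to 4096 + DEFAULT_MAX_TOKENS
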
@@ -17,6 +17,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
    _bedrock_converse_messages_pt,
    _bedrock_tools_pt,
)
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.bedrock import *
from litellm.types.llms.openai import (
@@ -128,6 +129,7 @@ class AmazonConverseConfig(BaseConfig):
            "claude-3-7" in model
        ):  # [TODO]: move to a 'supports_reasoning_content' param from model cost map
            supported_params.append("thinking")
            supported_params.append("reasoning_effort")
        return supported_params

    def map_tool_choice_values(
@@ -218,9 +220,7 @@ class AmazonConverseConfig(BaseConfig):
        messages: Optional[List[AllMessageValues]] = None,
    ) -> dict:
        is_thinking_enabled = self.is_thinking_enabled(non_default_params)
        self.update_optional_params_with_thinking_tokens(
            non_default_params=non_default_params, optional_params=optional_params
        )

        for param, value in non_default_params.items():
            if param == "response_format" and isinstance(value, dict):
                ignore_response_format_types = ["text"]
@@ -297,6 +297,14 @@ class AmazonConverseConfig(BaseConfig):
                optional_params["tool_choice"] = _tool_choice_value
            if param == "thinking":
                optional_params["thinking"] = value
            elif param == "reasoning_effort" and isinstance(value, str):
                optional_params["thinking"] = AnthropicConfig._map_reasoning_effort(
                    value
                )

        self.update_optional_params_with_thinking_tokens(
            non_default_params=non_default_params, optional_params=optional_params
        )

        return optional_params
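Because Bedrock delegates to AnthropicConfig._map_reasoning_effort, the same budgets apply across providers; the parametrized test at the end of this diff checks exactly that through get_optional_params. A sketch of the same check (assuming get_optional_params is importable from litellm.utils):

    from litellm.utils import get_optional_params

    optional_params = get_optional_params(
        model="anthropic.claude-3-7-sonnet-20250219-v1:0",
        custom_llm_provider="bedrock",
        reasoning_effort="high",
    )
    assert optional_params["thinking"] == {"type": "enabled", "budget_tokens": 4096}
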
@@ -1113,3 +1113,6 @@ ResponsesAPIStreamingResponse = Annotated[
    ],
    Discriminator("type"),
]


REASONING_EFFORT = Literal["low", "medium", "high"]
@@ -5901,9 +5901,10 @@ class ModelResponseIterator:


 class ModelResponseListIterator:
-    def __init__(self, model_responses):
+    def __init__(self, model_responses, delay: Optional[float] = None):
         self.model_responses = model_responses
         self.index = 0
+        self.delay = delay

     # Sync iterator
     def __iter__(self):
@@ -5914,6 +5915,8 @@ class ModelResponseListIterator:
            raise StopIteration
        model_response = self.model_responses[self.index]
        self.index += 1
        if self.delay:
            time.sleep(self.delay)
        return model_response

    # Async iterator
@@ -5925,6 +5928,8 @@ class ModelResponseListIterator:
            raise StopAsyncIteration
        model_response = self.model_responses[self.index]
        self.index += 1
        if self.delay:
            await asyncio.sleep(self.delay)
        return model_response
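The new delay parameter pauses between canned chunks so completion_start_time lands measurably after the request start. A standalone sketch of the same pattern (a re-implementation for illustration, not the test helper itself):

    import asyncio
    from typing import List, Optional

    class DelayedListIterator:
        """Yield canned chunks with an optional pause between them."""

        def __init__(self, items: List[str], delay: Optional[float] = None):
            self.items = items
            self.index = 0
            self.delay = delay

        def __aiter__(self):
            return self

        async def __anext__(self):
            if self.index >= len(self.items):
                raise StopAsyncIteration
            item = self.items[self.index]
            self.index += 1
            if self.delay:
                await asyncio.sleep(self.delay)  # simulate inter-chunk latency
            return item

    async def main():
        async for chunk in DelayedListIterator(["a", "b", "c"], delay=0.1):
            print(chunk)

    asyncio.run(main())
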
@@ -1,6 +1,7 @@
import json
import os
import sys
import time
from unittest.mock import MagicMock, Mock, patch

import pytest
@@ -19,6 +20,7 @@ from litellm.types.utils import (
    Delta,
    ModelResponseStream,
    PromptTokensDetailsWrapper,
    StandardLoggingPayload,
    StreamingChoices,
    Usage,
)
@@ -36,6 +38,22 @@ def initialized_custom_stream_wrapper() -> CustomStreamWrapper:
    return streaming_handler


@pytest.fixture
def logging_obj() -> Logging:
    import time

    logging_obj = Logging(
        model="my-random-model",
        messages=[{"role": "user", "content": "Hey"}],
        stream=True,
        call_type="completion",
        start_time=time.time(),
        litellm_call_id="12345",
        function_id="1245",
    )
    return logging_obj


bedrock_chunks = [
    ModelResponseStream(
        id="chatcmpl-d249def8-a78b-464c-87b5-3a6f43565292",
@@ -577,3 +595,36 @@ def test_streaming_handler_with_stop_chunk(
        **args, model_response=ModelResponseStream()
    )
    assert returned_chunk is None


@pytest.mark.asyncio
async def test_streaming_completion_start_time(logging_obj: Logging):
    """Test that the start time is set correctly"""
    from litellm.integrations.custom_logger import CustomLogger

    class MockCallback(CustomLogger):
        pass

    mock_callback = MockCallback()
    litellm.success_callback = [mock_callback, "langfuse"]

    completion_stream = ModelResponseListIterator(
        model_responses=bedrock_chunks, delay=0.1
    )

    response = CustomStreamWrapper(
        completion_stream=completion_stream,
        model="bedrock/claude-3-5-sonnet-20240620-v1:0",
        logging_obj=logging_obj,
    )

    async for chunk in response:
        print(chunk)

    await asyncio.sleep(2)

    assert logging_obj.model_call_details["completion_start_time"] is not None
    assert (
        logging_obj.model_call_details["completion_start_time"]
        < logging_obj.model_call_details["end_time"]
    )
@@ -1379,3 +1379,20 @@ def test_azure_modalities_param():
    )
    assert optional_params["modalities"] == ["text", "audio"]
    assert optional_params["audio"] == {"type": "audio_input", "input": "test.wav"}


@pytest.mark.parametrize(
    "model, provider",
    [
        ("claude-3-7-sonnet-20240620-v1:0", "anthropic"),
        ("anthropic.claude-3-7-sonnet-20250219-v1:0", "bedrock"),
        ("invoke/anthropic.claude-3-7-sonnet-20240620-v1:0", "bedrock"),
        ("claude-3-7-sonnet@20250219", "vertex_ai"),
    ],
)
def test_anthropic_unified_reasoning_content(model, provider):
    optional_params = get_optional_params(
        model=model,
        custom_llm_provider=provider,
        reasoning_effort="high",
    )
    assert optional_params["thinking"] == {"type": "enabled", "budget_tokens": 4096}