Yuki Watanabe 2025-04-23 00:48:24 -07:00 committed by GitHub
commit 2c20b3726b
9 changed files with 189 additions and 104 deletions


@@ -123,6 +123,8 @@ _known_custom_logger_compatible_callbacks: List = list(
callbacks: List[
Union[Callable, _custom_logger_compatible_callbacks_literal, CustomLogger]
] = []
# If true, callbacks will be executed synchronously (blocking).
sync_logging: bool = False
langfuse_default_tags: Optional[List[str]] = None
langsmith_batch_size: Optional[int] = None
prometheus_initialize_budget_metrics: Optional[bool] = False
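In practice the new flag is a module-level setting like the others in this block. A minimal usage sketch (the model name and prompt are illustrative only; the flag itself is what this commit adds):

import litellm

# Run success callbacks inline (blocking) instead of offloading them to the
# shared background executor. Defaults to False, which keeps the previous
# fire-and-forget behavior.
litellm.sync_logging = True

response = litellm.completion(
    model="gpt-4o-mini",  # any configured model/provider
    messages=[{"role": "user", "content": "hi"}],
)
# With sync_logging enabled, success callbacks should have finished running
# by the time completion() returns.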


@@ -277,10 +277,7 @@ class LLMCachingHandler:
is_async=False,
)
threading.Thread(
target=logging_obj.success_handler,
args=(cached_result, start_time, end_time, cache_hit),
).start()
logging_obj.success_handler(cached_result, start_time, end_time, True)
cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
**kwargs
)
@@ -446,10 +443,7 @@ class LLMCachingHandler:
cached_result, start_time, end_time, cache_hit
)
)
threading.Thread(
target=logging_obj.success_handler,
args=(cached_result, start_time, end_time, cache_hit),
).start()
logging_obj.success_handler(cached_result, start_time, end_time, cache_hit)
async def _retrieve_from_cache(
self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]
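Both caching paths above stop spawning a dedicated thread per logged cache hit and instead call success_handler directly, leaving any offloading to the handler itself (see the logging changes further down). A standalone sketch of that general trade-off, using hypothetical names rather than litellm's internals:

import threading
from concurrent.futures import ThreadPoolExecutor

def log_success(result, cache_hit):
    ...  # hypothetical callback work

# Before: a short-lived OS thread is created for every logged event.
threading.Thread(target=log_success, args=("cached result", True)).start()

# After: the caller simply invokes the handler; if work is offloaded at all,
# it goes to one shared pool, so worker threads are reused across calls.
executor = ThreadPoolExecutor(max_workers=4)
executor.submit(log_success, "cached result", True)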


@@ -125,7 +125,11 @@ class MlflowLogger(CustomLogger):
# If this is the final chunk, end the span. The final chunk
# has complete_streaming_response that gathers the full response.
if final_response := kwargs.get("complete_streaming_response"):
final_response = (
kwargs.get("complete_streaming_response")
or kwargs.get("async_complete_streaming_response")
)
if final_response:
end_time_ns = int(end_time.timestamp() * 1e9)
self._extract_and_set_chat_attributes(span, kwargs, final_response)
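The MLflow logger now finalizes the span from whichever assembled response is present, since async streaming calls populate a different kwarg. A hedged sketch of the lookup, with a hypothetical helper name:

from typing import Any, Dict, Optional

def get_final_streaming_response(kwargs: Dict[str, Any]) -> Optional[Any]:
    # The fully assembled response may arrive under either key, depending on
    # whether the streaming call was sync or async.
    return (
        kwargs.get("complete_streaming_response")
        or kwargs.get("async_complete_streaming_response")
    )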


@@ -8,12 +8,13 @@ import os
import re
import subprocess
import sys
import threading
import time
import traceback
import uuid
from datetime import datetime as dt_object
from functools import lru_cache
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast
from pydantic import BaseModel
@@ -1226,7 +1227,36 @@ class Logging(LiteLLMLoggingBaseClass):
except Exception as e:
raise Exception(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}")
def success_handler( # noqa: PLR0915
def success_handler(
self,
result=None,
start_time=None,
end_time=None,
cache_hit=None,
synchronous=None,
**kwargs
):
"""
Execute the success handler either synchronously (blocking the caller) or on the background executor.
If the `synchronous` argument is not provided, the global `litellm.sync_logging` setting is used.
"""
if synchronous is None:
synchronous = litellm.sync_logging
if synchronous:
self._success_handler(result, start_time, end_time, cache_hit, **kwargs)
else:
executor.submit(
self._success_handler,
result,
start_time,
end_time,
cache_hit,
**kwargs,
)
def _success_handler( # noqa: PLR0915
self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
):
verbose_logger.debug(
@@ -2376,12 +2406,7 @@ class Logging(LiteLLMLoggingBaseClass):
if self._should_run_sync_callbacks_for_async_calls() is False:
return
executor.submit(
self.success_handler,
result,
start_time,
end_time,
)
self.success_handler(result, start_time, end_time)
def _should_run_sync_callbacks_for_async_calls(self) -> bool:
"""


@@ -1445,32 +1445,47 @@ class CustomStreamWrapper:
"""
Runs success logging in a thread and adds the response to the cache
"""
if litellm.disable_streaming_logging is True:
"""
[NOT RECOMMENDED]
Set this via `litellm.disable_streaming_logging = True`.
def _run():
if litellm.disable_streaming_logging is True:
"""
[NOT RECOMMENDED]
Set this via `litellm.disable_streaming_logging = True`.
Disables streaming logging.
"""
return
## ASYNC LOGGING
# Create an event loop for the new thread
if self.logging_loop is not None:
future = asyncio.run_coroutine_threadsafe(
self.logging_obj.async_success_handler(
processed_chunk, None, None, cache_hit
),
loop=self.logging_loop,
)
future.result()
else:
asyncio.run(
self.logging_obj.async_success_handler(
processed_chunk, None, None, cache_hit
Disables streaming logging.
"""
return
if not litellm.sync_logging:
## ASYNC LOGGING
# Create an event loop for the new thread
if self.logging_loop is not None:
future = asyncio.run_coroutine_threadsafe(
self.logging_obj.async_success_handler(
processed_chunk, None, None, cache_hit
),
loop=self.logging_loop,
)
future.result()
else:
asyncio.run(
self.logging_obj.async_success_handler(
processed_chunk, None, None, cache_hit
)
)
## SYNC LOGGING
self.logging_obj.success_handler(processed_chunk, None, None, cache_hit)
## Sync store in cache
if self.logging_obj._llm_caching_handler is not None:
self.logging_obj._llm_caching_handler._sync_add_streaming_response_to_cache(
processed_chunk
)
)
## SYNC LOGGING
self.logging_obj.success_handler(processed_chunk, None, None, cache_hit)
if litellm.sync_logging:
_run()
else:
executor.submit(_run)
def finish_reason_handler(self):
model_response = self.model_response_creator()
@@ -1522,11 +1537,8 @@ class CustomStreamWrapper:
completion_start_time=datetime.datetime.now()
)
## LOGGING
executor.submit(
self.run_success_logging_and_cache_storage,
response,
cache_hit,
) # log response
self.run_success_logging_and_cache_storage(response, cache_hit)
choice = response.choices[0]
if isinstance(choice, StreamingChoices):
self.response_uptil_now += choice.delta.get("content", "") or ""
@@ -1576,21 +1588,12 @@ class CustomStreamWrapper:
),
cache_hit=cache_hit,
)
executor.submit(
self.logging_obj.success_handler,
complete_streaming_response.model_copy(deep=True),
None,
None,
cache_hit,
)
logging_result = complete_streaming_response.model_copy(deep=True)
else:
executor.submit(
self.logging_obj.success_handler,
response,
None,
None,
cache_hit,
)
logging_result = response
self.logging_obj.success_handler(logging_result, None, None, cache_hit)
if self.sent_stream_usage is False and self.send_stream_usage is True:
self.sent_stream_usage = True
return response
@@ -1602,11 +1605,7 @@ class CustomStreamWrapper:
usage = calculate_total_usage(chunks=self.chunks)
processed_chunk._hidden_params["usage"] = usage
## LOGGING
executor.submit(
self.run_success_logging_and_cache_storage,
processed_chunk,
cache_hit,
) # log response
self.run_success_logging_and_cache_storage(processed_chunk, cache_hit)
return processed_chunk
except Exception as e:
traceback_exception = traceback.format_exc()
@@ -1762,22 +1761,19 @@ class CustomStreamWrapper:
self.sent_stream_usage = True
return response
asyncio.create_task(
self.logging_obj.async_success_handler(
complete_streaming_response,
cache_hit=cache_hit,
start_time=None,
end_time=None,
)
)
executor.submit(
self.logging_obj.success_handler,
complete_streaming_response,
logging_params = dict(
result=complete_streaming_response,
cache_hit=cache_hit,
start_time=None,
end_time=None,
)
if litellm.sync_logging:
await self.logging_obj.async_success_handler(**logging_params)
else:
asyncio.create_task(self.logging_obj.async_success_handler(**logging_params))
self.logging_obj.success_handler(**logging_params)
raise StopAsyncIteration # Re-raise StopIteration
else:
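The streaming wrapper bundles all per-chunk success logging and cache storage into a local _run() closure and then decides where that closure executes. A minimal sketch of the same closure pattern, with hypothetical names:

from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=4)  # stand-in for the shared executor
sync_logging = False                          # stand-in for litellm.sync_logging

def log_stream_chunk(chunk):
    def _run():
        # All per-chunk logging / cache-storage work goes in one callable...
        print(f"logging chunk: {chunk!r}")

    if sync_logging:
        _run()                  # ...run inline, blocking the streaming loop
    else:
        executor.submit(_run)   # ...or fire-and-forget on the shared pool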


@@ -122,15 +122,9 @@ class PassThroughStreamingHandler:
standard_logging_response_object = StandardPassThroughResponseObject(
response=f"cannot parse chunks to standard response object. Chunks={all_chunks}"
)
threading.Thread(
target=litellm_logging_obj.success_handler,
args=(
standard_logging_response_object,
start_time,
end_time,
False,
),
).start()
litellm_logging_obj.success_handler(
standard_logging_response_object, start_time, end_time, False
)
await litellm_logging_obj.async_success_handler(
result=standard_logging_response_object,
start_time=start_time,


@@ -170,8 +170,7 @@ class ResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
)
)
executor.submit(
self.logging_obj.success_handler,
self.logging_obj.success_handler(
result=self.completed_response,
cache_hit=None,
start_time=self.start_time,


@@ -808,9 +808,7 @@ async def _client_async_logging_helper(
f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
)
# check if user does not want this to be logged
asyncio.create_task(
logging_obj.async_success_handler(result, start_time, end_time)
)
await logging_obj.async_success_handler(result, start_time, end_time)
logging_obj.handle_sync_success_callbacks_for_async_calls(
result=result,
start_time=start_time,
@@ -1183,12 +1181,8 @@ def client(original_function): # noqa: PLR0915
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
verbose_logger.info("Wrapper: Completed Call, calling success_handler")
executor.submit(
logging_obj.success_handler,
result,
start_time,
end_time,
)
logging_obj.success_handler(result, start_time, end_time)
# RETURN RESULT
update_response_metadata(
result=result,
@@ -1357,15 +1351,18 @@ def client(original_function): # noqa: PLR0915
)
# LOG SUCCESS - handle streaming success logging in the _next_ object
asyncio.create_task(
_client_async_logging_helper(
logging_obj=logging_obj,
result=result,
start_time=start_time,
end_time=end_time,
is_completion_with_fallbacks=is_completion_with_fallbacks,
)
async_logging_params = dict(
logging_obj=logging_obj,
result=result,
start_time=start_time,
end_time=end_time,
is_completion_with_fallbacks=is_completion_with_fallbacks,
)
if litellm.sync_logging:
await _client_async_logging_helper(**async_logging_params)
else:
asyncio.create_task(_client_async_logging_helper(**async_logging_params))
logging_obj.handle_sync_success_callbacks_for_async_calls(
result=result,
start_time=start_time,


@@ -2,8 +2,12 @@ import json
import os
import sys
from datetime import datetime
import threading
from unittest.mock import AsyncMock
from litellm.litellm_core_utils.dd_tracing import contextmanager
from litellm.utils import executor
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system-path
@@ -297,6 +301,7 @@ def test_dynamic_logging_global_callback():
start_time=datetime.now(),
end_time=datetime.now(),
cache_hit=False,
synchronous=True,
)
except Exception as e:
print(f"Error: {e}")
@@ -323,3 +328,72 @@ def test_get_combined_callback_list():
assert "lago" in _logging.get_combined_callback_list(
dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"]
)
@pytest.mark.parametrize("sync_logging", [True, False])
def test_success_handler_sync_async(sync_logging):
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import ModelResponse, Choices, Message
cl = CustomLogger()
litellm_logging = LiteLLMLoggingObj(
model="claude-3-opus-20240229",
messages=[{"role": "user", "content": "hi"}],
stream=False,
call_type="completion",
start_time=datetime.now(),
litellm_call_id="123",
function_id="456",
)
litellm.sync_logging = sync_logging
result = ModelResponse(
id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70",
created=1732306261,
model="claude-3-opus-20240229",
object="chat.completion",
choices=[
Choices(
finish_reason="stop",
index=0,
message=Message(
content="hello",
role="assistant",
tool_calls=None,
function_call=None,
),
)
],
)
with (
patch.object(cl, "log_success_event") as mock_log_success_event,
patch.object(
executor, "submit",
side_effect=lambda *args: args[0](*args[1:])
) as mock_executor,
):
litellm.success_callback = [cl]
litellm_logging.success_handler(
result=result,
start_time=datetime.now(),
end_time=datetime.now(),
cache_hit=False,
)
if sync_logging:
mock_executor.assert_not_called()
mock_log_success_event.assert_called_once()
assert "standard_logging_object" in mock_log_success_event.call_args.kwargs["kwargs"]
else:
# executor.submit is patched to run the callable inline, so the callback has already fired
mock_executor.assert_called_once()
mock_log_success_event.assert_called_once()
assert "standard_logging_object" in mock_log_success_event.call_args.kwargs["kwargs"]