commit 2c20b3726b
Yuki Watanabe, 2025-04-23 00:48:24 -07:00, committed by GitHub
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
9 changed files with 189 additions and 104 deletions

View file

@@ -123,6 +123,8 @@ _known_custom_logger_compatible_callbacks: List = list(
 callbacks: List[
     Union[Callable, _custom_logger_compatible_callbacks_literal, CustomLogger]
 ] = []
+# If true, callbacks will be executed synchronously (blocking).
+sync_logging: bool = False
 langfuse_default_tags: Optional[List[str]] = None
 langsmith_batch_size: Optional[int] = None
 prometheus_initialize_budget_metrics: Optional[bool] = False
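
The new module-level flag is the only knob this commit adds. A minimal usage sketch, assuming the public litellm API (litellm.completion, litellm.success_callback); the model name and callback choice are illustrative and not part of this commit:

import litellm

# Illustrative: force success callbacks to run inline (blocking) instead of
# being handed to the background executor.
litellm.sync_logging = True
litellm.success_callback = ["langfuse"]

# With sync_logging enabled, success callbacks have finished by the time
# completion() returns (requires valid provider credentials to actually run).
response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hi"}],
)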

View file

@@ -277,10 +277,7 @@ class LLMCachingHandler:
                     is_async=False,
                 )
-                threading.Thread(
-                    target=logging_obj.success_handler,
-                    args=(cached_result, start_time, end_time, cache_hit),
-                ).start()
+                logging_obj.success_handler(cached_result, start_time, end_time, True)
                 cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
                     **kwargs
                 )
@@ -446,10 +443,7 @@ class LLMCachingHandler:
                         cached_result, start_time, end_time, cache_hit
                     )
                 )
-                threading.Thread(
-                    target=logging_obj.success_handler,
-                    args=(cached_result, start_time, end_time, cache_hit),
-                ).start()
+                logging_obj.success_handler(cached_result, start_time, end_time, cache_hit)

     async def _retrieve_from_cache(
         self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]

View file

@@ -125,7 +125,11 @@ class MlflowLogger(CustomLogger):
             # If this is the final chunk, end the span. The final chunk
             # has complete_streaming_response that gathers the full response.
-            if final_response := kwargs.get("complete_streaming_response"):
+            final_response = (
+                kwargs.get("complete_streaming_response")
+                or kwargs.get("async_complete_streaming_response")
+            )
+            if final_response:
                 end_time_ns = int(end_time.timestamp() * 1e9)
                 self._extract_and_set_chat_attributes(span, kwargs, final_response)

View file

@@ -8,12 +8,13 @@ import os
 import re
 import subprocess
 import sys
+import threading
 import time
 import traceback
 import uuid
 from datetime import datetime as dt_object
 from functools import lru_cache
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast
+from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast

 from pydantic import BaseModel
@@ -1226,7 +1227,36 @@ class Logging(LiteLLMLoggingBaseClass):
         except Exception as e:
             raise Exception(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}")

-    def success_handler(  # noqa: PLR0915
+    def success_handler(
+        self,
+        result=None,
+        start_time=None,
+        end_time=None,
+        cache_hit=None,
+        synchronous=None,
+        **kwargs
+    ):
+        """
+        Execute the success handler function in a sync or async manner.
+
+        If synchronous argument is not provided, global `litellm.sync_logging` config is used.
+        """
+        if synchronous is None:
+            synchronous = litellm.sync_logging
+
+        if synchronous:
+            self._success_handler(result, start_time, end_time, cache_hit, **kwargs)
+        else:
+            executor.submit(
+                self._success_handler,
+                result,
+                start_time,
+                end_time,
+                cache_hit,
+                **kwargs,
+            )
+
+    def _success_handler(  # noqa: PLR0915
         self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
     ):
         verbose_logger.debug(
@@ -2376,12 +2406,7 @@ class Logging(LiteLLMLoggingBaseClass):
         if self._should_run_sync_callbacks_for_async_calls() is False:
             return

-        executor.submit(
-            self.success_handler,
-            result,
-            start_time,
-            end_time,
-        )
+        self.success_handler(result, start_time, end_time)

     def _should_run_sync_callbacks_for_async_calls(self) -> bool:
         """

View file

@@ -1445,6 +1445,7 @@ class CustomStreamWrapper:
         """
         Runs success logging in a thread and adds the response to the cache
         """
+        def _run():
             if litellm.disable_streaming_logging is True:
                 """
                 [NOT RECOMMENDED]
@@ -1453,6 +1454,8 @@
                 Disables streaming logging.
                 """
                 return
+
+            if not litellm.sync_logging:
                 ## ASYNC LOGGING
                 # Create an event loop for the new thread
                 if self.logging_loop is not None:
@@ -1469,9 +1472,21 @@
                             processed_chunk, None, None, cache_hit
                         )
                     )
+
             ## SYNC LOGGING
             self.logging_obj.success_handler(processed_chunk, None, None, cache_hit)
+
+            ## Sync store in cache
+            if self.logging_obj._llm_caching_handler is not None:
+                self.logging_obj._llm_caching_handler._sync_add_streaming_response_to_cache(
+                    processed_chunk
+                )
+
+        if litellm.sync_logging:
+            _run()
+        else:
+            executor.submit(_run)

     def finish_reason_handler(self):
         model_response = self.model_response_creator()
         _finish_reason = self.received_finish_reason or self.intermittent_finish_reason
@@ -1522,11 +1537,8 @@
                     completion_start_time=datetime.datetime.now()
                 )
                 ## LOGGING
-                executor.submit(
-                    self.run_success_logging_and_cache_storage,
-                    response,
-                    cache_hit,
-                )  # log response
+                self.run_success_logging_and_cache_storage(response, cache_hit)
                 choice = response.choices[0]
                 if isinstance(choice, StreamingChoices):
                     self.response_uptil_now += choice.delta.get("content", "") or ""
@@ -1576,21 +1588,12 @@
                         ),
                         cache_hit=cache_hit,
                     )
-                    executor.submit(
-                        self.logging_obj.success_handler,
-                        complete_streaming_response.model_copy(deep=True),
-                        None,
-                        None,
-                        cache_hit,
-                    )
+                    logging_result = complete_streaming_response.model_copy(deep=True)
                 else:
-                    executor.submit(
-                        self.logging_obj.success_handler,
-                        response,
-                        None,
-                        None,
-                        cache_hit,
-                    )
+                    logging_result = response
+
+                self.logging_obj.success_handler(logging_result, None, None, cache_hit)
                 if self.sent_stream_usage is False and self.send_stream_usage is True:
                     self.sent_stream_usage = True
                     return response
@@ -1602,11 +1605,7 @@
                     usage = calculate_total_usage(chunks=self.chunks)
                     processed_chunk._hidden_params["usage"] = usage
                     ## LOGGING
-                    executor.submit(
-                        self.run_success_logging_and_cache_storage,
-                        processed_chunk,
-                        cache_hit,
-                    )  # log response
+                    self.run_success_logging_and_cache_storage(processed_chunk, cache_hit)
                     return processed_chunk
         except Exception as e:
             traceback_exception = traceback.format_exc()
@@ -1762,22 +1761,19 @@
                     self.sent_stream_usage = True
                     return response

-                asyncio.create_task(
-                    self.logging_obj.async_success_handler(
-                        complete_streaming_response,
-                        cache_hit=cache_hit,
-                        start_time=None,
-                        end_time=None,
-                    )
-                )
-
-                executor.submit(
-                    self.logging_obj.success_handler,
-                    complete_streaming_response,
-                    cache_hit=cache_hit,
-                    start_time=None,
-                    end_time=None,
-                )
+                logging_params = dict(
+                    result=complete_streaming_response,
+                    cache_hit=cache_hit,
+                    start_time=None,
+                    end_time=None,
+                )
+                if litellm.sync_logging:
+                    await self.logging_obj.async_success_handler(**logging_params)
+                else:
+                    asyncio.create_task(self.logging_obj.async_success_handler(**logging_params))
+
+                self.logging_obj.success_handler(**logging_params)

                 raise StopAsyncIteration  # Re-raise StopIteration
             else:
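
In the async iterator, the same choice shows up as awaiting the logging coroutine versus scheduling it as a task. A self-contained asyncio sketch of that pattern, using illustrative names that are not part of this diff:

import asyncio

sync_logging = True  # stand-in for litellm.sync_logging

async def async_success_handler(**params):
    # Pretend to ship the result to a logging backend.
    await asyncio.sleep(0)
    print("logged", params["result"])

async def finish_stream(result):
    logging_params = dict(result=result, cache_hit=False, start_time=None, end_time=None)
    if sync_logging:
        # Blocking path: logging has finished before the stream signals completion.
        await async_success_handler(**logging_params)
    else:
        # Fire-and-forget path: the task may still be running when iteration ends.
        asyncio.create_task(async_success_handler(**logging_params))

asyncio.run(finish_stream("final chunk"))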

View file

@@ -122,15 +122,9 @@ class PassThroughStreamingHandler:
             standard_logging_response_object = StandardPassThroughResponseObject(
                 response=f"cannot parse chunks to standard response object. Chunks={all_chunks}"
             )
-        threading.Thread(
-            target=litellm_logging_obj.success_handler,
-            args=(
-                standard_logging_response_object,
-                start_time,
-                end_time,
-                False,
-            ),
-        ).start()
+        litellm_logging_obj.success_handler(
+            standard_logging_response_object, start_time, end_time, False
+        )

         await litellm_logging_obj.async_success_handler(
             result=standard_logging_response_object,
             start_time=start_time,

View file

@@ -170,8 +170,7 @@ class ResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
                     )
                 )
-            executor.submit(
-                self.logging_obj.success_handler,
+            self.logging_obj.success_handler(
                 result=self.completed_response,
                 cache_hit=None,
                 start_time=self.start_time,

View file

@@ -808,9 +808,7 @@ async def _client_async_logging_helper(
         f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
     )
     # check if user does not want this to be logged
-    asyncio.create_task(
-        logging_obj.async_success_handler(result, start_time, end_time)
-    )
+    await logging_obj.async_success_handler(result, start_time, end_time)
     logging_obj.handle_sync_success_callbacks_for_async_calls(
         result=result,
         start_time=start_time,
@@ -1183,12 +1181,8 @@ def client(original_function):  # noqa: PLR0915
                 # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
                 verbose_logger.info("Wrapper: Completed Call, calling success_handler")
-                executor.submit(
-                    logging_obj.success_handler,
-                    result,
-                    start_time,
-                    end_time,
-                )
+                logging_obj.success_handler(result, start_time, end_time)
                 # RETURN RESULT
                 update_response_metadata(
                     result=result,
@@ -1357,15 +1351,18 @@
             )
             # LOG SUCCESS - handle streaming success logging in the _next_ object
-            asyncio.create_task(
-                _client_async_logging_helper(
-                    logging_obj=logging_obj,
-                    result=result,
-                    start_time=start_time,
-                    end_time=end_time,
-                    is_completion_with_fallbacks=is_completion_with_fallbacks,
-                )
-            )
+            async_logging_params = dict(
+                logging_obj=logging_obj,
+                result=result,
+                start_time=start_time,
+                end_time=end_time,
+                is_completion_with_fallbacks=is_completion_with_fallbacks,
+            )
+            if litellm.sync_logging:
+                await _client_async_logging_helper(**async_logging_params)
+            else:
+                asyncio.create_task(_client_async_logging_helper(**async_logging_params))

             logging_obj.handle_sync_success_callbacks_for_async_calls(
                 result=result,
                 start_time=start_time,
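
From the caller's perspective, the async client path then behaves roughly as sketched below, assuming the public litellm.acompletion API; the model name and callback are illustrative, not part of this commit:

import asyncio
import litellm

litellm.sync_logging = True          # await the logging helper inside the wrapper
litellm.success_callback = ["langfuse"]

async def main():
    response = await litellm.acompletion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "hi"}],
    )
    # With sync_logging enabled, success logging has already completed here,
    # rather than running as a pending asyncio task in the background.
    print(response.choices[0].message.content)

asyncio.run(main())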

View file

@@ -2,8 +2,12 @@ import json
 import os
 import sys
 from datetime import datetime
+import threading
 from unittest.mock import AsyncMock

+from litellm.litellm_core_utils.dd_tracing import contextmanager
+from litellm.utils import executor
+
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system-path
@@ -297,6 +301,7 @@ def test_dynamic_logging_global_callback():
             start_time=datetime.now(),
             end_time=datetime.now(),
             cache_hit=False,
+            synchronous=True,
         )
     except Exception as e:
         print(f"Error: {e}")
@@ -323,3 +328,72 @@ def test_get_combined_callback_list():
     assert "lago" in _logging.get_combined_callback_list(
         dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"]
     )
+
+
+@pytest.mark.parametrize("sync_logging", [True, False])
+def test_success_handler_sync_async(sync_logging):
+    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+    from litellm.integrations.custom_logger import CustomLogger
+    from litellm.types.utils import ModelResponse, Choices, Message
+
+    cl = CustomLogger()
+
+    litellm_logging = LiteLLMLoggingObj(
+        model="claude-3-opus-20240229",
+        messages=[{"role": "user", "content": "hi"}],
+        stream=False,
+        call_type="completion",
+        start_time=datetime.now(),
+        litellm_call_id="123",
+        function_id="456",
+    )
+
+    litellm.sync_logging = sync_logging
+
+    result = ModelResponse(
+        id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70",
+        created=1732306261,
+        model="claude-3-opus-20240229",
+        object="chat.completion",
+        choices=[
+            Choices(
+                finish_reason="stop",
+                index=0,
+                message=Message(
+                    content="hello",
+                    role="assistant",
+                    tool_calls=None,
+                    function_call=None,
+                ),
+            )
+        ],
+    )
+
+    with (
+        patch.object(cl, "log_success_event") as mock_log_success_event,
+        patch.object(
+            executor, "submit",
+            side_effect=lambda *args: args[0](*args[1:])
+        ) as mock_executor,
+    ):
+        litellm.success_callback = [cl]
+
+        litellm_logging.success_handler(
+            result=result,
+            start_time=datetime.now(),
+            end_time=datetime.now(),
+            cache_hit=False,
+        )
+
+        if sync_logging:
+            mock_executor.assert_not_called()
+            mock_log_success_event.assert_called_once()
+            assert "standard_logging_object" in mock_log_success_event.call_args.kwargs["kwargs"]
+        else:
+            # Wait for the thread to finish
+            mock_executor.assert_called_once()
+            mock_log_success_event.assert_called_once()
+            assert "standard_logging_object" in mock_log_success_event.call_args.kwargs["kwargs"]