fix(utils.py): correctly log streaming cache hits (#5417) (#5426)

Fixes https://github.com/BerriAI/litellm/issues/5401
Krish Dholakia 2024-08-28 22:50:33 -07:00 committed by GitHub
parent 8ce4b8e195
commit 3e8f5009f4
3 changed files with 103 additions and 26 deletions
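For context, the behaviour this commit fixes can be reproduced roughly as below. This is a hedged sketch based on the test added in this commit: the probe handler, sleep durations, and model are assumptions, and a valid OpenAI key (or a mocked backend) is needed for the completion calls to succeed.

import time
import litellm
from litellm.caching import Cache
from litellm.integrations.custom_logger import CustomLogger

class CacheHitProbe(CustomLogger):
    # Hypothetical probe handler; after this fix, streamed cache hits should
    # reach log_success_event with cache_hit=True and a populated
    # standard_logging_object in kwargs.
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("cache_hit:", kwargs.get("cache_hit"))
        print("standard_logging_object set:", kwargs.get("standard_logging_object") is not None)

litellm.cache = Cache()
litellm.callbacks = [CacheHitProbe()]

for _ in range(2):  # the second call should be served from the cache
    resp = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        caching=True,
        stream=True,
    )
    for chunk in resp:
        pass
    time.sleep(2)  # success logging for streams runs in a background thread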


@@ -610,12 +610,23 @@ class Logging:
self.model_call_details["litellm_params"]["metadata"][
"hidden_params"
] = result._hidden_params
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
)
)
else: # streaming chunks + image gen.
self.model_call_details["response_cost"] = None
if (
litellm.max_budget
and self.stream == False
and self.stream is False
and result is not None
and "content" in result
):
@@ -628,17 +639,6 @@
total_time=float_diff,
)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
)
)
return start_time, end_time, result
except Exception as e:
raise Exception(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}")
@@ -646,9 +646,7 @@ class Logging:
def success_handler(
self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
):
verbose_logger.debug(
f"Logging Details LiteLLM-Success Call: Cache_hit={cache_hit}"
)
print_verbose(f"Logging Details LiteLLM-Success Call: Cache_hit={cache_hit}")
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time,
end_time=end_time,
@@ -695,6 +693,16 @@ class Logging:
self.model_call_details["response_cost"] = (
self._response_cost_calculator(result=complete_streaming_response)
)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=complete_streaming_response,
start_time=start_time,
end_time=end_time,
logging_obj=self,
)
)
if self.dynamic_success_callbacks is not None and isinstance(
self.dynamic_success_callbacks, list
):
@@ -714,7 +722,6 @@ class Logging:
)
## LOGGING HOOK ##
for callback in callbacks:
if isinstance(callback, CustomLogger):
self.model_call_details, result = callback.logging_hook(
@@ -726,7 +733,7 @@
for callback in callbacks:
try:
litellm_params = self.model_call_details.get("litellm_params", {})
if litellm_params.get("no-log", False) == True:
if litellm_params.get("no-log", False) is True:
# proxy cost tracking cal backs should run
if not (
isinstance(callback, CustomLogger)
@@ -1192,6 +1199,7 @@ class Logging:
)
)
result = self.model_call_details["complete_response"]
callback.log_success_event(
kwargs=self.model_call_details,
response_obj=result,
@@ -1199,7 +1207,7 @@
end_time=end_time,
)
if (
callable(callback) == True
callable(callback) is True
and self.model_call_details.get("litellm_params", {}).get(
"acompletion", False
)
@@ -1301,6 +1309,7 @@ class Logging:
result=complete_streaming_response
)
)
verbose_logger.debug(
f"Model={self.model}; cost={self.model_call_details['response_cost']}"
)
@@ -1310,6 +1319,16 @@
)
self.model_call_details["response_cost"] = None
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=complete_streaming_response,
start_time=start_time,
end_time=end_time,
logging_obj=self,
)
)
if self.dynamic_async_success_callbacks is not None and isinstance(
self.dynamic_async_success_callbacks, list
):
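With the hunks above, the standardized logging payload is now also built when the success handlers receive a complete streaming response, so it can be read from async callbacks as well. A minimal sketch, assuming litellm's CustomLogger interface and the payload fields exercised by the test below:

import litellm
from litellm.integrations.custom_logger import CustomLogger

class PayloadProbe(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        payload = kwargs.get("standard_logging_object") or {}
        # cache_hit / response_cost / saved_cache_cost are the fields the new
        # test asserts on; other fields are not shown in this diff
        print(payload.get("cache_hit"), payload.get("response_cost"), payload.get("saved_cache_cost"))

litellm.callbacks = [PayloadProbe()]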


@@ -1297,3 +1297,53 @@ def test_aaastandard_logging_payload_cache_hit():
assert standard_logging_object["cache_hit"] is True
assert standard_logging_object["response_cost"] == 0
assert standard_logging_object["saved_cache_cost"] > 0
def test_logging_async_cache_hit_sync_call():
from litellm.types.utils import StandardLoggingPayload
litellm.cache = Cache()
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
caching=True,
stream=True,
)
for chunk in response:
print(chunk)
time.sleep(3)
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
litellm.success_callback = []
with patch.object(
customHandler, "log_success_event", new=MagicMock()
) as mock_client:
resp = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
caching=True,
stream=True,
)
for chunk in resp:
print(chunk)
time.sleep(2)
mock_client.assert_called_once()
assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
is not None
)
standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
"kwargs"
]["standard_logging_object"]
assert standard_logging_object["cache_hit"] is True
assert standard_logging_object["response_cost"] == 0
assert standard_logging_object["saved_cache_cost"] > 0
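The new test makes two sync, streamed, cached calls and asserts that the second one is logged as a cache hit with zero response cost and a positive saved cache cost. A hedged way to run just this test (the containing test module is not shown in this diff, so a -k filter is used instead of a file path; a live OpenAI key is needed unless the completion calls are mocked):

import pytest

# Run only the new test from the litellm test suite's working directory.
pytest.main(["-k", "test_logging_async_cache_hit_sync_call", "-x", "-s"])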


@@ -10548,8 +10548,8 @@ class CustomStreamWrapper:
"""
self.logging_loop = loop
def run_success_logging_in_thread(self, processed_chunk):
if litellm.disable_streaming_logging == True:
def run_success_logging_in_thread(self, processed_chunk, cache_hit: bool):
if litellm.disable_streaming_logging is True:
"""
[NOT RECOMMENDED]
Set this via `litellm.disable_streaming_logging = True`.
@@ -10561,14 +10561,20 @@ class CustomStreamWrapper:
# Create an event loop for the new thread
if self.logging_loop is not None:
future = asyncio.run_coroutine_threadsafe(
self.logging_obj.async_success_handler(processed_chunk),
self.logging_obj.async_success_handler(
processed_chunk, None, None, cache_hit
),
loop=self.logging_loop,
)
result = future.result()
else:
asyncio.run(self.logging_obj.async_success_handler(processed_chunk))
asyncio.run(
self.logging_obj.async_success_handler(
processed_chunk, None, None, cache_hit
)
)
## SYNC LOGGING
self.logging_obj.success_handler(processed_chunk)
self.logging_obj.success_handler(processed_chunk, None, None, cache_hit)
def finish_reason_handler(self):
model_response = self.model_response_creator()
@@ -10616,7 +10622,8 @@ class CustomStreamWrapper:
continue
## LOGGING
threading.Thread(
target=self.run_success_logging_in_thread, args=(response,)
target=self.run_success_logging_in_thread,
args=(response, cache_hit),
).start() # log response
self.response_uptil_now += (
response.choices[0].delta.get("content", "") or ""
@@ -10678,8 +10685,8 @@ class CustomStreamWrapper:
processed_chunk._hidden_params["usage"] = usage
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
args=(processed_chunk, None, None, cache_hit),
target=self.run_success_logging_in_thread,
args=(processed_chunk, cache_hit),
).start() # log response
return processed_chunk
except Exception as e:
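Both streaming log paths above now funnel through run_success_logging_in_thread, so cache_hit reaches the async success handler as well as the sync one. A simplified stand-in showing how the positional arguments line up (the real Logging and CustomStreamWrapper classes are far larger; this only mirrors the signatures visible in the hunks above):

import asyncio
import threading

class _Logging:  # stand-in for litellm's Logging object
    def success_handler(self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs):
        print(f"sync success log, cache_hit={cache_hit}")

    async def async_success_handler(self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs):
        print(f"async success log, cache_hit={cache_hit}")

class _StreamWrapper:  # stand-in for CustomStreamWrapper
    def __init__(self):
        self.logging_obj = _Logging()

    def run_success_logging_in_thread(self, processed_chunk, cache_hit: bool):
        # async handler first, then the sync one -- both receive cache_hit
        asyncio.run(self.logging_obj.async_success_handler(processed_chunk, None, None, cache_hit))
        self.logging_obj.success_handler(processed_chunk, None, None, cache_hit)

w = _StreamWrapper()
threading.Thread(target=w.run_success_logging_in_thread, args=({"content": "hi"}, True)).start()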
@@ -10776,6 +10783,7 @@ class CustomStreamWrapper:
if processed_chunk is None:
continue
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
args=(processed_chunk, None, None, cache_hit),