fix(utils.py): fix cache hits for streaming

Fixes https://github.com/BerriAI/litellm/issues/4109
Krrish Dholakia 2024-07-26 19:03:42 -07:00
parent c0717133a9
commit fe0b55f2ca
5 changed files with 42 additions and 16 deletions
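
The underlying bug: when a streamed response was served from the cache, CustomStreamWrapper logged it without a cache_hit flag, so success callbacks saw cached streams as ordinary completions. A minimal repro sketch, assuming the in-memory litellm.Cache and the documented CustomLogger callback interface; the model name and messages are placeholders:

import litellm
from litellm.caching import Cache
from litellm.integrations.custom_logger import CustomLogger

class CacheHitLogger(CustomLogger):
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        # after this fix, cache_hit should be True for cached streams
        print("cache_hit:", kwargs.get("cache_hit"))

litellm.cache = Cache()  # in-memory cache
litellm.callbacks = [CacheHitLogger()]

for _ in range(2):  # the second, identical call should be served from cache
    stream = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
        caching=True,
    )
    for _chunk in stream:  # drain the stream so the success handler fires
        pass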

@@ -463,7 +463,7 @@ class OpenTelemetry(CustomLogger):
             #############################################
             # OTEL Attributes for the RAW Request to https://docs.anthropic.com/en/api/messages
-            if complete_input_dict:
+            if complete_input_dict and isinstance(complete_input_dict, dict):
                 for param, val in complete_input_dict.items():
                     if not isinstance(val, str):
                         val = str(val)

@@ -1220,7 +1220,9 @@ class Logging:
         """
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
         """
-        print_verbose("Logging Details LiteLLM-Async Success Call")
+        print_verbose(
+            "Logging Details LiteLLM-Async Success Call, cache_hit={}".format(cache_hit)
+        )
         start_time, end_time, result = self._success_handler_helper_fn(
             start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
         )

@@ -2,3 +2,7 @@ model_list:
   - model_name: "*"
     litellm_params:
       model: "*"
+
+litellm_settings:
+  success_callback: ["logfire"]
+  cache: true
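
This config wildcards all models, enables caching, and logs successes to Logfire, which is enough to exercise the streaming cache-hit path through the proxy. A usage sketch against a locally running proxy; the base URL and key below are placeholders, assuming the proxy's default port:

from openai import OpenAI

# assumes `litellm --config config.yaml` is running locally; placeholder key
client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

for _ in range(2):  # identical requests: the second one should hit the cache
    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
    )
    for _chunk in stream:
        pass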

@@ -625,6 +625,7 @@ def test_chat_completion_optional_params(mock_acompletion, client_no_auth):
 # Run the test
 # test_chat_completion_optional_params()
 # Test Reading config.yaml file
 from litellm.proxy.proxy_server import ProxyConfig

@@ -10009,6 +10009,12 @@ class CustomStreamWrapper:
         return model_response

     def __next__(self):
+        cache_hit = False
+        if (
+            self.custom_llm_provider is not None
+            and self.custom_llm_provider == "cached_response"
+        ):
+            cache_hit = True
         try:
             if self.completion_stream is None:
                 self.fetch_sync_stream()
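
A note on the new guard: since == against None is simply False, the is-not-None check is redundant and the flag reduces to one equality. A behavior-equivalent sketch (the helper name is hypothetical):

from typing import Optional

def is_cache_hit(custom_llm_provider: Optional[str]) -> bool:
    # equivalent to the six-line guard above: None == "cached_response" is False
    return custom_llm_provider == "cached_response"

assert is_cache_hit("cached_response") is True
assert is_cache_hit(None) is False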
@@ -10073,7 +10079,8 @@ class CustomStreamWrapper:
                     response.usage = complete_streaming_response.usage  # type: ignore
                     ## LOGGING
                     threading.Thread(
-                        target=self.logging_obj.success_handler, args=(response,)
+                        target=self.logging_obj.success_handler,
+                        args=(response, None, None, cache_hit),
                     ).start()  # log response
                     self.sent_stream_usage = True
                     return response
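
The thread now passes (response, None, None, cache_hit) positionally, which presumes success_handler accepts the result, a start time, an end time, and then the cache-hit flag, with the two Nones backfilled inside the handler. A sketch of the assumed signature (parameter names are an assumption, not copied from the source):

def success_handler(self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs):
    # start_time/end_time of None are backfilled via _success_handler_helper_fn;
    # cache_hit now survives the thread hop for streamed responses
    ...

The same four-tuple recurs at every logging call site below, so the sync and async paths stay consistent.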
@@ -10083,7 +10090,8 @@ class CustomStreamWrapper:
                 processed_chunk = self.finish_reason_handler()
                 ## LOGGING
                 threading.Thread(
-                    target=self.logging_obj.success_handler, args=(processed_chunk,)
+                    target=self.logging_obj.success_handler,
+                    args=(processed_chunk, None, None, cache_hit),
                 ).start()  # log response
                 return processed_chunk
         except Exception as e:
@@ -10120,6 +10128,12 @@ class CustomStreamWrapper:
         return self.completion_stream

     async def __anext__(self):
+        cache_hit = False
+        if (
+            self.custom_llm_provider is not None
+            and self.custom_llm_provider == "cached_response"
+        ):
+            cache_hit = True
         try:
             if self.completion_stream is None:
                 await self.fetch_stream()
@@ -10174,11 +10188,12 @@ class CustomStreamWrapper:
                         continue
                     ## LOGGING
                     threading.Thread(
-                        target=self.logging_obj.success_handler, args=(processed_chunk,)
+                        target=self.logging_obj.success_handler,
+                        args=(processed_chunk, None, None, cache_hit),
                     ).start()  # log response
                     asyncio.create_task(
                         self.logging_obj.async_success_handler(
-                            processed_chunk,
+                            processed_chunk, cache_hit=cache_hit
                         )
                     )
                     self.response_uptil_now += (
@@ -10225,11 +10240,11 @@ class CustomStreamWrapper:
                     ## LOGGING
                     threading.Thread(
                         target=self.logging_obj.success_handler,
-                        args=(processed_chunk,),
+                        args=(processed_chunk, None, None, cache_hit),
                     ).start()  # log processed_chunk
                     asyncio.create_task(
                         self.logging_obj.async_success_handler(
-                            processed_chunk,
+                            processed_chunk, cache_hit=cache_hit
                         )
                     )
@@ -10257,11 +10272,12 @@ class CustomStreamWrapper:
                         response.usage = complete_streaming_response.usage
                         ## LOGGING
                         threading.Thread(
-                            target=self.logging_obj.success_handler, args=(response,)
+                            target=self.logging_obj.success_handler,
+                            args=(response, None, None, cache_hit),
                         ).start()  # log response
                         asyncio.create_task(
                             self.logging_obj.async_success_handler(
-                                response,
+                                response, cache_hit=cache_hit
                             )
                         )
                         self.sent_stream_usage = True
@@ -10272,11 +10288,12 @@ class CustomStreamWrapper:
                 processed_chunk = self.finish_reason_handler()
                 ## LOGGING
                 threading.Thread(
-                    target=self.logging_obj.success_handler, args=(processed_chunk,)
+                    target=self.logging_obj.success_handler,
+                    args=(processed_chunk, None, None, cache_hit),
                 ).start()  # log response
                 asyncio.create_task(
                     self.logging_obj.async_success_handler(
-                        processed_chunk,
+                        processed_chunk, cache_hit=cache_hit
                     )
                 )
                 return processed_chunk
@@ -10295,11 +10312,12 @@ class CustomStreamWrapper:
                         response.usage = complete_streaming_response.usage
                         ## LOGGING
                         threading.Thread(
-                            target=self.logging_obj.success_handler, args=(response,)
+                            target=self.logging_obj.success_handler,
+                            args=(response, None, None, cache_hit),
                         ).start()  # log response
                         asyncio.create_task(
                             self.logging_obj.async_success_handler(
-                                response,
+                                response, cache_hit=cache_hit
                             )
                         )
                         self.sent_stream_usage = True
@@ -10310,11 +10328,12 @@ class CustomStreamWrapper:
                 processed_chunk = self.finish_reason_handler()
                 ## LOGGING
                 threading.Thread(
-                    target=self.logging_obj.success_handler, args=(processed_chunk,)
+                    target=self.logging_obj.success_handler,
+                    args=(processed_chunk, None, None, cache_hit),
                 ).start()  # log response
                 asyncio.create_task(
                     self.logging_obj.async_success_handler(
-                        processed_chunk,
+                        processed_chunk, cache_hit=cache_hit
                     )
                 )
                 return processed_chunk