fix(utils.py): fix cache hits for streaming

Fixes https://github.com/BerriAI/litellm/issues/4109
This commit is contained in:
Krrish Dholakia 2024-07-26 19:03:42 -07:00
parent c0717133a9
commit fe0b55f2ca
5 changed files with 42 additions and 16 deletions

View file

@@ -10009,6 +10009,12 @@ class CustomStreamWrapper:
return model_response
def __next__(self):
cache_hit = False
if (
self.custom_llm_provider is not None
and self.custom_llm_provider == "cached_response"
):
cache_hit = True
try:
if self.completion_stream is None:
self.fetch_sync_stream()
@@ -10073,7 +10079,8 @@ class CustomStreamWrapper:
response.usage = complete_streaming_response.usage # type: ignore
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(response,)
target=self.logging_obj.success_handler,
args=(response, None, None, cache_hit),
).start() # log response
self.sent_stream_usage = True
return response
@@ -10083,7 +10090,8 @@ class CustomStreamWrapper:
processed_chunk = self.finish_reason_handler()
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(processed_chunk,)
target=self.logging_obj.success_handler,
args=(processed_chunk, None, None, cache_hit),
).start() # log response
return processed_chunk
except Exception as e:
@@ -10120,6 +10128,12 @@ class CustomStreamWrapper:
return self.completion_stream
async def __anext__(self):
cache_hit = False
if (
self.custom_llm_provider is not None
and self.custom_llm_provider == "cached_response"
):
cache_hit = True
try:
if self.completion_stream is None:
await self.fetch_stream()
@@ -10174,11 +10188,12 @@ class CustomStreamWrapper:
continue
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(processed_chunk,)
target=self.logging_obj.success_handler,
args=(processed_chunk, None, None, cache_hit),
).start() # log response
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk,
processed_chunk, cache_hit=cache_hit
)
)
self.response_uptil_now += (
@@ -10225,11 +10240,11 @@ class CustomStreamWrapper:
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
args=(processed_chunk,),
args=(processed_chunk, None, None, cache_hit),
).start() # log processed_chunk
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk,
processed_chunk, cache_hit=cache_hit
)
)
@@ -10257,11 +10272,12 @@ class CustomStreamWrapper:
response.usage = complete_streaming_response.usage
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(response,)
target=self.logging_obj.success_handler,
args=(response, None, None, cache_hit),
).start() # log response
asyncio.create_task(
self.logging_obj.async_success_handler(
response,
response, cache_hit=cache_hit
)
)
self.sent_stream_usage = True
@@ -10272,11 +10288,12 @@ class CustomStreamWrapper:
processed_chunk = self.finish_reason_handler()
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(processed_chunk,)
target=self.logging_obj.success_handler,
args=(processed_chunk, None, None, cache_hit),
).start() # log response
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk,
processed_chunk, cache_hit=cache_hit
)
)
return processed_chunk
@@ -10295,11 +10312,12 @@ class CustomStreamWrapper:
response.usage = complete_streaming_response.usage
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(response,)
target=self.logging_obj.success_handler,
args=(response, None, None, cache_hit),
).start() # log response
asyncio.create_task(
self.logging_obj.async_success_handler(
response,
response, cache_hit=cache_hit
)
)
self.sent_stream_usage = True
@@ -10310,11 +10328,12 @@ class CustomStreamWrapper:
processed_chunk = self.finish_reason_handler()
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(processed_chunk,)
target=self.logging_obj.success_handler,
args=(processed_chunk, None, None, cache_hit),
).start() # log response
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk,
processed_chunk, cache_hit=cache_hit
)
)
return processed_chunk