fix(streaming_handler.py): support logging complete streaming response on cache hit

Krrish Dholakia 2025-03-17 18:10:39 -07:00
parent ba6369e359
commit 301375bf84
3 changed files with 66 additions and 22 deletions
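In short, the diff below appears to stop scheduling a cache write for every streamed chunk and instead caches the fully assembled streaming response once, skipping the write when the response was itself served from the cache. A minimal sketch of that gating, using hypothetical stand-in names rather than litellm's internals:

import asyncio
from typing import Any, List, Optional


class FakeCacheWriter:
    # Hypothetical stand-in for litellm's _llm_caching_handler; illustration only.
    def __init__(self) -> None:
        self.writes: List[Any] = []

    async def add_streaming_response_to_cache(self, response: Any) -> None:
        self.writes.append(response)


async def cache_complete_stream(
    writer: Optional[FakeCacheWriter], complete_response: Any, cache_hit: bool
) -> None:
    # Mirrors the new async_cache_streaming_response: write only for fresh
    # (non-cached) responses, and only if a caching handler is configured.
    if not cache_hit and writer is not None:
        await writer.add_streaming_response_to_cache(complete_response)


async def main() -> None:
    writer = FakeCacheWriter()
    await cache_complete_stream(writer, {"content": "hello"}, cache_hit=False)  # written
    await cache_complete_stream(writer, {"content": "hello"}, cache_hit=True)   # skipped
    assert len(writer.writes) == 1


asyncio.run(main())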


@@ -790,6 +790,7 @@ class LLMCachingHandler:
        - Else append the chunk to self.async_streaming_chunks
        """
        complete_streaming_response: Optional[
            Union[ModelResponse, TextCompletionResponse]
        ] = _assemble_complete_response_from_streaming_chunks(
@@ -800,7 +801,6 @@ class LLMCachingHandler:
            streaming_chunks=self.async_streaming_chunks,
            is_async=True,
        )
        # if a complete_streaming_response is assembled, add it to the cache
        if complete_streaming_response is not None:
            await self.async_set_cache(
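For context, the handler buffers chunks in self.async_streaming_chunks and, in this hunk, assembles them into a single complete response before deciding whether to cache it. A toy sketch of that assembly idea (the helper below is illustrative, not litellm's _assemble_complete_response_from_streaming_chunks):

from typing import Dict, List


def assemble_from_chunks(chunks: List[Dict]) -> Dict:
    # Concatenate the content deltas of every buffered chunk into one message.
    content = "".join(
        (choice.get("delta") or {}).get("content") or ""
        for chunk in chunks
        for choice in chunk.get("choices", [])
    )
    return {"choices": [{"message": {"role": "assistant", "content": content}}]}


chunks = [
    {"choices": [{"delta": {"content": "hel"}}]},
    {"choices": [{"delta": {"content": "lo"}}]},
]
assert assemble_from_chunks(chunks)["choices"][0]["message"]["content"] == "hello"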


@@ -1481,6 +1481,15 @@ class CustomStreamWrapper:
                processed_chunk
            )

+    async def async_cache_streaming_response(self, processed_chunk, cache_hit: bool):
+        """
+        Caches the streaming response
+        """
+        if not cache_hit and self.logging_obj._llm_caching_handler is not None:
+            await self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
+                processed_chunk
+            )
+
    def run_success_logging_and_cache_storage(self, processed_chunk, cache_hit: bool):
        """
        Runs success logging in a thread and adds the response to the cache
@@ -1711,13 +1720,6 @@ class CustomStreamWrapper:
            if processed_chunk is None:
                continue

-            if self.logging_obj._llm_caching_handler is not None:
-                asyncio.create_task(
-                    self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
-                        processed_chunk=cast(ModelResponse, processed_chunk),
-                    )
-                )
-
            choice = processed_chunk.choices[0]
            if isinstance(choice, StreamingChoices):
                self.response_uptil_now += choice.delta.get("content", "") or ""
@@ -1788,6 +1790,14 @@
                        "usage",
                        getattr(complete_streaming_response, "usage"),
                    )
+                asyncio.create_task(
+                    self.async_cache_streaming_response(
+                        processed_chunk=complete_streaming_response.model_copy(
+                            deep=True
+                        ),
+                        cache_hit=cache_hit,
+                    )
+                )
            if self.sent_stream_usage is False and self.send_stream_usage is True:
                self.sent_stream_usage = True
                return response
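The cache write above is scheduled with asyncio.create_task on a deep copy of the assembled response, presumably so the stream is not blocked by the write and so later mutations of the response object do not leak into the cached copy. A minimal sketch of that pattern, assuming a pydantic v2 model (as litellm's ModelResponse is):

import asyncio

from pydantic import BaseModel


class Response(BaseModel):
    content: str


async def write_to_cache(resp: Response) -> None:
    await asyncio.sleep(0)  # stand-in for the real async cache write
    assert resp.content == "hello"  # sees the snapshot, not the later mutation


async def finish_stream(resp: Response, cache_hit: bool) -> None:
    if not cache_hit:
        # Snapshot the response before scheduling the background write.
        asyncio.create_task(write_to_cache(resp.model_copy(deep=True)))
    resp.content += " (mutated after scheduling)"
    await asyncio.sleep(0.01)  # let the background task run before returning


asyncio.run(finish_stream(Response(content="hello"), cache_hit=False))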


@@ -8,6 +8,7 @@ from dotenv import load_dotenv
load_dotenv()
import os
+import json

sys.path.insert(
    0, os.path.abspath("../..")
@@ -2146,29 +2147,62 @@ async def test_redis_proxy_batch_redis_get_cache():
    assert "cache_key" in response._hidden_params

-def test_logging_turn_off_message_logging_streaming():
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_logging_turn_off_message_logging_streaming(sync_mode):
    litellm.turn_off_message_logging = True
    mock_obj = Cache(type="local")
    litellm.cache = mock_obj

-    with patch.object(mock_obj, "add_cache") as mock_client:
+    with patch.object(mock_obj, "add_cache") as mock_client, patch.object(
+        mock_obj, "async_add_cache"
+    ) as mock_async_client:
        print(f"mock_obj.add_cache: {mock_obj.add_cache}")

-        resp = litellm.completion(
-            model="gpt-3.5-turbo",
-            messages=[{"role": "user", "content": "hi"}],
-            mock_response="hello",
-            stream=True,
-        )
+        if sync_mode is True:
+            resp = litellm.completion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "hi"}],
+                mock_response="hello",
+                stream=True,
+            )

-        for chunk in resp:
-            continue
+            for chunk in resp:
+                continue

-        time.sleep(1)
+            time.sleep(1)
+            mock_client.assert_called_once()
+            print(f"mock_client.call_args: {mock_client.call_args}")
+            assert mock_client.call_args.args[0].choices[0].message.content == "hello"
+        else:
+            resp = await litellm.acompletion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "hi"}],
+                mock_response="hello",
+                stream=True,
+            )

-        mock_client.assert_called_once()
-        print(f"mock_client.call_args: {mock_client.call_args}")
-        assert mock_client.call_args.args[0].choices[0].message.content == "hello"
+            async for chunk in resp:
+                continue
+            await asyncio.sleep(1)
+            mock_async_client.assert_called_once()
+            print(f"mock_async_client.call_args: {mock_async_client.call_args.args[0]}")
+            print(
+                f"mock_async_client.call_args: {json.loads(mock_async_client.call_args.args[0])}"
+            )
+            json_mock = json.loads(mock_async_client.call_args.args[0])
+            try:
+                assert json_mock["choices"][0]["message"]["content"] == "hello"
+            except Exception as e:
+                print(
+                    f"mock_async_client.call_args.args[0]: {mock_async_client.call_args.args[0]}"
+                )
+                print(
+                    f"mock_async_client.call_args.args[0]['choices']: {mock_async_client.call_args.args[0]['choices']}"
+                )
+                raise e
def test_basic_caching_import():
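Note on the async branch of the updated test: it asserts against json.loads(mock_async_client.call_args.args[0]), which implies that in this setup the async path hands async_add_cache a JSON string, whereas the sync path receives a response object and is checked with attribute access. A small self-contained sketch of that assertion pattern, with the payload hard-coded purely as an assumption:

import json
from unittest.mock import MagicMock

mock_async_client = MagicMock()
# Assumed shape of what the streaming handler passes on the async path: a JSON string.
mock_async_client('{"choices": [{"message": {"role": "assistant", "content": "hello"}}]}')

mock_async_client.assert_called_once()
json_mock = json.loads(mock_async_client.call_args.args[0])
assert json_mock["choices"][0]["message"]["content"] == "hello"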