fix(streaming_handler.py): support logging complete streaming response on cache hit

This commit is contained in:
Krrish Dholakia 2025-03-17 18:10:39 -07:00
parent ba6369e359
commit 301375bf84
3 changed files with 66 additions and 22 deletions

View file

@ -8,6 +8,7 @@ from dotenv import load_dotenv
load_dotenv()
import os
import json
sys.path.insert(
0, os.path.abspath("../..")
@ -2146,29 +2147,62 @@ async def test_redis_proxy_batch_redis_get_cache():
assert "cache_key" in response._hidden_params
def test_logging_turn_off_message_logging_streaming():
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_logging_turn_off_message_logging_streaming(sync_mode):
litellm.turn_off_message_logging = True
mock_obj = Cache(type="local")
litellm.cache = mock_obj
with patch.object(mock_obj, "add_cache") as mock_client:
with patch.object(mock_obj, "add_cache") as mock_client, patch.object(
mock_obj, "async_add_cache"
) as mock_async_client:
print(f"mock_obj.add_cache: {mock_obj.add_cache}")
resp = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
mock_response="hello",
stream=True,
)
if sync_mode is True:
resp = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
mock_response="hello",
stream=True,
)
for chunk in resp:
continue
for chunk in resp:
continue
time.sleep(1)
time.sleep(1)
mock_client.assert_called_once()
print(f"mock_client.call_args: {mock_client.call_args}")
assert mock_client.call_args.args[0].choices[0].message.content == "hello"
else:
resp = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
mock_response="hello",
stream=True,
)
mock_client.assert_called_once()
print(f"mock_client.call_args: {mock_client.call_args}")
assert mock_client.call_args.args[0].choices[0].message.content == "hello"
async for chunk in resp:
continue
await asyncio.sleep(1)
mock_async_client.assert_called_once()
print(f"mock_async_client.call_args: {mock_async_client.call_args.args[0]}")
print(
f"mock_async_client.call_args: {json.loads(mock_async_client.call_args.args[0])}"
)
json_mock = json.loads(mock_async_client.call_args.args[0])
try:
assert json_mock["choices"][0]["message"]["content"] == "hello"
except Exception as e:
print(
f"mock_async_client.call_args.args[0]: {mock_async_client.call_args.args[0]}"
)
print(
f"mock_async_client.call_args.args[0]['choices']: {mock_async_client.call_args.args[0]['choices']}"
)
raise e
def test_basic_caching_import():