fix(streaming_handler.py): support logging complete streaming response on cache hit
This commit is contained in: parent ba6369e359, commit 301375bf84
3 changed files with 66 additions and 22 deletions
@@ -790,6 +790,7 @@ class LLMCachingHandler:
        - Else append the chunk to self.async_streaming_chunks
        """
        complete_streaming_response: Optional[
            Union[ModelResponse, TextCompletionResponse]
        ] = _assemble_complete_response_from_streaming_chunks(
@@ -800,7 +801,6 @@ class LLMCachingHandler:
            streaming_chunks=self.async_streaming_chunks,
            is_async=True,
        )

        # if a complete_streaming_response is assembled, add it to the cache
        if complete_streaming_response is not None:
            await self.async_set_cache(
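The handler above only writes to the cache once a complete response has been assembled from the buffered streaming chunks. A minimal sketch of that idea, using illustrative names rather than litellm's actual _assemble_complete_response_from_streaming_chunks helper:

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeDelta:
    # Stand-in for a streaming chunk's delta; not a litellm type.
    content: Optional[str] = None


@dataclass
class ChunkBuffer:
    chunks: List[FakeDelta] = field(default_factory=list)

    def append(self, chunk: FakeDelta) -> None:
        self.chunks.append(chunk)

    def assemble(self) -> Optional[str]:
        # Mirrors the "only cache if a complete response could be assembled" check:
        # return None when nothing was buffered, one joined response otherwise.
        if not self.chunks:
            return None
        return "".join(c.content or "" for c in self.chunks)


buffer = ChunkBuffer()
for piece in ("Hel", "lo", "!"):
    buffer.append(FakeDelta(content=piece))
assert buffer.assemble() == "Hello!"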
@@ -1481,6 +1481,15 @@ class CustomStreamWrapper:
                processed_chunk
            )

+    async def async_cache_streaming_response(self, processed_chunk, cache_hit: bool):
+        """
+        Caches the streaming response
+        """
+        if not cache_hit and self.logging_obj._llm_caching_handler is not None:
+            await self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
+                processed_chunk
+            )
+
    def run_success_logging_and_cache_storage(self, processed_chunk, cache_hit: bool):
        """
        Runs success logging in a thread and adds the response to the cache
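The new async_cache_streaming_response gates the write on cache_hit, so a response that was itself served from the cache is never written back. A hedged, self-contained sketch of that gate; ToyCache and ToyWrapper are illustrative, not litellm classes:

import asyncio
from typing import Optional


class ToyCache:
    def __init__(self) -> None:
        self.store: dict = {}

    async def async_add_cache(self, key: str, value: str) -> None:
        self.store[key] = value


class ToyWrapper:
    def __init__(self, cache: Optional[ToyCache]) -> None:
        self.cache = cache

    async def async_cache_streaming_response(self, key: str, response: str, cache_hit: bool) -> None:
        # Only cache fresh responses; a cache hit is already stored.
        if not cache_hit and self.cache is not None:
            await self.cache.async_add_cache(key, response)


async def main() -> None:
    cache = ToyCache()
    wrapper = ToyWrapper(cache)
    await wrapper.async_cache_streaming_response("k1", "hello", cache_hit=False)
    await wrapper.async_cache_streaming_response("k2", "hello again", cache_hit=True)
    assert "k1" in cache.store and "k2" not in cache.store


asyncio.run(main())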
@@ -1711,13 +1720,6 @@ class CustomStreamWrapper:
                if processed_chunk is None:
                    continue

-                if self.logging_obj._llm_caching_handler is not None:
-                    asyncio.create_task(
-                        self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
-                            processed_chunk=cast(ModelResponse, processed_chunk),
-                        )
-                    )
-
                choice = processed_chunk.choices[0]
                if isinstance(choice, StreamingChoices):
                    self.response_uptil_now += choice.delta.get("content", "") or ""
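The block removed above scheduled a background task for every streamed chunk; the replacement (see the next hunk) schedules a single task once the full response has been assembled. A rough sketch of the difference, with toy helpers standing in for the litellm caching handler:

import asyncio


async def write(cache: dict, key, value) -> None:
    # Stand-in for the caching handler's add-to-cache call.
    cache[key] = value


async def per_chunk_tasks(chunks, cache: dict) -> int:
    # Old behaviour: one background task per chunk.
    tasks = [asyncio.create_task(write(cache, i, c)) for i, c in enumerate(chunks)]
    await asyncio.gather(*tasks)
    return len(tasks)


async def single_task(chunks, cache: dict) -> int:
    # New behaviour: one task for the assembled response at end of stream.
    task = asyncio.create_task(write(cache, "final", "".join(chunks)))
    await task
    return 1


async def main() -> None:
    chunks = ["hel", "lo"]
    assert await per_chunk_tasks(chunks, {}) == len(chunks)
    assert await single_task(chunks, {}) == 1


asyncio.run(main())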
@@ -1788,6 +1790,14 @@ class CustomStreamWrapper:
                    "usage",
                    getattr(complete_streaming_response, "usage"),
                )
+                asyncio.create_task(
+                    self.async_cache_streaming_response(
+                        processed_chunk=complete_streaming_response.model_copy(
+                            deep=True
+                        ),
+                        cache_hit=cache_hit,
+                    )
+                )
            if self.sent_stream_usage is False and self.send_stream_usage is True:
                self.sent_stream_usage = True
                return response
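The assembled response is handed to asyncio.create_task as a deep copy, so returning the final chunk to the caller is not held up by the cache write and the cached object cannot be mutated afterwards. A minimal sketch of that pattern, assuming pydantic v2's model_copy as in the diff; ToyResponse, slow_cache_write, and finish_stream are illustrative names:

import asyncio

from pydantic import BaseModel


class ToyResponse(BaseModel):
    content: str


async def slow_cache_write(response: ToyResponse, store: dict) -> None:
    await asyncio.sleep(0.01)  # pretend the cache backend is slow
    store["last"] = response.content


async def finish_stream(store: dict) -> ToyResponse:
    response = ToyResponse(content="hello")
    # Fire-and-forget: schedule the write on a deep copy of the response.
    task = asyncio.create_task(slow_cache_write(response.model_copy(deep=True), store))
    # Awaited here only so the example is deterministic; the real wrapper returns
    # to the caller without waiting for the cache write.
    await task
    return response


store: dict = {}
asyncio.run(finish_stream(store))
assert store["last"] == "hello"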
@@ -8,6 +8,7 @@ from dotenv import load_dotenv

load_dotenv()
import os
+import json

sys.path.insert(
    0, os.path.abspath("../..")
@@ -2146,29 +2147,62 @@ async def test_redis_proxy_batch_redis_get_cache():
    assert "cache_key" in response._hidden_params


-def test_logging_turn_off_message_logging_streaming():
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_logging_turn_off_message_logging_streaming(sync_mode):
    litellm.turn_off_message_logging = True
    mock_obj = Cache(type="local")
    litellm.cache = mock_obj

-    with patch.object(mock_obj, "add_cache") as mock_client:
+    with patch.object(mock_obj, "add_cache") as mock_client, patch.object(
+        mock_obj, "async_add_cache"
+    ) as mock_async_client:
        print(f"mock_obj.add_cache: {mock_obj.add_cache}")

-        resp = litellm.completion(
-            model="gpt-3.5-turbo",
-            messages=[{"role": "user", "content": "hi"}],
-            mock_response="hello",
-            stream=True,
-        )
-
-        for chunk in resp:
-            continue
-
-        time.sleep(1)
-        mock_client.assert_called_once()
-        print(f"mock_client.call_args: {mock_client.call_args}")
-        assert mock_client.call_args.args[0].choices[0].message.content == "hello"
+        if sync_mode is True:
+            resp = litellm.completion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "hi"}],
+                mock_response="hello",
+                stream=True,
+            )
+
+            for chunk in resp:
+                continue
+
+            time.sleep(1)
+            mock_client.assert_called_once()
+            print(f"mock_client.call_args: {mock_client.call_args}")
+            assert mock_client.call_args.args[0].choices[0].message.content == "hello"
+        else:
+            resp = await litellm.acompletion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "hi"}],
+                mock_response="hello",
+                stream=True,
+            )
+
+            async for chunk in resp:
+                continue
+
+            await asyncio.sleep(1)
+
+            mock_async_client.assert_called_once()
+            print(f"mock_async_client.call_args: {mock_async_client.call_args.args[0]}")
+            print(
+                f"mock_async_client.call_args: {json.loads(mock_async_client.call_args.args[0])}"
+            )
+            json_mock = json.loads(mock_async_client.call_args.args[0])
+            try:
+                assert json_mock["choices"][0]["message"]["content"] == "hello"
+            except Exception as e:
+                print(
+                    f"mock_async_client.call_args.args[0]: {mock_async_client.call_args.args[0]}"
+                )
+                print(
+                    f"mock_async_client.call_args.args[0]['choices']: {mock_async_client.call_args.args[0]['choices']}"
+                )
+                raise e


def test_basic_caching_import():
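The updated test patches both the sync and async cache methods on the same object, so one parametrized test can assert which path performed the write. A stripped-down sketch of that dual-patch pattern; ToyCache, write_sync, and write_async are illustrative, and unittest.mock substitutes an AsyncMock for the async method automatically on Python 3.8+:

import asyncio
from unittest.mock import patch


class ToyCache:
    def add_cache(self, value: str) -> None:  # sync write path
        pass

    async def async_add_cache(self, value: str) -> None:  # async write path
        pass


def write_sync(cache: ToyCache) -> None:
    cache.add_cache("hello")


async def write_async(cache: ToyCache) -> None:
    await cache.async_add_cache("hello")


cache = ToyCache()
with patch.object(cache, "add_cache") as mock_sync, patch.object(
    cache, "async_add_cache"
) as mock_async:
    write_sync(cache)
    asyncio.run(write_async(cache))
    # Each mock records only the calls made on its own path.
    mock_sync.assert_called_once_with("hello")
    mock_async.assert_called_once_with("hello")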