mirror of https://github.com/BerriAI/litellm.git
fix(streaming_handler.py): support logging complete streaming response on cache hit
parent ba6369e359
commit 301375bf84

3 changed files with 66 additions and 22 deletions
@@ -790,6 +790,7 @@ class LLMCachingHandler:
         - Else append the chunk to self.async_streaming_chunks

         """

         complete_streaming_response: Optional[
             Union[ModelResponse, TextCompletionResponse]
         ] = _assemble_complete_response_from_streaming_chunks(
@@ -800,7 +801,6 @@ class LLMCachingHandler:
             streaming_chunks=self.async_streaming_chunks,
             is_async=True,
         )

         # if a complete_streaming_response is assembled, add it to the cache
         if complete_streaming_response is not None:
             await self.async_set_cache(
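A note on the two hunks above: the surrounding handler method buffers incoming chunks and, once the stream finishes, assembles a complete response via _assemble_complete_response_from_streaming_chunks and writes it to the cache. A minimal standalone sketch of that assemble-then-cache idea, assuming nothing about litellm's internals (SimpleCache and StreamCacheBuffer are illustrative names only):

# Illustrative sketch only; class and method names are hypothetical, not litellm APIs.
import asyncio
from typing import Any, Dict, List, Optional


class SimpleCache:
    """Tiny in-memory stand-in for an async cache backend."""

    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}

    async def async_set_cache(self, key: str, value: Any) -> None:
        self._store[key] = value


class StreamCacheBuffer:
    """Buffers streaming chunks and caches the assembled response at the end."""

    def __init__(self, cache: SimpleCache, cache_key: str) -> None:
        self.cache = cache
        self.cache_key = cache_key
        self.async_streaming_chunks: List[Dict[str, Any]] = []

    async def add_streaming_chunk(self, chunk: Dict[str, Any]) -> None:
        self.async_streaming_chunks.append(chunk)
        # Only assemble once the stream signals completion.
        if chunk["choices"][0].get("finish_reason") is None:
            return
        complete: Optional[Dict[str, Any]] = self._assemble()
        # Only cache if a complete response could actually be assembled.
        if complete is not None:
            await self.cache.async_set_cache(self.cache_key, complete)

    def _assemble(self) -> Optional[Dict[str, Any]]:
        if not self.async_streaming_chunks:
            return None
        content = "".join(
            c["choices"][0]["delta"].get("content") or ""
            for c in self.async_streaming_chunks
        )
        return {"choices": [{"message": {"role": "assistant", "content": content}}]}


async def _demo() -> None:
    buf = StreamCacheBuffer(SimpleCache(), cache_key="demo")
    await buf.add_streaming_chunk(
        {"choices": [{"delta": {"content": "hel"}, "finish_reason": None}]}
    )
    await buf.add_streaming_chunk(
        {"choices": [{"delta": {"content": "lo"}, "finish_reason": "stop"}]}
    )
    print(buf.cache._store["demo"]["choices"][0]["message"]["content"])  # -> hello


asyncio.run(_demo())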
@@ -1481,6 +1481,15 @@ class CustomStreamWrapper:
                 processed_chunk
             )

+    async def async_cache_streaming_response(self, processed_chunk, cache_hit: bool):
+        """
+        Caches the streaming response
+        """
+        if not cache_hit and self.logging_obj._llm_caching_handler is not None:
+            await self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
+                processed_chunk
+            )
+
     def run_success_logging_and_cache_storage(self, processed_chunk, cache_hit: bool):
         """
         Runs success logging in a thread and adds the response to the cache
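The added async_cache_streaming_response wrapper gates the cache write on cache_hit, so a response that was itself served from cache is not written back, while freshly generated responses are. A cut-down illustration of that guard, using a stand-in handler rather than litellm's LLMCachingHandler:

# Illustrative only: FakeCachingHandler is a stand-in, not a litellm class.
import asyncio


class FakeCachingHandler:
    def __init__(self) -> None:
        self.writes = 0

    async def _add_streaming_response_to_cache(self, processed_chunk) -> None:
        self.writes += 1


async def async_cache_streaming_response(handler, processed_chunk, cache_hit: bool) -> None:
    # Same gate as in the diff: cache only fresh responses, and only when a
    # caching handler is actually attached.
    if not cache_hit and handler is not None:
        await handler._add_streaming_response_to_cache(processed_chunk)


async def _demo() -> None:
    handler = FakeCachingHandler()
    await async_cache_streaming_response(handler, {"content": "hi"}, cache_hit=True)
    await async_cache_streaming_response(handler, {"content": "hi"}, cache_hit=False)
    print(handler.writes)  # -> 1: only the non-cache-hit call was stored


asyncio.run(_demo())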
@@ -1711,13 +1720,6 @@ class CustomStreamWrapper:
                 if processed_chunk is None:
                     continue

-                if self.logging_obj._llm_caching_handler is not None:
-                    asyncio.create_task(
-                        self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
-                            processed_chunk=cast(ModelResponse, processed_chunk),
-                        )
-                    )
-
                 choice = processed_chunk.choices[0]
                 if isinstance(choice, StreamingChoices):
                     self.response_uptil_now += choice.delta.get("content", "") or ""
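This hunk drops the per-chunk asyncio.create_task call from the iteration path; with this change, the cache write is scheduled once from the assembled response instead (see the next hunk). A rough sketch of that "single task after the final chunk" shape, with purely illustrative names:

# Illustrative sketch; fake_stream/cache_complete_response are hypothetical helpers.
import asyncio
from typing import AsyncIterator, List


async def fake_stream() -> AsyncIterator[str]:
    for piece in ["hel", "lo"]:
        yield piece


async def cache_complete_response(text: str, writes: List[str]) -> None:
    writes.append(text)  # stand-in for the actual cache write


async def consume(stream: AsyncIterator[str]) -> str:
    chunks: List[str] = []
    writes: List[str] = []
    async for chunk in stream:
        chunks.append(chunk)  # per-chunk work stays cheap: no task spawned here
    complete = "".join(chunks)
    # One background task for the fully assembled response.
    task = asyncio.create_task(cache_complete_response(complete, writes))
    await task  # awaited here only to keep the demo deterministic
    assert writes == ["hello"]
    return complete


print(asyncio.run(consume(fake_stream())))  # -> hello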
@@ -1788,6 +1790,14 @@ class CustomStreamWrapper:
                     "usage",
                     getattr(complete_streaming_response, "usage"),
                 )
+                asyncio.create_task(
+                    self.async_cache_streaming_response(
+                        processed_chunk=complete_streaming_response.model_copy(
+                            deep=True
+                        ),
+                        cache_hit=cache_hit,
+                    )
+                )
             if self.sent_stream_usage is False and self.send_stream_usage is True:
                 self.sent_stream_usage = True
                 return response
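The new call passes complete_streaming_response.model_copy(deep=True) into the background task. Presumably the copy is there so the cached object stays stable even if the original response is touched after the task is scheduled; a small dict-based illustration of that effect, with copy.deepcopy standing in for pydantic's model_copy(deep=True):

# Illustration only: a plain dict plays the role of the response object.
import asyncio
import copy


async def cache_later(snapshot: dict, sink: list) -> None:
    await asyncio.sleep(0)  # yield control, as a real cache write would
    sink.append(snapshot["usage"]["total_tokens"])


async def main() -> None:
    response = {"usage": {"total_tokens": 5}}
    cached: list = []
    # The deep copy is taken when the task is created, so later mutations of
    # `response` are not visible to the background write.
    asyncio.create_task(cache_later(copy.deepcopy(response), cached))
    response["usage"]["total_tokens"] = 999  # caller keeps working on the object
    await asyncio.sleep(0.01)
    print(cached)  # -> [5]


asyncio.run(main())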
@@ -8,6 +8,7 @@ from dotenv import load_dotenv

 load_dotenv()
 import os
+import json

 sys.path.insert(
     0, os.path.abspath("../..")
@@ -2146,29 +2147,62 @@ async def test_redis_proxy_batch_redis_get_cache():
     assert "cache_key" in response._hidden_params


-def test_logging_turn_off_message_logging_streaming():
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_logging_turn_off_message_logging_streaming(sync_mode):
     litellm.turn_off_message_logging = True
     mock_obj = Cache(type="local")
     litellm.cache = mock_obj

-    with patch.object(mock_obj, "add_cache") as mock_client:
+    with patch.object(mock_obj, "add_cache") as mock_client, patch.object(
+        mock_obj, "async_add_cache"
+    ) as mock_async_client:
         print(f"mock_obj.add_cache: {mock_obj.add_cache}")

-        resp = litellm.completion(
-            model="gpt-3.5-turbo",
-            messages=[{"role": "user", "content": "hi"}],
-            mock_response="hello",
-            stream=True,
-        )
-
-        for chunk in resp:
-            continue
-
-        time.sleep(1)
-
-        mock_client.assert_called_once()
-        print(f"mock_client.call_args: {mock_client.call_args}")
-        assert mock_client.call_args.args[0].choices[0].message.content == "hello"
+        if sync_mode is True:
+            resp = litellm.completion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "hi"}],
+                mock_response="hello",
+                stream=True,
+            )
+
+            for chunk in resp:
+                continue
+
+            time.sleep(1)
+
+            mock_client.assert_called_once()
+            print(f"mock_client.call_args: {mock_client.call_args}")
+            assert mock_client.call_args.args[0].choices[0].message.content == "hello"
+        else:
+            resp = await litellm.acompletion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "hi"}],
+                mock_response="hello",
+                stream=True,
+            )
+
+            async for chunk in resp:
+                continue
+
+            await asyncio.sleep(1)
+
+            mock_async_client.assert_called_once()
+            print(f"mock_async_client.call_args: {mock_async_client.call_args.args[0]}")
+            print(
+                f"mock_async_client.call_args: {json.loads(mock_async_client.call_args.args[0])}"
+            )
+            json_mock = json.loads(mock_async_client.call_args.args[0])
+            try:
+                assert json_mock["choices"][0]["message"]["content"] == "hello"
+            except Exception as e:
+                print(
+                    f"mock_async_client.call_args.args[0]: {mock_async_client.call_args.args[0]}"
+                )
+                print(
+                    f"mock_async_client.call_args.args[0]['choices']: {mock_async_client.call_args.args[0]['choices']}"
+                )
+                raise e


 def test_basic_caching_import():
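The updated test patches both the sync add_cache and the async async_add_cache entry points and asserts against whichever mock the chosen sync_mode should hit. A trimmed-down version of that double-patch pattern on a toy cache class (not litellm's Cache):

# Illustration only: ToyCache is a stand-in for the cache object under test.
import asyncio
from unittest.mock import patch


class ToyCache:
    def add_cache(self, value):  # sync write path
        pass

    async def async_add_cache(self, value):  # async write path
        pass


async def store_async(cache: ToyCache, value) -> None:
    await cache.async_add_cache(value)


cache = ToyCache()
with patch.object(cache, "add_cache") as mock_sync, patch.object(
    cache, "async_add_cache"
) as mock_async:
    cache.add_cache("hello")  # roughly what the sync_mode=True branch exercises
    asyncio.run(store_async(cache, "hello"))  # roughly the sync_mode=False branch
    mock_sync.assert_called_once()
    mock_async.assert_called_once()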