forked from phoenix/litellm-mirror
(testing) add unit tests for LLMCachingHandler Class (#6279)
* add unit testing for test_async_set_cache
* test test_async_log_cache_hit_on_callbacks
* assert the correct response type is returned
* test_convert_cached_result_to_model_response
* unit testing for caching handler
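For context, the flow these tests exercise is a set/get round-trip on LLMCachingHandler: a response is written to the cache with async_set_cache, then read back with _async_get_cache. Below is a minimal sketch of that flow, reusing the keyword signatures that appear in the test file; it assumes litellm.cache is already configured and a LiteLLMLogging object is on hand, so treat it as an illustration rather than an API reference:

    import asyncio
    from datetime import datetime

    import litellm
    from litellm.caching.caching_handler import LLMCachingHandler
    from litellm.types.utils import CallTypes

    async def cache_round_trip(response, messages, logging_obj):
        # One handler per call, tagged with the originating function and start time.
        handler = LLMCachingHandler(
            original_function=litellm.completion,
            request_kwargs={},
            start_time=datetime.now(),
        )
        # Store the response under the same kwargs used to look it up later.
        await handler.async_set_cache(
            result=response,
            original_function=litellm.acompletion,
            kwargs={"messages": messages},
        )
        await asyncio.sleep(2)  # the tests below also wait for the async write to land
        # On a hit, cached_result carries the stored response object.
        hit = await handler._async_get_cache(
            model="gpt-3.5-turbo",
            original_function=litellm.acompletion,
            logging_obj=logging_obj,
            start_time=datetime.now(),
            call_type=CallTypes.acompletion.value,
            kwargs={"messages": messages},
        )
        return hit.cached_result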
This commit is contained in:
parent 202b5cc2cd
commit f724f3131d
3 changed files with 361 additions and 2 deletions
343 tests/local_testing/test_caching_handler.py Normal file
@@ -0,0 +1,343 @@
import os
import sys
import time
import traceback
import uuid

from dotenv import load_dotenv
from test_rerank import assert_response_shape


load_dotenv()
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import hashlib
import random

import pytest

import litellm
from litellm import aembedding, completion, embedding
from litellm.caching.caching import Cache

from unittest.mock import AsyncMock, patch, MagicMock
from litellm.caching.caching_handler import LLMCachingHandler, CachingHandlerResponse
from litellm.caching.caching import LiteLLMCacheType
from litellm.types.utils import CallTypes
from litellm.types.rerank import RerankResponse
from litellm.types.utils import (
    ModelResponse,
    EmbeddingResponse,
    TextCompletionResponse,
    TranscriptionResponse,
    Embedding,
)
from datetime import timedelta, datetime
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm._logging import verbose_logger
import logging

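# NOTE: these are integration-style tests; setup_cache() below talks to a live Redis
# instance, so REDIS_HOST, REDIS_PORT, and REDIS_PASSWORD must be set in the environment.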
def setup_cache():
    # Set up the cache
    cache = Cache(
        type=LiteLLMCacheType.REDIS,
        host=os.environ["REDIS_HOST"],
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
    )
    litellm.cache = cache
    return cache

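# Canned responses shared across tests. The fresh uuid4 ids let test_async_set_get_cache
# assert that the exact object round-trips (cached_response.cached_result.id == result.id).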
chat_completion_response = litellm.ModelResponse(
    id=str(uuid.uuid4()),
    choices=[
        litellm.Choices(
            message=litellm.Message(
                role="assistant", content="Hello, how can I help you today?"
            )
        )
    ],
)

text_completion_response = litellm.TextCompletionResponse(
    id=str(uuid.uuid4()),
    choices=[litellm.utils.TextChoices(text="Hello, how can I help you today?")],
)

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "response", [chat_completion_response, text_completion_response]
)
async def test_async_set_get_cache(response):
    litellm.set_verbose = True
    setup_cache()
    verbose_logger.setLevel(logging.DEBUG)
    caching_handler = LLMCachingHandler(
        original_function=completion, request_kwargs={}, start_time=datetime.now()
    )

    messages = [{"role": "user", "content": f"Unique message {datetime.now()}"}]

    logging_obj = LiteLLMLogging(
        litellm_call_id=str(datetime.now()),
        call_type=CallTypes.completion.value,
        model="gpt-3.5-turbo",
        messages=messages,
        function_id=str(uuid.uuid4()),
        stream=False,
        start_time=datetime.now(),
    )

    result = response
    print("result", result)

    original_function = (
        litellm.acompletion
        if isinstance(response, litellm.ModelResponse)
        else litellm.atext_completion
    )
    if isinstance(response, litellm.ModelResponse):
        kwargs = {"messages": messages}
        call_type = CallTypes.acompletion.value
    else:
        kwargs = {"prompt": f"Hello, how can I help you today? {datetime.now()}"}
        call_type = CallTypes.atext_completion.value

    await caching_handler.async_set_cache(
        result=result, original_function=original_function, kwargs=kwargs
    )

    await asyncio.sleep(2)

    # Verify the result was cached
    cached_response = await caching_handler._async_get_cache(
        model="gpt-3.5-turbo",
        original_function=original_function,
        logging_obj=logging_obj,
        start_time=datetime.now(),
        call_type=call_type,
        kwargs=kwargs,
    )

    assert cached_response.cached_result is not None
    assert cached_response.cached_result.id == result.id

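# _async_log_cache_hit_on_callbacks appears to fire async_success_handler as an asyncio
# task and success_handler off the event loop, hence the sleeps before each assertion.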
@pytest.mark.asyncio
async def test_async_log_cache_hit_on_callbacks():
    """
    Assert logging callbacks are called after a cache hit
    """
    # Setup
    caching_handler = LLMCachingHandler(
        original_function=completion, request_kwargs={}, start_time=datetime.now()
    )

    mock_logging_obj = MagicMock()
    mock_logging_obj.async_success_handler = AsyncMock()
    mock_logging_obj.success_handler = MagicMock()

    cached_result = "Mocked cached result"
    start_time = datetime.now()
    end_time = start_time + timedelta(seconds=1)
    cache_hit = True

    # Call the method
    caching_handler._async_log_cache_hit_on_callbacks(
        logging_obj=mock_logging_obj,
        cached_result=cached_result,
        start_time=start_time,
        end_time=end_time,
        cache_hit=cache_hit,
    )

    # Wait for the async task to complete
    await asyncio.sleep(0.5)

    print("mock logging obj methods called", mock_logging_obj.mock_calls)

    # Assertions
    mock_logging_obj.async_success_handler.assert_called_once_with(
        cached_result, start_time, end_time, cache_hit
    )

    # Wait for the thread to complete
    await asyncio.sleep(0.5)

    mock_logging_obj.success_handler.assert_called_once_with(
        cached_result, start_time, end_time, cache_hit
    )

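# Each CallTypes value below should rehydrate the raw cached dict into its concrete
# response class (ModelResponse, TextCompletionResponse, EmbeddingResponse,
# RerankResponse, TranscriptionResponse) rather than returning the bare dict.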
@pytest.mark.parametrize(
    "call_type, cached_result, expected_type",
    [
        (
            CallTypes.completion.value,
            {
                "id": "test",
                "choices": [{"message": {"role": "assistant", "content": "Hello"}}],
            },
            ModelResponse,
        ),
        (
            CallTypes.text_completion.value,
            {"id": "test", "choices": [{"text": "Hello"}]},
            TextCompletionResponse,
        ),
        (
            CallTypes.embedding.value,
            {"data": [{"embedding": [0.1, 0.2, 0.3]}]},
            EmbeddingResponse,
        ),
        (
            CallTypes.rerank.value,
            {"id": "test", "results": [{"index": 0, "score": 0.9}]},
            RerankResponse,
        ),
        (
            CallTypes.transcription.value,
            {"text": "Hello, world!"},
            TranscriptionResponse,
        ),
    ],
)
def test_convert_cached_result_to_model_response(
    call_type, cached_result, expected_type
):
    """
    Assert that the cached result is converted to the correct type
    """
    caching_handler = LLMCachingHandler(
        original_function=lambda: None, request_kwargs={}, start_time=datetime.now()
    )
    logging_obj = LiteLLMLogging(
        litellm_call_id=str(datetime.now()),
        call_type=call_type,
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, how can I help you today?"}],
        function_id=str(uuid.uuid4()),
        stream=False,
        start_time=datetime.now(),
    )

    result = caching_handler._convert_cached_result_to_model_response(
        cached_result=cached_result,
        call_type=call_type,
        kwargs={},
        logging_obj=logging_obj,
        model="test-model",
        args=(),
    )

    assert isinstance(result, expected_type)
    assert result is not None

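# Partial-hit merging: cached embeddings keep their original slots and the None slots
# are back-filled by index from the fresh API response.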
def test_combine_cached_embedding_response_with_api_result():
    """
    If the cached response has [cache_hit, None, cache_hit]
    result should be [cache_hit, api_result, cache_hit]
    """
    # Setup
    caching_handler = LLMCachingHandler(
        original_function=lambda: None, request_kwargs={}, start_time=datetime.now()
    )

    start_time = datetime.now()
    end_time = start_time + timedelta(seconds=1)

    # Create a CachingHandlerResponse with some cached and some None values
    cached_response = EmbeddingResponse(
        data=[
            Embedding(embedding=[0.1, 0.2, 0.3], index=0, object="embedding"),
            None,
            Embedding(embedding=[0.7, 0.8, 0.9], index=2, object="embedding"),
        ]
    )
    caching_handler_response = CachingHandlerResponse(
        final_embedding_cached_response=cached_response
    )

    # Create an API EmbeddingResponse for the missing value
    api_response = EmbeddingResponse(
        data=[Embedding(embedding=[0.4, 0.5, 0.6], index=1, object="embedding")]
    )

    # Call the method
    result = caching_handler._combine_cached_embedding_response_with_api_result(
        _caching_handler_response=caching_handler_response,
        embedding_response=api_response,
        start_time=start_time,
        end_time=end_time,
    )

    # Assertions
    assert isinstance(result, EmbeddingResponse)
    assert len(result.data) == 3
    assert result.data[0].embedding == [0.1, 0.2, 0.3]
    assert result.data[1].embedding == [0.4, 0.5, 0.6]
    assert result.data[2].embedding == [0.7, 0.8, 0.9]
    assert result._hidden_params["cache_hit"] == True
    assert isinstance(result._response_ms, float)
    assert result._response_ms > 0

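# Same merge semantics as above, now with several missing slots interleaved between hits.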
def test_combine_cached_embedding_response_multiple_missing_values():
    """
    If the cached response has [cache_hit, None, None, cache_hit, None]
    result should be [cache_hit, api_result, api_result, cache_hit, api_result]
    """

    # Setup
    caching_handler = LLMCachingHandler(
        original_function=lambda: None, request_kwargs={}, start_time=datetime.now()
    )

    start_time = datetime.now()
    end_time = start_time + timedelta(seconds=1)

    # Create a CachingHandlerResponse with some cached and some None values
    cached_response = EmbeddingResponse(
        data=[
            Embedding(embedding=[0.1, 0.2, 0.3], index=0, object="embedding"),
            None,
            None,
            Embedding(embedding=[0.7, 0.8, 0.9], index=3, object="embedding"),
            None,
        ]
    )

    caching_handler_response = CachingHandlerResponse(
        final_embedding_cached_response=cached_response
    )

    # Create an API EmbeddingResponse for the missing values
    api_response = EmbeddingResponse(
        data=[
            Embedding(embedding=[0.4, 0.5, 0.6], index=1, object="embedding"),
            Embedding(embedding=[0.4, 0.5, 0.6], index=2, object="embedding"),
            Embedding(embedding=[0.4, 0.5, 0.6], index=4, object="embedding"),
        ]
    )

    # Call the method
    result = caching_handler._combine_cached_embedding_response_with_api_result(
        _caching_handler_response=caching_handler_response,
        embedding_response=api_response,
        start_time=start_time,
        end_time=end_time,
    )

    # Assertions
    assert isinstance(result, EmbeddingResponse)
    assert len(result.data) == 5
    assert result.data[0].embedding == [0.1, 0.2, 0.3]
    assert result.data[1].embedding == [0.4, 0.5, 0.6]
    assert result.data[2].embedding == [0.4, 0.5, 0.6]
    assert result.data[3].embedding == [0.7, 0.8, 0.9]
    assert result.data[4].embedding == [0.4, 0.5, 0.6]