forked from phoenix/litellm-mirror

(testing) add unit tests for LLMCachingHandler Class (#6279)

* add unit testing for test_async_set_cache
* test test_async_log_cache_hit_on_callbacks
* assert the correct response type is returned
* test_convert_cached_result_to_model_response
* unit testing for caching handler
parent 202b5cc2cd
commit f724f3131d
3 changed files with 361 additions and 2 deletions
@@ -1 +1,8 @@
-from .caching import Cache
+from .caching import Cache, LiteLLMCacheType
+from .disk_cache import DiskCache
+from .dual_cache import DualCache
+from .in_memory_cache import InMemoryCache
+from .qdrant_semantic_cache import QdrantSemanticCache
+from .redis_cache import RedisCache
+from .redis_semantic_cache import RedisSemanticCache
+from .s3_cache import S3Cache
@@ -512,7 +512,16 @@ class LLMCachingHandler:
         model: str,
         args: Tuple[Any, ...],
         custom_llm_provider: Optional[str] = None,
-    ) -> Optional[Any]:
+    ) -> Optional[
+        Union[
+            ModelResponse,
+            TextCompletionResponse,
+            EmbeddingResponse,
+            RerankResponse,
+            TranscriptionResponse,
+            CustomStreamWrapper,
+        ]
+    ]:
         """
         Internal method to process the cached result

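For orientation only (not part of the diff): a minimal sketch of how a caller can narrow the widened return annotation above; the function name and branches here are illustrative assumptions, not code from this commit.

from typing import Optional, Union

from litellm.types.utils import ModelResponse, TextCompletionResponse

# Hypothetical caller: with the Union annotation, a type checker can verify
# that each cached-result shape is handled, which the old Optional[Any] hid.
def summarize_cached(
    result: Optional[Union[ModelResponse, TextCompletionResponse]]
) -> str:
    if result is None:
        return "cache miss"
    if isinstance(result, ModelResponse):
        return result.choices[0].message.content or ""
    return result.choices[0].text  # TextCompletionResponse branch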
tests/local_testing/test_caching_handler.py (new file, 343 lines)
@@ -0,0 +1,343 @@
import os
import sys
import time
import traceback
import uuid

from dotenv import load_dotenv
from test_rerank import assert_response_shape


load_dotenv()
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import hashlib
import random

import pytest

import litellm
from litellm import aembedding, completion, embedding
from litellm.caching.caching import Cache

from unittest.mock import AsyncMock, patch, MagicMock
from litellm.caching.caching_handler import LLMCachingHandler, CachingHandlerResponse
from litellm.caching.caching import LiteLLMCacheType
from litellm.types.utils import CallTypes
from litellm.types.rerank import RerankResponse
from litellm.types.utils import (
    ModelResponse,
    EmbeddingResponse,
    TextCompletionResponse,
    TranscriptionResponse,
    Embedding,
)
from datetime import timedelta, datetime
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm._logging import verbose_logger
import logging


def setup_cache():
    # Set up the cache
    cache = Cache(
        type=LiteLLMCacheType.REDIS,
        host=os.environ["REDIS_HOST"],
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
    )
    litellm.cache = cache
    return cache

chat_completion_response = litellm.ModelResponse(
    id=str(uuid.uuid4()),
    choices=[
        litellm.Choices(
            message=litellm.Message(
                role="assistant", content="Hello, how can I help you today?"
            )
        )
    ],
)

text_completion_response = litellm.TextCompletionResponse(
    id=str(uuid.uuid4()),
    choices=[litellm.utils.TextChoices(text="Hello, how can I help you today?")],
)

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "response", [chat_completion_response, text_completion_response]
)
async def test_async_set_get_cache(response):
    litellm.set_verbose = True
    setup_cache()
    verbose_logger.setLevel(logging.DEBUG)
    caching_handler = LLMCachingHandler(
        original_function=completion, request_kwargs={}, start_time=datetime.now()
    )

    messages = [{"role": "user", "content": f"Unique message {datetime.now()}"}]

    logging_obj = LiteLLMLogging(
        litellm_call_id=str(datetime.now()),
        call_type=CallTypes.completion.value,
        model="gpt-3.5-turbo",
        messages=messages,
        function_id=str(uuid.uuid4()),
        stream=False,
        start_time=datetime.now(),
    )

    result = response
    print("result", result)

    original_function = (
        litellm.acompletion
        if isinstance(response, litellm.ModelResponse)
        else litellm.atext_completion
    )
    if isinstance(response, litellm.ModelResponse):
        kwargs = {"messages": messages}
        call_type = CallTypes.acompletion.value
    else:
        kwargs = {"prompt": f"Hello, how can I help you today? {datetime.now()}"}
        call_type = CallTypes.atext_completion.value

    await caching_handler.async_set_cache(
        result=result, original_function=original_function, kwargs=kwargs
    )

    await asyncio.sleep(2)

    # Verify the result was cached
    cached_response = await caching_handler._async_get_cache(
        model="gpt-3.5-turbo",
        original_function=original_function,
        logging_obj=logging_obj,
        start_time=datetime.now(),
        call_type=call_type,
        kwargs=kwargs,
    )

    assert cached_response.cached_result is not None
    assert cached_response.cached_result.id == result.id

@pytest.mark.asyncio
async def test_async_log_cache_hit_on_callbacks():
    """
    Assert logging callbacks are called after a cache hit
    """
    # Setup
    caching_handler = LLMCachingHandler(
        original_function=completion, request_kwargs={}, start_time=datetime.now()
    )

    mock_logging_obj = MagicMock()
    mock_logging_obj.async_success_handler = AsyncMock()
    mock_logging_obj.success_handler = MagicMock()

    cached_result = "Mocked cached result"
    start_time = datetime.now()
    end_time = start_time + timedelta(seconds=1)
    cache_hit = True

    # Call the method
    caching_handler._async_log_cache_hit_on_callbacks(
        logging_obj=mock_logging_obj,
        cached_result=cached_result,
        start_time=start_time,
        end_time=end_time,
        cache_hit=cache_hit,
    )

    # Wait for the async task to complete
    await asyncio.sleep(0.5)

    print("mock logging obj methods called", mock_logging_obj.mock_calls)

    # Assertions
    mock_logging_obj.async_success_handler.assert_called_once_with(
        cached_result, start_time, end_time, cache_hit
    )

    # Wait for the thread to complete
    await asyncio.sleep(0.5)

    mock_logging_obj.success_handler.assert_called_once_with(
        cached_result, start_time, end_time, cache_hit
    )

@pytest.mark.parametrize(
    "call_type, cached_result, expected_type",
    [
        (
            CallTypes.completion.value,
            {
                "id": "test",
                "choices": [{"message": {"role": "assistant", "content": "Hello"}}],
            },
            ModelResponse,
        ),
        (
            CallTypes.text_completion.value,
            {"id": "test", "choices": [{"text": "Hello"}]},
            TextCompletionResponse,
        ),
        (
            CallTypes.embedding.value,
            {"data": [{"embedding": [0.1, 0.2, 0.3]}]},
            EmbeddingResponse,
        ),
        (
            CallTypes.rerank.value,
            {"id": "test", "results": [{"index": 0, "score": 0.9}]},
            RerankResponse,
        ),
        (
            CallTypes.transcription.value,
            {"text": "Hello, world!"},
            TranscriptionResponse,
        ),
    ],
)
def test_convert_cached_result_to_model_response(
    call_type, cached_result, expected_type
):
    """
    Assert that the cached result is converted to the correct type
    """
    caching_handler = LLMCachingHandler(
        original_function=lambda: None, request_kwargs={}, start_time=datetime.now()
    )
    logging_obj = LiteLLMLogging(
        litellm_call_id=str(datetime.now()),
        call_type=call_type,
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, how can I help you today?"}],
        function_id=str(uuid.uuid4()),
        stream=False,
        start_time=datetime.now(),
    )

    result = caching_handler._convert_cached_result_to_model_response(
        cached_result=cached_result,
        call_type=call_type,
        kwargs={},
        logging_obj=logging_obj,
        model="test-model",
        args=(),
    )

    assert isinstance(result, expected_type)
    assert result is not None

def test_combine_cached_embedding_response_with_api_result():
    """
    If the cached response has [cache_hit, None, cache_hit]
    result should be [cache_hit, api_result, cache_hit]
    """
    # Setup
    caching_handler = LLMCachingHandler(
        original_function=lambda: None, request_kwargs={}, start_time=datetime.now()
    )

    start_time = datetime.now()
    end_time = start_time + timedelta(seconds=1)

    # Create a CachingHandlerResponse with some cached and some None values
    cached_response = EmbeddingResponse(
        data=[
            Embedding(embedding=[0.1, 0.2, 0.3], index=0, object="embedding"),
            None,
            Embedding(embedding=[0.7, 0.8, 0.9], index=2, object="embedding"),
        ]
    )
    caching_handler_response = CachingHandlerResponse(
        final_embedding_cached_response=cached_response
    )

    # Create an API EmbeddingResponse for the missing value
    api_response = EmbeddingResponse(
        data=[Embedding(embedding=[0.4, 0.5, 0.6], index=1, object="embedding")]
    )

    # Call the method
    result = caching_handler._combine_cached_embedding_response_with_api_result(
        _caching_handler_response=caching_handler_response,
        embedding_response=api_response,
        start_time=start_time,
        end_time=end_time,
    )

    # Assertions
    assert isinstance(result, EmbeddingResponse)
    assert len(result.data) == 3
    assert result.data[0].embedding == [0.1, 0.2, 0.3]
    assert result.data[1].embedding == [0.4, 0.5, 0.6]
    assert result.data[2].embedding == [0.7, 0.8, 0.9]
    assert result._hidden_params["cache_hit"] == True
    assert isinstance(result._response_ms, float)
    assert result._response_ms > 0

def test_combine_cached_embedding_response_multiple_missing_values():
    """
    If the cached response has [cache_hit, None, None, cache_hit, None]
    result should be [cache_hit, api_result, api_result, cache_hit, api_result]
    """

    # Setup
    caching_handler = LLMCachingHandler(
        original_function=lambda: None, request_kwargs={}, start_time=datetime.now()
    )

    start_time = datetime.now()
    end_time = start_time + timedelta(seconds=1)

    # Create a CachingHandlerResponse with some cached and some None values
    cached_response = EmbeddingResponse(
        data=[
            Embedding(embedding=[0.1, 0.2, 0.3], index=0, object="embedding"),
            None,
            None,
            Embedding(embedding=[0.7, 0.8, 0.9], index=3, object="embedding"),
            None,
        ]
    )

    caching_handler_response = CachingHandlerResponse(
        final_embedding_cached_response=cached_response
    )

    # Create an API EmbeddingResponse for the missing values
    api_response = EmbeddingResponse(
        data=[
            Embedding(embedding=[0.4, 0.5, 0.6], index=1, object="embedding"),
            Embedding(embedding=[0.4, 0.5, 0.6], index=2, object="embedding"),
            Embedding(embedding=[0.4, 0.5, 0.6], index=4, object="embedding"),
        ]
    )

    # Call the method
    result = caching_handler._combine_cached_embedding_response_with_api_result(
        _caching_handler_response=caching_handler_response,
        embedding_response=api_response,
        start_time=start_time,
        end_time=end_time,
    )

    # Assertions
    assert isinstance(result, EmbeddingResponse)
    assert len(result.data) == 5
    assert result.data[0].embedding == [0.1, 0.2, 0.3]
    assert result.data[1].embedding == [0.4, 0.5, 0.6]
    assert result.data[2].embedding == [0.4, 0.5, 0.6]
    assert result.data[3].embedding == [0.7, 0.8, 0.9]
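A hedged note on running the new module locally (not part of this commit): the Redis-backed test reads REDIS_HOST, REDIS_PORT, and REDIS_PASSWORD from the environment via setup_cache, so they must be exported first; the invocation below is just one way to do it.

# Illustrative local invocation; the path comes from the diff above, the flags are assumptions.
import pytest

if __name__ == "__main__":
    # test_async_set_get_cache needs REDIS_HOST/REDIS_PORT/REDIS_PASSWORD in the environment.
    raise SystemExit(pytest.main(["tests/local_testing/test_caching_handler.py", "-vv"]))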