From f724f3131d8b6c2c5608ab810071e21503a932bf Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Thu, 17 Oct 2024 19:12:57 +0530
Subject: [PATCH] (testing) add unit tests for LLMCachingHandler Class (#6279)

* add unit testing for test_async_set_cache

* test test_async_log_cache_hit_on_callbacks

* assert the correct response type is returned

* test_convert_cached_result_to_model_response

* unit testing for caching handler
---
 litellm/caching/__init__.py                 |   9 +-
 litellm/caching/caching_handler.py          |  11 +-
 tests/local_testing/test_caching_handler.py | 343 ++++++++++++++++++++
 3 files changed, 361 insertions(+), 2 deletions(-)
 create mode 100644 tests/local_testing/test_caching_handler.py

diff --git a/litellm/caching/__init__.py b/litellm/caching/__init__.py
index 65aea7f09..f10675f5e 100644
--- a/litellm/caching/__init__.py
+++ b/litellm/caching/__init__.py
@@ -1 +1,8 @@
-from .caching import Cache
+from .caching import Cache, LiteLLMCacheType
+from .disk_cache import DiskCache
+from .dual_cache import DualCache
+from .in_memory_cache import InMemoryCache
+from .qdrant_semantic_cache import QdrantSemanticCache
+from .redis_cache import RedisCache
+from .redis_semantic_cache import RedisSemanticCache
+from .s3_cache import S3Cache
diff --git a/litellm/caching/caching_handler.py b/litellm/caching/caching_handler.py
index 771c319d7..1acc9f3c2 100644
--- a/litellm/caching/caching_handler.py
+++ b/litellm/caching/caching_handler.py
@@ -512,7 +512,16 @@ class LLMCachingHandler:
         model: str,
         args: Tuple[Any, ...],
         custom_llm_provider: Optional[str] = None,
-    ) -> Optional[Any]:
+    ) -> Optional[
+        Union[
+            ModelResponse,
+            TextCompletionResponse,
+            EmbeddingResponse,
+            RerankResponse,
+            TranscriptionResponse,
+            CustomStreamWrapper,
+        ]
+    ]:
         """
         Internal method to process the cached result
diff --git a/tests/local_testing/test_caching_handler.py b/tests/local_testing/test_caching_handler.py
new file mode 100644
index 000000000..11f7831bc
--- /dev/null
+++ b/tests/local_testing/test_caching_handler.py
@@ -0,0 +1,343 @@
+import os
+import sys
+import time
+import traceback
+import uuid
+
+from dotenv import load_dotenv
+from test_rerank import assert_response_shape
+
+load_dotenv()
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import asyncio
+import hashlib
+import random
+
+import pytest
+
+import litellm
+from litellm import aembedding, completion, embedding
+from litellm.caching.caching import Cache
+
+from unittest.mock import AsyncMock, patch, MagicMock
+from litellm.caching.caching_handler import LLMCachingHandler, CachingHandlerResponse
+from litellm.caching.caching import LiteLLMCacheType
+from litellm.types.utils import CallTypes
+from litellm.types.rerank import RerankResponse
+from litellm.types.utils import (
+    ModelResponse,
+    EmbeddingResponse,
+    TextCompletionResponse,
+    TranscriptionResponse,
+    Embedding,
+)
+from datetime import timedelta, datetime
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
+from litellm._logging import verbose_logger
+import logging
+
+
+def setup_cache():
+    # Set up the cache
+    cache = Cache(
+        type=LiteLLMCacheType.REDIS,
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+    )
+    litellm.cache = cache
+    return cache
+
+
+chat_completion_response = litellm.ModelResponse(
+    id=str(uuid.uuid4()),
+    choices=[
+        litellm.Choices(
+            message=litellm.Message(
+                role="assistant", content="Hello, how can I help you today?"
+            )
+        )
+    ],
+)
+
+text_completion_response = litellm.TextCompletionResponse(
+    id=str(uuid.uuid4()),
+    choices=[litellm.utils.TextChoices(text="Hello, how can I help you today?")],
+)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "response", [chat_completion_response, text_completion_response]
+)
+async def test_async_set_get_cache(response):
+    litellm.set_verbose = True
+    setup_cache()
+    verbose_logger.setLevel(logging.DEBUG)
+    caching_handler = LLMCachingHandler(
+        original_function=completion, request_kwargs={}, start_time=datetime.now()
+    )
+
+    messages = [{"role": "user", "content": f"Unique message {datetime.now()}"}]
+
+    logging_obj = LiteLLMLogging(
+        litellm_call_id=str(datetime.now()),
+        call_type=CallTypes.completion.value,
+        model="gpt-3.5-turbo",
+        messages=messages,
+        function_id=str(uuid.uuid4()),
+        stream=False,
+        start_time=datetime.now(),
+    )
+
+    result = response
+    print("result", result)
+
+    original_function = (
+        litellm.acompletion
+        if isinstance(response, litellm.ModelResponse)
+        else litellm.atext_completion
+    )
+    if isinstance(response, litellm.ModelResponse):
+        kwargs = {"messages": messages}
+        call_type = CallTypes.acompletion.value
+    else:
+        kwargs = {"prompt": f"Hello, how can I help you today? {datetime.now()}"}
+        call_type = CallTypes.atext_completion.value
+
+    await caching_handler.async_set_cache(
+        result=result, original_function=original_function, kwargs=kwargs
+    )
+
+    await asyncio.sleep(2)
+
+    # Verify the result was cached
+    cached_response = await caching_handler._async_get_cache(
+        model="gpt-3.5-turbo",
+        original_function=original_function,
+        logging_obj=logging_obj,
+        start_time=datetime.now(),
+        call_type=call_type,
+        kwargs=kwargs,
+    )
+
+    assert cached_response.cached_result is not None
+    assert cached_response.cached_result.id == result.id
+
+
+@pytest.mark.asyncio
+async def test_async_log_cache_hit_on_callbacks():
+    """
+    Assert logging callbacks are called after a cache hit
+    """
+    # Setup
+    caching_handler = LLMCachingHandler(
+        original_function=completion, request_kwargs={}, start_time=datetime.now()
+    )
+
+    mock_logging_obj = MagicMock()
+    mock_logging_obj.async_success_handler = AsyncMock()
+    mock_logging_obj.success_handler = MagicMock()
+
+    cached_result = "Mocked cached result"
+    start_time = datetime.now()
+    end_time = start_time + timedelta(seconds=1)
+    cache_hit = True
+
+    # Call the method
+    caching_handler._async_log_cache_hit_on_callbacks(
+        logging_obj=mock_logging_obj,
+        cached_result=cached_result,
+        start_time=start_time,
+        end_time=end_time,
+        cache_hit=cache_hit,
+    )
+
+    # Wait for the async task to complete
+    await asyncio.sleep(0.5)
+
+    print("mock logging obj methods called", mock_logging_obj.mock_calls)
+
+    # Assertions
+    mock_logging_obj.async_success_handler.assert_called_once_with(
+        cached_result, start_time, end_time, cache_hit
+    )
+
+    # Wait for the thread to complete
+    await asyncio.sleep(0.5)
+
+    mock_logging_obj.success_handler.assert_called_once_with(
+        cached_result, start_time, end_time, cache_hit
+    )
+
+
+@pytest.mark.parametrize(
+    "call_type, cached_result, expected_type",
+    [
+        (
+            CallTypes.completion.value,
+            {
+                "id": "test",
+                "choices": [{"message": {"role": "assistant", "content": "Hello"}}],
+            },
+            ModelResponse,
+        ),
+        (
+            CallTypes.text_completion.value,
+            {"id": "test", "choices": [{"text": "Hello"}]},
+            TextCompletionResponse,
+        ),
+        (
+            CallTypes.embedding.value,
+            {"data": [{"embedding": [0.1, 0.2, 0.3]}]},
+            EmbeddingResponse,
+        ),
+        (
+            CallTypes.rerank.value,
+            {"id": "test", "results": [{"index": 0, "score": 0.9}]},
+            RerankResponse,
+        ),
+        (
+            CallTypes.transcription.value,
+            {"text": "Hello, world!"},
+            TranscriptionResponse,
+        ),
+    ],
+)
+def test_convert_cached_result_to_model_response(
+    call_type, cached_result, expected_type
+):
+    """
+    Assert that the cached result is converted to the correct type
+    """
+    caching_handler = LLMCachingHandler(
+        original_function=lambda: None, request_kwargs={}, start_time=datetime.now()
+    )
+    logging_obj = LiteLLMLogging(
+        litellm_call_id=str(datetime.now()),
+        call_type=call_type,
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "Hello, how can I help you today?"}],
+        function_id=str(uuid.uuid4()),
+        stream=False,
+        start_time=datetime.now(),
+    )
+
+    result = caching_handler._convert_cached_result_to_model_response(
+        cached_result=cached_result,
+        call_type=call_type,
+        kwargs={},
+        logging_obj=logging_obj,
+        model="test-model",
+        args=(),
+    )
+
+    assert isinstance(result, expected_type)
+    assert result is not None
+
+
+def test_combine_cached_embedding_response_with_api_result():
+    """
+    If the cached response has [cache_hit, None, cache_hit]
+    result should be [cache_hit, api_result, cache_hit]
+    """
+    # Setup
+    caching_handler = LLMCachingHandler(
+        original_function=lambda: None, request_kwargs={}, start_time=datetime.now()
+    )
+
+    start_time = datetime.now()
+    end_time = start_time + timedelta(seconds=1)
+
+    # Create a CachingHandlerResponse with some cached and some None values
+    cached_response = EmbeddingResponse(
+        data=[
+            Embedding(embedding=[0.1, 0.2, 0.3], index=0, object="embedding"),
+            None,
+            Embedding(embedding=[0.7, 0.8, 0.9], index=2, object="embedding"),
+        ]
+    )
+    caching_handler_response = CachingHandlerResponse(
+        final_embedding_cached_response=cached_response
+    )
+
+    # Create an API EmbeddingResponse for the missing value
+    api_response = EmbeddingResponse(
+        data=[Embedding(embedding=[0.4, 0.5, 0.6], index=1, object="embedding")]
+    )
+
+    # Call the method
+    result = caching_handler._combine_cached_embedding_response_with_api_result(
+        _caching_handler_response=caching_handler_response,
+        embedding_response=api_response,
+        start_time=start_time,
+        end_time=end_time,
+    )
+
+    # Assertions
+    assert isinstance(result, EmbeddingResponse)
+    assert len(result.data) == 3
+    assert result.data[0].embedding == [0.1, 0.2, 0.3]
+    assert result.data[1].embedding == [0.4, 0.5, 0.6]
+    assert result.data[2].embedding == [0.7, 0.8, 0.9]
+    assert result._hidden_params["cache_hit"] == True
+    assert isinstance(result._response_ms, float)
+    assert result._response_ms > 0
+
+
+def test_combine_cached_embedding_response_multiple_missing_values():
+    """
+    If the cached response has [cache_hit, None, None, cache_hit, None]
+    result should be [cache_hit, api_result, api_result, cache_hit, api_result]
+    """
+
+    # Setup
+    caching_handler = LLMCachingHandler(
+        original_function=lambda: None, request_kwargs={}, start_time=datetime.now()
+    )
+
+    start_time = datetime.now()
+    end_time = start_time + timedelta(seconds=1)
+
+    # Create a CachingHandlerResponse with some cached and some None values
+    cached_response = EmbeddingResponse(
+        data=[
+            Embedding(embedding=[0.1, 0.2, 0.3], index=0, object="embedding"),
+            None,
+            None,
+            Embedding(embedding=[0.7, 0.8, 0.9], index=3, object="embedding"),
+            None,
+        ]
+    )
+
+    caching_handler_response = CachingHandlerResponse(
+        final_embedding_cached_response=cached_response
+    )
+
+    # Create an API EmbeddingResponse for the missing values
+    api_response = EmbeddingResponse(
+        data=[
+            Embedding(embedding=[0.4, 0.5, 0.6], index=1, object="embedding"),
+            Embedding(embedding=[0.4, 0.5, 0.6], index=2, object="embedding"),
+            Embedding(embedding=[0.4, 0.5, 0.6], index=4, object="embedding"),
+        ]
+    )
+
+    # Call the method
+    result = caching_handler._combine_cached_embedding_response_with_api_result(
+        _caching_handler_response=caching_handler_response,
+        embedding_response=api_response,
+        start_time=start_time,
+        end_time=end_time,
+    )
+
+    # Assertions
+    assert isinstance(result, EmbeddingResponse)
+    assert len(result.data) == 5
+    assert result.data[0].embedding == [0.1, 0.2, 0.3]
+    assert result.data[1].embedding == [0.4, 0.5, 0.6]
+    assert result.data[2].embedding == [0.4, 0.5, 0.6]
+    assert result.data[3].embedding == [0.7, 0.8, 0.9]
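
For context only, and not part of the patch itself: a minimal sketch of what the widened re-exports in litellm/caching/__init__.py allow callers to import directly, mirroring the setup_cache() helper in the new test file. The Redis connection values below are illustrative placeholders, not values from this change.

    import litellm
    from litellm.caching import Cache, LiteLLMCacheType  # re-exported by this change

    # Configure a global Redis-backed cache; host/port/password are placeholders.
    litellm.cache = Cache(
        type=LiteLLMCacheType.REDIS,
        host="localhost",
        port="6379",
        password="",
    )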