(bug fix) TTL not being set for embedding caching requests (#6095)

* fix ttl for cache pipeline settings

* add test for caching

* add test for setting ttls on redis caching
This commit is contained in:
Ishaan Jaff 2024-10-07 15:53:18 +05:30 committed by GitHub
parent 285b589095
commit 2c8bba293f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 99 additions and 14 deletions

View file

@ -22,6 +22,9 @@ import litellm
from litellm import aembedding, completion, embedding
from litellm.caching import Cache
from unittest.mock import AsyncMock, patch, MagicMock
import datetime
# litellm.set_verbose=True
# Single-turn chat payload (one "user" message).
# NOTE(review): presumably shared by the completion caching tests elsewhere in
# this file - it is not referenced in the part visible here; confirm.
messages = [{"role": "user", "content": "who is ishaan Github? "}]
@ -579,6 +582,56 @@ async def test_embedding_caching_base_64():
assert embedding_val_2.data[1]["embedding"] == embedding_val_1.data[1]["embedding"]
@pytest.mark.asyncio
async def test_embedding_caching_redis_ttl():
    """
    Test default_in_redis_ttl is used for embedding caching

    The Redis pipeline is replaced with a mock and every recorded
    ``set`` call on it must carry ``ex=timedelta(seconds=2.5)``, i.e. the
    ``default_in_redis_ttl`` the cache was configured with.

    The global ``litellm.cache`` / ``litellm.set_verbose`` values that this
    test overrides are restored on exit, so a cache pointing at a dummy host
    does not leak into other tests.

    issue: https://github.com/BerriAI/litellm/issues/6010
    """
    ttl_seconds = 2.5

    # Remember the process-wide settings we are about to override.
    previous_cache = getattr(litellm, "cache", None)
    previous_verbose = getattr(litellm, "set_verbose", False)

    # Create a mock for the pipeline. The object produced by entering the
    # pipeline as an async context manager exposes `set`, which we record.
    mock_pipeline = AsyncMock()
    mock_set = AsyncMock()
    mock_pipeline.__aenter__.return_value.set = mock_set

    try:
        litellm.set_verbose = True

        # Patch the Redis class to return our mock
        with patch("redis.asyncio.Redis.pipeline", return_value=mock_pipeline):
            litellm.cache = Cache(
                type="redis",
                host="dummy_host",
                password="dummy_password",
                default_in_redis_ttl=ttl_seconds,
            )

            # Unique prefixes so the inputs cannot collide with earlier runs.
            inputs = [
                f"{uuid.uuid4()} hello this is ishaan",
                f"{uuid.uuid4()} hello this is ishaan again",
            ]

            # Call the embedding method; only the cache side effect matters,
            # so the response itself is not kept.
            await litellm.aembedding(
                model="azure/azure-embedding-model",
                input=inputs,
                encoding_format="base64",
            )

            # NOTE(review): the pipeline is mocked, so nothing actually
            # expires here. Presumably this pause gives the asynchronous
            # cache write time to run before the assertions - confirm before
            # shortening or removing it.
            await asyncio.sleep(3)

            # Check if set was called on the pipeline
            mock_set.assert_called()

            # Check if the TTL was set correctly on every recorded call
            expected_ttl = datetime.timedelta(seconds=ttl_seconds)
            for recorded_call in mock_set.call_args_list:
                args, kwargs = recorded_call
                print(f"redis pipeline set args: {args}")
                print(f"redis pipeline set kwargs: {kwargs}")
                assert kwargs.get("ex") == expected_ttl
    finally:
        # Undo the global overrides regardless of how the test ended.
        litellm.cache = previous_cache
        litellm.set_verbose = previous_verbose
@pytest.mark.asyncio
async def test_redis_cache_basic():
"""