forked from phoenix/litellm-mirror
(bug fix) TTL not being set for embedding caching requests (#6095)
* fix ttl for cache pipeline settings * add test for caching * add test for setting ttls on redis caching
This commit is contained in:
parent
285b589095
commit
2c8bba293f
2 changed files with 99 additions and 14 deletions
|
@ -22,6 +22,9 @@ import litellm
|
|||
from litellm import aembedding, completion, embedding
|
||||
from litellm.caching import Cache
|
||||
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
import datetime
|
||||
|
||||
# litellm.set_verbose=True
|
||||
|
||||
messages = [{"role": "user", "content": "who is ishaan Github? "}]
|
||||
|
@ -579,6 +582,56 @@ async def test_embedding_caching_base_64():
|
|||
assert embedding_val_2.data[1]["embedding"] == embedding_val_1.data[1]["embedding"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedding_caching_redis_ttl():
|
||||
"""
|
||||
Test default_in_redis_ttl is used for embedding caching
|
||||
|
||||
issue: https://github.com/BerriAI/litellm/issues/6010
|
||||
"""
|
||||
litellm.set_verbose = True
|
||||
|
||||
# Create a mock for the pipeline
|
||||
mock_pipeline = AsyncMock()
|
||||
mock_set = AsyncMock()
|
||||
mock_pipeline.__aenter__.return_value.set = mock_set
|
||||
# Patch the Redis class to return our mock
|
||||
with patch("redis.asyncio.Redis.pipeline", return_value=mock_pipeline):
|
||||
# Simulate the context manager behavior for the pipeline
|
||||
litellm.cache = Cache(
|
||||
type="redis",
|
||||
host="dummy_host",
|
||||
password="dummy_password",
|
||||
default_in_redis_ttl=2.5,
|
||||
)
|
||||
|
||||
inputs = [
|
||||
f"{uuid.uuid4()} hello this is ishaan",
|
||||
f"{uuid.uuid4()} hello this is ishaan again",
|
||||
]
|
||||
|
||||
# Call the embedding method
|
||||
embedding_val_1 = await litellm.aembedding(
|
||||
model="azure/azure-embedding-model",
|
||||
input=inputs,
|
||||
encoding_format="base64",
|
||||
)
|
||||
|
||||
await asyncio.sleep(3) # Wait for TTL to expire
|
||||
|
||||
# Check if set was called on the pipeline
|
||||
mock_set.assert_called()
|
||||
|
||||
# Check if the TTL was set correctly
|
||||
for call in mock_set.call_args_list:
|
||||
args, kwargs = call
|
||||
print(f"redis pipeline set args: {args}")
|
||||
print(f"redis pipeline set kwargs: {kwargs}")
|
||||
assert kwargs.get("ex") == datetime.timedelta(
|
||||
seconds=2.5
|
||||
) # Check if TTL is set to 2.5 seconds
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_cache_basic():
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue