mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
Implement a reusable cache store abstraction with in-memory and Redis backends as the foundation for the prompt caching feature (PR1 of progressive delivery).

- Add a CacheStore protocol defining the cache interface
- Implement MemoryCacheStore with LRU, LFU, and TTL-only eviction policies
- Implement RedisCacheStore with connection pooling and retry logic
- Add a CircuitBreaker to protect against cache backend failures
- Include comprehensive unit tests (55 tests, >80% coverage)
- Add dependencies: cachetools>=5.5.0, redis>=5.2.0

This abstraction enables flexible caching implementations for the prompt caching middleware without coupling it to specific storage backends.

Signed-by: William Caban <willliam.caban@gmail.com>
421 lines
14 KiB
Python
421 lines
14 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Unit tests for RedisCacheStore implementation."""
|
|
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from llama_stack.providers.utils.cache import CacheError, RedisCacheStore
|
|
|
|
|
|
class TestRedisCacheStore:
    """Test suite for RedisCacheStore.

    No real Redis server is required. Tests that exercise I/O patch the
    ``Redis`` and ``ConnectionPool`` names in
    ``llama_stack.providers.utils.cache.redis`` and supply ``AsyncMock``
    clients. With stacked ``@patch`` decorators the bottom-most patch is
    injected first, which is why ``mock_redis_class`` precedes
    ``mock_pool_class`` in every signature.
    """

    @staticmethod
    def _install_mock_client(mock_redis_class: MagicMock, **methods: AsyncMock) -> AsyncMock:
        """Build an async Redis client mock and make the patched ``Redis`` class return it.

        Args:
            mock_redis_class: The mock injected by ``@patch(...Redis)``.
            **methods: Client methods to attach, e.g. ``get=AsyncMock(return_value=...)``.
                These are applied after the default ``ping``, so passing ``ping=...``
                overrides it.

        Returns:
            The client mock that ``Redis(...)`` will return.
        """
        client = AsyncMock()
        # A successful ping is the default so connection setup succeeds.
        client.ping = AsyncMock()
        for name, method in methods.items():
            setattr(client, name, method)
        mock_redis_class.return_value = client
        return client

    # The construction tests await nothing, so they do not need an event loop.
    def test_init_default_params(self):
        """Test initialization with default parameters."""
        cache = RedisCacheStore()
        assert cache.host == "localhost"
        assert cache.port == 6379
        assert cache.db == 0
        assert cache.password is None
        assert cache.connection_pool_size == 10
        assert cache.timeout_ms == 100
        assert cache.default_ttl == 600
        assert cache.max_retries == 3
        assert cache.key_prefix == "llama_stack:"

    def test_init_custom_params(self):
        """Test initialization with custom parameters."""
        cache = RedisCacheStore(
            host="redis.example.com",
            port=6380,
            db=1,
            password="secret",
            connection_pool_size=20,
            timeout_ms=200,
            default_ttl=300,
            max_retries=5,
            key_prefix="test:",
        )
        assert cache.host == "redis.example.com"
        assert cache.port == 6380
        assert cache.db == 1
        assert cache.password == "secret"
        assert cache.connection_pool_size == 20
        assert cache.timeout_ms == 200
        assert cache.default_ttl == 300
        assert cache.max_retries == 5
        assert cache.key_prefix == "test:"

    def test_init_invalid_params(self):
        """Test initialization with invalid parameters."""
        with pytest.raises(ValueError, match="connection_pool_size must be positive"):
            RedisCacheStore(connection_pool_size=0)

        with pytest.raises(ValueError, match="timeout_ms must be positive"):
            RedisCacheStore(timeout_ms=0)

        with pytest.raises(ValueError, match="default_ttl must be positive"):
            RedisCacheStore(default_ttl=0)

        with pytest.raises(ValueError, match="max_retries must be non-negative"):
            RedisCacheStore(max_retries=-1)

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_ensure_connection(self, mock_redis_class, mock_pool_class):
        """Test connection establishment."""
        mock_redis = self._install_mock_client(mock_redis_class)

        cache = RedisCacheStore()
        client = await cache._ensure_connection()

        # The store must hand back exactly the client built by the patched Redis class.
        assert client is mock_redis
        mock_pool_class.assert_called_once()
        # ping is async: check it was awaited, not merely called.
        mock_redis.ping.assert_awaited_once()

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_connection_failure(self, mock_redis_class, mock_pool_class):
        """Test connection failure handling."""
        from redis.exceptions import ConnectionError as RedisConnectionError

        # The health-check ping fails, so connecting must fail.
        self._install_mock_client(
            mock_redis_class,
            ping=AsyncMock(side_effect=RedisConnectionError("Connection refused")),
        )
        cache = RedisCacheStore()

        with pytest.raises(CacheError, match="Failed to connect to Redis"):
            await cache._ensure_connection()

    def test_make_key(self):
        """Test key prefixing."""
        cache = RedisCacheStore(key_prefix="test:")
        assert cache._make_key("mykey") == "test:mykey"
        assert cache._make_key("another") == "test:another"

    def test_serialize_deserialize(self):
        """Test value serialization."""
        cache = RedisCacheStore()

        # Simple value
        assert cache._serialize("hello") == '"hello"'
        assert cache._deserialize('"hello"') == "hello"

        # Dictionary
        data = {"key": "value", "number": 42}
        assert cache._deserialize(cache._serialize(data)) == data

        # List
        list_data = [1, 2, "three"]
        assert cache._deserialize(cache._serialize(list_data)) == list_data

    def test_serialize_error(self):
        """Test serialization error handling."""
        cache = RedisCacheStore()

        # An arbitrary object instance that json cannot encode.
        class NonSerializable:
            pass

        with pytest.raises(ValueError, match="Value is not JSON-serializable"):
            cache._serialize(NonSerializable())

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_set_and_get(self, mock_redis_class, mock_pool_class):
        """Test set and get operations."""
        mock_redis = self._install_mock_client(
            mock_redis_class,
            get=AsyncMock(return_value=json.dumps("value1")),
            setex=AsyncMock(),
        )
        cache = RedisCacheStore()

        await cache.set("key1", "value1")
        mock_redis.setex.assert_awaited_once()
        # With no explicit ttl, set() should fall back to default_ttl (600).
        # The TTL is the second positional argument, as in test_set_with_custom_ttl.
        assert mock_redis.setex.call_args[0][1] == 600

        value = await cache.get("key1")
        assert value == "value1"
        mock_redis.get.assert_awaited_once()

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_get_nonexistent_key(self, mock_redis_class, mock_pool_class):
        """Test getting a non-existent key."""
        # Redis returns None for a missing key.
        self._install_mock_client(mock_redis_class, get=AsyncMock(return_value=None))
        cache = RedisCacheStore()

        value = await cache.get("nonexistent")
        assert value is None

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_set_with_custom_ttl(self, mock_redis_class, mock_pool_class):
        """Test setting value with custom TTL."""
        mock_redis = self._install_mock_client(mock_redis_class, setex=AsyncMock())
        cache = RedisCacheStore(default_ttl=600)

        await cache.set("key1", "value1", ttl=300)

        # An explicit ttl must override default_ttl.
        call_args = mock_redis.setex.call_args
        assert call_args[0][1] == 300  # TTL argument

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_delete(self, mock_redis_class, mock_pool_class):
        """Test deleting a key."""
        mock_redis = self._install_mock_client(
            mock_redis_class,
            delete=AsyncMock(return_value=1),  # 1 key deleted
        )
        cache = RedisCacheStore()

        deleted = await cache.delete("key1")
        assert deleted is True

        # Redis reports 0 deleted keys for a missing key.
        mock_redis.delete = AsyncMock(return_value=0)
        deleted = await cache.delete("nonexistent")
        assert deleted is False

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_exists(self, mock_redis_class, mock_pool_class):
        """Test checking key existence."""
        mock_redis = self._install_mock_client(
            mock_redis_class,
            exists=AsyncMock(return_value=1),  # Exists
        )
        cache = RedisCacheStore()

        exists = await cache.exists("key1")
        assert exists is True

        # Redis reports a count of 0 for a missing key.
        mock_redis.exists = AsyncMock(return_value=0)
        exists = await cache.exists("nonexistent")
        assert exists is False

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_ttl(self, mock_redis_class, mock_pool_class):
        """Test getting remaining TTL."""
        mock_redis = self._install_mock_client(mock_redis_class, ttl=AsyncMock(return_value=300))
        cache = RedisCacheStore()

        ttl = await cache.ttl("key1")
        assert ttl == 300

        # Redis TTL returns -2 when the key does not exist.
        mock_redis.ttl = AsyncMock(return_value=-2)
        ttl = await cache.ttl("nonexistent")
        assert ttl is None

        # Redis TTL returns -1 when the key exists but has no expiry.
        mock_redis.ttl = AsyncMock(return_value=-1)
        ttl = await cache.ttl("no_ttl_key")
        assert ttl is None

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_clear(self, mock_redis_class, mock_pool_class):
        """Test clearing all entries."""
        # Two SCAN pages; the second returns cursor 0, which ends the iteration.
        mock_redis = self._install_mock_client(
            mock_redis_class,
            scan=AsyncMock(
                side_effect=[
                    (10, ["llama_stack:key1", "llama_stack:key2"]),
                    (0, ["llama_stack:key3"]),  # cursor 0 indicates end
                ]
            ),
            delete=AsyncMock(return_value=3),
        )
        cache = RedisCacheStore()

        await cache.clear()

        assert mock_redis.scan.call_count == 2
        mock_redis.delete.assert_called()

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_size(self, mock_redis_class, mock_pool_class):
        """Test getting cache size."""
        # Keys are spread over two SCAN pages; size must count across both.
        self._install_mock_client(
            mock_redis_class,
            scan=AsyncMock(
                side_effect=[
                    (10, ["llama_stack:key1", "llama_stack:key2"]),
                    (0, ["llama_stack:key3"]),
                ]
            ),
        )
        cache = RedisCacheStore()

        size = await cache.size()
        assert size == 3

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_retry_logic(self, mock_redis_class, mock_pool_class):
        """Test retry logic for transient failures."""
        from redis.exceptions import TimeoutError as RedisTimeoutError

        # Fail twice, then succeed on the third attempt.
        mock_redis = self._install_mock_client(
            mock_redis_class,
            get=AsyncMock(
                side_effect=[
                    RedisTimeoutError("Timeout"),
                    RedisTimeoutError("Timeout"),
                    json.dumps("success"),
                ]
            ),
        )
        cache = RedisCacheStore(max_retries=3)

        value = await cache.get("key1")
        assert value == "success"
        assert mock_redis.get.call_count == 3

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_retry_exhaustion(self, mock_redis_class, mock_pool_class):
        """Test behavior when all retries are exhausted."""
        from redis.exceptions import TimeoutError as RedisTimeoutError

        # Every attempt times out.
        mock_redis = self._install_mock_client(
            mock_redis_class,
            get=AsyncMock(side_effect=RedisTimeoutError("Timeout")),
        )
        cache = RedisCacheStore(max_retries=2)

        with pytest.raises(CacheError, match="failed after .* attempts"):
            await cache.get("key1")

        # Initial attempt + 2 retries.
        assert mock_redis.get.call_count == 3

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_close(self, mock_redis_class, mock_pool_class):
        """Test closing Redis connection."""
        mock_redis = self._install_mock_client(mock_redis_class, close=AsyncMock())

        mock_pool = AsyncMock()
        mock_pool.disconnect = AsyncMock()
        mock_pool_class.return_value = mock_pool

        # A connection must exist before there is anything to close.
        cache = RedisCacheStore()
        await cache._ensure_connection()

        await cache.close()

        # Both the client and its pool must be released.
        mock_redis.close.assert_called_once()
        mock_pool.disconnect.assert_called_once()

    def test_get_stats(self):
        """Test getting cache statistics."""
        cache = RedisCacheStore(
            host="redis.example.com",
            port=6380,
            db=1,
            connection_pool_size=20,
            timeout_ms=200,
            default_ttl=300,
            max_retries=5,
            key_prefix="test:",
        )

        stats = cache.get_stats()

        assert stats["host"] == "redis.example.com"
        assert stats["port"] == 6380
        assert stats["db"] == 1
        assert stats["connection_pool_size"] == 20
        assert stats["timeout_ms"] == 200
        assert stats["default_ttl"] == 300
        assert stats["max_retries"] == 5
        assert stats["key_prefix"] == "test:"
        assert stats["connected"] is False  # Not connected yet