feat(cache): add cache store abstraction layer

Implement reusable cache store abstraction with in-memory and Redis
backends as foundation for prompt caching feature (PR1 of progressive delivery).

- Add CacheStore protocol defining cache interface
- Implement MemoryCacheStore with LRU, LFU, and TTL-only eviction policies
- Implement RedisCacheStore with connection pooling and retry logic
- Add CircuitBreaker for cache backend failure protection
- Include comprehensive unit tests (55 tests, >80% coverage)
- Add dependencies: cachetools>=5.5.0, redis>=5.2.0

This abstraction enables flexible caching implementations for the prompt
caching middleware without coupling to specific storage backends.

Signed-by: William Caban <willliam.caban@gmail.com>
This commit is contained in:
William Caban 2025-11-15 14:45:49 -05:00
parent 97f535c4f1
commit 299c575daa
10 changed files with 2175 additions and 1 deletions

View file

@ -0,0 +1,421 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Unit tests for RedisCacheStore implementation."""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.providers.utils.cache import CacheError, RedisCacheStore
class TestRedisCacheStore:
"""Test suite for RedisCacheStore."""
async def test_init_default_params(self):
    """Check every attribute of a store constructed without arguments."""
    store = RedisCacheStore()

    # Attribute name -> value asserted for a no-argument store.
    expected = {
        "host": "localhost",
        "port": 6379,
        "db": 0,
        "connection_pool_size": 10,
        "timeout_ms": 100,
        "default_ttl": 600,
        "max_retries": 3,
        "key_prefix": "llama_stack:",
    }
    for attr, value in expected.items():
        assert getattr(store, attr) == value, attr

    # The password is checked by identity rather than equality.
    assert store.password is None
async def test_init_custom_params(self):
    """Check that each constructor argument is exposed as a same-named attribute."""
    # One dict drives both construction and verification, so the two
    # cannot drift apart.
    overrides = {
        "host": "redis.example.com",
        "port": 6380,
        "db": 1,
        "password": "secret",
        "connection_pool_size": 20,
        "timeout_ms": 200,
        "default_ttl": 300,
        "max_retries": 5,
        "key_prefix": "test:",
    }

    store = RedisCacheStore(**overrides)

    for attr, value in overrides.items():
        assert getattr(store, attr) == value, attr
async def test_init_invalid_params(self):
    """Check that out-of-range constructor arguments raise ValueError."""
    # (constructor kwargs, pattern the error message must match)
    rejected = [
        ({"connection_pool_size": 0}, "connection_pool_size must be positive"),
        ({"timeout_ms": 0}, "timeout_ms must be positive"),
        ({"default_ttl": 0}, "default_ttl must be positive"),
        ({"max_retries": -1}, "max_retries must be non-negative"),
    ]

    for kwargs, message in rejected:
        with pytest.raises(ValueError, match=message):
            RedisCacheStore(**kwargs)
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
@patch("llama_stack.providers.utils.cache.redis.Redis")
async def test_ensure_connection(self, mock_redis_class, mock_pool_class):
    """Test connection establishment.

    With the pool and client classes patched, ``_ensure_connection`` is
    expected to build exactly one pool, return the patched client, and
    await a ping on it.
    """
    # Setup mocks
    mock_pool = MagicMock()
    mock_pool_class.return_value = mock_pool
    mock_redis = AsyncMock()
    mock_redis.ping = AsyncMock()
    mock_redis_class.return_value = mock_redis

    # Create cache
    cache = RedisCacheStore()

    # Ensure connection
    client = await cache._ensure_connection()

    # The returned object must be the very client the patched class produced.
    assert client is mock_redis
    mock_pool_class.assert_called_once()
    # assert_called_once() would also pass if the ping coroutine were created
    # but never awaited; assert_awaited_once() proves the ping actually ran.
    mock_redis.ping.assert_awaited_once()
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
@patch("llama_stack.providers.utils.cache.redis.Redis")
async def test_connection_failure(self, mock_redis_class, mock_pool_class):
    """Check that a failing ping surfaces as CacheError from ``_ensure_connection``."""
    from redis.exceptions import ConnectionError as RedisConnectionError

    # A client whose ping always raises a Redis connection error.
    refused = RedisConnectionError("Connection refused")
    failing_client = AsyncMock()
    failing_client.ping = AsyncMock(side_effect=refused)
    mock_redis_class.return_value = failing_client

    store = RedisCacheStore()

    with pytest.raises(CacheError, match="Failed to connect to Redis"):
        await store._ensure_connection()
def test_make_key(self):
    """Check that ``_make_key`` prepends the configured prefix."""
    store = RedisCacheStore(key_prefix="test:")

    # raw key -> expected prefixed key
    expected_keys = {
        "mykey": "test:mykey",
        "another": "test:another",
    }
    for raw, prefixed in expected_keys.items():
        assert store._make_key(raw) == prefixed
def test_serialize_deserialize(self):
    """Check the exact encoding of a string and round-trips of containers."""
    store = RedisCacheStore()

    # A plain string has a known serialized form in both directions.
    assert store._serialize("hello") == '"hello"'
    assert store._deserialize('"hello"') == "hello"

    # Containers only need to survive a serialize/deserialize round-trip.
    samples = (
        {"key": "value", "number": 42},
        [1, 2, "three"],
    )
    for original in samples:
        encoded = store._serialize(original)
        assert store._deserialize(encoded) == original
def test_serialize_error(self):
    """Check that ``_serialize`` raises ValueError for an arbitrary object."""

    class Opaque:
        """Plain class with no JSON representation."""

    store = RedisCacheStore()
    unserializable = Opaque()

    with pytest.raises(ValueError, match="Value is not JSON-serializable"):
        store._serialize(unserializable)
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
@patch("llama_stack.providers.utils.cache.redis.Redis")
async def test_set_and_get(self, mock_redis_class, mock_pool_class):
    """Store a value, then read it back through the mocked client."""
    # The mocked client answers GET with the JSON encoding of the value.
    stored_payload = json.dumps("value1")
    client = AsyncMock()
    client.ping = AsyncMock()
    client.setex = AsyncMock()
    client.get = AsyncMock(return_value=stored_payload)
    mock_redis_class.return_value = client

    store = RedisCacheStore()

    # Write path: exactly one SETEX call.
    await store.set("key1", "value1")
    client.setex.assert_called_once()

    # Read path: the payload is decoded back to the original value.
    fetched = await store.get("key1")
    assert fetched == "value1"
    client.get.assert_called_once()
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
@patch("llama_stack.providers.utils.cache.redis.Redis")
async def test_get_nonexistent_key(self, mock_redis_class, mock_pool_class):
    """Check that ``get`` returns None when the client's GET returns None."""
    client = AsyncMock()
    client.ping = AsyncMock()
    client.get = AsyncMock(return_value=None)
    mock_redis_class.return_value = client

    store = RedisCacheStore()

    missing = await store.get("nonexistent")

    assert missing is None
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
@patch("llama_stack.providers.utils.cache.redis.Redis")
async def test_set_with_custom_ttl(self, mock_redis_class, mock_pool_class):
    """Check that an explicit ``ttl`` reaches SETEX instead of the default."""
    client = AsyncMock()
    client.ping = AsyncMock()
    client.setex = AsyncMock()
    mock_redis_class.return_value = client

    # The default TTL (600) differs from the per-call TTL (300), so the
    # assertion below can tell which one was used.
    store = RedisCacheStore(default_ttl=600)
    await store.set("key1", "value1", ttl=300)

    # The TTL is read from the second positional argument of the SETEX call.
    positional_args = client.setex.call_args[0]
    ttl_argument = positional_args[1]
    assert ttl_argument == 300
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
@patch("llama_stack.providers.utils.cache.redis.Redis")
async def test_delete(self, mock_redis_class, mock_pool_class):
    """Check that ``delete`` maps the client's deleted-key count to a bool."""
    client = AsyncMock()
    client.ping = AsyncMock()
    mock_redis_class.return_value = client

    store = RedisCacheStore()

    # The client reports one key removed.
    client.delete = AsyncMock(return_value=1)
    removed = await store.delete("key1")
    assert removed is True

    # The client reports nothing removed.
    client.delete = AsyncMock(return_value=0)
    removed = await store.delete("nonexistent")
    assert removed is False
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
@patch("llama_stack.providers.utils.cache.redis.Redis")
async def test_exists(self, mock_redis_class, mock_pool_class):
    """Check that ``exists`` maps the client's EXISTS count to a bool."""
    client = AsyncMock()
    client.ping = AsyncMock()
    mock_redis_class.return_value = client

    store = RedisCacheStore()

    # The client reports the key as present.
    client.exists = AsyncMock(return_value=1)
    present = await store.exists("key1")
    assert present is True

    # The client reports the key as absent.
    client.exists = AsyncMock(return_value=0)
    present = await store.exists("nonexistent")
    assert present is False
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
@patch("llama_stack.providers.utils.cache.redis.Redis")
async def test_ttl(self, mock_redis_class, mock_pool_class):
    """Check how ``ttl`` translates the three kinds of client TTL reply."""
    client = AsyncMock()
    client.ping = AsyncMock()
    mock_redis_class.return_value = client

    store = RedisCacheStore()

    # A positive reply is passed through as the remaining seconds.
    client.ttl = AsyncMock(return_value=300)
    remaining = await store.ttl("key1")
    assert remaining == 300

    # -2 (key doesn't exist) and -1 (key has no TTL) both become None.
    sentinel_replies = [
        ("nonexistent", -2),
        ("no_ttl_key", -1),
    ]
    for key, reply in sentinel_replies:
        client.ttl = AsyncMock(return_value=reply)
        remaining = await store.ttl(key)
        assert remaining is None
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
@patch("llama_stack.providers.utils.cache.redis.Redis")
async def test_clear(self, mock_redis_class, mock_pool_class):
    """Check that ``clear`` scans through every page and issues deletes."""
    # Two SCAN pages; a cursor of 0 on the second page indicates the end.
    scan_pages = [
        (10, ["llama_stack:key1", "llama_stack:key2"]),
        (0, ["llama_stack:key3"]),
    ]
    client = AsyncMock()
    client.ping = AsyncMock()
    client.scan = AsyncMock(side_effect=scan_pages)
    client.delete = AsyncMock(return_value=3)
    mock_redis_class.return_value = client

    store = RedisCacheStore()
    await store.clear()

    # Both pages were consumed and at least one DELETE was issued.
    assert client.scan.call_count == 2
    client.delete.assert_called()
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
@patch("llama_stack.providers.utils.cache.redis.Redis")
async def test_size(self, mock_redis_class, mock_pool_class):
    """Check that ``size`` counts keys across every SCAN page."""
    # Two keys on the first page plus one on the last page: three in total.
    scan_pages = [
        (10, ["llama_stack:key1", "llama_stack:key2"]),
        (0, ["llama_stack:key3"]),
    ]
    client = AsyncMock()
    client.ping = AsyncMock()
    client.scan = AsyncMock(side_effect=scan_pages)
    mock_redis_class.return_value = client

    store = RedisCacheStore()

    total = await store.size()

    assert total == 3
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
@patch("llama_stack.providers.utils.cache.redis.Redis")
async def test_retry_logic(self, mock_redis_class, mock_pool_class):
    """Check that ``get`` succeeds once a transient timeout clears."""
    from redis.exceptions import TimeoutError as RedisTimeoutError

    # GET times out twice and answers on the third call.
    get_outcomes = [
        RedisTimeoutError("Timeout"),
        RedisTimeoutError("Timeout"),
        json.dumps("success"),
    ]
    client = AsyncMock()
    client.ping = AsyncMock()
    client.get = AsyncMock(side_effect=get_outcomes)
    mock_redis_class.return_value = client

    store = RedisCacheStore(max_retries=3)

    result = await store.get("key1")

    assert result == "success"
    # Two failed attempts plus the successful one.
    assert client.get.call_count == 3
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
@patch("llama_stack.providers.utils.cache.redis.Redis")
async def test_retry_exhaustion(self, mock_redis_class, mock_pool_class):
    """Check that ``get`` raises CacheError once every retry has failed."""
    from redis.exceptions import TimeoutError as RedisTimeoutError

    # GET times out on every call.
    client = AsyncMock()
    client.ping = AsyncMock()
    client.get = AsyncMock(side_effect=RedisTimeoutError("Timeout"))
    mock_redis_class.return_value = client

    store = RedisCacheStore(max_retries=2)

    with pytest.raises(CacheError, match="failed after .* attempts"):
        await store.get("key1")

    # One initial attempt plus max_retries (2) retries.
    assert client.get.call_count == 3
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
@patch("llama_stack.providers.utils.cache.redis.Redis")
async def test_close(self, mock_redis_class, mock_pool_class):
    """Check that ``close`` closes the client and disconnects the pool."""
    # Pool mock, with an explicit async disconnect.
    pool = AsyncMock()
    pool.disconnect = AsyncMock()
    mock_pool_class.return_value = pool

    # Client mock, with explicit async ping and close.
    client = AsyncMock()
    client.ping = AsyncMock()
    client.close = AsyncMock()
    mock_redis_class.return_value = client

    # Connect first, then close.
    store = RedisCacheStore()
    await store._ensure_connection()
    await store.close()

    client.close.assert_called_once()
    pool.disconnect.assert_called_once()
def test_get_stats(self):
    """Check that ``get_stats`` echoes the configuration and connection state."""
    config = {
        "host": "redis.example.com",
        "port": 6380,
        "db": 1,
        "connection_pool_size": 20,
        "timeout_ms": 200,
        "default_ttl": 300,
        "max_retries": 5,
        "key_prefix": "test:",
    }
    store = RedisCacheStore(**config)

    report = store.get_stats()

    # Every configured value must appear under the same key in the report.
    for field, value in config.items():
        assert report[field] == value, field

    # No operation has been issued, so the store reports itself unconnected.
    assert report["connected"] is False