feat(cache): add cache store abstraction layer

Implement a reusable cache store abstraction with in-memory and Redis
backends as the foundation for the prompt caching feature (PR1 of the progressive delivery).

- Add CacheStore protocol defining cache interface
- Implement MemoryCacheStore with LRU, LFU, and TTL-only eviction policies
- Implement RedisCacheStore with connection pooling and retry logic
- Add CircuitBreaker for cache backend failure protection
- Include comprehensive unit tests (55 tests, >80% coverage)
- Add dependencies: cachetools>=5.5.0, redis>=5.2.0

This abstraction enables flexible caching implementations for the prompt
caching middleware without coupling to specific storage backends.

Signed-by: William Caban <willliam.caban@gmail.com>
This commit is contained in:
William Caban 2025-11-15 14:45:49 -05:00
parent 97f535c4f1
commit 299c575daa
10 changed files with 2175 additions and 1 deletions

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Unit tests for cache store implementations."""

View file

@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Unit tests for cache store base classes and utilities."""
import asyncio
import pytest
from llama_stack.providers.utils.cache import CacheError, CircuitBreaker
class TestCacheError:
    """Checks for the CacheError exception type."""

    def test_init_with_message(self):
        """A message-only CacheError stringifies to its message and has no cause."""
        message = "Failed to connect to cache"
        err = CacheError(message)

        assert str(err) == message
        assert err.cause is None

    def test_init_with_cause(self):
        """The ``cause`` keyword is stored on the error and the message is kept."""
        underlying = ValueError("Invalid value")
        message = "Failed to set cache key"
        err = CacheError(message, cause=underlying)

        assert str(err) == message
        assert err.cause == underlying
class TestCircuitBreaker:
    """Behavioural tests for the CircuitBreaker state machine.

    The breaker is driven through ``record_failure`` / ``record_success`` /
    ``reset`` and observed via ``state``, ``failure_count``,
    ``last_failure_time``, ``is_closed()`` and ``get_state()``.
    """

    @staticmethod
    def _fail(breaker, times):
        """Record ``times`` consecutive failures on ``breaker``."""
        for _ in range(times):
            breaker.record_failure()

    def test_init_default_params(self):
        """A default breaker uses threshold 10, timeout 60 and starts CLOSED."""
        cb = CircuitBreaker()

        assert cb.failure_threshold == 10
        assert cb.recovery_timeout == 60
        assert cb.failure_count == 0
        assert cb.last_failure_time is None
        assert cb.state == "CLOSED"

    def test_init_custom_params(self):
        """Constructor keyword arguments are stored on the instance."""
        cb = CircuitBreaker(failure_threshold=5, recovery_timeout=30)

        assert cb.failure_threshold == 5
        assert cb.recovery_timeout == 30

    def test_is_closed_initial_state(self):
        """A fresh breaker reports itself as closed."""
        cb = CircuitBreaker()

        assert cb.is_closed() is True
        assert cb.get_state() == "CLOSED"

    def test_record_success(self):
        """record_success wipes the failure bookkeeping."""
        cb = CircuitBreaker()

        # Accumulate a couple of failures first.
        self._fail(cb, 2)
        assert cb.failure_count == 2

        # A success must reset everything.
        cb.record_success()
        assert cb.failure_count == 0
        assert cb.last_failure_time is None
        assert cb.state == "CLOSED"

    def test_record_failure_below_threshold(self):
        """Failures short of the threshold leave the circuit closed."""
        cb = CircuitBreaker(failure_threshold=5)

        # One fewer failure than the threshold.
        self._fail(cb, 4)

        assert cb.is_closed() is True
        assert cb.state == "CLOSED"
        assert cb.failure_count == 4

    def test_record_failure_reach_threshold(self):
        """Hitting the threshold opens the circuit."""
        cb = CircuitBreaker(failure_threshold=3)

        self._fail(cb, 3)

        # The circuit must be open at this point.
        assert cb.state == "OPEN"
        assert cb.is_closed() is False

    def test_circuit_open_blocks_requests(self):
        """An open circuit reports not-closed, i.e. callers are blocked."""
        cb = CircuitBreaker(failure_threshold=3, recovery_timeout=10)

        self._fail(cb, 3)

        assert cb.is_closed() is False
        assert cb.state == "OPEN"

    async def test_recovery_timeout(self):
        """After the recovery timeout the breaker moves to HALF_OPEN."""
        cb = CircuitBreaker(failure_threshold=3, recovery_timeout=1)

        self._fail(cb, 3)
        assert cb.state == "OPEN"
        assert cb.is_closed() is False

        # Let the recovery timeout elapse.
        await asyncio.sleep(1.1)

        # is_closed() is queried first; the test relies on it to expose the
        # HALF_OPEN transition before ``state`` is inspected.
        assert cb.is_closed() is True
        assert cb.state == "HALF_OPEN"

    async def test_half_open_success_closes_circuit(self):
        """A success while HALF_OPEN closes the circuit again."""
        cb = CircuitBreaker(failure_threshold=3, recovery_timeout=1)

        self._fail(cb, 3)
        await asyncio.sleep(1.1)

        # Querying is_closed() triggers the state transition.
        assert cb.is_closed() is True
        assert cb.state == "HALF_OPEN"

        cb.record_success()
        assert cb.state == "CLOSED"
        assert cb.failure_count == 0

    async def test_half_open_failure_reopens_circuit(self):
        """A failure while HALF_OPEN sends the circuit back to OPEN."""
        cb = CircuitBreaker(failure_threshold=3, recovery_timeout=1)

        self._fail(cb, 3)
        await asyncio.sleep(1.1)

        # Querying is_closed() triggers the state transition.
        assert cb.is_closed() is True
        assert cb.state == "HALF_OPEN"

        cb.record_failure()
        assert cb.state == "OPEN"

    def test_reset(self):
        """reset() returns an open breaker to its pristine CLOSED state."""
        cb = CircuitBreaker(failure_threshold=3)

        self._fail(cb, 3)
        assert cb.state == "OPEN"

        cb.reset()
        assert cb.state == "CLOSED"
        assert cb.failure_count == 0
        assert cb.last_failure_time is None

    def test_get_state(self):
        """get_state() tracks the CLOSED -> OPEN progression."""
        cb = CircuitBreaker(failure_threshold=3)

        # Fresh breaker.
        assert cb.get_state() == "CLOSED"

        # One failure is still below the threshold.
        cb.record_failure()
        assert cb.get_state() == "CLOSED"

        # Two more reach the threshold.
        self._fail(cb, 2)
        assert cb.get_state() == "OPEN"

    async def test_multiple_recovery_attempts(self):
        """A failed probe reopens the circuit; a later successful one closes it."""
        cb = CircuitBreaker(failure_threshold=2, recovery_timeout=1)

        self._fail(cb, 2)
        assert cb.state == "OPEN"

        # First probe: fails, so the circuit reopens.
        await asyncio.sleep(1.1)
        assert cb.is_closed() is True  # triggers the state check
        assert cb.state == "HALF_OPEN"
        cb.record_failure()
        assert cb.state == "OPEN"

        # Second probe: succeeds, so the circuit closes.
        await asyncio.sleep(1.1)
        assert cb.is_closed() is True  # triggers the state check
        assert cb.state == "HALF_OPEN"
        cb.record_success()
        assert cb.state == "CLOSED"

    def test_failure_count_tracking(self):
        """failure_count increments per failure and resets on success."""
        cb = CircuitBreaker(failure_threshold=5)

        assert cb.failure_count == 0
        cb.record_failure()
        assert cb.failure_count == 1
        cb.record_failure()
        assert cb.failure_count == 2

        # A success clears the counter.
        cb.record_success()
        assert cb.failure_count == 0

    async def test_concurrent_operations(self):
        """Interleaved coroutines recording failures are all counted."""
        cb = CircuitBreaker(failure_threshold=10)

        async def record_failures(count: int):
            for _ in range(count):
                cb.record_failure()
                # Yield to the event loop so the workers interleave.
                await asyncio.sleep(0.01)

        # Three workers, three failures each.
        await asyncio.gather(*(record_failures(3) for _ in range(3)))

        assert cb.failure_count == 9
        assert cb.state == "CLOSED"

        # The tenth failure reaches the threshold.
        cb.record_failure()
        assert cb.state == "OPEN"

View file

@ -0,0 +1,332 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Unit tests for MemoryCacheStore implementation."""
import asyncio
import pytest
from llama_stack.providers.utils.cache import CacheError, MemoryCacheStore
class TestMemoryCacheStore:
    """Tests for the in-memory cache backend, MemoryCacheStore."""

    async def test_init_default_params(self):
        """A default store exposes the expected default configuration."""
        store = MemoryCacheStore()

        assert store.max_entries == 1000
        assert store.max_memory_mb == 512
        assert store.default_ttl == 600
        assert store.eviction_policy == "lru"

    async def test_init_custom_params(self):
        """Constructor keyword arguments are reflected as attributes."""
        settings = {
            "max_entries": 500,
            "max_memory_mb": 256,
            "default_ttl": 300,
            "eviction_policy": "lfu",
        }
        store = MemoryCacheStore(**settings)

        for attribute, expected in settings.items():
            assert getattr(store, attribute) == expected

    async def test_init_invalid_params(self):
        """Each invalid constructor argument raises ValueError with its message."""
        bad_configs = (
            ({"max_entries": 0}, "max_entries must be positive"),
            ({"default_ttl": 0}, "default_ttl must be positive"),
            ({"max_memory_mb": 0}, "max_memory_mb must be positive"),
            ({"eviction_policy": "invalid"}, "Unknown eviction policy"),
        )

        for kwargs, message in bad_configs:
            with pytest.raises(ValueError, match=message):
                MemoryCacheStore(**kwargs)

    async def test_set_and_get(self):
        """A stored value can be read back."""
        store = MemoryCacheStore()

        await store.set("key1", "value1")

        assert await store.get("key1") == "value1"

    async def test_get_nonexistent_key(self):
        """Reading an unknown key yields None."""
        store = MemoryCacheStore()

        assert await store.get("nonexistent") is None

    async def test_set_with_custom_ttl(self):
        """A per-key TTL overrides the store default and expires the entry."""
        store = MemoryCacheStore(default_ttl=10)

        await store.set("key1", "value1", ttl=1)

        # Present right after being written.
        assert await store.get("key1") == "value1"

        # Gone once the one-second TTL has passed.
        await asyncio.sleep(1.1)
        assert await store.get("key1") is None

    async def test_set_complex_value(self):
        """Nested dicts and heterogeneous lists round-trip unchanged."""
        store = MemoryCacheStore()
        samples = (
            ("complex", {"nested": {"key": "value"}, "list": [1, 2, 3]}),
            ("list", [1, "two", {"three": 3}]),
        )

        for key, payload in samples:
            await store.set(key, payload)
            assert await store.get(key) == payload

    async def test_delete(self):
        """delete() reports whether a key was removed."""
        store = MemoryCacheStore()

        await store.set("key1", "value1")
        removed = await store.delete("key1")
        assert removed is True

        # The key is really gone.
        assert await store.get("key1") is None

        # Deleting something that was never stored reports False.
        removed = await store.delete("nonexistent")
        assert removed is False

    async def test_exists(self):
        """exists() is False for missing and expired keys, True for live ones."""
        store = MemoryCacheStore()

        # Never stored.
        assert await store.exists("key1") is False

        # Stored and live.
        await store.set("key1", "value1")
        assert await store.exists("key1") is True

        # Stored but expired.
        await store.set("key2", "value2", ttl=1)
        await asyncio.sleep(1.1)
        assert await store.exists("key2") is False

    async def test_ttl(self):
        """ttl() returns the remaining lifetime, or None if missing/expired."""
        store = MemoryCacheStore()

        # Never stored.
        assert await store.ttl("nonexistent") is None

        # Live key: remaining TTL is close to what was requested.
        await store.set("key1", "value1", ttl=10)
        remaining = await store.ttl("key1")
        assert remaining is not None
        assert 8 <= remaining <= 10  # allow some tolerance

        # Expired key.
        await store.set("key2", "value2", ttl=1)
        await asyncio.sleep(1.1)
        assert await store.ttl("key2") is None

    async def test_clear(self):
        """clear() removes every entry."""
        store = MemoryCacheStore()
        entries = (("key1", "value1"), ("key2", "value2"), ("key3", "value3"))

        for key, value in entries:
            await store.set(key, value)

        await store.clear()

        for key, _ in entries:
            assert await store.get(key) is None

    async def test_size(self):
        """size() follows sets, deletes and clears."""
        store = MemoryCacheStore()

        # Empty to begin with.
        assert await store.size() == 0

        # Two entries added.
        await store.set("key1", "value1")
        await store.set("key2", "value2")
        assert await store.size() == 2

        # One removed.
        await store.delete("key1")
        assert await store.size() == 1

        # All removed.
        await store.clear()
        assert await store.size() == 0

    async def test_lru_eviction(self):
        """With the LRU policy the least recently used key is evicted."""
        store = MemoryCacheStore(max_entries=3, eviction_policy="lru")

        # Fill the store to capacity.
        await store.set("key1", "value1")
        await store.set("key2", "value2")
        await store.set("key3", "value3")

        # Touch key1 so that key2 becomes the least recently used.
        await store.get("key1")

        # Overflow by one entry.
        await store.set("key4", "value4")

        # key2 is the expected victim.
        assert await store.get("key1") == "value1"
        assert await store.get("key2") is None
        assert await store.get("key3") == "value3"
        assert await store.get("key4") == "value4"

    async def test_lfu_eviction(self):
        """With the LFU policy the least frequently used key is evicted."""
        store = MemoryCacheStore(max_entries=3, eviction_policy="lfu")

        # Fill the store to capacity.
        await store.set("key1", "value1")
        await store.set("key2", "value2")
        await store.set("key3", "value3")

        # Read key1 three times and key2 twice; key3 is never read after
        # being set, so it is the least frequently used entry.
        for _ in range(3):
            await store.get("key1")
        for _ in range(2):
            await store.get("key2")

        # Overflow by one entry.
        await store.set("key4", "value4")

        # key3 is the expected victim.
        assert await store.get("key1") == "value1"
        assert await store.get("key2") == "value2"
        assert await store.get("key3") is None
        assert await store.get("key4") == "value4"

    async def test_concurrent_access(self):
        """Concurrent sets followed by concurrent gets see every value."""
        store = MemoryCacheStore()
        entries = (("key1", "value1"), ("key2", "value2"), ("key3", "value3"))

        # Writes are issued together.
        await asyncio.gather(*(store.set(key, value) for key, value in entries))

        # Reads are issued together; gather keeps the submission order.
        results = await asyncio.gather(*(store.get(key) for key, _ in entries))

        assert results == ["value1", "value2", "value3"]

    async def test_update_existing_key(self):
        """Setting an existing key overwrites its value."""
        store = MemoryCacheStore()

        await store.set("key1", "value1")
        assert await store.get("key1") == "value1"

        await store.set("key1", "value2")
        assert await store.get("key1") == "value2"

    async def test_get_stats(self):
        """get_stats() reports the entry count and the configuration."""
        settings = {
            "max_entries": 100,
            "max_memory_mb": 128,
            "default_ttl": 300,
            "eviction_policy": "lru",
        }
        store = MemoryCacheStore(**settings)

        await store.set("key1", "value1")
        await store.set("key2", "value2")

        stats = store.get_stats()

        assert stats["size"] == 2
        for name, expected in settings.items():
            assert stats[name] == expected

    async def test_ttl_expiration_cleanup(self):
        """Reading an expired key removes it, which size() then reflects."""
        store = MemoryCacheStore()

        # One short-lived and one long-lived entry.
        await store.set("key1", "value1", ttl=1)
        await store.set("key2", "value2", ttl=10)

        # Both are counted initially.
        assert await store.size() == 2

        # Let only key1 expire.
        await asyncio.sleep(1.1)

        # Reading the expired key is expected to clean it up.
        assert await store.get("key1") is None

        # The size reflects the cleanup.
        assert await store.size() == 1

        # The long-lived entry is untouched.
        assert await store.get("key2") == "value2"

View file

@ -0,0 +1,421 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Unit tests for RedisCacheStore implementation."""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.providers.utils.cache import CacheError, RedisCacheStore
class TestRedisCacheStore:
    """Tests for RedisCacheStore, run against mocked Redis client classes.

    Network-facing tests patch ``ConnectionPool`` and ``Redis`` in
    ``llama_stack.providers.utils.cache.redis``. Because ``patch`` decorators
    apply bottom-up, the ``Redis`` mock is the first injected argument and the
    ``ConnectionPool`` mock is the second.
    """

    @staticmethod
    def _install_client(redis_cls, **stubs):
        """Build an AsyncMock client, attach ``stubs`` to it and wire it to ``redis_cls``.

        The client always gets an awaitable ``ping``; every keyword argument
        is set as an attribute of the same name. The client is returned so
        tests can inspect or re-stub it.
        """
        client = AsyncMock()
        client.ping = AsyncMock()
        for attribute, stub in stubs.items():
            setattr(client, attribute, stub)
        redis_cls.return_value = client
        return client

    async def test_init_default_params(self):
        """A default store exposes the expected default configuration."""
        cache = RedisCacheStore()
        defaults = {
            "host": "localhost",
            "port": 6379,
            "db": 0,
            "connection_pool_size": 10,
            "timeout_ms": 100,
            "default_ttl": 600,
            "max_retries": 3,
            "key_prefix": "llama_stack:",
        }

        for attribute, expected in defaults.items():
            assert getattr(cache, attribute) == expected
        assert cache.password is None

    async def test_init_custom_params(self):
        """Constructor keyword arguments are reflected as attributes."""
        settings = {
            "host": "redis.example.com",
            "port": 6380,
            "db": 1,
            "password": "secret",
            "connection_pool_size": 20,
            "timeout_ms": 200,
            "default_ttl": 300,
            "max_retries": 5,
            "key_prefix": "test:",
        }
        cache = RedisCacheStore(**settings)

        for attribute, expected in settings.items():
            assert getattr(cache, attribute) == expected

    async def test_init_invalid_params(self):
        """Each invalid constructor argument raises ValueError with its message."""
        bad_configs = (
            ({"connection_pool_size": 0}, "connection_pool_size must be positive"),
            ({"timeout_ms": 0}, "timeout_ms must be positive"),
            ({"default_ttl": 0}, "default_ttl must be positive"),
            ({"max_retries": -1}, "max_retries must be non-negative"),
        )

        for kwargs, message in bad_configs:
            with pytest.raises(ValueError, match=message):
                RedisCacheStore(**kwargs)

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_ensure_connection(self, mock_redis_class, mock_pool_class):
        """_ensure_connection builds the pool once, pings, and returns the client."""
        mock_pool_class.return_value = MagicMock()
        client = self._install_client(mock_redis_class)

        cache = RedisCacheStore()
        connection = await cache._ensure_connection()

        assert connection == client
        mock_pool_class.assert_called_once()
        client.ping.assert_called_once()

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_connection_failure(self, mock_redis_class, mock_pool_class):
        """A failing ping surfaces as CacheError."""
        from redis.exceptions import ConnectionError as RedisConnectionError

        # The ping itself is what fails.
        client = AsyncMock()
        client.ping = AsyncMock(side_effect=RedisConnectionError("Connection refused"))
        mock_redis_class.return_value = client

        cache = RedisCacheStore()

        with pytest.raises(CacheError, match="Failed to connect to Redis"):
            await cache._ensure_connection()

    def test_make_key(self):
        """_make_key prepends the configured prefix."""
        cache = RedisCacheStore(key_prefix="test:")

        assert cache._make_key("mykey") == "test:mykey"
        assert cache._make_key("another") == "test:another"

    def test_serialize_deserialize(self):
        """Values are serialized to JSON text and deserialize back unchanged."""
        cache = RedisCacheStore()

        # Plain string: the exact wire format is pinned.
        assert cache._serialize("hello") == '"hello"'
        assert cache._deserialize('"hello"') == "hello"

        # Containers: only the round trip is checked.
        for payload in ({"key": "value", "number": 42}, [1, 2, "three"]):
            encoded = cache._serialize(payload)
            assert cache._deserialize(encoded) == payload

    def test_serialize_error(self):
        """Serializing a non-JSON-serializable object raises ValueError."""
        cache = RedisCacheStore()

        class NonSerializable:
            pass

        with pytest.raises(ValueError, match="Value is not JSON-serializable"):
            cache._serialize(NonSerializable())

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_set_and_get(self, mock_redis_class, mock_pool_class):
        """set() goes through setex; get() decodes what the client returns."""
        client = self._install_client(
            mock_redis_class,
            get=AsyncMock(return_value=json.dumps("value1")),
            setex=AsyncMock(),
        )
        cache = RedisCacheStore()

        await cache.set("key1", "value1")
        client.setex.assert_called_once()

        assert await cache.get("key1") == "value1"
        client.get.assert_called_once()

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_get_nonexistent_key(self, mock_redis_class, mock_pool_class):
        """A client miss (None) is returned as None."""
        self._install_client(mock_redis_class, get=AsyncMock(return_value=None))
        cache = RedisCacheStore()

        assert await cache.get("nonexistent") is None

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_set_with_custom_ttl(self, mock_redis_class, mock_pool_class):
        """An explicit ttl is what gets passed to setex, not the default."""
        client = self._install_client(mock_redis_class, setex=AsyncMock())
        cache = RedisCacheStore(default_ttl=600)

        await cache.set("key1", "value1", ttl=300)

        # The TTL is the second positional argument of setex.
        assert client.setex.call_args.args[1] == 300

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_delete(self, mock_redis_class, mock_pool_class):
        """delete() maps the client's deleted-count to a boolean."""
        # The client reports one key deleted.
        client = self._install_client(mock_redis_class, delete=AsyncMock(return_value=1))
        cache = RedisCacheStore()

        removed = await cache.delete("key1")
        assert removed is True

        # The client reports nothing deleted.
        client.delete = AsyncMock(return_value=0)
        removed = await cache.delete("nonexistent")
        assert removed is False

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_exists(self, mock_redis_class, mock_pool_class):
        """exists() maps the client's match count to a boolean."""
        # The client reports the key as present.
        client = self._install_client(mock_redis_class, exists=AsyncMock(return_value=1))
        cache = RedisCacheStore()

        present = await cache.exists("key1")
        assert present is True

        # The client reports the key as absent.
        client.exists = AsyncMock(return_value=0)
        present = await cache.exists("nonexistent")
        assert present is False

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_ttl(self, mock_redis_class, mock_pool_class):
        """ttl() passes through positive values and maps -2 / -1 to None."""
        client = self._install_client(mock_redis_class, ttl=AsyncMock(return_value=300))
        cache = RedisCacheStore()

        # A positive TTL is returned as-is.
        assert await cache.ttl("key1") == 300

        # -2 stands for a missing key, -1 for a key without a TTL.
        for key, raw in (("nonexistent", -2), ("no_ttl_key", -1)):
            client.ttl = AsyncMock(return_value=raw)
            assert await cache.ttl(key) is None

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_clear(self, mock_redis_class, mock_pool_class):
        """clear() scans through every page and deletes what it finds."""
        pages = [
            (10, ["llama_stack:key1", "llama_stack:key2"]),
            (0, ["llama_stack:key3"]),  # cursor 0 indicates end
        ]
        client = self._install_client(
            mock_redis_class,
            scan=AsyncMock(side_effect=pages),
            delete=AsyncMock(return_value=3),
        )
        cache = RedisCacheStore()

        await cache.clear()

        assert client.scan.call_count == 2
        client.delete.assert_called()

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_size(self, mock_redis_class, mock_pool_class):
        """size() counts the keys across all scan pages."""
        pages = [
            (10, ["llama_stack:key1", "llama_stack:key2"]),
            (0, ["llama_stack:key3"]),
        ]
        self._install_client(mock_redis_class, scan=AsyncMock(side_effect=pages))
        cache = RedisCacheStore()

        assert await cache.size() == 3

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_retry_logic(self, mock_redis_class, mock_pool_class):
        """Two transient timeouts followed by a success yield the value."""
        from redis.exceptions import TimeoutError as RedisTimeoutError

        # Fail twice, then answer.
        outcomes = [
            RedisTimeoutError("Timeout"),
            RedisTimeoutError("Timeout"),
            json.dumps("success"),
        ]
        client = self._install_client(mock_redis_class, get=AsyncMock(side_effect=outcomes))
        cache = RedisCacheStore(max_retries=3)

        assert await cache.get("key1") == "success"
        assert client.get.call_count == 3

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_retry_exhaustion(self, mock_redis_class, mock_pool_class):
        """A permanently failing client raises CacheError once retries run out."""
        from redis.exceptions import TimeoutError as RedisTimeoutError

        # Every call times out.
        client = self._install_client(
            mock_redis_class,
            get=AsyncMock(side_effect=RedisTimeoutError("Timeout")),
        )
        cache = RedisCacheStore(max_retries=2)

        with pytest.raises(CacheError, match="failed after .* attempts"):
            await cache.get("key1")

        # Initial attempt plus two retries.
        assert client.get.call_count == 3

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_close(self, mock_redis_class, mock_pool_class):
        """close() closes the client and disconnects the pool."""
        client = self._install_client(mock_redis_class, close=AsyncMock())
        pool = AsyncMock()
        pool.disconnect = AsyncMock()
        mock_pool_class.return_value = pool

        # A connection has to exist before it can be closed.
        cache = RedisCacheStore()
        await cache._ensure_connection()

        await cache.close()

        client.close.assert_called_once()
        pool.disconnect.assert_called_once()

    def test_get_stats(self):
        """get_stats() echoes the configuration and reports not connected."""
        settings = {
            "host": "redis.example.com",
            "port": 6380,
            "db": 1,
            "connection_pool_size": 20,
            "timeout_ms": 200,
            "default_ttl": 300,
            "max_retries": 5,
            "key_prefix": "test:",
        }
        cache = RedisCacheStore(**settings)

        stats = cache.get_stats()

        for name, expected in settings.items():
            assert stats[name] == expected
        # No connection has been opened by this test.
        assert stats["connected"] is False