Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 18:00:36 +00:00)
feat(cache): add cache store abstraction layer
Implement a reusable cache store abstraction with in-memory and Redis backends as the foundation for the prompt caching feature (PR 1 of a progressive delivery).

- Add a CacheStore protocol defining the cache interface
- Implement MemoryCacheStore with LRU, LFU, and TTL-only eviction policies
- Implement RedisCacheStore with connection pooling and retry logic
- Add CircuitBreaker for cache backend failure protection
- Include comprehensive unit tests (55 tests, >80% coverage)
- Add dependencies: cachetools>=5.5.0, redis>=5.2.0

This abstraction lets the prompt caching middleware use different caching implementations without coupling to a specific storage backend.

Signed-off-by: William Caban <willliam.caban@gmail.com>
parent 97f535c4f1
commit 299c575daa

10 changed files with 2175 additions and 1 deletion
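For orientation, here is a minimal sketch of the CacheStore interface and CacheError exception that the new tests exercise. It is reconstructed from the commit message and from the assertions in the test files below, so the exact method signatures are assumptions, not the code added by this commit.

# Hypothetical sketch of the cache interface exercised by the tests below;
# the real definitions live in llama_stack.providers.utils.cache.
from typing import Any, Protocol, runtime_checkable


class CacheError(Exception):
    """Raised when a cache backend operation fails; keeps the underlying cause."""

    def __init__(self, message: str, cause: Exception | None = None) -> None:
        super().__init__(message)
        self.cause = cause


@runtime_checkable
class CacheStore(Protocol):
    """Async key-value cache contract shared by MemoryCacheStore and RedisCacheStore."""

    async def get(self, key: str) -> Any | None: ...
    async def set(self, key: str, value: Any, ttl: int | None = None) -> None: ...
    async def delete(self, key: str) -> bool: ...
    async def exists(self, key: str) -> bool: ...
    async def ttl(self, key: str) -> int | None: ...
    async def clear(self) -> None: ...
    async def size(self) -> int: ...
    def get_stats(self) -> dict[str, Any]: ...
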
tests/unit/providers/utils/cache/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Unit tests for cache store implementations."""
tests/unit/providers/utils/cache/test_cache_store.py (new file, 257 lines)
@@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Unit tests for cache store base classes and utilities."""

import asyncio

import pytest

from llama_stack.providers.utils.cache import CacheError, CircuitBreaker


class TestCacheError:
    """Test suite for CacheError exception."""

    def test_init_with_message(self):
        """Test CacheError initialization with message."""
        error = CacheError("Failed to connect to cache")
        assert str(error) == "Failed to connect to cache"
        assert error.cause is None

    def test_init_with_cause(self):
        """Test CacheError initialization with underlying cause."""
        cause = ValueError("Invalid value")
        error = CacheError("Failed to set cache key", cause=cause)
        assert str(error) == "Failed to set cache key"
        assert error.cause == cause


class TestCircuitBreaker:
    """Test suite for CircuitBreaker."""

    def test_init_default_params(self):
        """Test initialization with default parameters."""
        breaker = CircuitBreaker()
        assert breaker.failure_threshold == 10
        assert breaker.recovery_timeout == 60
        assert breaker.failure_count == 0
        assert breaker.last_failure_time is None
        assert breaker.state == "CLOSED"

    def test_init_custom_params(self):
        """Test initialization with custom parameters."""
        breaker = CircuitBreaker(failure_threshold=5, recovery_timeout=30)
        assert breaker.failure_threshold == 5
        assert breaker.recovery_timeout == 30

    def test_is_closed_initial_state(self):
        """Test is_closed in initial state."""
        breaker = CircuitBreaker()
        assert breaker.is_closed() is True
        assert breaker.get_state() == "CLOSED"

    def test_record_success(self):
        """Test recording successful operations."""
        breaker = CircuitBreaker()

        # Record some failures
        breaker.record_failure()
        breaker.record_failure()
        assert breaker.failure_count == 2

        # Record success should reset
        breaker.record_success()
        assert breaker.failure_count == 0
        assert breaker.last_failure_time is None
        assert breaker.state == "CLOSED"

    def test_record_failure_below_threshold(self):
        """Test recording failures below threshold."""
        breaker = CircuitBreaker(failure_threshold=5)

        # Record failures below threshold
        for i in range(4):
            breaker.record_failure()
            assert breaker.is_closed() is True
            assert breaker.state == "CLOSED"

        assert breaker.failure_count == 4

    def test_record_failure_reach_threshold(self):
        """Test circuit breaker opens when threshold reached."""
        breaker = CircuitBreaker(failure_threshold=3)

        # Record failures to reach threshold
        for i in range(3):
            breaker.record_failure()

        # Should be open now
        assert breaker.state == "OPEN"
        assert breaker.is_closed() is False

    def test_circuit_open_blocks_requests(self):
        """Test that open circuit blocks requests."""
        breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=10)

        # Open the circuit
        for i in range(3):
            breaker.record_failure()

        assert breaker.is_closed() is False
        assert breaker.state == "OPEN"

    async def test_recovery_timeout(self):
        """Test circuit breaker recovery after timeout."""
        breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=1)

        # Open the circuit
        for i in range(3):
            breaker.record_failure()

        assert breaker.state == "OPEN"
        assert breaker.is_closed() is False

        # Wait for recovery timeout
        await asyncio.sleep(1.1)

        # Should enter HALF_OPEN state
        assert breaker.is_closed() is True
        assert breaker.state == "HALF_OPEN"

    async def test_half_open_success_closes_circuit(self):
        """Test successful request in HALF_OPEN closes circuit."""
        breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=1)

        # Open the circuit
        for i in range(3):
            breaker.record_failure()

        # Wait for recovery
        await asyncio.sleep(1.1)

        # Trigger state transition by calling is_closed()
        assert breaker.is_closed() is True
        assert breaker.state == "HALF_OPEN"

        # Record success
        breaker.record_success()
        assert breaker.state == "CLOSED"
        assert breaker.failure_count == 0

    async def test_half_open_failure_reopens_circuit(self):
        """Test failed request in HALF_OPEN reopens circuit."""
        breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=1)

        # Open the circuit
        for i in range(3):
            breaker.record_failure()

        # Wait for recovery
        await asyncio.sleep(1.1)

        # Trigger state transition by calling is_closed()
        assert breaker.is_closed() is True
        assert breaker.state == "HALF_OPEN"

        # Record failure
        breaker.record_failure()
        assert breaker.state == "OPEN"

    def test_reset(self):
        """Test manual reset of circuit breaker."""
        breaker = CircuitBreaker(failure_threshold=3)

        # Open the circuit
        for i in range(3):
            breaker.record_failure()

        assert breaker.state == "OPEN"

        # Manual reset
        breaker.reset()
        assert breaker.state == "CLOSED"
        assert breaker.failure_count == 0
        assert breaker.last_failure_time is None

    def test_get_state(self):
        """Test getting circuit breaker state."""
        breaker = CircuitBreaker(failure_threshold=3)

        # Initial state
        assert breaker.get_state() == "CLOSED"

        # After failures
        breaker.record_failure()
        assert breaker.get_state() == "CLOSED"

        # Open state
        for i in range(2):
            breaker.record_failure()
        assert breaker.get_state() == "OPEN"

    async def test_multiple_recovery_attempts(self):
        """Test multiple recovery attempts."""
        breaker = CircuitBreaker(failure_threshold=2, recovery_timeout=1)

        # Open the circuit
        breaker.record_failure()
        breaker.record_failure()
        assert breaker.state == "OPEN"

        # First recovery attempt fails
        await asyncio.sleep(1.1)
        assert breaker.is_closed() is True  # Trigger state check
        assert breaker.state == "HALF_OPEN"
        breaker.record_failure()
        assert breaker.state == "OPEN"

        # Second recovery attempt succeeds
        await asyncio.sleep(1.1)
        assert breaker.is_closed() is True  # Trigger state check
        assert breaker.state == "HALF_OPEN"
        breaker.record_success()
        assert breaker.state == "CLOSED"

    def test_failure_count_tracking(self):
        """Test failure count tracking."""
        breaker = CircuitBreaker(failure_threshold=5)

        # Track failures
        assert breaker.failure_count == 0

        breaker.record_failure()
        assert breaker.failure_count == 1

        breaker.record_failure()
        assert breaker.failure_count == 2

        # Success resets count
        breaker.record_success()
        assert breaker.failure_count == 0

    async def test_concurrent_operations(self):
        """Test circuit breaker with concurrent operations."""
        breaker = CircuitBreaker(failure_threshold=10)

        async def record_failures(count: int):
            for _ in range(count):
                breaker.record_failure()
                await asyncio.sleep(0.01)

        # Concurrent failures
        await asyncio.gather(
            record_failures(3),
            record_failures(3),
            record_failures(3),
        )

        assert breaker.failure_count == 9
        assert breaker.state == "CLOSED"

        # One more should open it
        breaker.record_failure()
        assert breaker.state == "OPEN"
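The tests above pin down the CircuitBreaker contract: the breaker stays CLOSED until failure_threshold failures accumulate, then OPEN until recovery_timeout elapses, at which point the next is_closed() check moves it to HALF_OPEN; a success closes it again and a failure reopens it. A minimal sketch consistent with those assertions (an illustration only, not the implementation added by this commit) could look like:

import time


class CircuitBreaker:
    """Sketch of a circuit breaker matching the behavior the tests assert."""

    def __init__(self, failure_threshold: int = 10, recovery_timeout: int = 60) -> None:
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.failure_count = 0
        self.last_failure_time: float | None = None
        self.state = "CLOSED"

    def is_closed(self) -> bool:
        # An OPEN circuit transitions to HALF_OPEN once the recovery timeout elapses.
        if self.state == "OPEN" and self.last_failure_time is not None:
            if time.monotonic() - self.last_failure_time >= self.recovery_timeout:
                self.state = "HALF_OPEN"
        return self.state != "OPEN"

    def record_success(self) -> None:
        # Any success closes the circuit and clears failure tracking.
        self.failure_count = 0
        self.last_failure_time = None
        self.state = "CLOSED"

    def record_failure(self) -> None:
        # A failure in HALF_OPEN reopens immediately; in CLOSED it opens at the threshold.
        self.failure_count += 1
        self.last_failure_time = time.monotonic()
        if self.state == "HALF_OPEN" or self.failure_count >= self.failure_threshold:
            self.state = "OPEN"

    def reset(self) -> None:
        # Manual reset behaves like a recorded success.
        self.record_success()

    def get_state(self) -> str:
        return self.state
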
tests/unit/providers/utils/cache/test_memory_cache.py (new file, 332 lines)
@@ -0,0 +1,332 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Unit tests for MemoryCacheStore implementation."""

import asyncio

import pytest

from llama_stack.providers.utils.cache import CacheError, MemoryCacheStore


class TestMemoryCacheStore:
    """Test suite for MemoryCacheStore."""

    async def test_init_default_params(self):
        """Test initialization with default parameters."""
        cache = MemoryCacheStore()
        assert cache.max_entries == 1000
        assert cache.max_memory_mb == 512
        assert cache.default_ttl == 600
        assert cache.eviction_policy == "lru"

    async def test_init_custom_params(self):
        """Test initialization with custom parameters."""
        cache = MemoryCacheStore(
            max_entries=500,
            max_memory_mb=256,
            default_ttl=300,
            eviction_policy="lfu",
        )
        assert cache.max_entries == 500
        assert cache.max_memory_mb == 256
        assert cache.default_ttl == 300
        assert cache.eviction_policy == "lfu"

    async def test_init_invalid_params(self):
        """Test initialization with invalid parameters."""
        with pytest.raises(ValueError, match="max_entries must be positive"):
            MemoryCacheStore(max_entries=0)

        with pytest.raises(ValueError, match="default_ttl must be positive"):
            MemoryCacheStore(default_ttl=0)

        with pytest.raises(ValueError, match="max_memory_mb must be positive"):
            MemoryCacheStore(max_memory_mb=0)

        with pytest.raises(ValueError, match="Unknown eviction policy"):
            MemoryCacheStore(eviction_policy="invalid")  # type: ignore

    async def test_set_and_get(self):
        """Test basic set and get operations."""
        cache = MemoryCacheStore()

        # Set value
        await cache.set("key1", "value1")

        # Get value
        value = await cache.get("key1")
        assert value == "value1"

    async def test_get_nonexistent_key(self):
        """Test getting a non-existent key."""
        cache = MemoryCacheStore()
        value = await cache.get("nonexistent")
        assert value is None

    async def test_set_with_custom_ttl(self):
        """Test setting value with custom TTL."""
        cache = MemoryCacheStore(default_ttl=10)

        # Set with custom TTL
        await cache.set("key1", "value1", ttl=1)

        # Value should exist initially
        value = await cache.get("key1")
        assert value == "value1"

        # Wait for expiration
        await asyncio.sleep(1.1)

        # Value should be expired
        value = await cache.get("key1")
        assert value is None

    async def test_set_complex_value(self):
        """Test storing complex data types."""
        cache = MemoryCacheStore()

        # Test dictionary
        data = {"nested": {"key": "value"}, "list": [1, 2, 3]}
        await cache.set("complex", data)
        value = await cache.get("complex")
        assert value == data

        # Test list
        list_data = [1, "two", {"three": 3}]
        await cache.set("list", list_data)
        value = await cache.get("list")
        assert value == list_data

    async def test_delete(self):
        """Test deleting a key."""
        cache = MemoryCacheStore()

        # Set and delete
        await cache.set("key1", "value1")
        deleted = await cache.delete("key1")
        assert deleted is True

        # Verify deleted
        value = await cache.get("key1")
        assert value is None

        # Delete non-existent key
        deleted = await cache.delete("nonexistent")
        assert deleted is False

    async def test_exists(self):
        """Test checking key existence."""
        cache = MemoryCacheStore()

        # Non-existent key
        exists = await cache.exists("key1")
        assert exists is False

        # Existing key
        await cache.set("key1", "value1")
        exists = await cache.exists("key1")
        assert exists is True

        # Expired key
        await cache.set("key2", "value2", ttl=1)
        await asyncio.sleep(1.1)
        exists = await cache.exists("key2")
        assert exists is False

    async def test_ttl(self):
        """Test getting remaining TTL."""
        cache = MemoryCacheStore()

        # Non-existent key
        ttl = await cache.ttl("nonexistent")
        assert ttl is None

        # Key with TTL
        await cache.set("key1", "value1", ttl=10)
        ttl = await cache.ttl("key1")
        assert ttl is not None
        assert 8 <= ttl <= 10  # Allow some tolerance

        # Expired key
        await cache.set("key2", "value2", ttl=1)
        await asyncio.sleep(1.1)
        ttl = await cache.ttl("key2")
        assert ttl is None

    async def test_clear(self):
        """Test clearing all entries."""
        cache = MemoryCacheStore()

        # Add multiple entries
        await cache.set("key1", "value1")
        await cache.set("key2", "value2")
        await cache.set("key3", "value3")

        # Clear
        await cache.clear()

        # Verify all cleared
        assert await cache.get("key1") is None
        assert await cache.get("key2") is None
        assert await cache.get("key3") is None

    async def test_size(self):
        """Test getting cache size."""
        cache = MemoryCacheStore()

        # Empty cache
        size = await cache.size()
        assert size == 0

        # Add entries
        await cache.set("key1", "value1")
        await cache.set("key2", "value2")
        size = await cache.size()
        assert size == 2

        # Delete entry
        await cache.delete("key1")
        size = await cache.size()
        assert size == 1

        # Clear cache
        await cache.clear()
        size = await cache.size()
        assert size == 0

    async def test_lru_eviction(self):
        """Test LRU eviction policy."""
        cache = MemoryCacheStore(max_entries=3, eviction_policy="lru")

        # Fill cache
        await cache.set("key1", "value1")
        await cache.set("key2", "value2")
        await cache.set("key3", "value3")

        # Access key1 to make it recently used
        await cache.get("key1")

        # Add new entry, should evict key2 (least recently used)
        await cache.set("key4", "value4")

        # key2 should be evicted
        assert await cache.get("key1") == "value1"
        assert await cache.get("key2") is None
        assert await cache.get("key3") == "value3"
        assert await cache.get("key4") == "value4"

    async def test_lfu_eviction(self):
        """Test LFU eviction policy."""
        cache = MemoryCacheStore(max_entries=3, eviction_policy="lfu")

        # Fill cache
        await cache.set("key1", "value1")
        await cache.set("key2", "value2")
        await cache.set("key3", "value3")

        # Access key1 multiple times
        await cache.get("key1")
        await cache.get("key1")
        await cache.get("key1")

        # Access key2 twice
        await cache.get("key2")
        await cache.get("key2")

        # key3 accessed once (least frequently)

        # Add new entry, should evict key3 (least frequently used)
        await cache.set("key4", "value4")

        # key3 should be evicted
        assert await cache.get("key1") == "value1"
        assert await cache.get("key2") == "value2"
        assert await cache.get("key3") is None
        assert await cache.get("key4") == "value4"

    async def test_concurrent_access(self):
        """Test concurrent access to cache."""
        cache = MemoryCacheStore()

        async def set_value(key: str, value: str):
            await cache.set(key, value)

        async def get_value(key: str):
            return await cache.get(key)

        # Concurrent sets
        await asyncio.gather(
            set_value("key1", "value1"),
            set_value("key2", "value2"),
            set_value("key3", "value3"),
        )

        # Concurrent gets
        results = await asyncio.gather(
            get_value("key1"),
            get_value("key2"),
            get_value("key3"),
        )

        assert results == ["value1", "value2", "value3"]

    async def test_update_existing_key(self):
        """Test updating an existing key."""
        cache = MemoryCacheStore()

        # Set initial value
        await cache.set("key1", "value1")
        assert await cache.get("key1") == "value1"

        # Update value
        await cache.set("key1", "value2")
        assert await cache.get("key1") == "value2"

    async def test_get_stats(self):
        """Test getting cache statistics."""
        cache = MemoryCacheStore(
            max_entries=100,
            max_memory_mb=128,
            default_ttl=300,
            eviction_policy="lru",
        )

        await cache.set("key1", "value1")
        await cache.set("key2", "value2")

        stats = cache.get_stats()

        assert stats["size"] == 2
        assert stats["max_entries"] == 100
        assert stats["max_memory_mb"] == 128
        assert stats["default_ttl"] == 300
        assert stats["eviction_policy"] == "lru"

    async def test_ttl_expiration_cleanup(self):
        """Test that expired entries are cleaned up properly."""
        cache = MemoryCacheStore()

        # Set entry with short TTL
        await cache.set("key1", "value1", ttl=1)
        await cache.set("key2", "value2", ttl=10)

        # Initially both exist
        assert await cache.size() == 2

        # Wait for key1 to expire
        await asyncio.sleep(1.1)

        # Accessing expired key should clean it up
        assert await cache.get("key1") is None

        # Size should reflect cleanup
        size = await cache.size()
        assert size == 1

        # key2 should still exist
        assert await cache.get("key2") == "value2"
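Taken together, these tests describe the MemoryCacheStore surface: a bounded async key-value store with per-entry TTL overrides, LRU or LFU eviction, and a synchronous get_stats(). A brief usage sketch, assuming only the constructor arguments and methods exercised above:

import asyncio

from llama_stack.providers.utils.cache import MemoryCacheStore  # as imported by the tests


async def main() -> None:
    # Small bounded cache: at most 3 entries, least-recently-used evicted first.
    cache = MemoryCacheStore(max_entries=3, default_ttl=60, eviction_policy="lru")
    await cache.set("prompt:abc", {"tokens": 128}, ttl=30)  # per-entry TTL override
    if await cache.exists("prompt:abc"):
        print(await cache.get("prompt:abc"), await cache.ttl("prompt:abc"))
    print(await cache.size(), cache.get_stats())


asyncio.run(main())
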
tests/unit/providers/utils/cache/test_redis_cache.py (new file, 421 lines)
@@ -0,0 +1,421 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Unit tests for RedisCacheStore implementation."""

import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from llama_stack.providers.utils.cache import CacheError, RedisCacheStore


class TestRedisCacheStore:
    """Test suite for RedisCacheStore."""

    async def test_init_default_params(self):
        """Test initialization with default parameters."""
        cache = RedisCacheStore()
        assert cache.host == "localhost"
        assert cache.port == 6379
        assert cache.db == 0
        assert cache.password is None
        assert cache.connection_pool_size == 10
        assert cache.timeout_ms == 100
        assert cache.default_ttl == 600
        assert cache.max_retries == 3
        assert cache.key_prefix == "llama_stack:"

    async def test_init_custom_params(self):
        """Test initialization with custom parameters."""
        cache = RedisCacheStore(
            host="redis.example.com",
            port=6380,
            db=1,
            password="secret",
            connection_pool_size=20,
            timeout_ms=200,
            default_ttl=300,
            max_retries=5,
            key_prefix="test:",
        )
        assert cache.host == "redis.example.com"
        assert cache.port == 6380
        assert cache.db == 1
        assert cache.password == "secret"
        assert cache.connection_pool_size == 20
        assert cache.timeout_ms == 200
        assert cache.default_ttl == 300
        assert cache.max_retries == 5
        assert cache.key_prefix == "test:"

    async def test_init_invalid_params(self):
        """Test initialization with invalid parameters."""
        with pytest.raises(ValueError, match="connection_pool_size must be positive"):
            RedisCacheStore(connection_pool_size=0)

        with pytest.raises(ValueError, match="timeout_ms must be positive"):
            RedisCacheStore(timeout_ms=0)

        with pytest.raises(ValueError, match="default_ttl must be positive"):
            RedisCacheStore(default_ttl=0)

        with pytest.raises(ValueError, match="max_retries must be non-negative"):
            RedisCacheStore(max_retries=-1)

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_ensure_connection(self, mock_redis_class, mock_pool_class):
        """Test connection establishment."""
        # Setup mocks
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock()
        mock_redis_class.return_value = mock_redis

        # Create cache
        cache = RedisCacheStore()

        # Ensure connection
        redis = await cache._ensure_connection()

        # Verify connection was established
        assert redis == mock_redis
        mock_pool_class.assert_called_once()
        mock_redis.ping.assert_called_once()

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_connection_failure(self, mock_redis_class, mock_pool_class):
        """Test connection failure handling."""
        from redis.exceptions import ConnectionError as RedisConnectionError

        # Setup mocks to fail
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock(side_effect=RedisConnectionError("Connection refused"))
        mock_redis_class.return_value = mock_redis

        # Create cache
        cache = RedisCacheStore()

        # Connection should fail
        with pytest.raises(CacheError, match="Failed to connect to Redis"):
            await cache._ensure_connection()

    def test_make_key(self):
        """Test key prefixing."""
        cache = RedisCacheStore(key_prefix="test:")
        assert cache._make_key("mykey") == "test:mykey"
        assert cache._make_key("another") == "test:another"

    def test_serialize_deserialize(self):
        """Test value serialization."""
        cache = RedisCacheStore()

        # Simple value
        assert cache._serialize("hello") == '"hello"'
        assert cache._deserialize('"hello"') == "hello"

        # Dictionary
        data = {"key": "value", "number": 42}
        serialized = cache._serialize(data)
        assert cache._deserialize(serialized) == data

        # List
        list_data = [1, 2, "three"]
        serialized = cache._serialize(list_data)
        assert cache._deserialize(serialized) == list_data

    def test_serialize_error(self):
        """Test serialization error handling."""
        cache = RedisCacheStore()

        # Object that can't be serialized
        class NonSerializable:
            pass

        with pytest.raises(ValueError, match="Value is not JSON-serializable"):
            cache._serialize(NonSerializable())

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_set_and_get(self, mock_redis_class, mock_pool_class):
        """Test set and get operations."""
        # Setup mocks
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock()
        mock_redis.get = AsyncMock(return_value=json.dumps("value1"))
        mock_redis.setex = AsyncMock()
        mock_redis_class.return_value = mock_redis

        # Create cache
        cache = RedisCacheStore()

        # Set value
        await cache.set("key1", "value1")
        mock_redis.setex.assert_called_once()

        # Get value
        value = await cache.get("key1")
        assert value == "value1"
        mock_redis.get.assert_called_once()

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_get_nonexistent_key(self, mock_redis_class, mock_pool_class):
        """Test getting a non-existent key."""
        # Setup mocks
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock()
        mock_redis.get = AsyncMock(return_value=None)
        mock_redis_class.return_value = mock_redis

        # Create cache
        cache = RedisCacheStore()

        # Get non-existent key
        value = await cache.get("nonexistent")
        assert value is None

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_set_with_custom_ttl(self, mock_redis_class, mock_pool_class):
        """Test setting value with custom TTL."""
        # Setup mocks
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock()
        mock_redis.setex = AsyncMock()
        mock_redis_class.return_value = mock_redis

        # Create cache
        cache = RedisCacheStore(default_ttl=600)

        # Set with custom TTL
        await cache.set("key1", "value1", ttl=300)

        # Verify setex was called with custom TTL
        call_args = mock_redis.setex.call_args
        assert call_args[0][1] == 300  # TTL argument

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_delete(self, mock_redis_class, mock_pool_class):
        """Test deleting a key."""
        # Setup mocks
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock()
        mock_redis.delete = AsyncMock(return_value=1)  # 1 key deleted
        mock_redis_class.return_value = mock_redis

        # Create cache
        cache = RedisCacheStore()

        # Delete key
        deleted = await cache.delete("key1")
        assert deleted is True

        # Delete non-existent key
        mock_redis.delete = AsyncMock(return_value=0)
        deleted = await cache.delete("nonexistent")
        assert deleted is False

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_exists(self, mock_redis_class, mock_pool_class):
        """Test checking key existence."""
        # Setup mocks
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock()
        mock_redis.exists = AsyncMock(return_value=1)  # Exists
        mock_redis_class.return_value = mock_redis

        # Create cache
        cache = RedisCacheStore()

        # Check existing key
        exists = await cache.exists("key1")
        assert exists is True

        # Check non-existent key
        mock_redis.exists = AsyncMock(return_value=0)
        exists = await cache.exists("nonexistent")
        assert exists is False

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_ttl(self, mock_redis_class, mock_pool_class):
        """Test getting remaining TTL."""
        # Setup mocks
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock()
        mock_redis.ttl = AsyncMock(return_value=300)
        mock_redis_class.return_value = mock_redis

        # Create cache
        cache = RedisCacheStore()

        # Get TTL
        ttl = await cache.ttl("key1")
        assert ttl == 300

        # Key doesn't exist
        mock_redis.ttl = AsyncMock(return_value=-2)
        ttl = await cache.ttl("nonexistent")
        assert ttl is None

        # Key has no TTL
        mock_redis.ttl = AsyncMock(return_value=-1)
        ttl = await cache.ttl("no_ttl_key")
        assert ttl is None

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_clear(self, mock_redis_class, mock_pool_class):
        """Test clearing all entries."""
        # Setup mocks
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock()
        mock_redis.scan = AsyncMock(
            side_effect=[
                (10, ["llama_stack:key1", "llama_stack:key2"]),
                (0, ["llama_stack:key3"]),  # cursor 0 indicates end
            ]
        )
        mock_redis.delete = AsyncMock(return_value=3)
        mock_redis_class.return_value = mock_redis

        # Create cache
        cache = RedisCacheStore()

        # Clear cache
        await cache.clear()

        # Verify scan and delete were called
        assert mock_redis.scan.call_count == 2
        mock_redis.delete.assert_called()

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_size(self, mock_redis_class, mock_pool_class):
        """Test getting cache size."""
        # Setup mocks
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock()
        mock_redis.scan = AsyncMock(
            side_effect=[
                (10, ["llama_stack:key1", "llama_stack:key2"]),
                (0, ["llama_stack:key3"]),
            ]
        )
        mock_redis_class.return_value = mock_redis

        # Create cache
        cache = RedisCacheStore()

        # Get size
        size = await cache.size()
        assert size == 3

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_retry_logic(self, mock_redis_class, mock_pool_class):
        """Test retry logic for transient failures."""
        from redis.exceptions import TimeoutError as RedisTimeoutError

        # Setup mocks - fail twice, then succeed
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock()
        mock_redis.get = AsyncMock(
            side_effect=[
                RedisTimeoutError("Timeout"),
                RedisTimeoutError("Timeout"),
                json.dumps("success"),
            ]
        )
        mock_redis_class.return_value = mock_redis

        # Create cache with retries
        cache = RedisCacheStore(max_retries=3)

        # Should succeed after retries
        value = await cache.get("key1")
        assert value == "success"
        assert mock_redis.get.call_count == 3

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_retry_exhaustion(self, mock_redis_class, mock_pool_class):
        """Test behavior when all retries are exhausted."""
        from redis.exceptions import TimeoutError as RedisTimeoutError

        # Setup mocks - always fail
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=RedisTimeoutError("Timeout"))
        mock_redis_class.return_value = mock_redis

        # Create cache with limited retries
        cache = RedisCacheStore(max_retries=2)

        # Should raise CacheError after exhausting retries
        with pytest.raises(CacheError, match="failed after .* attempts"):
            await cache.get("key1")

        # Should have tried 3 times (initial + 2 retries)
        assert mock_redis.get.call_count == 3

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_close(self, mock_redis_class, mock_pool_class):
        """Test closing Redis connection."""
        # Setup mocks
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock()
        mock_redis.close = AsyncMock()
        mock_redis_class.return_value = mock_redis

        mock_pool = AsyncMock()
        mock_pool.disconnect = AsyncMock()
        mock_pool_class.return_value = mock_pool

        # Create cache and establish connection
        cache = RedisCacheStore()
        await cache._ensure_connection()

        # Close connection
        await cache.close()

        # Verify cleanup
        mock_redis.close.assert_called_once()
        mock_pool.disconnect.assert_called_once()

    def test_get_stats(self):
        """Test getting cache statistics."""
        cache = RedisCacheStore(
            host="redis.example.com",
            port=6380,
            db=1,
            connection_pool_size=20,
            timeout_ms=200,
            default_ttl=300,
            max_retries=5,
            key_prefix="test:",
        )

        stats = cache.get_stats()

        assert stats["host"] == "redis.example.com"
        assert stats["port"] == 6380
        assert stats["db"] == 1
        assert stats["connection_pool_size"] == 20
        assert stats["timeout_ms"] == 200
        assert stats["default_ttl"] == 300
        assert stats["max_retries"] == 5
        assert stats["key_prefix"] == "test:"
        assert stats["connected"] is False  # Not connected yet
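The Redis-backed tests imply the same interface plus connection management: values are JSON-serialized, keys are namespaced with key_prefix ("llama_stack:" by default), transient timeouts are retried up to max_retries, and hard failures surface as CacheError with the underlying cause attached. A usage sketch under those assumptions (it expects a reachable Redis server, here assumed at localhost:6379):

import asyncio

from llama_stack.providers.utils.cache import CacheError, RedisCacheStore  # as imported by the tests


async def main() -> None:
    # With the default key_prefix, "session:1" is stored under "llama_stack:session:1".
    cache = RedisCacheStore(host="localhost", port=6379, default_ttl=600, max_retries=3)
    try:
        await cache.set("session:1", {"user": "alice"}, ttl=300)
        print(await cache.get("session:1"), await cache.ttl("session:1"))
    except CacheError as err:
        # Backend failures are wrapped, with the underlying exception on err.cause.
        print(f"cache unavailable: {err} ({err.cause!r})")
    finally:
        await cache.close()


asyncio.run(main())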