mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
feat(cache): add cache store abstraction layer
Implement reusable cache store abstraction with in-memory and Redis backends as foundation for prompt caching feature (PR1 of progressive delivery).

- Add CacheStore protocol defining cache interface
- Implement MemoryCacheStore with LRU, LFU, and TTL-only eviction policies
- Implement RedisCacheStore with connection pooling and retry logic
- Add CircuitBreaker for cache backend failure protection
- Include comprehensive unit tests (55 tests, >80% coverage)
- Add dependencies: cachetools>=5.5.0, redis>=5.2.0

This abstraction enables flexible caching implementations for the prompt caching middleware without coupling to specific storage backends.

Signed-by: William Caban <willliam.caban@gmail.com>
This commit is contained in:
parent
97f535c4f1
commit
299c575daa
10 changed files with 2175 additions and 1 deletions
|
|
@ -26,6 +26,7 @@ classifiers = [
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"PyYAML>=6.0",
|
"PyYAML>=6.0",
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
|
"cachetools>=5.5.0", # for prompt caching
|
||||||
"fastapi>=0.115.0,<1.0", # server
|
"fastapi>=0.115.0,<1.0", # server
|
||||||
"fire", # for MCP in LLS client
|
"fire", # for MCP in LLS client
|
||||||
"httpx",
|
"httpx",
|
||||||
|
|
@ -37,6 +38,7 @@ dependencies = [
|
||||||
"python-dotenv",
|
"python-dotenv",
|
||||||
"pyjwt[crypto]>=2.10.0", # Pull crypto to support RS256 for jwt. Requires 2.10.0+ for ssl_context support.
|
"pyjwt[crypto]>=2.10.0", # Pull crypto to support RS256 for jwt. Requires 2.10.0+ for ssl_context support.
|
||||||
"pydantic>=2.11.9",
|
"pydantic>=2.11.9",
|
||||||
|
"redis>=5.2.0", # for prompt caching (Redis backend)
|
||||||
"rich",
|
"rich",
|
||||||
"starlette",
|
"starlette",
|
||||||
"termcolor",
|
"termcolor",
|
||||||
|
|
|
||||||
37
src/llama_stack/providers/utils/cache/__init__.py
vendored
Normal file
37
src/llama_stack/providers/utils/cache/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Cache store utilities for prompt caching.
|
||||||
|
|
||||||
|
This module provides cache store abstractions and implementations for use in
|
||||||
|
the Llama Stack server's prompt caching feature. Supports both memory-based
|
||||||
|
and Redis-based caching with configurable eviction policies and TTL management.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
from llama_stack.providers.utils.cache import MemoryCacheStore, RedisCacheStore
|
||||||
|
|
||||||
|
# Memory cache for development
|
||||||
|
memory_cache = MemoryCacheStore(max_entries=1000, eviction_policy="lru")
|
||||||
|
|
||||||
|
# Redis cache for production
|
||||||
|
redis_cache = RedisCacheStore(
|
||||||
|
host="localhost",
|
||||||
|
port=6379,
|
||||||
|
connection_pool_size=10
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .cache_store import CacheError, CacheStore, CircuitBreaker
|
||||||
|
from .memory import MemoryCacheStore
|
||||||
|
from .redis import RedisCacheStore
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CacheStore",
|
||||||
|
"CacheError",
|
||||||
|
"CircuitBreaker",
|
||||||
|
"MemoryCacheStore",
|
||||||
|
"RedisCacheStore",
|
||||||
|
]
|
||||||
256
src/llama_stack/providers/utils/cache/cache_store.py
vendored
Normal file
256
src/llama_stack/providers/utils/cache/cache_store.py
vendored
Normal file
|
|
@ -0,0 +1,256 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Cache store abstraction for prompt caching implementation.
|
||||||
|
|
||||||
|
This module provides a protocol-based abstraction for cache storage backends,
|
||||||
|
enabling flexible storage implementations (memory, Redis, etc.) for prompt
|
||||||
|
caching in the Llama Stack server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import Any, Optional, Protocol
|
||||||
|
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CacheStore(Protocol):
    """Structural interface that every cache backend must satisfy.

    A backend is a key-value store whose entries expire after a TTL.
    Every method is declared as a coroutine, so callers always ``await``
    them regardless of which backend is in use.
    """

    async def get(self, key: str) -> Optional[Any]:
        """Look up ``key``.

        Args:
            key: Cache key to retrieve

        Returns:
            The stored value, or None when the key is absent or expired

        Raises:
            CacheError: If cache backend is unavailable or operation fails
        """
        ...

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Store ``value`` under ``key``.

        Args:
            key: Cache key
            value: Value to cache (must be serializable)
            ttl: Lifetime in seconds; None selects the backend's default TTL

        Raises:
            CacheError: If cache backend is unavailable or operation fails
            ValueError: If value is not serializable
        """
        ...

    async def delete(self, key: str) -> bool:
        """Remove ``key`` from the cache.

        Args:
            key: Cache key to delete

        Returns:
            True when an entry was removed, False when there was none

        Raises:
            CacheError: If cache backend is unavailable or operation fails
        """
        ...

    async def exists(self, key: str) -> bool:
        """Report whether ``key`` is currently cached.

        Args:
            key: Cache key to check

        Returns:
            True when the key is present and unexpired, False otherwise

        Raises:
            CacheError: If cache backend is unavailable or operation fails
        """
        ...

    async def ttl(self, key: str) -> Optional[int]:
        """Report how long ``key`` has left to live.

        Args:
            key: Cache key

        Returns:
            Remaining seconds, or None when the key is absent or has no TTL

        Raises:
            CacheError: If cache backend is unavailable or operation fails
        """
        ...

    async def clear(self) -> None:
        """Drop every entry.

        Mainly intended for tests; in production this wipes all cached
        data, so use it with caution.

        Raises:
            CacheError: If cache backend is unavailable or operation fails
        """
        ...

    async def size(self) -> int:
        """Count the cached entries.

        Returns:
            Number of cached entries

        Raises:
            CacheError: If cache backend is unavailable or operation fails
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class CacheError(Exception):
    """Signals that a cache operation could not be completed.

    Backends raise this for unavailability, network trouble and other
    operational faults. Callers are expected to catch it and degrade
    gracefully rather than fail the surrounding request.
    """

    def __init__(self, message: str, cause: Optional[Exception] = None):
        """Build the error.

        Args:
            message: Error description (should start with "Failed to ...")
            cause: The underlying exception, when there is one
        """
        # Kept as a plain attribute so handlers can inspect the root failure.
        self.cause = cause
        super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitBreaker:
    """Failure guard that temporarily switches off a misbehaving cache backend.

    The breaker counts failures reported by the caller. Once the count
    reaches ``failure_threshold`` it opens and ``is_closed()`` starts
    returning False; after ``recovery_timeout`` seconds it lets calls
    through again to find out whether the backend has recovered.

    States:
        - CLOSED: Normal operation, requests go through
        - OPEN: Too many failures, requests are blocked
        - HALF_OPEN: Testing if backend has recovered

    Example:
        breaker = CircuitBreaker(failure_threshold=10, recovery_timeout=60)
        if breaker.is_closed():
            try:
                result = await cache.get(key)
                breaker.record_success()
            except CacheError:
                breaker.record_failure()
    """

    def __init__(self, failure_threshold: int = 10, recovery_timeout: int = 60):
        """Set up a breaker in the CLOSED state.

        Args:
            failure_threshold: Failure count at which the breaker opens
                (the count is reset by every recorded success)
            recovery_timeout: Seconds to stay OPEN before probing the backend
        """
        self.state = "CLOSED"  # CLOSED, OPEN, HALF_OPEN
        self.failure_count = 0
        self.last_failure_time: Optional[float] = None
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout

    def is_closed(self) -> bool:
        """Tell the caller whether it may attempt a cache operation.

        Despite the name, this also returns True in HALF_OPEN. Calling it
        while OPEN may move the breaker to HALF_OPEN as a side effect.

        Returns:
            True if operations should proceed, False if blocked
        """
        import time

        if self.state != "OPEN":
            # CLOSED lets everything through. HALF_OPEN does too: nothing here
            # limits how many probe calls pass before an outcome is recorded.
            return True

        if self.last_failure_time is not None:
            seconds_open = time.time() - self.last_failure_time
            if seconds_open >= self.recovery_timeout:
                self.state = "HALF_OPEN"
                logger.info("Circuit breaker entering HALF_OPEN state for recovery test")
                return True

        # Still inside the cool-down window (or no failure timestamp known).
        return False

    def record_success(self) -> None:
        """Note a successful operation and return the breaker to CLOSED."""
        recovering = self.state == "HALF_OPEN"
        if recovering:
            logger.info("Circuit breaker recovery successful, returning to CLOSED state")

        self.state = "CLOSED"
        self.last_failure_time = None
        self.failure_count = 0

    def record_failure(self) -> None:
        """Note a failed operation, opening the breaker when warranted."""
        import time

        self.last_failure_time = time.time()
        self.failure_count += 1

        if self.state == "HALF_OPEN":
            # The probe failed, so go straight back to OPEN for another cool-down.
            logger.warning("Circuit breaker recovery failed, returning to OPEN state")
            self.state = "OPEN"
            return

        if self.failure_count >= self.failure_threshold:
            logger.error(
                f"Circuit breaker OPEN after {self.failure_count} failures. "
                f"Cache operations disabled for {self.recovery_timeout}s"
            )
            self.state = "OPEN"

    def get_state(self) -> str:
        """Expose the current state.

        Returns:
            Current state: "CLOSED", "OPEN", or "HALF_OPEN"
        """
        return self.state

    def reset(self) -> None:
        """Force the breaker back to CLOSED and forget recorded failures.

        Meant for tests and administrative overrides.
        """
        self.state = "CLOSED"
        self.last_failure_time = None
        self.failure_count = 0
        logger.info("Circuit breaker manually reset to CLOSED state")
|
||||||
334
src/llama_stack/providers/utils/cache/memory.py
vendored
Normal file
334
src/llama_stack/providers/utils/cache/memory.py
vendored
Normal file
|
|
@ -0,0 +1,334 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""In-memory cache store implementation using cachetools.
|
||||||
|
|
||||||
|
This module provides a memory-based cache store suitable for development
|
||||||
|
and single-node deployments. For production multi-node deployments,
|
||||||
|
consider using RedisCacheStore instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
|
from cachetools import Cache, LFUCache, LRUCache, TTLCache # type: ignore # no types-cachetools available
|
||||||
|
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
from .cache_store import CacheError
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
EvictionPolicy = Literal["lru", "lfu", "ttl-only"]
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryCacheStore:
    """In-memory cache store with configurable eviction policies.

    Values are held in a cachetools cache selected by ``eviction_policy``:
        - "lru": ``LRUCache`` bounded by ``max_entries``
        - "lfu": ``LFUCache`` bounded by ``max_entries``
        - "ttl-only": ``TTLCache`` bounded by ``max_entries`` and built with
          ``default_ttl``

    Per-key expiration deadlines are tracked separately in ``_ttl_map`` and
    enforced lazily: an expired entry is removed the next time it is touched
    by ``get``/``exists``/``ttl`` or when ``size`` sweeps the store.

    NOTE(review): in "ttl-only" mode the backing TTLCache is constructed with
    ``default_ttl``, so a per-key ``ttl`` larger than ``default_ttl`` is
    presumably capped by the backing cache - confirm against cachetools.

    Concurrency: no locks are taken. The async methods contain no ``await``
    points, so calls issued from a single event loop do not interleave, but
    access from multiple threads is not synchronized.

    Example:
        cache = MemoryCacheStore(
            max_entries=1000,
            default_ttl=600,
            eviction_policy="lru"
        )
        await cache.set("key", "value", ttl=300)
        value = await cache.get("key")
    """

    def __init__(
        self,
        max_entries: int = 1000,
        max_memory_mb: Optional[int] = 512,
        default_ttl: int = 600,
        eviction_policy: EvictionPolicy = "lru",
    ):
        """Initialize memory cache store.

        Args:
            max_entries: Maximum number of entries to store
            max_memory_mb: Soft memory limit in MB; None disables the check
            default_ttl: Default time-to-live in seconds
            eviction_policy: Eviction strategy ("lru", "lfu", "ttl-only")

        Raises:
            ValueError: If invalid parameters provided
        """
        if max_entries <= 0:
            raise ValueError("max_entries must be positive")
        if default_ttl <= 0:
            raise ValueError("default_ttl must be positive")
        if max_memory_mb is not None and max_memory_mb <= 0:
            raise ValueError("max_memory_mb must be positive")

        self.max_entries = max_entries
        self.max_memory_mb = max_memory_mb
        self.default_ttl = default_ttl
        self.eviction_policy = eviction_policy

        # Backing store plus the absolute (time.time()) deadline of every key.
        self._cache: Cache = self._create_cache()
        self._ttl_map: dict[str, float] = {}

        logger.info(
            f"Initialized MemoryCacheStore: policy={eviction_policy}, "
            f"max_entries={max_entries}, max_memory={max_memory_mb}MB, "
            f"default_ttl={default_ttl}s"
        )

    def _create_cache(self) -> Cache:
        """Create cache instance based on eviction policy.

        Returns:
            Cache instance configured with chosen policy

        Raises:
            ValueError: If ``eviction_policy`` is not a known policy
        """
        if self.eviction_policy == "lru":
            return LRUCache(maxsize=self.max_entries)
        elif self.eviction_policy == "lfu":
            return LFUCache(maxsize=self.max_entries)
        elif self.eviction_policy == "ttl-only":
            return TTLCache(maxsize=self.max_entries, ttl=self.default_ttl)
        else:
            raise ValueError(f"Unknown eviction policy: {self.eviction_policy}")

    def _is_expired(self, key: str) -> bool:
        """Check if a key has expired, removing it when it has.

        Args:
            key: Cache key to check

        Returns:
            True if the key had a deadline that has passed (the entry is
            dropped as a side effect), False otherwise - including when no
            deadline is recorded for the key
        """
        expiration_time = self._ttl_map.get(key)
        if expiration_time is None:
            return False

        if time.time() >= expiration_time:
            # Lazy expiry: drop the value and its deadline together.
            self._cache.pop(key, None)
            self._ttl_map.pop(key, None)
            return True

        return False

    async def get(self, key: str) -> Optional[Any]:
        """Retrieve a value from the cache.

        Args:
            key: Cache key to retrieve

        Returns:
            Cached value if present and not expired, None otherwise

        Raises:
            CacheError: If cache operation fails
        """
        try:
            # Expiry is checked first so a stale value is never returned.
            if self._is_expired(key):
                return None

            value = self._cache.get(key)
            if value is not None:
                logger.debug(f"Cache hit: {key}")
            return value

        except Exception as e:
            logger.error(f"Failed to get cache key '{key}': {e}")
            raise CacheError(f"Failed to get cache key '{key}'", cause=e) from e

    async def set(
        self,
        key: str,
        value: Any,
        ttl: Optional[int] = None,
    ) -> None:
        """Store a value in the cache with optional TTL.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time-to-live in seconds. If None, uses default TTL.

        Raises:
            CacheError: If cache operation fails
        """
        try:
            effective_ttl = ttl if ttl is not None else self.default_ttl

            # Make room through the cache's own eviction policy *before*
            # inserting a new key. Letting the insert evict silently would
            # leave the victim's deadline behind in _ttl_map forever.
            if key not in self._cache:
                while len(self._cache) >= self.max_entries:
                    evicted_key, _ = self._cache.popitem()
                    self._ttl_map.pop(evicted_key, None)

            self._cache[key] = value
            self._ttl_map[key] = time.time() + effective_ttl

            # Soft limit only: this logs, it never evicts.
            if self.max_memory_mb is not None:
                self._check_memory_usage()

            logger.debug(f"Cache set: {key} (ttl={effective_ttl}s)")

        except Exception as e:
            logger.error(f"Failed to set cache key '{key}': {e}")
            raise CacheError(f"Failed to set cache key '{key}'", cause=e) from e

    def _check_memory_usage(self) -> None:
        """Log a warning if the estimated memory usage exceeds the soft limit.

        This never evicts anything; entry count is bounded by ``max_entries``.

        NOTE: ``sys.getsizeof`` is shallow, so the estimate covers only the
        two container objects themselves and not the cached keys and values.
        """
        try:
            cache_size_bytes = sys.getsizeof(self._cache) + sys.getsizeof(self._ttl_map)
            cache_size_mb = cache_size_bytes / (1024 * 1024)

            if self.max_memory_mb is not None and cache_size_mb > self.max_memory_mb:
                logger.warning(
                    f"Cache memory usage ({cache_size_mb:.1f}MB) exceeds "
                    f"soft limit ({self.max_memory_mb}MB). "
                    f"Consider increasing max_entries or max_memory_mb."
                )
        except Exception as e:
            # Deliberately best-effort: a failed estimate must not fail a set().
            logger.debug(f"Memory usage check failed: {e}")

    async def delete(self, key: str) -> bool:
        """Delete a key from the cache.

        Args:
            key: Cache key to delete

        Returns:
            True if key was deleted, False if key didn't exist

        Raises:
            CacheError: If cache operation fails
        """
        try:
            existed = key in self._cache
            self._cache.pop(key, None)
            self._ttl_map.pop(key, None)

            if existed:
                logger.debug(f"Cache delete: {key}")

            return existed

        except Exception as e:
            logger.error(f"Failed to delete cache key '{key}': {e}")
            raise CacheError(f"Failed to delete cache key '{key}'", cause=e) from e

    async def exists(self, key: str) -> bool:
        """Check if a key exists in the cache.

        Args:
            key: Cache key to check

        Returns:
            True if key exists and is not expired, False otherwise

        Raises:
            CacheError: If cache operation fails
        """
        try:
            if self._is_expired(key):
                return False
            return key in self._cache

        except Exception as e:
            logger.error(f"Failed to check cache key existence '{key}': {e}")
            raise CacheError(f"Failed to check cache key existence '{key}'", cause=e) from e

    async def ttl(self, key: str) -> Optional[int]:
        """Get the remaining TTL for a key.

        Args:
            key: Cache key

        Returns:
            Remaining TTL in whole seconds (never negative), or None if the
            key doesn't exist, has expired, or is no longer in the cache

        Raises:
            CacheError: If cache operation fails
        """
        try:
            if self._is_expired(key):
                return None

            expiration_time = self._ttl_map.get(key)
            if expiration_time is None:
                return None

            if key not in self._cache:
                # The value is gone (e.g. dropped by the backing cache), so
                # the recorded deadline is stale: discard it, report no key.
                self._ttl_map.pop(key, None)
                return None

            return max(0, int(expiration_time - time.time()))

        except Exception as e:
            logger.error(f"Failed to get TTL for cache key '{key}': {e}")
            raise CacheError(f"Failed to get TTL for cache key '{key}'", cause=e) from e

    async def clear(self) -> None:
        """Clear all entries from the cache.

        Raises:
            CacheError: If cache operation fails
        """
        try:
            self._cache.clear()
            self._ttl_map.clear()
            logger.info("Cache cleared")

        except Exception as e:
            logger.error(f"Failed to clear cache: {e}")
            raise CacheError("Failed to clear cache", cause=e) from e

    async def size(self) -> int:
        """Get the number of entries in the cache.

        Sweeps the store first, so expired entries are removed and not counted.

        Returns:
            Number of cached entries (excluding expired entries)

        Raises:
            CacheError: If cache operation fails
        """
        try:
            # Iterate over a snapshot: _is_expired() mutates _ttl_map.
            for key in list(self._ttl_map):
                if not self._is_expired(key) and key not in self._cache:
                    # Deadline without a value: drop the orphaned bookkeeping.
                    self._ttl_map.pop(key, None)

            return len(self._cache)

        except Exception as e:
            logger.error(f"Failed to get cache size: {e}")
            raise CacheError("Failed to get cache size", cause=e) from e

    def get_stats(self) -> dict[str, Any]:
        """Get cache statistics.

        Unlike ``size()``, this does not sweep expired entries first, so
        "size" is the raw length of the backing cache.

        Returns:
            Dictionary with cache statistics including size, policy, and limits
        """
        return {
            "size": len(self._cache),
            "max_entries": self.max_entries,
            "max_memory_mb": self.max_memory_mb,
            "default_ttl": self.default_ttl,
            "eviction_policy": self.eviction_policy,
        }
|
||||||
513
src/llama_stack/providers/utils/cache/redis.py
vendored
Normal file
513
src/llama_stack/providers/utils/cache/redis.py
vendored
Normal file
|
|
@ -0,0 +1,513 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Redis-based cache store implementation.
|
||||||
|
|
||||||
|
This module provides a production-ready Redis cache store with connection
|
||||||
|
pooling, retry logic, and comprehensive error handling. Suitable for
|
||||||
|
distributed deployments and high-throughput scenarios.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from redis import asyncio as aioredis
|
||||||
|
from redis.asyncio import ConnectionPool, Redis
|
||||||
|
from redis.exceptions import ConnectionError, RedisError, TimeoutError
|
||||||
|
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
from .cache_store import CacheError
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisCacheStore:
|
||||||
|
"""Redis-based cache store with connection pooling.
|
||||||
|
|
||||||
|
This implementation provides production-ready caching with:
|
||||||
|
- Connection pooling for efficient resource usage
|
||||||
|
- Automatic retry logic for transient failures
|
||||||
|
- Configurable timeouts to prevent blocking
|
||||||
|
- JSON serialization for complex data types
|
||||||
|
- Support for Redis cluster and sentinel
|
||||||
|
|
||||||
|
Example:
|
||||||
|
cache = RedisCacheStore(
|
||||||
|
host="localhost",
|
||||||
|
port=6379,
|
||||||
|
db=0,
|
||||||
|
password="secret",
|
||||||
|
connection_pool_size=10,
|
||||||
|
timeout_ms=100
|
||||||
|
)
|
||||||
|
await cache.set("key", {"data": "value"}, ttl=300)
|
||||||
|
value = await cache.get("key")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str = "localhost",
|
||||||
|
port: int = 6379,
|
||||||
|
db: int = 0,
|
||||||
|
password: Optional[str] = None,
|
||||||
|
connection_pool_size: int = 10,
|
||||||
|
timeout_ms: int = 100,
|
||||||
|
default_ttl: int = 600,
|
||||||
|
max_retries: int = 3,
|
||||||
|
key_prefix: str = "llama_stack:",
|
||||||
|
):
|
||||||
|
"""Initialize Redis cache store.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host: Redis server hostname
|
||||||
|
port: Redis server port
|
||||||
|
db: Redis database number (0-15)
|
||||||
|
password: Optional Redis password
|
||||||
|
connection_pool_size: Maximum connections in pool
|
||||||
|
timeout_ms: Operation timeout in milliseconds
|
||||||
|
default_ttl: Default time-to-live in seconds
|
||||||
|
max_retries: Maximum retry attempts for failed operations
|
||||||
|
key_prefix: Prefix for all cache keys (namespace isolation)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If invalid parameters provided
|
||||||
|
"""
|
||||||
|
if connection_pool_size <= 0:
|
||||||
|
raise ValueError("connection_pool_size must be positive")
|
||||||
|
if timeout_ms <= 0:
|
||||||
|
raise ValueError("timeout_ms must be positive")
|
||||||
|
if default_ttl <= 0:
|
||||||
|
raise ValueError("default_ttl must be positive")
|
||||||
|
if max_retries < 0:
|
||||||
|
raise ValueError("max_retries must be non-negative")
|
||||||
|
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.db = db
|
||||||
|
self.password = password
|
||||||
|
self.connection_pool_size = connection_pool_size
|
||||||
|
self.timeout_ms = timeout_ms
|
||||||
|
self.default_ttl = default_ttl
|
||||||
|
self.max_retries = max_retries
|
||||||
|
self.key_prefix = key_prefix
|
||||||
|
|
||||||
|
# Connection pool (lazy initialization)
|
||||||
|
self._pool: Optional[ConnectionPool] = None
|
||||||
|
self._redis: Optional[Redis] = None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Initialized RedisCacheStore: host={host}:{port}, db={db}, "
|
||||||
|
f"pool_size={connection_pool_size}, timeout={timeout_ms}ms, "
|
||||||
|
f"default_ttl={default_ttl}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _ensure_connection(self) -> Redis:
|
||||||
|
"""Ensure Redis connection is established.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Redis client instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
CacheError: If connection cannot be established
|
||||||
|
"""
|
||||||
|
if self._redis is not None:
|
||||||
|
return self._redis
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create connection pool
|
||||||
|
self._pool = ConnectionPool(
|
||||||
|
host=self.host,
|
||||||
|
port=self.port,
|
||||||
|
db=self.db,
|
||||||
|
password=self.password,
|
||||||
|
max_connections=self.connection_pool_size,
|
||||||
|
socket_timeout=self.timeout_ms / 1000.0,
|
||||||
|
socket_connect_timeout=self.timeout_ms / 1000.0,
|
||||||
|
decode_responses=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create Redis client
|
||||||
|
self._redis = Redis(connection_pool=self._pool)
|
||||||
|
|
||||||
|
# Test connection
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._redis.ping(),
|
||||||
|
timeout=self.timeout_ms / 1000.0
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Connected to Redis at {self.host}:{self.port}")
|
||||||
|
return self._redis
|
||||||
|
|
||||||
|
except (ConnectionError, TimeoutError) as e:
|
||||||
|
logger.error(f"Failed to connect to Redis: {e}")
|
||||||
|
raise CacheError(f"Failed to connect to Redis at {self.host}:{self.port}", cause=e) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize Redis connection: {e}")
|
||||||
|
raise CacheError("Failed to initialize Redis connection", cause=e) from e
|
||||||
|
|
||||||
|
def _make_key(self, key: str) -> str:
|
||||||
|
"""Create prefixed cache key for namespace isolation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Base cache key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Prefixed key
|
||||||
|
"""
|
||||||
|
return f"{self.key_prefix}{key}"
|
||||||
|
|
||||||
|
def _serialize(self, value: Any) -> str:
|
||||||
|
"""Serialize value for storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Value to serialize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON-serialized string
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If value cannot be serialized
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return json.dumps(value)
|
||||||
|
except (TypeError, ValueError) as e:
|
||||||
|
raise ValueError(f"Value is not JSON-serializable: {e}") from e
|
||||||
|
|
||||||
|
def _deserialize(self, data: str) -> Any:
|
||||||
|
"""Deserialize stored value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: JSON-serialized string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deserialized value
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If data cannot be deserialized
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return json.loads(data)
|
||||||
|
except (TypeError, ValueError) as e:
|
||||||
|
logger.warning(f"Failed to deserialize cache value: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _retry_operation(self, operation, *args, **kwargs) -> Any:
    """Run an async operation with a per-attempt timeout and exponential backoff.

    Each attempt is bounded by ``self.timeout_ms``. Connection and timeout
    failures are retried up to ``self.max_retries`` times; any other
    exception is wrapped in CacheError immediately.

    Args:
        operation: Async callable to invoke
        *args: Positional arguments for operation
        **kwargs: Keyword arguments for operation

    Returns:
        Whatever ``operation`` returns on the first successful attempt

    Raises:
        CacheError: If a non-transient error occurs or all attempts fail
    """
    total_attempts = self.max_retries + 1
    timeout_s = self.timeout_ms / 1000.0
    last_error = None

    for attempt in range(total_attempts):
        try:
            return await asyncio.wait_for(operation(*args, **kwargs), timeout=timeout_s)
        # asyncio.TimeoutError is listed explicitly so the timeout raised by
        # wait_for above is always treated as transient. It is a distinct
        # class from the builtin TimeoutError before Python 3.11.
        # NOTE(review): if this module binds the name TimeoutError to a
        # library exception (e.g. from the Redis client) the wait_for
        # timeout would otherwise skip the retry path -- confirm against
        # the file's import block.
        except (ConnectionError, TimeoutError, asyncio.TimeoutError) as e:
            last_error = e
            if attempt < self.max_retries:
                backoff = 2 ** attempt * 0.1  # 100ms, 200ms, 400ms
                logger.warning(
                    f"Redis operation failed (attempt {attempt + 1}/{total_attempts}), "
                    f"retrying in {backoff}s: {e}"
                )
                await asyncio.sleep(backoff)
            else:
                logger.error(f"Redis operation failed after {total_attempts} attempts")
        except Exception as e:
            # Don't retry on non-transient errors
            raise CacheError(f"Redis operation failed: {e}", cause=e) from e

    raise CacheError(f"Redis operation failed after {total_attempts} attempts", cause=last_error) from last_error
|
||||||
|
|
||||||
|
async def get(self, key: str) -> Optional[Any]:
    """Look up a cached value by key.

    Args:
        key: Cache key to retrieve

    Returns:
        The decoded value, or None when Redis has no entry for the key or
        the stored payload could not be decoded

    Raises:
        CacheError: If cache operation fails
    """
    try:
        client = await self._ensure_connection()
        namespaced = self._make_key(key)
        raw = await self._retry_operation(client.get, namespaced)

        # Redis reports a missing (or expired) key as None.
        if raw is None:
            return None

        result = self._deserialize(raw)
        if result is not None:
            logger.debug(f"Cache hit: {key}")
        return result
    except CacheError:
        raise
    except Exception as exc:
        logger.error(f"Failed to get cache key '{key}': {exc}")
        raise CacheError(f"Failed to get cache key '{key}'", cause=exc) from exc
|
||||||
|
|
||||||
|
async def set(
    self,
    key: str,
    value: Any,
    ttl: Optional[int] = None,
) -> None:
    """Write a value to the cache with an expiry.

    Args:
        key: Cache key
        value: Value to cache (must be JSON-serializable)
        ttl: Time-to-live in seconds. If None, ``self.default_ttl`` is used.

    Raises:
        CacheError: If cache operation fails
        ValueError: If value is not serializable
    """
    try:
        client = await self._ensure_connection()
        namespaced = self._make_key(key)
        payload = self._serialize(value)

        # Fall back to the store-wide default when no TTL was given.
        if ttl is None:
            expires_in = self.default_ttl
        else:
            expires_in = ttl

        await self._retry_operation(client.setex, namespaced, expires_in, payload)
        logger.debug(f"Cache set: {key} (ttl={expires_in}s)")
    except (ValueError, CacheError):
        # Serialization and cache errors pass through unchanged.
        raise
    except Exception as exc:
        logger.error(f"Failed to set cache key '{key}': {exc}")
        raise CacheError(f"Failed to set cache key '{key}'", cause=exc) from exc
|
||||||
|
|
||||||
|
async def delete(self, key: str) -> bool:
    """Remove a key from the cache.

    Args:
        key: Cache key to delete

    Returns:
        True if Redis reported at least one key removed, False otherwise

    Raises:
        CacheError: If cache operation fails
    """
    try:
        client = await self._ensure_connection()
        namespaced = self._make_key(key)
        removed = await self._retry_operation(client.delete, namespaced)

        was_present = removed > 0
        if was_present:
            logger.debug(f"Cache delete: {key}")
        return bool(was_present)
    except CacheError:
        raise
    except Exception as exc:
        logger.error(f"Failed to delete cache key '{key}': {exc}")
        raise CacheError(f"Failed to delete cache key '{key}'", cause=exc) from exc
|
||||||
|
|
||||||
|
async def exists(self, key: str) -> bool:
    """Report whether the cache currently holds a key.

    Args:
        key: Cache key to check

    Returns:
        True if Redis reports a positive count for the key, False otherwise

    Raises:
        CacheError: If cache operation fails
    """
    try:
        client = await self._ensure_connection()
        namespaced = self._make_key(key)
        hits = await self._retry_operation(client.exists, namespaced)
        return bool(hits > 0)
    except CacheError:
        raise
    except Exception as exc:
        logger.error(f"Failed to check cache key existence '{key}': {exc}")
        raise CacheError(f"Failed to check cache key existence '{key}'", cause=exc) from exc
|
||||||
|
|
||||||
|
async def ttl(self, key: str) -> Optional[int]:
    """Return how many seconds remain before a key expires.

    Args:
        key: Cache key

    Returns:
        Remaining TTL in seconds (never negative), or None if the key
        doesn't exist or has no TTL

    Raises:
        CacheError: If cache operation fails
    """
    try:
        client = await self._ensure_connection()
        namespaced = self._make_key(key)
        remaining = await self._retry_operation(client.ttl, namespaced)

        # Redis sentinels: -2 means the key doesn't exist, -1 means no TTL.
        if remaining in (-2, -1):
            return None
        return int(max(0, remaining))
    except CacheError:
        raise
    except Exception as exc:
        logger.error(f"Failed to get TTL for cache key '{key}': {exc}")
        raise CacheError(f"Failed to get TTL for cache key '{key}'", cause=exc) from exc
|
||||||
|
|
||||||
|
async def clear(self) -> None:
    """Delete every cache entry under this store's key prefix.

    Keys are discovered with SCAN using the pattern ``<key_prefix>*`` and
    deleted one batch at a time.

    Raises:
        CacheError: If cache operation fails
    """
    try:
        client = await self._ensure_connection()
        pattern = f"{self.key_prefix}*"

        removed_total = 0
        cursor = 0
        scanning = True
        while scanning:
            cursor, batch = await self._retry_operation(
                client.scan,
                cursor=cursor,
                match=pattern,
                count=100,
            )
            if batch:
                removed_total += await self._retry_operation(client.delete, *batch)
            # A returned cursor of 0 marks the end of the iteration.
            scanning = cursor != 0

        logger.info(f"Cache cleared: deleted {removed_total} keys")
    except CacheError:
        raise
    except Exception as exc:
        logger.error(f"Failed to clear cache: {exc}")
        raise CacheError("Failed to clear cache", cause=exc) from exc
|
||||||
|
|
||||||
|
async def size(self) -> int:
    """Get the number of entries in the cache.

    Keys are enumerated with SCAN using the pattern ``<key_prefix>*``.
    SCAN is allowed to return the same key more than once during a full
    iteration, so keys are collected into a set and counted once each.
    This holds every matching key name in memory for the duration of
    the call.

    Returns:
        Number of distinct keys matching key_prefix

    Raises:
        CacheError: If cache operation fails
    """
    try:
        redis = await self._ensure_connection()
        pattern = f"{self.key_prefix}*"

        seen = set()
        cursor = 0
        while True:
            cursor, keys = await self._retry_operation(
                redis.scan,
                cursor=cursor,
                match=pattern,
                count=100,
            )
            # De-duplicate: summing len(keys) would over-count repeats.
            seen.update(keys)

            # A returned cursor of 0 marks the end of the iteration.
            if cursor == 0:
                break

        return len(seen)

    except CacheError:
        raise
    except Exception as e:
        logger.error(f"Failed to get cache size: {e}")
        raise CacheError("Failed to get cache size", cause=e) from e
|
||||||
|
|
||||||
|
async def close(self) -> None:
    """Close Redis connection and cleanup resources.

    This should be called when the cache is no longer needed. Errors are
    logged as warnings and never raised. The client and the pool are torn
    down independently, so a failure closing the client no longer skips
    the pool disconnect, and both references are always cleared.
    """
    client, pool = self._redis, self._pool
    # Drop the references up front so a failing teardown cannot leave a
    # half-closed handle attached to the store.
    self._redis = None
    self._pool = None

    failures = []

    if client is not None:
        try:
            await client.close()
        except Exception as e:
            failures.append(e)

    if pool is not None:
        try:
            await pool.disconnect()
        except Exception as e:
            failures.append(e)

    if failures:
        for e in failures:
            logger.warning(f"Error closing Redis connection: {e}")
    else:
        logger.info("Redis connection closed")
|
||||||
|
|
||||||
|
def get_stats(self) -> dict[str, Any]:
    """Summarise this store's configuration and connection state.

    Returns:
        Dictionary of the configured settings plus a ``connected`` flag
        that is True when a Redis client object is currently held
    """
    reported = (
        "host",
        "port",
        "db",
        "connection_pool_size",
        "timeout_ms",
        "default_ttl",
        "max_retries",
        "key_prefix",
    )
    stats: dict[str, Any] = {name: getattr(self, name) for name in reported}
    stats["connected"] = self._redis is not None
    return stats
|
||||||
7
tests/unit/providers/utils/cache/__init__.py
vendored
Normal file
7
tests/unit/providers/utils/cache/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Unit tests for cache store implementations."""
|
||||||
257
tests/unit/providers/utils/cache/test_cache_store.py
vendored
Normal file
257
tests/unit/providers/utils/cache/test_cache_store.py
vendored
Normal file
|
|
@ -0,0 +1,257 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Unit tests for cache store base classes and utilities."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.cache import CacheError, CircuitBreaker
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheError:
    """Tests for the CacheError exception type."""

    def test_init_with_message(self):
        """A CacheError built from a message alone carries no cause."""
        err = CacheError("Failed to connect to cache")
        assert str(err) == "Failed to connect to cache"
        assert err.cause is None

    def test_init_with_cause(self):
        """The underlying exception is exposed through ``cause``."""
        root = ValueError("Invalid value")
        err = CacheError("Failed to set cache key", cause=root)
        assert str(err) == "Failed to set cache key"
        assert err.cause == root
|
||||||
|
|
||||||
|
|
||||||
|
class TestCircuitBreaker:
    """Tests for the CircuitBreaker state transitions and counters."""

    @staticmethod
    def _trip(cb, failures):
        """Record ``failures`` consecutive failures on ``cb``."""
        for _ in range(failures):
            cb.record_failure()

    def test_init_default_params(self):
        """Defaults: threshold 10, recovery 60, CLOSED with nothing recorded."""
        cb = CircuitBreaker()
        assert cb.failure_threshold == 10
        assert cb.recovery_timeout == 60
        assert cb.failure_count == 0
        assert cb.last_failure_time is None
        assert cb.state == "CLOSED"

    def test_init_custom_params(self):
        """Constructor arguments are stored on the instance."""
        cb = CircuitBreaker(failure_threshold=5, recovery_timeout=30)
        assert cb.failure_threshold == 5
        assert cb.recovery_timeout == 30

    def test_is_closed_initial_state(self):
        """A new breaker reports closed."""
        cb = CircuitBreaker()
        assert cb.is_closed() is True
        assert cb.get_state() == "CLOSED"

    def test_record_success(self):
        """A success wipes the failure counter and timestamp."""
        cb = CircuitBreaker()

        self._trip(cb, 2)
        assert cb.failure_count == 2

        cb.record_success()
        assert cb.failure_count == 0
        assert cb.last_failure_time is None
        assert cb.state == "CLOSED"

    def test_record_failure_below_threshold(self):
        """The breaker stays closed while failures remain under the threshold."""
        cb = CircuitBreaker(failure_threshold=5)

        for _ in range(4):
            cb.record_failure()
            assert cb.is_closed() is True
            assert cb.state == "CLOSED"

        assert cb.failure_count == 4

    def test_record_failure_reach_threshold(self):
        """Hitting the threshold opens the breaker."""
        cb = CircuitBreaker(failure_threshold=3)

        self._trip(cb, 3)

        assert cb.state == "OPEN"
        assert cb.is_closed() is False

    def test_circuit_open_blocks_requests(self):
        """An open breaker reports not-closed."""
        cb = CircuitBreaker(failure_threshold=3, recovery_timeout=10)

        self._trip(cb, 3)

        assert cb.is_closed() is False
        assert cb.state == "OPEN"

    async def test_recovery_timeout(self):
        """After the recovery timeout the breaker moves to HALF_OPEN."""
        cb = CircuitBreaker(failure_threshold=3, recovery_timeout=1)

        self._trip(cb, 3)
        assert cb.state == "OPEN"
        assert cb.is_closed() is False

        # Sleep just past the 1s recovery window.
        await asyncio.sleep(1.1)

        assert cb.is_closed() is True
        assert cb.state == "HALF_OPEN"

    async def test_half_open_success_closes_circuit(self):
        """A success while HALF_OPEN closes the breaker."""
        cb = CircuitBreaker(failure_threshold=3, recovery_timeout=1)

        self._trip(cb, 3)
        await asyncio.sleep(1.1)

        # is_closed() is what performs the OPEN -> HALF_OPEN transition.
        assert cb.is_closed() is True
        assert cb.state == "HALF_OPEN"

        cb.record_success()
        assert cb.state == "CLOSED"
        assert cb.failure_count == 0

    async def test_half_open_failure_reopens_circuit(self):
        """A failure while HALF_OPEN sends the breaker back to OPEN."""
        cb = CircuitBreaker(failure_threshold=3, recovery_timeout=1)

        self._trip(cb, 3)
        await asyncio.sleep(1.1)

        # is_closed() is what performs the OPEN -> HALF_OPEN transition.
        assert cb.is_closed() is True
        assert cb.state == "HALF_OPEN"

        cb.record_failure()
        assert cb.state == "OPEN"

    def test_reset(self):
        """reset() returns an open breaker to its initial state."""
        cb = CircuitBreaker(failure_threshold=3)

        self._trip(cb, 3)
        assert cb.state == "OPEN"

        cb.reset()
        assert cb.state == "CLOSED"
        assert cb.failure_count == 0
        assert cb.last_failure_time is None

    def test_get_state(self):
        """get_state() follows the breaker from CLOSED to OPEN."""
        cb = CircuitBreaker(failure_threshold=3)

        assert cb.get_state() == "CLOSED"

        # One failure is still under the threshold of three.
        cb.record_failure()
        assert cb.get_state() == "CLOSED"

        # Two more reach the threshold.
        self._trip(cb, 2)
        assert cb.get_state() == "OPEN"

    async def test_multiple_recovery_attempts(self):
        """A failed recovery re-opens the breaker; a later one can close it."""
        cb = CircuitBreaker(failure_threshold=2, recovery_timeout=1)

        self._trip(cb, 2)
        assert cb.state == "OPEN"

        # First recovery attempt fails.
        await asyncio.sleep(1.1)
        assert cb.is_closed() is True  # triggers the state check
        assert cb.state == "HALF_OPEN"
        cb.record_failure()
        assert cb.state == "OPEN"

        # Second recovery attempt succeeds.
        await asyncio.sleep(1.1)
        assert cb.is_closed() is True  # triggers the state check
        assert cb.state == "HALF_OPEN"
        cb.record_success()
        assert cb.state == "CLOSED"

    def test_failure_count_tracking(self):
        """failure_count increments per failure and resets on success."""
        cb = CircuitBreaker(failure_threshold=5)
        assert cb.failure_count == 0

        cb.record_failure()
        assert cb.failure_count == 1

        cb.record_failure()
        assert cb.failure_count == 2

        cb.record_success()
        assert cb.failure_count == 0

    async def test_concurrent_operations(self):
        """Failures recorded from interleaved coroutines are all counted."""
        cb = CircuitBreaker(failure_threshold=10)

        async def fail_repeatedly(times: int):
            for _ in range(times):
                cb.record_failure()
                await asyncio.sleep(0.01)

        await asyncio.gather(
            fail_repeatedly(3),
            fail_repeatedly(3),
            fail_repeatedly(3),
        )

        assert cb.failure_count == 9
        assert cb.state == "CLOSED"

        # The tenth failure reaches the threshold.
        cb.record_failure()
        assert cb.state == "OPEN"
|
||||||
332
tests/unit/providers/utils/cache/test_memory_cache.py
vendored
Normal file
332
tests/unit/providers/utils/cache/test_memory_cache.py
vendored
Normal file
|
|
@ -0,0 +1,332 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Unit tests for MemoryCacheStore implementation."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.cache import CacheError, MemoryCacheStore
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryCacheStore:
    """Tests for the in-memory cache store."""

    async def test_init_default_params(self):
        """Default construction uses 1000 entries, 512 MB, 600s TTL and LRU."""
        store = MemoryCacheStore()
        assert store.max_entries == 1000
        assert store.max_memory_mb == 512
        assert store.default_ttl == 600
        assert store.eviction_policy == "lru"

    async def test_init_custom_params(self):
        """Constructor arguments are stored on the instance."""
        store = MemoryCacheStore(
            max_entries=500,
            max_memory_mb=256,
            default_ttl=300,
            eviction_policy="lfu",
        )
        assert store.max_entries == 500
        assert store.max_memory_mb == 256
        assert store.default_ttl == 300
        assert store.eviction_policy == "lfu"

    async def test_init_invalid_params(self):
        """Non-positive limits and unknown policies are rejected."""
        with pytest.raises(ValueError, match="max_entries must be positive"):
            MemoryCacheStore(max_entries=0)

        with pytest.raises(ValueError, match="default_ttl must be positive"):
            MemoryCacheStore(default_ttl=0)

        with pytest.raises(ValueError, match="max_memory_mb must be positive"):
            MemoryCacheStore(max_memory_mb=0)

        with pytest.raises(ValueError, match="Unknown eviction policy"):
            MemoryCacheStore(eviction_policy="invalid")  # type: ignore

    async def test_set_and_get(self):
        """A stored value can be read back."""
        store = MemoryCacheStore()
        await store.set("key1", "value1")
        fetched = await store.get("key1")
        assert fetched == "value1"

    async def test_get_nonexistent_key(self):
        """Reading an unknown key yields None."""
        store = MemoryCacheStore()
        fetched = await store.get("nonexistent")
        assert fetched is None

    async def test_set_with_custom_ttl(self):
        """A per-entry TTL overrides the store default."""
        store = MemoryCacheStore(default_ttl=10)
        await store.set("key1", "value1", ttl=1)

        # Readable straight after the write.
        fetched = await store.get("key1")
        assert fetched == "value1"

        # Sleep just past the 1s TTL.
        await asyncio.sleep(1.1)

        fetched = await store.get("key1")
        assert fetched is None

    async def test_set_complex_value(self):
        """Nested dicts and heterogeneous lists round-trip unchanged."""
        store = MemoryCacheStore()

        mapping = {"nested": {"key": "value"}, "list": [1, 2, 3]}
        await store.set("complex", mapping)
        fetched = await store.get("complex")
        assert fetched == mapping

        sequence = [1, "two", {"three": 3}]
        await store.set("list", sequence)
        fetched = await store.get("list")
        assert fetched == sequence

    async def test_delete(self):
        """delete() reports whether a key was actually removed."""
        store = MemoryCacheStore()

        await store.set("key1", "value1")
        removed = await store.delete("key1")
        assert removed is True

        fetched = await store.get("key1")
        assert fetched is None

        removed = await store.delete("nonexistent")
        assert removed is False

    async def test_exists(self):
        """exists() is False for missing and expired keys, True otherwise."""
        store = MemoryCacheStore()

        present = await store.exists("key1")
        assert present is False

        await store.set("key1", "value1")
        present = await store.exists("key1")
        assert present is True

        # An entry past its TTL no longer counts as present.
        await store.set("key2", "value2", ttl=1)
        await asyncio.sleep(1.1)
        present = await store.exists("key2")
        assert present is False

    async def test_ttl(self):
        """ttl() reports remaining seconds, or None when missing or expired."""
        store = MemoryCacheStore()

        remaining = await store.ttl("nonexistent")
        assert remaining is None

        await store.set("key1", "value1", ttl=10)
        remaining = await store.ttl("key1")
        assert remaining is not None
        assert 8 <= remaining <= 10  # tolerate a little elapsed time

        await store.set("key2", "value2", ttl=1)
        await asyncio.sleep(1.1)
        remaining = await store.ttl("key2")
        assert remaining is None

    async def test_clear(self):
        """clear() removes every entry."""
        store = MemoryCacheStore()

        for idx in (1, 2, 3):
            await store.set(f"key{idx}", f"value{idx}")

        await store.clear()

        for idx in (1, 2, 3):
            assert await store.get(f"key{idx}") is None

    async def test_size(self):
        """size() follows inserts, deletes and clear()."""
        store = MemoryCacheStore()

        count = await store.size()
        assert count == 0

        await store.set("key1", "value1")
        await store.set("key2", "value2")
        count = await store.size()
        assert count == 2

        await store.delete("key1")
        count = await store.size()
        assert count == 1

        await store.clear()
        count = await store.size()
        assert count == 0

    async def test_lru_eviction(self):
        """With LRU, the least recently used entry is evicted first."""
        store = MemoryCacheStore(max_entries=3, eviction_policy="lru")

        for idx in (1, 2, 3):
            await store.set(f"key{idx}", f"value{idx}")

        # Reading key1 makes key2 the least recently used entry.
        await store.get("key1")

        # A fourth entry exceeds max_entries and forces an eviction.
        await store.set("key4", "value4")

        assert await store.get("key1") == "value1"
        assert await store.get("key2") is None
        assert await store.get("key3") == "value3"
        assert await store.get("key4") == "value4"

    async def test_lfu_eviction(self):
        """With LFU, the least frequently read entry is evicted first."""
        store = MemoryCacheStore(max_entries=3, eviction_policy="lfu")

        for idx in (1, 2, 3):
            await store.set(f"key{idx}", f"value{idx}")

        # Three reads of key1, two of key2; key3 is never read.
        for _ in range(3):
            await store.get("key1")
        for _ in range(2):
            await store.get("key2")

        # A fourth entry exceeds max_entries and forces an eviction.
        await store.set("key4", "value4")

        assert await store.get("key1") == "value1"
        assert await store.get("key2") == "value2"
        assert await store.get("key3") is None
        assert await store.get("key4") == "value4"

    async def test_concurrent_access(self):
        """Sets and gets issued through asyncio.gather all complete."""
        store = MemoryCacheStore()

        async def write(key: str, value: str):
            await store.set(key, value)

        async def read(key: str):
            return await store.get(key)

        await asyncio.gather(
            write("key1", "value1"),
            write("key2", "value2"),
            write("key3", "value3"),
        )

        fetched = await asyncio.gather(
            read("key1"),
            read("key2"),
            read("key3"),
        )

        assert fetched == ["value1", "value2", "value3"]

    async def test_update_existing_key(self):
        """Writing an existing key replaces its value."""
        store = MemoryCacheStore()

        await store.set("key1", "value1")
        assert await store.get("key1") == "value1"

        await store.set("key1", "value2")
        assert await store.get("key1") == "value2"

    async def test_get_stats(self):
        """get_stats() reports the current size and the configuration."""
        store = MemoryCacheStore(
            max_entries=100,
            max_memory_mb=128,
            default_ttl=300,
            eviction_policy="lru",
        )

        await store.set("key1", "value1")
        await store.set("key2", "value2")

        report = store.get_stats()

        assert report["size"] == 2
        assert report["max_entries"] == 100
        assert report["max_memory_mb"] == 128
        assert report["default_ttl"] == 300
        assert report["eviction_policy"] == "lru"

    async def test_ttl_expiration_cleanup(self):
        """Reading an expired entry removes it from the store."""
        store = MemoryCacheStore()

        await store.set("key1", "value1", ttl=1)
        await store.set("key2", "value2", ttl=10)

        assert await store.size() == 2

        # Sleep just past key1's 1s TTL.
        await asyncio.sleep(1.1)

        # Reading the expired key returns None and cleans it up.
        assert await store.get("key1") is None

        count = await store.size()
        assert count == 1

        assert await store.get("key2") == "value2"
|
||||||
421
tests/unit/providers/utils/cache/test_redis_cache.py
vendored
Normal file
421
tests/unit/providers/utils/cache/test_redis_cache.py
vendored
Normal file
|
|
@ -0,0 +1,421 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Unit tests for RedisCacheStore implementation."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.cache import CacheError, RedisCacheStore
|
||||||
|
|
||||||
|
|
||||||
|
class TestRedisCacheStore:
    """Unit tests for RedisCacheStore.

    Tests that exercise Redis operations patch the ``Redis`` and
    ``ConnectionPool`` names in ``llama_stack.providers.utils.cache.redis``
    and hand the store an ``AsyncMock`` client, so no real server is needed.
    """

    async def test_init_default_params(self):
        """A store built with no arguments exposes the expected defaults."""
        store = RedisCacheStore()

        assert (store.host, store.port, store.db) == ("localhost", 6379, 0)
        assert store.password is None

        remaining_defaults = {
            "connection_pool_size": 10,
            "timeout_ms": 100,
            "default_ttl": 600,
            "max_retries": 3,
            "key_prefix": "llama_stack:",
        }
        for attr, default in remaining_defaults.items():
            assert getattr(store, attr) == default

    async def test_init_custom_params(self):
        """Every constructor argument is stored on the attribute of the same name."""
        overrides = {
            "host": "redis.example.com",
            "port": 6380,
            "db": 1,
            "password": "secret",
            "connection_pool_size": 20,
            "timeout_ms": 200,
            "default_ttl": 300,
            "max_retries": 5,
            "key_prefix": "test:",
        }
        store = RedisCacheStore(**overrides)

        for attr, supplied in overrides.items():
            assert getattr(store, attr) == supplied

    async def test_init_invalid_params(self):
        """Out-of-range constructor arguments raise ValueError with a specific message."""
        rejected = (
            ({"connection_pool_size": 0}, "connection_pool_size must be positive"),
            ({"timeout_ms": 0}, "timeout_ms must be positive"),
            ({"default_ttl": 0}, "default_ttl must be positive"),
            ({"max_retries": -1}, "max_retries must be non-negative"),
        )
        for bad_kwargs, message in rejected:
            with pytest.raises(ValueError, match=message):
                RedisCacheStore(**bad_kwargs)

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_ensure_connection(self, mock_redis_class, mock_pool_class):
        """_ensure_connection() builds one pool, pings once, and returns the client."""
        pool = MagicMock()
        mock_pool_class.return_value = pool

        client = AsyncMock(ping=AsyncMock())
        mock_redis_class.return_value = client

        store = RedisCacheStore()
        connection = await store._ensure_connection()

        # The object handed back must be the patched client, reached via one
        # pool construction and one ping.
        assert connection == client
        mock_pool_class.assert_called_once()
        client.ping.assert_called_once()

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_connection_failure(self, mock_redis_class, mock_pool_class):
        """A Redis ConnectionError raised by ping surfaces as CacheError."""
        from redis.exceptions import ConnectionError as RedisConnectionError

        refusal = RedisConnectionError("Connection refused")
        client = AsyncMock(ping=AsyncMock(side_effect=refusal))
        mock_redis_class.return_value = client

        store = RedisCacheStore()

        with pytest.raises(CacheError, match="Failed to connect to Redis"):
            await store._ensure_connection()

    def test_make_key(self):
        """_make_key() prepends the configured prefix to the caller's key."""
        store = RedisCacheStore(key_prefix="test:")

        for raw, prefixed in (("mykey", "test:mykey"), ("another", "test:another")):
            assert store._make_key(raw) == prefixed

    def test_serialize_deserialize(self):
        """_serialize() emits JSON text and _deserialize() reverses it."""
        store = RedisCacheStore()

        # A plain string is encoded as a quoted JSON string.
        assert store._serialize("hello") == '"hello"'
        assert store._deserialize('"hello"') == "hello"

        # Containers must survive a full round trip unchanged.
        for payload in ({"key": "value", "number": 42}, [1, 2, "three"]):
            encoded = store._serialize(payload)
            assert store._deserialize(encoded) == payload

    def test_serialize_error(self):
        """Serializing an arbitrary object raises ValueError."""

        class NonSerializable:
            pass

        store = RedisCacheStore()

        with pytest.raises(ValueError, match="Value is not JSON-serializable"):
            store._serialize(NonSerializable())

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_set_and_get(self, mock_redis_class, mock_pool_class):
        """set() issues one SETEX and get() decodes the stored JSON value."""
        client = AsyncMock(
            ping=AsyncMock(),
            get=AsyncMock(return_value=json.dumps("value1")),
            setex=AsyncMock(),
        )
        mock_redis_class.return_value = client

        store = RedisCacheStore()

        await store.set("key1", "value1")
        client.setex.assert_called_once()

        fetched = await store.get("key1")
        assert fetched == "value1"
        client.get.assert_called_once()

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_get_nonexistent_key(self, mock_redis_class, mock_pool_class):
        """get() returns None when the client has nothing stored for the key."""
        client = AsyncMock(ping=AsyncMock(), get=AsyncMock(return_value=None))
        mock_redis_class.return_value = client

        store = RedisCacheStore()

        missing = await store.get("nonexistent")
        assert missing is None

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_set_with_custom_ttl(self, mock_redis_class, mock_pool_class):
        """A per-call ttl overrides default_ttl in the SETEX call."""
        client = AsyncMock(ping=AsyncMock(), setex=AsyncMock())
        mock_redis_class.return_value = client

        store = RedisCacheStore(default_ttl=600)

        await store.set("key1", "value1", ttl=300)

        # The TTL is the second positional argument of the recorded setex call.
        positional = client.setex.call_args[0]
        assert positional[1] == 300  # TTL argument

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_delete(self, mock_redis_class, mock_pool_class):
        """delete() maps the client's deleted-key count to a boolean."""
        client = AsyncMock(ping=AsyncMock())
        client.delete = AsyncMock(return_value=1)  # 1 key deleted
        mock_redis_class.return_value = client

        store = RedisCacheStore()

        removed = await store.delete("key1")
        assert removed is True

        # Swap the stub so the client now reports that nothing was deleted.
        client.delete = AsyncMock(return_value=0)
        removed = await store.delete("nonexistent")
        assert removed is False

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_exists(self, mock_redis_class, mock_pool_class):
        """exists() maps the client's 1/0 reply to True/False."""
        client = AsyncMock(ping=AsyncMock())
        client.exists = AsyncMock(return_value=1)  # Exists
        mock_redis_class.return_value = client

        store = RedisCacheStore()

        present = await store.exists("key1")
        assert present is True

        # Swap the stub so the client now reports the key as absent.
        client.exists = AsyncMock(return_value=0)
        present = await store.exists("nonexistent")
        assert present is False

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_ttl(self, mock_redis_class, mock_pool_class):
        """ttl() returns positive values as-is and maps -2 and -1 to None."""
        client = AsyncMock(ping=AsyncMock())
        client.ttl = AsyncMock(return_value=300)
        mock_redis_class.return_value = client

        store = RedisCacheStore()

        # A live key: the remaining time is passed through.
        remaining = await store.ttl("key1")
        assert remaining == 300

        # Client reply -2: key doesn't exist.
        client.ttl = AsyncMock(return_value=-2)
        remaining = await store.ttl("nonexistent")
        assert remaining is None

        # Client reply -1: key has no TTL.
        client.ttl = AsyncMock(return_value=-1)
        remaining = await store.ttl("no_ttl_key")
        assert remaining is None

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_clear(self, mock_redis_class, mock_pool_class):
        """clear() scans until the cursor returns to 0 and deletes what it finds."""
        scan_pages = [
            (10, ["llama_stack:key1", "llama_stack:key2"]),
            (0, ["llama_stack:key3"]),  # cursor 0 indicates end
        ]
        client = AsyncMock(
            ping=AsyncMock(),
            scan=AsyncMock(side_effect=scan_pages),
            delete=AsyncMock(return_value=3),
        )
        mock_redis_class.return_value = client

        store = RedisCacheStore()

        await store.clear()

        # Both scan pages must be consumed and at least one delete issued.
        assert client.scan.call_count == 2
        client.delete.assert_called()

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_size(self, mock_redis_class, mock_pool_class):
        """size() totals the keys returned across all scan pages."""
        scan_pages = [
            (10, ["llama_stack:key1", "llama_stack:key2"]),
            (0, ["llama_stack:key3"]),
        ]
        client = AsyncMock(ping=AsyncMock(), scan=AsyncMock(side_effect=scan_pages))
        mock_redis_class.return_value = client

        store = RedisCacheStore()

        total = await store.size()
        assert total == 3

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_retry_logic(self, mock_redis_class, mock_pool_class):
        """get() retries after Redis timeouts and returns the eventual value."""
        from redis.exceptions import TimeoutError as RedisTimeoutError

        # Two timeouts followed by a successful reply.
        replies = [
            RedisTimeoutError("Timeout"),
            RedisTimeoutError("Timeout"),
            json.dumps("success"),
        ]
        client = AsyncMock(ping=AsyncMock(), get=AsyncMock(side_effect=replies))
        mock_redis_class.return_value = client

        store = RedisCacheStore(max_retries=3)

        outcome = await store.get("key1")
        assert outcome == "success"
        assert client.get.call_count == 3

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_retry_exhaustion(self, mock_redis_class, mock_pool_class):
        """get() raises CacheError once every allowed attempt has timed out."""
        from redis.exceptions import TimeoutError as RedisTimeoutError

        # Every call times out.
        always_timeout = AsyncMock(side_effect=RedisTimeoutError("Timeout"))
        client = AsyncMock(ping=AsyncMock(), get=always_timeout)
        mock_redis_class.return_value = client

        store = RedisCacheStore(max_retries=2)

        with pytest.raises(CacheError, match="failed after .* attempts"):
            await store.get("key1")

        # Should have tried 3 times (initial + 2 retries)
        assert client.get.call_count == 3

    @patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
    @patch("llama_stack.providers.utils.cache.redis.Redis")
    async def test_close(self, mock_redis_class, mock_pool_class):
        """close() closes the client and disconnects the pool, once each."""
        client = AsyncMock(ping=AsyncMock(), close=AsyncMock())
        mock_redis_class.return_value = client

        pool = AsyncMock(disconnect=AsyncMock())
        mock_pool_class.return_value = pool

        # Connect first so there is something to tear down.
        store = RedisCacheStore()
        await store._ensure_connection()

        await store.close()

        client.close.assert_called_once()
        pool.disconnect.assert_called_once()

    def test_get_stats(self):
        """get_stats() echoes the configuration and reports connected=False before use."""
        settings = {
            "host": "redis.example.com",
            "port": 6380,
            "db": 1,
            "connection_pool_size": 20,
            "timeout_ms": 200,
            "default_ttl": 300,
            "max_retries": 5,
            "key_prefix": "test:",
        }
        store = RedisCacheStore(**settings)

        report = store.get_stats()

        for name, configured in settings.items():
            assert report[name] == configured
        assert report["connected"] is False  # Not connected yet
|
||||||
17
uv.lock
generated
17
uv.lock
generated
|
|
@ -1,5 +1,5 @@
|
||||||
version = 1
|
version = 1
|
||||||
revision = 2
|
revision = 3
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
resolution-markers = [
|
resolution-markers = [
|
||||||
"(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
"(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||||
|
|
@ -1996,6 +1996,7 @@ dependencies = [
|
||||||
{ name = "aiohttp" },
|
{ name = "aiohttp" },
|
||||||
{ name = "aiosqlite" },
|
{ name = "aiosqlite" },
|
||||||
{ name = "asyncpg" },
|
{ name = "asyncpg" },
|
||||||
|
{ name = "cachetools" },
|
||||||
{ name = "fastapi" },
|
{ name = "fastapi" },
|
||||||
{ name = "fire" },
|
{ name = "fire" },
|
||||||
{ name = "h11" },
|
{ name = "h11" },
|
||||||
|
|
@ -2013,6 +2014,7 @@ dependencies = [
|
||||||
{ name = "python-dotenv" },
|
{ name = "python-dotenv" },
|
||||||
{ name = "python-multipart" },
|
{ name = "python-multipart" },
|
||||||
{ name = "pyyaml" },
|
{ name = "pyyaml" },
|
||||||
|
{ name = "redis" },
|
||||||
{ name = "rich" },
|
{ name = "rich" },
|
||||||
{ name = "sqlalchemy", extra = ["asyncio"] },
|
{ name = "sqlalchemy", extra = ["asyncio"] },
|
||||||
{ name = "starlette" },
|
{ name = "starlette" },
|
||||||
|
|
@ -2147,6 +2149,7 @@ requires-dist = [
|
||||||
{ name = "aiohttp" },
|
{ name = "aiohttp" },
|
||||||
{ name = "aiosqlite", specifier = ">=0.21.0" },
|
{ name = "aiosqlite", specifier = ">=0.21.0" },
|
||||||
{ name = "asyncpg" },
|
{ name = "asyncpg" },
|
||||||
|
{ name = "cachetools", specifier = ">=5.5.0" },
|
||||||
{ name = "fastapi", specifier = ">=0.115.0,<1.0" },
|
{ name = "fastapi", specifier = ">=0.115.0,<1.0" },
|
||||||
{ name = "fire" },
|
{ name = "fire" },
|
||||||
{ name = "h11", specifier = ">=0.16.0" },
|
{ name = "h11", specifier = ">=0.16.0" },
|
||||||
|
|
@ -2166,6 +2169,7 @@ requires-dist = [
|
||||||
{ name = "python-multipart", specifier = ">=0.0.20" },
|
{ name = "python-multipart", specifier = ">=0.0.20" },
|
||||||
{ name = "pyyaml", specifier = ">=6.0" },
|
{ name = "pyyaml", specifier = ">=6.0" },
|
||||||
{ name = "pyyaml", specifier = ">=6.0.2" },
|
{ name = "pyyaml", specifier = ">=6.0.2" },
|
||||||
|
{ name = "redis", specifier = ">=5.2.0" },
|
||||||
{ name = "rich" },
|
{ name = "rich" },
|
||||||
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" },
|
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" },
|
||||||
{ name = "starlette" },
|
{ name = "starlette" },
|
||||||
|
|
@ -4398,6 +4402,15 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/ef/33/d8df6a2b214ffbe4138db9a1efe3248f67dc3c671f82308bea1582ecbbb7/qdrant_client-1.15.1-py3-none-any.whl", hash = "sha256:2b975099b378382f6ca1cfb43f0d59e541be6e16a5892f282a4b8de7eff5cb63", size = 337331, upload-time = "2025-07-31T19:35:17.539Z" },
|
{ url = "https://files.pythonhosted.org/packages/ef/33/d8df6a2b214ffbe4138db9a1efe3248f67dc3c671f82308bea1582ecbbb7/qdrant_client-1.15.1-py3-none-any.whl", hash = "sha256:2b975099b378382f6ca1cfb43f0d59e541be6e16a5892f282a4b8de7eff5cb63", size = 337331, upload-time = "2025-07-31T19:35:17.539Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "redis"
|
||||||
|
version = "7.0.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/57/8f/f125feec0b958e8d22c8f0b492b30b1991d9499a4315dfde466cf4289edc/redis-7.0.1.tar.gz", hash = "sha256:c949df947dca995dc68fdf5a7863950bf6df24f8d6022394585acc98e81624f1", size = 4755322, upload-time = "2025-10-27T14:34:00.33Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e9/97/9f22a33c475cda519f20aba6babb340fb2f2254a02fb947816960d1e669a/redis-7.0.1-py3-none-any.whl", hash = "sha256:4977af3c7d67f8f0eb8b6fec0dafc9605db9343142f634041fb0235f67c0588a", size = 339938, upload-time = "2025-10-27T14:33:58.553Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "referencing"
|
name = "referencing"
|
||||||
version = "0.36.2"
|
version = "0.36.2"
|
||||||
|
|
@ -4656,6 +4669,8 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/6b/fa/3234f913fe9a6525a7b97c6dad1f51e72b917e6872e051a5e2ffd8b16fbb/ruamel.yaml.clib-0.2.14-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:70eda7703b8126f5e52fcf276e6c0f40b0d314674f896fc58c47b0aef2b9ae83", size = 137970, upload-time = "2025-09-22T19:51:09.472Z" },
|
{ url = "https://files.pythonhosted.org/packages/6b/fa/3234f913fe9a6525a7b97c6dad1f51e72b917e6872e051a5e2ffd8b16fbb/ruamel.yaml.clib-0.2.14-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:70eda7703b8126f5e52fcf276e6c0f40b0d314674f896fc58c47b0aef2b9ae83", size = 137970, upload-time = "2025-09-22T19:51:09.472Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/ef/ec/4edbf17ac2c87fa0845dd366ef8d5852b96eb58fcd65fc1ecf5fe27b4641/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a0cb71ccc6ef9ce36eecb6272c81afdc2f565950cdcec33ae8e6cd8f7fc86f27", size = 739639, upload-time = "2025-09-22T19:51:10.566Z" },
|
{ url = "https://files.pythonhosted.org/packages/ef/ec/4edbf17ac2c87fa0845dd366ef8d5852b96eb58fcd65fc1ecf5fe27b4641/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a0cb71ccc6ef9ce36eecb6272c81afdc2f565950cdcec33ae8e6cd8f7fc86f27", size = 739639, upload-time = "2025-09-22T19:51:10.566Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/15/18/b0e1fafe59051de9e79cdd431863b03593ecfa8341c110affad7c8121efc/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e7cb9ad1d525d40f7d87b6df7c0ff916a66bc52cb61b66ac1b2a16d0c1b07640", size = 764456, upload-time = "2025-09-22T19:51:11.736Z" },
|
{ url = "https://files.pythonhosted.org/packages/15/18/b0e1fafe59051de9e79cdd431863b03593ecfa8341c110affad7c8121efc/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e7cb9ad1d525d40f7d87b6df7c0ff916a66bc52cb61b66ac1b2a16d0c1b07640", size = 764456, upload-time = "2025-09-22T19:51:11.736Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e7/cd/150fdb96b8fab27fe08d8a59fe67554568727981806e6bc2677a16081ec7/ruamel_yaml_clib-0.2.14-cp314-cp314-win32.whl", hash = "sha256:9b4104bf43ca0cd4e6f738cb86326a3b2f6eef00f417bd1e7efb7bdffe74c539", size = 102394, upload-time = "2025-11-14T21:57:36.703Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/bd/e6/a3fa40084558c7e1dc9546385f22a93949c890a8b2e445b2ba43935f51da/ruamel_yaml_clib-0.2.14-cp314-cp314-win_amd64.whl", hash = "sha256:13997d7d354a9890ea1ec5937a219817464e5cc344805b37671562a401ca3008", size = 122673, upload-time = "2025-11-14T21:57:38.177Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue