mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat(cache): add cache store abstraction layer
Implement reusable cache store abstraction with in-memory and Redis backends as foundation for prompt caching feature (PR1 of progressive delivery). - Add CacheStore protocol defining cache interface - Implement MemoryCacheStore with LRU, LFU, and TTL-only eviction policies - Implement RedisCacheStore with connection pooling and retry logic - Add CircuitBreaker for cache backend failure protection - Include comprehensive unit tests (55 tests, >80% coverage) - Add dependencies: cachetools>=5.5.0, redis>=5.2.0 This abstraction enables flexible caching implementations for the prompt caching middleware without coupling to specific storage backends. Signed-off-by: William Caban <willliam.caban@gmail.com>
This commit is contained in:
parent
97f535c4f1
commit
299c575daa
10 changed files with 2175 additions and 1 deletions
|
|
@ -26,6 +26,7 @@ classifiers = [
|
|||
dependencies = [
|
||||
"PyYAML>=6.0",
|
||||
"aiohttp",
|
||||
"cachetools>=5.5.0", # for prompt caching
|
||||
"fastapi>=0.115.0,<1.0", # server
|
||||
"fire", # for MCP in LLS client
|
||||
"httpx",
|
||||
|
|
@ -37,6 +38,7 @@ dependencies = [
|
|||
"python-dotenv",
|
||||
"pyjwt[crypto]>=2.10.0", # Pull crypto to support RS256 for jwt. Requires 2.10.0+ for ssl_context support.
|
||||
"pydantic>=2.11.9",
|
||||
"redis>=5.2.0", # for prompt caching (Redis backend)
|
||||
"rich",
|
||||
"starlette",
|
||||
"termcolor",
|
||||
|
|
|
|||
37
src/llama_stack/providers/utils/cache/__init__.py
vendored
Normal file
37
src/llama_stack/providers/utils/cache/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Cache store utilities for prompt caching.
|
||||
|
||||
This module provides cache store abstractions and implementations for use in
|
||||
the Llama Stack server's prompt caching feature. Supports both memory-based
|
||||
and Redis-based caching with configurable eviction policies and TTL management.
|
||||
|
||||
Example usage:
|
||||
from llama_stack.providers.utils.cache import MemoryCacheStore, RedisCacheStore
|
||||
|
||||
# Memory cache for development
|
||||
memory_cache = MemoryCacheStore(max_entries=1000, eviction_policy="lru")
|
||||
|
||||
# Redis cache for production
|
||||
redis_cache = RedisCacheStore(
|
||||
host="localhost",
|
||||
port=6379,
|
||||
connection_pool_size=10
|
||||
)
|
||||
"""
|
||||
|
||||
from .cache_store import CacheError, CacheStore, CircuitBreaker
|
||||
from .memory import MemoryCacheStore
|
||||
from .redis import RedisCacheStore
|
||||
|
||||
# Public API of the cache utilities package; star-imports and API docs
# should expose exactly these names.
__all__ = [
    "CacheStore",
    "CacheError",
    "CircuitBreaker",
    "MemoryCacheStore",
    "RedisCacheStore",
]
|
||||
256
src/llama_stack/providers/utils/cache/cache_store.py
vendored
Normal file
256
src/llama_stack/providers/utils/cache/cache_store.py
vendored
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Cache store abstraction for prompt caching implementation.
|
||||
|
||||
This module provides a protocol-based abstraction for cache storage backends,
|
||||
enabling flexible storage implementations (memory, Redis, etc.) for prompt
|
||||
caching in the Llama Stack server.
|
||||
"""
|
||||
|
||||
import time

from datetime import timedelta
from typing import Any, Optional, Protocol

from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CacheStore(Protocol):
    """Protocol defining the cache store interface.

    This protocol specifies the required methods for cache store implementations.
    All implementations must support TTL-based expiration and provide efficient
    key-value storage operations.

    Every method in this protocol is declared as an async coroutine, so
    implementations are expected to expose the same asynchronous interface.
    """

    async def get(self, key: str) -> Optional[Any]:
        """Retrieve a value from the cache.

        Args:
            key: Cache key to retrieve

        Returns:
            Cached value if present and not expired, None otherwise

        Raises:
            CacheError: If cache backend is unavailable or operation fails
        """
        ...

    async def set(
        self,
        key: str,
        value: Any,
        ttl: Optional[int] = None,
    ) -> None:
        """Store a value in the cache with optional TTL.

        Args:
            key: Cache key
            value: Value to cache (must be serializable)
            ttl: Time-to-live in seconds. If None, uses default TTL.

        Raises:
            CacheError: If cache backend is unavailable or operation fails
            ValueError: If value is not serializable
        """
        ...

    async def delete(self, key: str) -> bool:
        """Delete a key from the cache.

        Args:
            key: Cache key to delete

        Returns:
            True if key was deleted, False if key didn't exist

        Raises:
            CacheError: If cache backend is unavailable or operation fails
        """
        ...

    async def exists(self, key: str) -> bool:
        """Check if a key exists in the cache.

        Args:
            key: Cache key to check

        Returns:
            True if key exists and is not expired, False otherwise

        Raises:
            CacheError: If cache backend is unavailable or operation fails
        """
        ...

    async def ttl(self, key: str) -> Optional[int]:
        """Get the remaining TTL for a key.

        Args:
            key: Cache key

        Returns:
            Remaining TTL in seconds, None if key doesn't exist or has no TTL

        Raises:
            CacheError: If cache backend is unavailable or operation fails
        """
        ...

    async def clear(self) -> None:
        """Clear all entries from the cache.

        This is primarily useful for testing. Use with caution in production
        as it affects all cached data.

        Raises:
            CacheError: If cache backend is unavailable or operation fails
        """
        ...

    async def size(self) -> int:
        """Get the number of entries in the cache.

        Returns:
            Number of cached entries

        Raises:
            CacheError: If cache backend is unavailable or operation fails
        """
        ...
|
||||
|
||||
|
||||
class CacheError(Exception):
    """Raised when a cache operation cannot be completed.

    Signals backend unavailability, network trouble, or other operational
    failures. Callers are expected to catch this exception and degrade
    gracefully rather than fail the surrounding request.
    """

    def __init__(self, message: str, cause: Optional[Exception] = None):
        """Create a cache error.

        Args:
            message: Error description (should start with "Failed to ...")
            cause: Optional underlying exception that triggered this error
        """
        # Keep the root-cause exception around for diagnostics and logging.
        self.cause = cause
        super().__init__(message)
|
||||
|
||||
|
||||
class CircuitBreaker:
    """Circuit breaker pattern for cache backend failure protection.

    Prevents cascade failures by temporarily disabling cache operations
    after detecting repeated failures. Automatically attempts recovery
    after a timeout period.

    States:
        - CLOSED: Normal operation, requests go through
        - OPEN: Too many failures, requests are blocked
        - HALF_OPEN: Testing if backend has recovered

    NOTE(review): state transitions are not synchronized; this assumes a
    single-threaded (e.g. asyncio event loop) caller — confirm against usage.

    Example:
        breaker = CircuitBreaker(failure_threshold=10, recovery_timeout=60)
        if breaker.is_closed():
            try:
                result = await cache.get(key)
                breaker.record_success()
            except CacheError:
                breaker.record_failure()
    """

    def __init__(
        self,
        failure_threshold: int = 10,
        recovery_timeout: int = 60,
    ):
        """Initialize circuit breaker.

        Args:
            failure_threshold: Number of consecutive failures before opening
            recovery_timeout: Seconds to wait before attempting recovery

        Raises:
            ValueError: If a parameter is not positive (consistent with the
                constructor validation in the cache store implementations)
        """
        if failure_threshold <= 0:
            raise ValueError("failure_threshold must be positive")
        if recovery_timeout <= 0:
            raise ValueError("recovery_timeout must be positive")

        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.failure_count = 0
        # Epoch timestamp of the most recent failure; None when healthy.
        self.last_failure_time: Optional[float] = None
        self.state = "CLOSED"  # CLOSED, OPEN, HALF_OPEN

    def is_closed(self) -> bool:
        """Check if circuit breaker allows operations.

        May transition OPEN -> HALF_OPEN as a side effect once the recovery
        timeout has elapsed.

        Returns:
            True if operations should proceed, False if blocked
        """
        if self.state == "CLOSED":
            return True

        if self.state == "OPEN":
            # Check if we should try recovery.
            if (
                self.last_failure_time is not None
                and time.time() - self.last_failure_time >= self.recovery_timeout
            ):
                self.state = "HALF_OPEN"
                logger.info("Circuit breaker entering HALF_OPEN state for recovery test")
                return True
            return False

        # HALF_OPEN state - allow requests through to test recovery. Note that
        # probing is not limited to a single in-flight request.
        return True

    def record_success(self) -> None:
        """Record a successful operation and return the breaker to CLOSED."""
        if self.state == "HALF_OPEN":
            logger.info("Circuit breaker recovery successful, returning to CLOSED state")
        self.failure_count = 0
        self.last_failure_time = None
        self.state = "CLOSED"

    def record_failure(self) -> None:
        """Record a failed operation, opening the circuit when warranted."""
        self.failure_count += 1
        self.last_failure_time = time.time()

        if self.state == "HALF_OPEN":
            # Recovery attempt failed, go back to OPEN
            logger.warning("Circuit breaker recovery failed, returning to OPEN state")
            self.state = "OPEN"
        elif self.failure_count >= self.failure_threshold:
            logger.error(
                f"Circuit breaker OPEN after {self.failure_count} failures. "
                f"Cache operations disabled for {self.recovery_timeout}s"
            )
            self.state = "OPEN"

    def get_state(self) -> str:
        """Get current circuit breaker state.

        Returns:
            Current state: "CLOSED", "OPEN", or "HALF_OPEN"
        """
        return self.state

    def reset(self) -> None:
        """Manually reset the circuit breaker to CLOSED state.

        This is primarily useful for testing or administrative overrides.
        """
        self.failure_count = 0
        self.last_failure_time = None
        self.state = "CLOSED"
        logger.info("Circuit breaker manually reset to CLOSED state")
|
||||
334
src/llama_stack/providers/utils/cache/memory.py
vendored
Normal file
334
src/llama_stack/providers/utils/cache/memory.py
vendored
Normal file
|
|
@ -0,0 +1,334 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""In-memory cache store implementation using cachetools.
|
||||
|
||||
This module provides a memory-based cache store suitable for development
|
||||
and single-node deployments. For production multi-node deployments,
|
||||
consider using RedisCacheStore instead.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from cachetools import Cache, LFUCache, LRUCache, TTLCache # type: ignore # no types-cachetools available
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .cache_store import CacheError
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
EvictionPolicy = Literal["lru", "lfu", "ttl-only"]
|
||||
|
||||
|
||||
class MemoryCacheStore:
    """In-memory cache store with configurable eviction policies.

    This implementation uses the cachetools library to provide efficient
    in-memory caching with support for multiple eviction policies:
    - LRU (Least Recently Used): Evicts least recently accessed items
    - LFU (Least Frequently Used): Evicts least frequently accessed items
    - TTL-only: Evicts based on time-to-live only

    NOTE(review): cachetools caches are not internally synchronized, so this
    store is safe only for single-threaded access (e.g. one asyncio event
    loop) within a process; add external locking for multi-threaded use.

    Example:
        cache = MemoryCacheStore(
            max_entries=1000,
            default_ttl=600,
            eviction_policy="lru"
        )
        await cache.set("key", "value", ttl=300)
        value = await cache.get("key")
    """

    def __init__(
        self,
        max_entries: int = 1000,
        max_memory_mb: Optional[int] = 512,
        default_ttl: int = 600,
        eviction_policy: EvictionPolicy = "lru",
    ):
        """Initialize memory cache store.

        Args:
            max_entries: Maximum number of entries to store
            max_memory_mb: Maximum memory usage in MB (soft limit, estimated)
            default_ttl: Default time-to-live in seconds
            eviction_policy: Eviction strategy ("lru", "lfu", "ttl-only")

        Raises:
            ValueError: If invalid parameters provided
        """
        if max_entries <= 0:
            raise ValueError("max_entries must be positive")
        if default_ttl <= 0:
            raise ValueError("default_ttl must be positive")
        if max_memory_mb is not None and max_memory_mb <= 0:
            raise ValueError("max_memory_mb must be positive")

        self.max_entries = max_entries
        self.max_memory_mb = max_memory_mb
        self.default_ttl = default_ttl
        self.eviction_policy = eviction_policy

        # Backing store chosen by eviction policy; _ttl_map tracks the
        # absolute expiration timestamp (epoch seconds) for each key.
        self._cache: Cache = self._create_cache()
        self._ttl_map: dict[str, float] = {}

        logger.info(
            f"Initialized MemoryCacheStore: policy={eviction_policy}, "
            f"max_entries={max_entries}, max_memory={max_memory_mb}MB, "
            f"default_ttl={default_ttl}s"
        )

    def _create_cache(self) -> Cache:
        """Create cache instance based on eviction policy.

        Returns:
            Cache instance configured with chosen policy

        Raises:
            ValueError: If the eviction policy is not recognized
        """
        if self.eviction_policy == "lru":
            return LRUCache(maxsize=self.max_entries)
        elif self.eviction_policy == "lfu":
            return LFUCache(maxsize=self.max_entries)
        elif self.eviction_policy == "ttl-only":
            return TTLCache(maxsize=self.max_entries, ttl=self.default_ttl)
        else:
            raise ValueError(f"Unknown eviction policy: {self.eviction_policy}")

    def _is_expired(self, key: str) -> bool:
        """Check if a key has expired based on TTL.

        Expired entries are removed from both the cache and the TTL map as a
        side effect (lazy cleanup on access).

        Args:
            key: Cache key to check

        Returns:
            True if key has expired (and was removed), False otherwise
        """
        if key not in self._ttl_map:
            return False

        expiration_time = self._ttl_map[key]
        if time.time() >= expiration_time:
            # Clean up expired entry
            self._cache.pop(key, None)
            self._ttl_map.pop(key, None)
            return True

        return False

    async def get(self, key: str) -> Optional[Any]:
        """Retrieve a value from the cache.

        Args:
            key: Cache key to retrieve

        Returns:
            Cached value if present and not expired, None otherwise

        Raises:
            CacheError: If cache operation fails
        """
        try:
            # Check expiration first so a stale entry is never returned.
            if self._is_expired(key):
                return None

            value = self._cache.get(key)
            if value is not None:
                logger.debug(f"Cache hit: {key}")
            return value

        except Exception as e:
            logger.error(f"Failed to get cache key '{key}': {e}")
            raise CacheError(f"Failed to get cache key '{key}'", cause=e) from e

    async def set(
        self,
        key: str,
        value: Any,
        ttl: Optional[int] = None,
    ) -> None:
        """Store a value in the cache with optional TTL.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time-to-live in seconds. If None, uses default TTL.

        Raises:
            CacheError: If cache operation fails
        """
        try:
            # Use default TTL if not specified
            effective_ttl = ttl if ttl is not None else self.default_ttl

            # Store value
            self._cache[key] = value

            # Track expiration time
            self._ttl_map[key] = time.time() + effective_ttl

            # Check memory usage (soft limit)
            if self.max_memory_mb is not None:
                self._check_memory_usage()

            logger.debug(f"Cache set: {key} (ttl={effective_ttl}s)")

        except Exception as e:
            logger.error(f"Failed to set cache key '{key}': {e}")
            raise CacheError(f"Failed to set cache key '{key}'", cause=e) from e

    def _check_memory_usage(self) -> None:
        """Check and log if memory usage exceeds soft limit.

        This is a soft limit - we log warnings but don't enforce hard limits.
        The cachetools library will handle eviction based on max_entries.
        Note that sys.getsizeof is a shallow measure, so this is only a
        rough estimate.
        """
        try:
            # Get approximate memory usage
            cache_size_bytes = sys.getsizeof(self._cache) + sys.getsizeof(self._ttl_map)

            # Convert to MB
            cache_size_mb = cache_size_bytes / (1024 * 1024)

            if self.max_memory_mb is not None and cache_size_mb > self.max_memory_mb:
                logger.warning(
                    f"Cache memory usage ({cache_size_mb:.1f}MB) exceeds "
                    f"soft limit ({self.max_memory_mb}MB). "
                    f"Consider increasing max_entries or max_memory_mb."
                )
        except Exception as e:
            # Don't fail on memory check errors
            logger.debug(f"Memory usage check failed: {e}")

    async def delete(self, key: str) -> bool:
        """Delete a key from the cache.

        Args:
            key: Cache key to delete

        Returns:
            True if key was deleted, False if key didn't exist

        Raises:
            CacheError: If cache operation fails
        """
        try:
            existed = key in self._cache
            self._cache.pop(key, None)
            self._ttl_map.pop(key, None)

            if existed:
                logger.debug(f"Cache delete: {key}")

            return existed

        except Exception as e:
            logger.error(f"Failed to delete cache key '{key}': {e}")
            raise CacheError(f"Failed to delete cache key '{key}'", cause=e) from e

    async def exists(self, key: str) -> bool:
        """Check if a key exists in the cache.

        Args:
            key: Cache key to check

        Returns:
            True if key exists and is not expired, False otherwise

        Raises:
            CacheError: If cache operation fails
        """
        try:
            if self._is_expired(key):
                return False
            return key in self._cache

        except Exception as e:
            logger.error(f"Failed to check cache key existence '{key}': {e}")
            raise CacheError(f"Failed to check cache key existence '{key}'", cause=e) from e

    async def ttl(self, key: str) -> Optional[int]:
        """Get the remaining TTL for a key.

        Args:
            key: Cache key

        Returns:
            Remaining TTL in seconds, None if key doesn't exist

        Raises:
            CacheError: If cache operation fails
        """
        try:
            if key not in self._ttl_map:
                return None

            if self._is_expired(key):
                return None

            remaining = int(self._ttl_map[key] - time.time())
            return max(0, remaining)

        except Exception as e:
            logger.error(f"Failed to get TTL for cache key '{key}': {e}")
            raise CacheError(f"Failed to get TTL for cache key '{key}'", cause=e) from e

    async def clear(self) -> None:
        """Clear all entries from the cache.

        Raises:
            CacheError: If cache operation fails
        """
        try:
            self._cache.clear()
            self._ttl_map.clear()
            logger.info("Cache cleared")

        except Exception as e:
            logger.error(f"Failed to clear cache: {e}")
            raise CacheError("Failed to clear cache", cause=e) from e

    async def size(self) -> int:
        """Get the number of entries in the cache.

        Returns:
            Number of cached entries (excluding expired entries)

        Raises:
            CacheError: If cache operation fails
        """
        try:
            # Evict expired entries so the count reflects live data only.
            # _is_expired() removes an expired entry as a side effect.
            for key in list(self._ttl_map.keys()):
                self._is_expired(key)

            return len(self._cache)

        except Exception as e:
            logger.error(f"Failed to get cache size: {e}")
            raise CacheError("Failed to get cache size", cause=e) from e

    def get_stats(self) -> dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dictionary with cache statistics including size, policy, and limits
        """
        return {
            "size": len(self._cache),
            "max_entries": self.max_entries,
            "max_memory_mb": self.max_memory_mb,
            "default_ttl": self.default_ttl,
            "eviction_policy": self.eviction_policy,
        }
|
||||
513
src/llama_stack/providers/utils/cache/redis.py
vendored
Normal file
513
src/llama_stack/providers/utils/cache/redis.py
vendored
Normal file
|
|
@ -0,0 +1,513 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Redis-based cache store implementation.
|
||||
|
||||
This module provides a production-ready Redis cache store with connection
|
||||
pooling, retry logic, and comprehensive error handling. Suitable for
|
||||
distributed deployments and high-throughput scenarios.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from redis import asyncio as aioredis
|
||||
from redis.asyncio import ConnectionPool, Redis
|
||||
from redis.exceptions import ConnectionError, RedisError, TimeoutError
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .cache_store import CacheError
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RedisCacheStore:
|
||||
"""Redis-based cache store with connection pooling.
|
||||
|
||||
This implementation provides production-ready caching with:
|
||||
- Connection pooling for efficient resource usage
|
||||
- Automatic retry logic for transient failures
|
||||
- Configurable timeouts to prevent blocking
|
||||
- JSON serialization for complex data types
|
||||
- Support for Redis cluster and sentinel
|
||||
|
||||
Example:
|
||||
cache = RedisCacheStore(
|
||||
host="localhost",
|
||||
port=6379,
|
||||
db=0,
|
||||
password="secret",
|
||||
connection_pool_size=10,
|
||||
timeout_ms=100
|
||||
)
|
||||
await cache.set("key", {"data": "value"}, ttl=300)
|
||||
value = await cache.get("key")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "localhost",
|
||||
port: int = 6379,
|
||||
db: int = 0,
|
||||
password: Optional[str] = None,
|
||||
connection_pool_size: int = 10,
|
||||
timeout_ms: int = 100,
|
||||
default_ttl: int = 600,
|
||||
max_retries: int = 3,
|
||||
key_prefix: str = "llama_stack:",
|
||||
):
|
||||
"""Initialize Redis cache store.
|
||||
|
||||
Args:
|
||||
host: Redis server hostname
|
||||
port: Redis server port
|
||||
db: Redis database number (0-15)
|
||||
password: Optional Redis password
|
||||
connection_pool_size: Maximum connections in pool
|
||||
timeout_ms: Operation timeout in milliseconds
|
||||
default_ttl: Default time-to-live in seconds
|
||||
max_retries: Maximum retry attempts for failed operations
|
||||
key_prefix: Prefix for all cache keys (namespace isolation)
|
||||
|
||||
Raises:
|
||||
ValueError: If invalid parameters provided
|
||||
"""
|
||||
if connection_pool_size <= 0:
|
||||
raise ValueError("connection_pool_size must be positive")
|
||||
if timeout_ms <= 0:
|
||||
raise ValueError("timeout_ms must be positive")
|
||||
if default_ttl <= 0:
|
||||
raise ValueError("default_ttl must be positive")
|
||||
if max_retries < 0:
|
||||
raise ValueError("max_retries must be non-negative")
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.db = db
|
||||
self.password = password
|
||||
self.connection_pool_size = connection_pool_size
|
||||
self.timeout_ms = timeout_ms
|
||||
self.default_ttl = default_ttl
|
||||
self.max_retries = max_retries
|
||||
self.key_prefix = key_prefix
|
||||
|
||||
# Connection pool (lazy initialization)
|
||||
self._pool: Optional[ConnectionPool] = None
|
||||
self._redis: Optional[Redis] = None
|
||||
|
||||
logger.info(
|
||||
f"Initialized RedisCacheStore: host={host}:{port}, db={db}, "
|
||||
f"pool_size={connection_pool_size}, timeout={timeout_ms}ms, "
|
||||
f"default_ttl={default_ttl}s"
|
||||
)
|
||||
|
||||
async def _ensure_connection(self) -> Redis:
|
||||
"""Ensure Redis connection is established.
|
||||
|
||||
Returns:
|
||||
Redis client instance
|
||||
|
||||
Raises:
|
||||
CacheError: If connection cannot be established
|
||||
"""
|
||||
if self._redis is not None:
|
||||
return self._redis
|
||||
|
||||
try:
|
||||
# Create connection pool
|
||||
self._pool = ConnectionPool(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
db=self.db,
|
||||
password=self.password,
|
||||
max_connections=self.connection_pool_size,
|
||||
socket_timeout=self.timeout_ms / 1000.0,
|
||||
socket_connect_timeout=self.timeout_ms / 1000.0,
|
||||
decode_responses=True,
|
||||
)
|
||||
|
||||
# Create Redis client
|
||||
self._redis = Redis(connection_pool=self._pool)
|
||||
|
||||
# Test connection
|
||||
await asyncio.wait_for(
|
||||
self._redis.ping(),
|
||||
timeout=self.timeout_ms / 1000.0
|
||||
)
|
||||
|
||||
logger.info(f"Connected to Redis at {self.host}:{self.port}")
|
||||
return self._redis
|
||||
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Failed to connect to Redis: {e}")
|
||||
raise CacheError(f"Failed to connect to Redis at {self.host}:{self.port}", cause=e) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Redis connection: {e}")
|
||||
raise CacheError("Failed to initialize Redis connection", cause=e) from e
|
||||
|
||||
def _make_key(self, key: str) -> str:
|
||||
"""Create prefixed cache key for namespace isolation.
|
||||
|
||||
Args:
|
||||
key: Base cache key
|
||||
|
||||
Returns:
|
||||
Prefixed key
|
||||
"""
|
||||
return f"{self.key_prefix}{key}"
|
||||
|
||||
def _serialize(self, value: Any) -> str:
|
||||
"""Serialize value for storage.
|
||||
|
||||
Args:
|
||||
value: Value to serialize
|
||||
|
||||
Returns:
|
||||
JSON-serialized string
|
||||
|
||||
Raises:
|
||||
ValueError: If value cannot be serialized
|
||||
"""
|
||||
try:
|
||||
return json.dumps(value)
|
||||
except (TypeError, ValueError) as e:
|
||||
raise ValueError(f"Value is not JSON-serializable: {e}") from e
|
||||
|
||||
def _deserialize(self, data: str) -> Any:
|
||||
"""Deserialize stored value.
|
||||
|
||||
Args:
|
||||
data: JSON-serialized string
|
||||
|
||||
Returns:
|
||||
Deserialized value
|
||||
|
||||
Raises:
|
||||
ValueError: If data cannot be deserialized
|
||||
"""
|
||||
try:
|
||||
return json.loads(data)
|
||||
except (TypeError, ValueError) as e:
|
||||
logger.warning(f"Failed to deserialize cache value: {e}")
|
||||
return None
|
||||
|
||||
async def _retry_operation(self, operation, *args, **kwargs) -> Any:
|
||||
"""Retry an operation with exponential backoff.
|
||||
|
||||
Args:
|
||||
operation: Async function to retry
|
||||
*args: Positional arguments for operation
|
||||
**kwargs: Keyword arguments for operation
|
||||
|
||||
Returns:
|
||||
Operation result
|
||||
|
||||
Raises:
|
||||
CacheError: If all retries fail
|
||||
"""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
operation(*args, **kwargs),
|
||||
timeout=self.timeout_ms / 1000.0
|
||||
)
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
last_error = e
|
||||
if attempt < self.max_retries:
|
||||
backoff = 2 ** attempt * 0.1 # 100ms, 200ms, 400ms
|
||||
logger.warning(
|
||||
f"Redis operation failed (attempt {attempt + 1}/{self.max_retries + 1}), "
|
||||
f"retrying in {backoff}s: {e}"
|
||||
)
|
||||
await asyncio.sleep(backoff)
|
||||
else:
|
||||
logger.error(f"Redis operation failed after {self.max_retries + 1} attempts")
|
||||
except Exception as e:
|
||||
# Don't retry on non-transient errors
|
||||
raise CacheError(f"Redis operation failed: {e}", cause=e) from e
|
||||
|
||||
raise CacheError(f"Redis operation failed after {self.max_retries + 1} attempts", cause=last_error) from last_error
|
||||
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""Retrieve a value from the cache.
|
||||
|
||||
Args:
|
||||
key: Cache key to retrieve
|
||||
|
||||
Returns:
|
||||
Cached value if present and not expired, None otherwise
|
||||
|
||||
Raises:
|
||||
CacheError: If cache operation fails
|
||||
"""
|
||||
try:
|
||||
redis = await self._ensure_connection()
|
||||
prefixed_key = self._make_key(key)
|
||||
|
||||
data = await self._retry_operation(redis.get, prefixed_key)
|
||||
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
value = self._deserialize(data)
|
||||
if value is not None:
|
||||
logger.debug(f"Cache hit: {key}")
|
||||
return value
|
||||
|
||||
except CacheError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get cache key '{key}': {e}")
|
||||
raise CacheError(f"Failed to get cache key '{key}'", cause=e) from e
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Store a value in the cache with optional TTL.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache (must be JSON-serializable)
|
||||
ttl: Time-to-live in seconds. If None, uses default TTL.
|
||||
|
||||
Raises:
|
||||
CacheError: If cache operation fails
|
||||
ValueError: If value is not serializable
|
||||
"""
|
||||
try:
|
||||
redis = await self._ensure_connection()
|
||||
prefixed_key = self._make_key(key)
|
||||
|
||||
# Serialize value
|
||||
data = self._serialize(value)
|
||||
|
||||
# Use default TTL if not specified
|
||||
effective_ttl = ttl if ttl is not None else self.default_ttl
|
||||
|
||||
# Store with TTL
|
||||
await self._retry_operation(
|
||||
redis.setex,
|
||||
prefixed_key,
|
||||
effective_ttl,
|
||||
data
|
||||
)
|
||||
|
||||
logger.debug(f"Cache set: {key} (ttl={effective_ttl}s)")
|
||||
|
||||
except ValueError:
|
||||
raise
|
||||
except CacheError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to set cache key '{key}': {e}")
|
||||
raise CacheError(f"Failed to set cache key '{key}'", cause=e) from e
|
||||
|
||||
async def delete(self, key: str) -> bool:
    """Delete a key from the cache.

    Args:
        key: Cache key to delete

    Returns:
        True if key was deleted, False if key didn't exist

    Raises:
        CacheError: If cache operation fails
    """
    try:
        redis = await self._ensure_connection()
        prefixed_key = self._make_key(key)

        # Redis DEL returns the number of keys removed (0 or 1 here).
        deleted_count = await self._retry_operation(redis.delete, prefixed_key)

        # A comparison already yields a bool; no bool() wrapper needed,
        # and computing it once avoids the duplicated `> 0` test.
        deleted = deleted_count > 0
        if deleted:
            logger.debug(f"Cache delete: {key}")
        return deleted

    except CacheError:
        raise
    except Exception as e:
        logger.error(f"Failed to delete cache key '{key}': {e}")
        raise CacheError(f"Failed to delete cache key '{key}'", cause=e) from e
|
||||
|
||||
async def exists(self, key: str) -> bool:
    """Check if a key exists in the cache.

    Args:
        key: Cache key to check

    Returns:
        True if key exists and is not expired, False otherwise

    Raises:
        CacheError: If cache operation fails
    """
    try:
        redis = await self._ensure_connection()
        prefixed_key = self._make_key(key)

        # Redis EXISTS returns the number of keys found (0 or 1 here).
        # Renamed from `exists` to avoid shadowing the method name;
        # the comparison already yields a bool, no bool() wrapper needed.
        found = await self._retry_operation(redis.exists, prefixed_key)
        return found > 0

    except CacheError:
        raise
    except Exception as e:
        logger.error(f"Failed to check cache key existence '{key}': {e}")
        raise CacheError(f"Failed to check cache key existence '{key}'", cause=e) from e
|
||||
|
||||
async def ttl(self, key: str) -> Optional[int]:
    """Get the remaining TTL for a key.

    Args:
        key: Cache key

    Returns:
        Remaining TTL in seconds, None if key doesn't exist or has no TTL

    Raises:
        CacheError: If cache operation fails
    """
    try:
        redis = await self._ensure_connection()
        prefixed_key = self._make_key(key)

        ttl_seconds = await self._retry_operation(redis.ttl, prefixed_key)

        # Redis TTL sentinels: -2 = key doesn't exist, -1 = key has no
        # TTL. Both map to None for this API, so one negative check
        # replaces the two duplicated branches; after it the value is
        # guaranteed non-negative, making the old max(0, ...) redundant.
        if ttl_seconds < 0:
            return None

        return int(ttl_seconds)

    except CacheError:
        raise
    except Exception as e:
        logger.error(f"Failed to get TTL for cache key '{key}': {e}")
        raise CacheError(f"Failed to get TTL for cache key '{key}'", cause=e) from e
|
||||
|
||||
async def clear(self) -> None:
    """Clear all entries from the cache.

    This deletes all keys matching the key_prefix pattern.

    Raises:
        CacheError: If cache operation fails
    """
    try:
        redis = await self._ensure_connection()
        pattern = f"{self.key_prefix}*"

        # Walk the keyspace with SCAN (cursor-based, non-blocking) and
        # delete each batch of matching keys as it arrives.
        removed = 0
        cursor = 0
        while True:
            cursor, batch = await self._retry_operation(
                redis.scan, cursor=cursor, match=pattern, count=100
            )
            if batch:
                removed += await self._retry_operation(redis.delete, *batch)
            if cursor == 0:
                # SCAN signals completion by returning cursor 0.
                break

        logger.info(f"Cache cleared: deleted {removed} keys")

    except CacheError:
        raise
    except Exception as e:
        logger.error(f"Failed to clear cache: {e}")
        raise CacheError("Failed to clear cache", cause=e) from e
|
||||
|
||||
async def size(self) -> int:
    """Get the number of entries in the cache.

    Returns:
        Number of cached entries matching key_prefix

    Raises:
        CacheError: If cache operation fails
    """
    try:
        redis = await self._ensure_connection()
        pattern = f"{self.key_prefix}*"

        # Tally keys via cursor-based SCAN so we never block Redis with
        # a full KEYS sweep.
        total = 0
        cursor = 0
        while True:
            cursor, batch = await self._retry_operation(
                redis.scan, cursor=cursor, match=pattern, count=100
            )
            total += len(batch)
            if cursor == 0:
                # Cursor 0 means the iteration is complete.
                break

        return total

    except CacheError:
        raise
    except Exception as e:
        logger.error(f"Failed to get cache size: {e}")
        raise CacheError("Failed to get cache size", cause=e) from e
|
||||
|
||||
async def close(self) -> None:
    """Close Redis connection and cleanup resources.

    This should be called when the cache is no longer needed. Errors are
    logged rather than raised.
    """
    # Drop the references up front so that even if close()/disconnect()
    # raises below, subsequent operations reconnect cleanly instead of
    # reusing a half-closed client. (Previously the attributes were only
    # cleared after a successful close, leaving stale handles on error.)
    client, self._redis = self._redis, None
    pool, self._pool = self._pool, None

    try:
        if client is not None:
            await client.close()

        if pool is not None:
            await pool.disconnect()

        logger.info("Redis connection closed")

    except Exception as e:
        logger.warning(f"Error closing Redis connection: {e}")
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
    """Get cache statistics.

    Returns:
        Dictionary with cache configuration and connection info
    """
    # Snapshot of static configuration plus current connection status.
    stats: dict[str, Any] = {
        "host": self.host,
        "port": self.port,
        "db": self.db,
        "connection_pool_size": self.connection_pool_size,
        "timeout_ms": self.timeout_ms,
        "default_ttl": self.default_ttl,
        "max_retries": self.max_retries,
        "key_prefix": self.key_prefix,
        "connected": self._redis is not None,
    }
    return stats
|
||||
7
tests/unit/providers/utils/cache/__init__.py
vendored
Normal file
7
tests/unit/providers/utils/cache/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Unit tests for cache store implementations."""
|
||||
257
tests/unit/providers/utils/cache/test_cache_store.py
vendored
Normal file
257
tests/unit/providers/utils/cache/test_cache_store.py
vendored
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Unit tests for cache store base classes and utilities."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.utils.cache import CacheError, CircuitBreaker
|
||||
|
||||
|
||||
class TestCacheError:
    """Test suite for CacheError exception."""

    def test_init_with_message(self):
        """Test CacheError initialization with message."""
        err = CacheError("Failed to connect to cache")
        assert str(err) == "Failed to connect to cache"
        assert err.cause is None

    def test_init_with_cause(self):
        """Test CacheError initialization with underlying cause."""
        underlying = ValueError("Invalid value")
        err = CacheError("Failed to set cache key", cause=underlying)
        assert str(err) == "Failed to set cache key"
        assert err.cause == underlying
|
||||
|
||||
|
||||
class TestCircuitBreaker:
    """Test suite for CircuitBreaker."""

    # NOTE: several tests below call breaker.is_closed() explicitly to
    # trigger the OPEN -> HALF_OPEN transition after the recovery timeout;
    # the state attribute alone does not advance on its own.

    def test_init_default_params(self):
        """Test initialization with default parameters."""
        breaker = CircuitBreaker()
        assert breaker.failure_threshold == 10
        assert breaker.recovery_timeout == 60
        assert breaker.failure_count == 0
        assert breaker.last_failure_time is None
        assert breaker.state == "CLOSED"

    def test_init_custom_params(self):
        """Test initialization with custom parameters."""
        breaker = CircuitBreaker(failure_threshold=5, recovery_timeout=30)
        assert breaker.failure_threshold == 5
        assert breaker.recovery_timeout == 30

    def test_is_closed_initial_state(self):
        """Test is_closed in initial state."""
        breaker = CircuitBreaker()
        assert breaker.is_closed() is True
        assert breaker.get_state() == "CLOSED"

    def test_record_success(self):
        """Test recording successful operations."""
        breaker = CircuitBreaker()

        # Record some failures
        breaker.record_failure()
        breaker.record_failure()
        assert breaker.failure_count == 2

        # Record success should reset
        breaker.record_success()
        assert breaker.failure_count == 0
        assert breaker.last_failure_time is None
        assert breaker.state == "CLOSED"

    def test_record_failure_below_threshold(self):
        """Test recording failures below threshold."""
        breaker = CircuitBreaker(failure_threshold=5)

        # Record failures below threshold
        for i in range(4):
            breaker.record_failure()
            assert breaker.is_closed() is True
            assert breaker.state == "CLOSED"

        assert breaker.failure_count == 4

    def test_record_failure_reach_threshold(self):
        """Test circuit breaker opens when threshold reached."""
        breaker = CircuitBreaker(failure_threshold=3)

        # Record failures to reach threshold
        for i in range(3):
            breaker.record_failure()

        # Should be open now
        assert breaker.state == "OPEN"
        assert breaker.is_closed() is False

    def test_circuit_open_blocks_requests(self):
        """Test that open circuit blocks requests."""
        breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=10)

        # Open the circuit
        for i in range(3):
            breaker.record_failure()

        assert breaker.is_closed() is False
        assert breaker.state == "OPEN"

    async def test_recovery_timeout(self):
        """Test circuit breaker recovery after timeout."""
        breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=1)

        # Open the circuit
        for i in range(3):
            breaker.record_failure()

        assert breaker.state == "OPEN"
        assert breaker.is_closed() is False

        # Wait for recovery timeout
        await asyncio.sleep(1.1)

        # Should enter HALF_OPEN state
        assert breaker.is_closed() is True
        assert breaker.state == "HALF_OPEN"

    async def test_half_open_success_closes_circuit(self):
        """Test successful request in HALF_OPEN closes circuit."""
        breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=1)

        # Open the circuit
        for i in range(3):
            breaker.record_failure()

        # Wait for recovery
        await asyncio.sleep(1.1)

        # Trigger state transition by calling is_closed()
        assert breaker.is_closed() is True
        assert breaker.state == "HALF_OPEN"

        # Record success
        breaker.record_success()
        assert breaker.state == "CLOSED"
        assert breaker.failure_count == 0

    async def test_half_open_failure_reopens_circuit(self):
        """Test failed request in HALF_OPEN reopens circuit."""
        breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=1)

        # Open the circuit
        for i in range(3):
            breaker.record_failure()

        # Wait for recovery
        await asyncio.sleep(1.1)

        # Trigger state transition by calling is_closed()
        assert breaker.is_closed() is True
        assert breaker.state == "HALF_OPEN"

        # Record failure
        breaker.record_failure()
        assert breaker.state == "OPEN"

    def test_reset(self):
        """Test manual reset of circuit breaker."""
        breaker = CircuitBreaker(failure_threshold=3)

        # Open the circuit
        for i in range(3):
            breaker.record_failure()

        assert breaker.state == "OPEN"

        # Manual reset
        breaker.reset()
        assert breaker.state == "CLOSED"
        assert breaker.failure_count == 0
        assert breaker.last_failure_time is None

    def test_get_state(self):
        """Test getting circuit breaker state."""
        breaker = CircuitBreaker(failure_threshold=3)

        # Initial state
        assert breaker.get_state() == "CLOSED"

        # After failures
        breaker.record_failure()
        assert breaker.get_state() == "CLOSED"

        # Open state
        for i in range(2):
            breaker.record_failure()
        assert breaker.get_state() == "OPEN"

    async def test_multiple_recovery_attempts(self):
        """Test multiple recovery attempts."""
        breaker = CircuitBreaker(failure_threshold=2, recovery_timeout=1)

        # Open the circuit
        breaker.record_failure()
        breaker.record_failure()
        assert breaker.state == "OPEN"

        # First recovery attempt fails
        await asyncio.sleep(1.1)
        assert breaker.is_closed() is True  # Trigger state check
        assert breaker.state == "HALF_OPEN"
        breaker.record_failure()
        assert breaker.state == "OPEN"

        # Second recovery attempt succeeds
        await asyncio.sleep(1.1)
        assert breaker.is_closed() is True  # Trigger state check
        assert breaker.state == "HALF_OPEN"
        breaker.record_success()
        assert breaker.state == "CLOSED"

    def test_failure_count_tracking(self):
        """Test failure count tracking."""
        breaker = CircuitBreaker(failure_threshold=5)

        # Track failures
        assert breaker.failure_count == 0

        breaker.record_failure()
        assert breaker.failure_count == 1

        breaker.record_failure()
        assert breaker.failure_count == 2

        # Success resets count
        breaker.record_success()
        assert breaker.failure_count == 0

    async def test_concurrent_operations(self):
        """Test circuit breaker with concurrent operations."""
        breaker = CircuitBreaker(failure_threshold=10)

        async def record_failures(count: int):
            for _ in range(count):
                breaker.record_failure()
                await asyncio.sleep(0.01)

        # Concurrent failures
        await asyncio.gather(
            record_failures(3),
            record_failures(3),
            record_failures(3),
        )

        assert breaker.failure_count == 9
        assert breaker.state == "CLOSED"

        # One more should open it
        breaker.record_failure()
        assert breaker.state == "OPEN"
|
||||
332
tests/unit/providers/utils/cache/test_memory_cache.py
vendored
Normal file
332
tests/unit/providers/utils/cache/test_memory_cache.py
vendored
Normal file
|
|
@ -0,0 +1,332 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Unit tests for MemoryCacheStore implementation."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.utils.cache import CacheError, MemoryCacheStore
|
||||
|
||||
|
||||
class TestMemoryCacheStore:
    """Test suite for MemoryCacheStore."""

    # NOTE: TTL expiration tests below rely on real asyncio.sleep delays
    # (~1.1s each), so this suite has small wall-clock overhead by design.

    async def test_init_default_params(self):
        """Test initialization with default parameters."""
        cache = MemoryCacheStore()
        assert cache.max_entries == 1000
        assert cache.max_memory_mb == 512
        assert cache.default_ttl == 600
        assert cache.eviction_policy == "lru"

    async def test_init_custom_params(self):
        """Test initialization with custom parameters."""
        cache = MemoryCacheStore(
            max_entries=500,
            max_memory_mb=256,
            default_ttl=300,
            eviction_policy="lfu",
        )
        assert cache.max_entries == 500
        assert cache.max_memory_mb == 256
        assert cache.default_ttl == 300
        assert cache.eviction_policy == "lfu"

    async def test_init_invalid_params(self):
        """Test initialization with invalid parameters."""
        with pytest.raises(ValueError, match="max_entries must be positive"):
            MemoryCacheStore(max_entries=0)

        with pytest.raises(ValueError, match="default_ttl must be positive"):
            MemoryCacheStore(default_ttl=0)

        with pytest.raises(ValueError, match="max_memory_mb must be positive"):
            MemoryCacheStore(max_memory_mb=0)

        with pytest.raises(ValueError, match="Unknown eviction policy"):
            MemoryCacheStore(eviction_policy="invalid")  # type: ignore

    async def test_set_and_get(self):
        """Test basic set and get operations."""
        cache = MemoryCacheStore()

        # Set value
        await cache.set("key1", "value1")

        # Get value
        value = await cache.get("key1")
        assert value == "value1"

    async def test_get_nonexistent_key(self):
        """Test getting a non-existent key."""
        cache = MemoryCacheStore()
        value = await cache.get("nonexistent")
        assert value is None

    async def test_set_with_custom_ttl(self):
        """Test setting value with custom TTL."""
        cache = MemoryCacheStore(default_ttl=10)

        # Set with custom TTL
        await cache.set("key1", "value1", ttl=1)

        # Value should exist initially
        value = await cache.get("key1")
        assert value == "value1"

        # Wait for expiration
        await asyncio.sleep(1.1)

        # Value should be expired
        value = await cache.get("key1")
        assert value is None

    async def test_set_complex_value(self):
        """Test storing complex data types."""
        cache = MemoryCacheStore()

        # Test dictionary
        data = {"nested": {"key": "value"}, "list": [1, 2, 3]}
        await cache.set("complex", data)
        value = await cache.get("complex")
        assert value == data

        # Test list
        list_data = [1, "two", {"three": 3}]
        await cache.set("list", list_data)
        value = await cache.get("list")
        assert value == list_data

    async def test_delete(self):
        """Test deleting a key."""
        cache = MemoryCacheStore()

        # Set and delete
        await cache.set("key1", "value1")
        deleted = await cache.delete("key1")
        assert deleted is True

        # Verify deleted
        value = await cache.get("key1")
        assert value is None

        # Delete non-existent key
        deleted = await cache.delete("nonexistent")
        assert deleted is False

    async def test_exists(self):
        """Test checking key existence."""
        cache = MemoryCacheStore()

        # Non-existent key
        exists = await cache.exists("key1")
        assert exists is False

        # Existing key
        await cache.set("key1", "value1")
        exists = await cache.exists("key1")
        assert exists is True

        # Expired key
        await cache.set("key2", "value2", ttl=1)
        await asyncio.sleep(1.1)
        exists = await cache.exists("key2")
        assert exists is False

    async def test_ttl(self):
        """Test getting remaining TTL."""
        cache = MemoryCacheStore()

        # Non-existent key
        ttl = await cache.ttl("nonexistent")
        assert ttl is None

        # Key with TTL
        await cache.set("key1", "value1", ttl=10)
        ttl = await cache.ttl("key1")
        assert ttl is not None
        assert 8 <= ttl <= 10  # Allow some tolerance

        # Expired key
        await cache.set("key2", "value2", ttl=1)
        await asyncio.sleep(1.1)
        ttl = await cache.ttl("key2")
        assert ttl is None

    async def test_clear(self):
        """Test clearing all entries."""
        cache = MemoryCacheStore()

        # Add multiple entries
        await cache.set("key1", "value1")
        await cache.set("key2", "value2")
        await cache.set("key3", "value3")

        # Clear
        await cache.clear()

        # Verify all cleared
        assert await cache.get("key1") is None
        assert await cache.get("key2") is None
        assert await cache.get("key3") is None

    async def test_size(self):
        """Test getting cache size."""
        cache = MemoryCacheStore()

        # Empty cache
        size = await cache.size()
        assert size == 0

        # Add entries
        await cache.set("key1", "value1")
        await cache.set("key2", "value2")
        size = await cache.size()
        assert size == 2

        # Delete entry
        await cache.delete("key1")
        size = await cache.size()
        assert size == 1

        # Clear cache
        await cache.clear()
        size = await cache.size()
        assert size == 0

    async def test_lru_eviction(self):
        """Test LRU eviction policy."""
        cache = MemoryCacheStore(max_entries=3, eviction_policy="lru")

        # Fill cache
        await cache.set("key1", "value1")
        await cache.set("key2", "value2")
        await cache.set("key3", "value3")

        # Access key1 to make it recently used
        await cache.get("key1")

        # Add new entry, should evict key2 (least recently used)
        await cache.set("key4", "value4")

        # key2 should be evicted
        assert await cache.get("key1") == "value1"
        assert await cache.get("key2") is None
        assert await cache.get("key3") == "value3"
        assert await cache.get("key4") == "value4"

    async def test_lfu_eviction(self):
        """Test LFU eviction policy."""
        cache = MemoryCacheStore(max_entries=3, eviction_policy="lfu")

        # Fill cache
        await cache.set("key1", "value1")
        await cache.set("key2", "value2")
        await cache.set("key3", "value3")

        # Access key1 multiple times
        await cache.get("key1")
        await cache.get("key1")
        await cache.get("key1")

        # Access key2 twice
        await cache.get("key2")
        await cache.get("key2")

        # key3 accessed once (least frequently)

        # Add new entry, should evict key3 (least frequently used)
        await cache.set("key4", "value4")

        # key3 should be evicted
        assert await cache.get("key1") == "value1"
        assert await cache.get("key2") == "value2"
        assert await cache.get("key3") is None
        assert await cache.get("key4") == "value4"

    async def test_concurrent_access(self):
        """Test concurrent access to cache."""
        cache = MemoryCacheStore()

        async def set_value(key: str, value: str):
            await cache.set(key, value)

        async def get_value(key: str):
            return await cache.get(key)

        # Concurrent sets
        await asyncio.gather(
            set_value("key1", "value1"),
            set_value("key2", "value2"),
            set_value("key3", "value3"),
        )

        # Concurrent gets
        results = await asyncio.gather(
            get_value("key1"),
            get_value("key2"),
            get_value("key3"),
        )

        assert results == ["value1", "value2", "value3"]

    async def test_update_existing_key(self):
        """Test updating an existing key."""
        cache = MemoryCacheStore()

        # Set initial value
        await cache.set("key1", "value1")
        assert await cache.get("key1") == "value1"

        # Update value
        await cache.set("key1", "value2")
        assert await cache.get("key1") == "value2"

    async def test_get_stats(self):
        """Test getting cache statistics."""
        cache = MemoryCacheStore(
            max_entries=100,
            max_memory_mb=128,
            default_ttl=300,
            eviction_policy="lru",
        )

        await cache.set("key1", "value1")
        await cache.set("key2", "value2")

        stats = cache.get_stats()

        assert stats["size"] == 2
        assert stats["max_entries"] == 100
        assert stats["max_memory_mb"] == 128
        assert stats["default_ttl"] == 300
        assert stats["eviction_policy"] == "lru"

    async def test_ttl_expiration_cleanup(self):
        """Test that expired entries are cleaned up properly."""
        cache = MemoryCacheStore()

        # Set entry with short TTL
        await cache.set("key1", "value1", ttl=1)
        await cache.set("key2", "value2", ttl=10)

        # Initially both exist
        assert await cache.size() == 2

        # Wait for key1 to expire
        await asyncio.sleep(1.1)

        # Accessing expired key should clean it up
        assert await cache.get("key1") is None

        # Size should reflect cleanup
        size = await cache.size()
        assert size == 1

        # key2 should still exist
        assert await cache.get("key2") == "value2"
|
||||
421
tests/unit/providers/utils/cache/test_redis_cache.py
vendored
Normal file
421
tests/unit/providers/utils/cache/test_redis_cache.py
vendored
Normal file
|
|
@ -0,0 +1,421 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Unit tests for RedisCacheStore implementation."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.utils.cache import CacheError, RedisCacheStore
|
||||
|
||||
|
||||
class TestRedisCacheStore:
|
||||
"""Test suite for RedisCacheStore."""
|
||||
|
||||
async def test_init_default_params(self):
|
||||
"""Test initialization with default parameters."""
|
||||
cache = RedisCacheStore()
|
||||
assert cache.host == "localhost"
|
||||
assert cache.port == 6379
|
||||
assert cache.db == 0
|
||||
assert cache.password is None
|
||||
assert cache.connection_pool_size == 10
|
||||
assert cache.timeout_ms == 100
|
||||
assert cache.default_ttl == 600
|
||||
assert cache.max_retries == 3
|
||||
assert cache.key_prefix == "llama_stack:"
|
||||
|
||||
async def test_init_custom_params(self):
|
||||
"""Test initialization with custom parameters."""
|
||||
cache = RedisCacheStore(
|
||||
host="redis.example.com",
|
||||
port=6380,
|
||||
db=1,
|
||||
password="secret",
|
||||
connection_pool_size=20,
|
||||
timeout_ms=200,
|
||||
default_ttl=300,
|
||||
max_retries=5,
|
||||
key_prefix="test:",
|
||||
)
|
||||
assert cache.host == "redis.example.com"
|
||||
assert cache.port == 6380
|
||||
assert cache.db == 1
|
||||
assert cache.password == "secret"
|
||||
assert cache.connection_pool_size == 20
|
||||
assert cache.timeout_ms == 200
|
||||
assert cache.default_ttl == 300
|
||||
assert cache.max_retries == 5
|
||||
assert cache.key_prefix == "test:"
|
||||
|
||||
async def test_init_invalid_params(self):
|
||||
"""Test initialization with invalid parameters."""
|
||||
with pytest.raises(ValueError, match="connection_pool_size must be positive"):
|
||||
RedisCacheStore(connection_pool_size=0)
|
||||
|
||||
with pytest.raises(ValueError, match="timeout_ms must be positive"):
|
||||
RedisCacheStore(timeout_ms=0)
|
||||
|
||||
with pytest.raises(ValueError, match="default_ttl must be positive"):
|
||||
RedisCacheStore(default_ttl=0)
|
||||
|
||||
with pytest.raises(ValueError, match="max_retries must be non-negative"):
|
||||
RedisCacheStore(max_retries=-1)
|
||||
|
||||
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
|
||||
@patch("llama_stack.providers.utils.cache.redis.Redis")
|
||||
async def test_ensure_connection(self, mock_redis_class, mock_pool_class):
|
||||
"""Test connection establishment."""
|
||||
# Setup mocks
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock()
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
# Create cache
|
||||
cache = RedisCacheStore()
|
||||
|
||||
# Ensure connection
|
||||
redis = await cache._ensure_connection()
|
||||
|
||||
# Verify connection was established
|
||||
assert redis == mock_redis
|
||||
mock_pool_class.assert_called_once()
|
||||
mock_redis.ping.assert_called_once()
|
||||
|
||||
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
|
||||
@patch("llama_stack.providers.utils.cache.redis.Redis")
|
||||
async def test_connection_failure(self, mock_redis_class, mock_pool_class):
|
||||
"""Test connection failure handling."""
|
||||
from redis.exceptions import ConnectionError as RedisConnectionError
|
||||
|
||||
# Setup mocks to fail
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock(side_effect=RedisConnectionError("Connection refused"))
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
# Create cache
|
||||
cache = RedisCacheStore()
|
||||
|
||||
# Connection should fail
|
||||
with pytest.raises(CacheError, match="Failed to connect to Redis"):
|
||||
await cache._ensure_connection()
|
||||
|
||||
def test_make_key(self):
|
||||
"""Test key prefixing."""
|
||||
cache = RedisCacheStore(key_prefix="test:")
|
||||
assert cache._make_key("mykey") == "test:mykey"
|
||||
assert cache._make_key("another") == "test:another"
|
||||
|
||||
def test_serialize_deserialize(self):
|
||||
"""Test value serialization."""
|
||||
cache = RedisCacheStore()
|
||||
|
||||
# Simple value
|
||||
assert cache._serialize("hello") == '"hello"'
|
||||
assert cache._deserialize('"hello"') == "hello"
|
||||
|
||||
# Dictionary
|
||||
data = {"key": "value", "number": 42}
|
||||
serialized = cache._serialize(data)
|
||||
assert cache._deserialize(serialized) == data
|
||||
|
||||
# List
|
||||
list_data = [1, 2, "three"]
|
||||
serialized = cache._serialize(list_data)
|
||||
assert cache._deserialize(serialized) == list_data
|
||||
|
||||
def test_serialize_error(self):
|
||||
"""Test serialization error handling."""
|
||||
cache = RedisCacheStore()
|
||||
|
||||
# Object that can't be serialized
|
||||
class NonSerializable:
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError, match="Value is not JSON-serializable"):
|
||||
cache._serialize(NonSerializable())
|
||||
|
||||
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
|
||||
@patch("llama_stack.providers.utils.cache.redis.Redis")
|
||||
async def test_set_and_get(self, mock_redis_class, mock_pool_class):
|
||||
"""Test set and get operations."""
|
||||
# Setup mocks
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=json.dumps("value1"))
|
||||
mock_redis.setex = AsyncMock()
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
# Create cache
|
||||
cache = RedisCacheStore()
|
||||
|
||||
# Set value
|
||||
await cache.set("key1", "value1")
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
# Get value
|
||||
value = await cache.get("key1")
|
||||
assert value == "value1"
|
||||
mock_redis.get.assert_called_once()
|
||||
|
||||
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
|
||||
@patch("llama_stack.providers.utils.cache.redis.Redis")
|
||||
async def test_get_nonexistent_key(self, mock_redis_class, mock_pool_class):
|
||||
"""Test getting a non-existent key."""
|
||||
# Setup mocks
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
# Create cache
|
||||
cache = RedisCacheStore()
|
||||
|
||||
# Get non-existent key
|
||||
value = await cache.get("nonexistent")
|
||||
assert value is None
|
||||
|
||||
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
|
||||
@patch("llama_stack.providers.utils.cache.redis.Redis")
|
||||
async def test_set_with_custom_ttl(self, mock_redis_class, mock_pool_class):
|
||||
"""Test setting value with custom TTL."""
|
||||
# Setup mocks
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock()
|
||||
mock_redis.setex = AsyncMock()
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
# Create cache
|
||||
cache = RedisCacheStore(default_ttl=600)
|
||||
|
||||
# Set with custom TTL
|
||||
await cache.set("key1", "value1", ttl=300)
|
||||
|
||||
# Verify setex was called with custom TTL
|
||||
call_args = mock_redis.setex.call_args
|
||||
assert call_args[0][1] == 300 # TTL argument
|
||||
|
||||
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
|
||||
@patch("llama_stack.providers.utils.cache.redis.Redis")
|
||||
async def test_delete(self, mock_redis_class, mock_pool_class):
|
||||
"""Test deleting a key."""
|
||||
# Setup mocks
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock()
|
||||
mock_redis.delete = AsyncMock(return_value=1) # 1 key deleted
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
# Create cache
|
||||
cache = RedisCacheStore()
|
||||
|
||||
# Delete key
|
||||
deleted = await cache.delete("key1")
|
||||
assert deleted is True
|
||||
|
||||
# Delete non-existent key
|
||||
mock_redis.delete = AsyncMock(return_value=0)
|
||||
deleted = await cache.delete("nonexistent")
|
||||
assert deleted is False
|
||||
|
||||
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
|
||||
@patch("llama_stack.providers.utils.cache.redis.Redis")
|
||||
async def test_exists(self, mock_redis_class, mock_pool_class):
|
||||
"""Test checking key existence."""
|
||||
# Setup mocks
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock()
|
||||
mock_redis.exists = AsyncMock(return_value=1) # Exists
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
# Create cache
|
||||
cache = RedisCacheStore()
|
||||
|
||||
# Check existing key
|
||||
exists = await cache.exists("key1")
|
||||
assert exists is True
|
||||
|
||||
# Check non-existent key
|
||||
mock_redis.exists = AsyncMock(return_value=0)
|
||||
exists = await cache.exists("nonexistent")
|
||||
assert exists is False
|
||||
|
||||
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
|
||||
@patch("llama_stack.providers.utils.cache.redis.Redis")
|
||||
async def test_ttl(self, mock_redis_class, mock_pool_class):
|
||||
"""Test getting remaining TTL."""
|
||||
# Setup mocks
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock()
|
||||
mock_redis.ttl = AsyncMock(return_value=300)
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
# Create cache
|
||||
cache = RedisCacheStore()
|
||||
|
||||
# Get TTL
|
||||
ttl = await cache.ttl("key1")
|
||||
assert ttl == 300
|
||||
|
||||
# Key doesn't exist
|
||||
mock_redis.ttl = AsyncMock(return_value=-2)
|
||||
ttl = await cache.ttl("nonexistent")
|
||||
assert ttl is None
|
||||
|
||||
# Key has no TTL
|
||||
mock_redis.ttl = AsyncMock(return_value=-1)
|
||||
ttl = await cache.ttl("no_ttl_key")
|
||||
assert ttl is None
|
||||
|
||||
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
|
||||
@patch("llama_stack.providers.utils.cache.redis.Redis")
|
||||
async def test_clear(self, mock_redis_class, mock_pool_class):
|
||||
"""Test clearing all entries."""
|
||||
# Setup mocks
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock()
|
||||
mock_redis.scan = AsyncMock(
|
||||
side_effect=[
|
||||
(10, ["llama_stack:key1", "llama_stack:key2"]),
|
||||
(0, ["llama_stack:key3"]), # cursor 0 indicates end
|
||||
]
|
||||
)
|
||||
mock_redis.delete = AsyncMock(return_value=3)
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
# Create cache
|
||||
cache = RedisCacheStore()
|
||||
|
||||
# Clear cache
|
||||
await cache.clear()
|
||||
|
||||
# Verify scan and delete were called
|
||||
assert mock_redis.scan.call_count == 2
|
||||
mock_redis.delete.assert_called()
|
||||
|
||||
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
|
||||
@patch("llama_stack.providers.utils.cache.redis.Redis")
|
||||
async def test_size(self, mock_redis_class, mock_pool_class):
|
||||
"""Test getting cache size."""
|
||||
# Setup mocks
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock()
|
||||
mock_redis.scan = AsyncMock(
|
||||
side_effect=[
|
||||
(10, ["llama_stack:key1", "llama_stack:key2"]),
|
||||
(0, ["llama_stack:key3"]),
|
||||
]
|
||||
)
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
# Create cache
|
||||
cache = RedisCacheStore()
|
||||
|
||||
# Get size
|
||||
size = await cache.size()
|
||||
assert size == 3
|
||||
|
||||
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
|
||||
@patch("llama_stack.providers.utils.cache.redis.Redis")
|
||||
async def test_retry_logic(self, mock_redis_class, mock_pool_class):
|
||||
"""Test retry logic for transient failures."""
|
||||
from redis.exceptions import TimeoutError as RedisTimeoutError
|
||||
|
||||
# Setup mocks - fail twice, then succeed
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock()
|
||||
mock_redis.get = AsyncMock(
|
||||
side_effect=[
|
||||
RedisTimeoutError("Timeout"),
|
||||
RedisTimeoutError("Timeout"),
|
||||
json.dumps("success"),
|
||||
]
|
||||
)
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
# Create cache with retries
|
||||
cache = RedisCacheStore(max_retries=3)
|
||||
|
||||
# Should succeed after retries
|
||||
value = await cache.get("key1")
|
||||
assert value == "success"
|
||||
assert mock_redis.get.call_count == 3
|
||||
|
||||
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
|
||||
@patch("llama_stack.providers.utils.cache.redis.Redis")
|
||||
async def test_retry_exhaustion(self, mock_redis_class, mock_pool_class):
|
||||
"""Test behavior when all retries are exhausted."""
|
||||
from redis.exceptions import TimeoutError as RedisTimeoutError
|
||||
|
||||
# Setup mocks - always fail
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=RedisTimeoutError("Timeout"))
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
# Create cache with limited retries
|
||||
cache = RedisCacheStore(max_retries=2)
|
||||
|
||||
# Should raise CacheError after exhausting retries
|
||||
with pytest.raises(CacheError, match="failed after .* attempts"):
|
||||
await cache.get("key1")
|
||||
|
||||
# Should have tried 3 times (initial + 2 retries)
|
||||
assert mock_redis.get.call_count == 3
|
||||
|
||||
@patch("llama_stack.providers.utils.cache.redis.ConnectionPool")
|
||||
@patch("llama_stack.providers.utils.cache.redis.Redis")
|
||||
async def test_close(self, mock_redis_class, mock_pool_class):
|
||||
"""Test closing Redis connection."""
|
||||
# Setup mocks
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock()
|
||||
mock_redis.close = AsyncMock()
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
mock_pool = AsyncMock()
|
||||
mock_pool.disconnect = AsyncMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Create cache and establish connection
|
||||
cache = RedisCacheStore()
|
||||
await cache._ensure_connection()
|
||||
|
||||
# Close connection
|
||||
await cache.close()
|
||||
|
||||
# Verify cleanup
|
||||
mock_redis.close.assert_called_once()
|
||||
mock_pool.disconnect.assert_called_once()
|
||||
|
||||
def test_get_stats(self):
    """get_stats() echoes every constructor setting plus connection state."""
    store = RedisCacheStore(
        host="redis.example.com",
        port=6380,
        db=1,
        connection_pool_size=20,
        timeout_ms=200,
        default_ttl=300,
        max_retries=5,
        key_prefix="test:",
    )

    stats = store.get_stats()

    # Every configuration value should be reported back unchanged.
    expected = {
        "host": "redis.example.com",
        "port": 6380,
        "db": 1,
        "connection_pool_size": 20,
        "timeout_ms": 200,
        "default_ttl": 300,
        "max_retries": 5,
        "key_prefix": "test:",
    }
    for field, value in expected.items():
        assert stats[field] == value

    # No call has touched Redis yet, so the lazy connection is still down.
    assert stats["connected"] is False
17
uv.lock
generated
17
uv.lock
generated
|
|
@ -1,5 +1,5 @@
|
|||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = ">=3.12"
|
||||
resolution-markers = [
|
||||
"(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
|
|
@ -1996,6 +1996,7 @@ dependencies = [
|
|||
{ name = "aiohttp" },
|
||||
{ name = "aiosqlite" },
|
||||
{ name = "asyncpg" },
|
||||
{ name = "cachetools" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "fire" },
|
||||
{ name = "h11" },
|
||||
|
|
@ -2013,6 +2014,7 @@ dependencies = [
|
|||
{ name = "python-dotenv" },
|
||||
{ name = "python-multipart" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "redis" },
|
||||
{ name = "rich" },
|
||||
{ name = "sqlalchemy", extra = ["asyncio"] },
|
||||
{ name = "starlette" },
|
||||
|
|
@ -2147,6 +2149,7 @@ requires-dist = [
|
|||
{ name = "aiohttp" },
|
||||
{ name = "aiosqlite", specifier = ">=0.21.0" },
|
||||
{ name = "asyncpg" },
|
||||
{ name = "cachetools", specifier = ">=5.5.0" },
|
||||
{ name = "fastapi", specifier = ">=0.115.0,<1.0" },
|
||||
{ name = "fire" },
|
||||
{ name = "h11", specifier = ">=0.16.0" },
|
||||
|
|
@ -2166,6 +2169,7 @@ requires-dist = [
|
|||
{ name = "python-multipart", specifier = ">=0.0.20" },
|
||||
{ name = "pyyaml", specifier = ">=6.0" },
|
||||
{ name = "pyyaml", specifier = ">=6.0.2" },
|
||||
{ name = "redis", specifier = ">=5.2.0" },
|
||||
{ name = "rich" },
|
||||
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" },
|
||||
{ name = "starlette" },
|
||||
|
|
@ -4398,6 +4402,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/ef/33/d8df6a2b214ffbe4138db9a1efe3248f67dc3c671f82308bea1582ecbbb7/qdrant_client-1.15.1-py3-none-any.whl", hash = "sha256:2b975099b378382f6ca1cfb43f0d59e541be6e16a5892f282a4b8de7eff5cb63", size = 337331, upload-time = "2025-07-31T19:35:17.539Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redis"
|
||||
version = "7.0.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/57/8f/f125feec0b958e8d22c8f0b492b30b1991d9499a4315dfde466cf4289edc/redis-7.0.1.tar.gz", hash = "sha256:c949df947dca995dc68fdf5a7863950bf6df24f8d6022394585acc98e81624f1", size = 4755322, upload-time = "2025-10-27T14:34:00.33Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/97/9f22a33c475cda519f20aba6babb340fb2f2254a02fb947816960d1e669a/redis-7.0.1-py3-none-any.whl", hash = "sha256:4977af3c7d67f8f0eb8b6fec0dafc9605db9343142f634041fb0235f67c0588a", size = 339938, upload-time = "2025-10-27T14:33:58.553Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "referencing"
|
||||
version = "0.36.2"
|
||||
|
|
@ -4656,6 +4669,8 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/6b/fa/3234f913fe9a6525a7b97c6dad1f51e72b917e6872e051a5e2ffd8b16fbb/ruamel.yaml.clib-0.2.14-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:70eda7703b8126f5e52fcf276e6c0f40b0d314674f896fc58c47b0aef2b9ae83", size = 137970, upload-time = "2025-09-22T19:51:09.472Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ef/ec/4edbf17ac2c87fa0845dd366ef8d5852b96eb58fcd65fc1ecf5fe27b4641/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a0cb71ccc6ef9ce36eecb6272c81afdc2f565950cdcec33ae8e6cd8f7fc86f27", size = 739639, upload-time = "2025-09-22T19:51:10.566Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/15/18/b0e1fafe59051de9e79cdd431863b03593ecfa8341c110affad7c8121efc/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e7cb9ad1d525d40f7d87b6df7c0ff916a66bc52cb61b66ac1b2a16d0c1b07640", size = 764456, upload-time = "2025-09-22T19:51:11.736Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e7/cd/150fdb96b8fab27fe08d8a59fe67554568727981806e6bc2677a16081ec7/ruamel_yaml_clib-0.2.14-cp314-cp314-win32.whl", hash = "sha256:9b4104bf43ca0cd4e6f738cb86326a3b2f6eef00f417bd1e7efb7bdffe74c539", size = 102394, upload-time = "2025-11-14T21:57:36.703Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bd/e6/a3fa40084558c7e1dc9546385f22a93949c890a8b2e445b2ba43935f51da/ruamel_yaml_clib-0.2.14-cp314-cp314-win_amd64.whl", hash = "sha256:13997d7d354a9890ea1ec5937a219817464e5cc344805b37671562a401ca3008", size = 122673, upload-time = "2025-11-14T21:57:38.177Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue