mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

fix(lowest_tpm_rpm_v2.py): support batch writing increments to redis

This commit is contained in:
parent 5e892a1e92
commit 1328afe612

5 changed files with 176 additions and 10 deletions
@@ -7,6 +7,7 @@ DEFAULT_MAX_RETRIES = 2
 DEFAULT_FAILURE_THRESHOLD_PERCENT = (
     0.5  # default: cooldown a deployment if 50% of requests fail in a given minute
 )
+DEFAULT_REDIS_SYNC_INTERVAL = 1
 DEFAULT_COOLDOWN_TIME_SECONDS = 5
 DEFAULT_REPLICATE_POLLING_RETRIES = 5
 DEFAULT_REPLICATE_POLLING_DELAY_SECONDS = 1
File diff suppressed because one or more lines are too long
@@ -3,13 +3,13 @@ model_list:
     litellm_params:
       model: azure/chatgpt-v-2
       api_key: os.environ/AZURE_API_KEY
-      api_base: os.environ/AZURE_API_BASE
+      api_base: http://0.0.0.0:8090

 litellm_settings:
-  callbacks: ["prometheus"]
+  callbacks: ["otel"]

 router_settings:
-  routing_strategy: usage-based-routing-v2 # 👈 KEY CHANGE
+  # routing_strategy: usage-based-routing-v2 # 👈 KEY CHANGE
   redis_host: os.environ/REDIS_HOST
   redis_password: os.environ/REDIS_PASSWORD
   redis_port: os.environ/REDIS_PORT
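For reference, the proxy config above maps roughly onto the following programmatic Router setup. This is a hedged sketch, not part of the commit: the "gpt-3.5-turbo" model alias is an invented placeholder, and the Redis values are read from the same environment variables the config references.

import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",  # hypothetical alias for this example
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": os.environ["AZURE_API_KEY"],
                "api_base": "http://0.0.0.0:8090",
            },
        }
    ],
    routing_strategy="usage-based-routing-v2",  # the strategy this commit extends
    redis_host=os.environ["REDIS_HOST"],
    redis_password=os.environ["REDIS_PASSWORD"],
    redis_port=int(os.environ["REDIS_PORT"]),
)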

litellm/router_strategy/base_routing_strategy.py (new file, 154 lines)
@@ -0,0 +1,154 @@
"""
Base class across routing strategies to abstract common functions like batch incrementing redis
"""

import asyncio
from abc import ABC, abstractmethod
from typing import List, Optional, Union

from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache
from litellm.caching.redis_cache import RedisPipelineIncrementOperation
from litellm.constants import DEFAULT_REDIS_SYNC_INTERVAL


class BaseRoutingStrategy(ABC):
    def __init__(
        self,
        dual_cache: DualCache,
        should_batch_redis_writes: bool,
        default_sync_interval: Optional[Union[int, float]],
    ):
        self.dual_cache = dual_cache
        self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
        if should_batch_redis_writes:
            asyncio.create_task(
                self.periodic_sync_in_memory_spend_with_redis(
                    default_sync_interval=default_sync_interval
                )
            )

    async def _increment_value_in_current_window(
        self, key: str, value: Union[int, float], ttl: int
    ):
        """
        Increment spend within the existing budget window.

        Runs once the budget start time exists in the Redis cache (on the 2nd and subsequent requests to the same provider).

        - Increments the spend in the in-memory cache (so spend is instantly updated in memory)
        - Queues the increment operation for the Redis pipeline (batched pipeline to optimize performance; Redis is used for the multi-instance environment of LiteLLM)
        """
        result = await self.dual_cache.in_memory_cache.async_increment(
            key=key,
            value=value,
            ttl=ttl,
        )
        increment_op = RedisPipelineIncrementOperation(
            key=key,
            increment_value=value,
            ttl=ttl,
        )
        self.redis_increment_operation_queue.append(increment_op)

        return result

    async def periodic_sync_in_memory_spend_with_redis(
        self, default_sync_interval: Optional[Union[int, float]]
    ):
        """
        Handler that triggers _sync_in_memory_spend_with_redis every DEFAULT_REDIS_SYNC_INTERVAL seconds.

        Required for multi-instance environment usage of provider budgets.
        """
        default_sync_interval = default_sync_interval or DEFAULT_REDIS_SYNC_INTERVAL
        while True:
            try:
                await self._sync_in_memory_spend_with_redis()
                await asyncio.sleep(
                    default_sync_interval
                )  # Wait for DEFAULT_REDIS_SYNC_INTERVAL seconds before next sync
            except Exception as e:
                verbose_router_logger.error(f"Error in periodic sync task: {str(e)}")
                await asyncio.sleep(
                    default_sync_interval
                )  # Still wait DEFAULT_REDIS_SYNC_INTERVAL seconds on error before retrying

    async def _push_in_memory_increments_to_redis(self):
        """
        How this works:
        - async_log_success_event collects all provider spend increments in `redis_increment_operation_queue`
        - This function pushes all increments to Redis in a batched pipeline to optimize performance

        Only runs if Redis is initialized.
        """
        try:
            if not self.dual_cache.redis_cache:
                return  # Redis is not initialized

            verbose_router_logger.debug(
                "Pushing Redis Increment Pipeline for queue: %s",
                self.redis_increment_operation_queue,
            )
            if len(self.redis_increment_operation_queue) > 0:
                asyncio.create_task(
                    self.dual_cache.redis_cache.async_increment_pipeline(
                        increment_list=self.redis_increment_operation_queue,
                    )
                )

            self.redis_increment_operation_queue = []

        except Exception as e:
            verbose_router_logger.error(
                f"Error syncing in-memory cache with Redis: {str(e)}"
            )

    @abstractmethod
    def get_cache_keys(self) -> List:
        pass

    async def _sync_in_memory_spend_with_redis(self):
        """
        Ensures the in-memory cache is updated with the latest Redis values for all provider spends.

        Why do we need this?
        - Optimization to hit sub-100ms latency; performance was impacted when Redis was used for a read/write per request
        - To use provider budgets in a multi-instance environment, we use Redis to sync spend across all instances

        What this does:
        1. Push all provider spend increments to Redis
        2. Fetch all current provider spend from Redis to update the in-memory cache
        """

        try:
            # No need to sync if Redis cache is not initialized
            if self.dual_cache.redis_cache is None:
                return

            # 1. Push all provider spend increments to Redis
            await self._push_in_memory_increments_to_redis()

            # 2. Fetch all current provider spend from Redis to update in-memory cache
            cache_keys = self.get_cache_keys()

            # Batch fetch current spend values from Redis
            redis_values = await self.dual_cache.redis_cache.async_batch_get_cache(
                key_list=cache_keys
            )

            # Update in-memory cache with Redis values
            if isinstance(redis_values, dict):  # check if redis_values is a dictionary
                for key, value in redis_values.items():
                    if value is not None:
                        await self.dual_cache.in_memory_cache.async_set_cache(
                            key=key, value=float(value)
                        )
                        verbose_router_logger.debug(
                            f"Updated in-memory cache for {key}: {value}"
                        )

        except Exception as e:
            verbose_router_logger.error(
                f"Error syncing in-memory cache with Redis: {str(e)}"
            )
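To make the contract concrete, here is a minimal sketch of a strategy built on BaseRoutingStrategy. It is illustrative only, not part of the commit: ExampleStrategy, record_request, and the 60-second TTL are invented for the example; only the base-class methods come from the file above.

from typing import List, Union

from litellm.caching.caching import DualCache
from litellm.router_strategy.base_routing_strategy import BaseRoutingStrategy


class ExampleStrategy(BaseRoutingStrategy):
    def __init__(self, dual_cache: DualCache):
        # Note: __init__ spawns the periodic sync task via asyncio.create_task,
        # so instances must be created while an event loop is running
        super().__init__(
            dual_cache=dual_cache,
            should_batch_redis_writes=True,
            default_sync_interval=0.1,  # same interval the TPM/RPM handler uses below
        )
        self._tracked_keys: List[str] = []

    def get_cache_keys(self) -> List:
        # Tells the periodic sync which keys to pull back from Redis
        return self._tracked_keys

    async def record_request(self, key: str, value: Union[int, float] = 1):
        if key not in self._tracked_keys:
            self._tracked_keys.append(key)
        # Updates in-memory immediately; the Redis write is queued and flushed in a batch
        return await self._increment_value_in_current_window(key=key, value=value, ttl=60)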

litellm/router_strategy/lowest_tpm_rpm_v2.py

@@ -15,6 +15,8 @@ from litellm.types.router import RouterErrors
 from litellm.types.utils import LiteLLMPydanticObjectBase, StandardLoggingPayload
 from litellm.utils import get_utc_datetime, print_verbose

+from .base_routing_strategy import BaseRoutingStrategy
+
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

@@ -27,7 +29,7 @@ class RoutingArgs(LiteLLMPydanticObjectBase):
     ttl: int = 1 * 60  # 1min (RPM/TPM expire key)


-class LowestTPMLoggingHandler_v2(CustomLogger):
+class LowestTPMLoggingHandler_v2(BaseRoutingStrategy, CustomLogger):
     """
     Updated version of TPM/RPM Logging.

@@ -51,6 +53,12 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         self.router_cache = router_cache
         self.model_list = model_list
         self.routing_args = RoutingArgs(**routing_args)
+        BaseRoutingStrategy.__init__(
+            self,
+            dual_cache=router_cache,
+            should_batch_redis_writes=True,
+            default_sync_interval=0.1,
+        )

     def pre_call_check(self, deployment: Dict) -> Optional[Dict]:
         """
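One consequence of this wiring, sketched here as an assumption rather than taken from the diff: because BaseRoutingStrategy.__init__ calls asyncio.create_task, the handler must be constructed while an event loop is running. The constructor signature below is inferred from the assignments shown in the hunk above.

import asyncio

from litellm.caching.caching import DualCache
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2


async def main():
    # Built inside a running loop so the periodic Redis-sync task can start
    handler = LowestTPMLoggingHandler_v2(router_cache=DualCache(), model_list=[])
    await asyncio.sleep(0.3)  # let a few 0.1s sync ticks fire


asyncio.run(main())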
@@ -104,6 +112,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             )
         else:
             # if local result below limit, check redis ## prevent unnecessary redis checks
+
             result = self.router_cache.increment_cache(
                 key=rpm_key, value=1, ttl=self.routing_args.ttl
             )
@@ -186,12 +195,15 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             )
         else:
             # if local result below limit, check redis ## prevent unnecessary redis checks
-            result = await self.router_cache.async_increment_cache(
-                key=rpm_key,
-                value=1,
-                ttl=self.routing_args.ttl,
-                parent_otel_span=parent_otel_span,
-            )
+            result = await self._increment_value_in_current_window(
+                key=rpm_key, value=1, ttl=self.routing_args.ttl
+            )
+            # result = await self.router_cache.async_increment_cache(
+            #     key=rpm_key,
+            #     value=1,
+            #     ttl=self.routing_args.ttl,
+            #     parent_otel_span=parent_otel_span,
+            # )
             if result is not None and result > deployment_rpm:
                 raise litellm.RateLimitError(
                     message="Deployment over defined rpm limit={}. current usage={}".format(
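For intuition about what the queued RedisPipelineIncrementOperation items buy: the flush in _push_in_memory_increments_to_redis collapses N per-request round trips into one pipelined call. A rough redis-py equivalent, purely illustrative (the flush_increments helper is invented; the dict access mirrors the operation's key, increment_value, and ttl fields):

import redis.asyncio as redis


async def flush_increments(r: redis.Redis, queue: list) -> None:
    # Each queued op mirrors RedisPipelineIncrementOperation: key, increment_value, ttl
    async with r.pipeline(transaction=False) as pipe:
        for op in queue:
            pipe.incrbyfloat(op["key"], op["increment_value"])
            pipe.expire(op["key"], op["ttl"])
        await pipe.execute()  # one network round trip for the whole batch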