forked from phoenix/litellm-mirror
use redis async_increment_pipeline
commit 24ab979486
parent 5dd8726685
1 changed file with 15 additions and 36 deletions
@@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union
 import litellm
 from litellm._logging import verbose_router_logger
 from litellm.caching.caching import DualCache
+from litellm.caching.redis_cache import RedisPipelineIncrementOperation
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.litellm_core_utils.duration_parser import duration_in_seconds
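Note on the new import: this commit replaces per-key delta syncing with a buffered Redis pipeline, and RedisPipelineIncrementOperation describes one buffered increment. Judging from the keyword arguments used later in this diff, it is shaped roughly like the TypedDict below — a sketch inferred from usage, not the verbatim litellm definition:

    from typing import TypedDict

    class RedisPipelineIncrementOperation(TypedDict):
        # One queued increment: target key, amount to add, and TTL to apply.
        key: str
        increment_value: float
        ttl_seconds: int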
@@ -51,7 +52,7 @@ DEFAULT_REDIS_SYNC_INTERVAL = 60
 class ProviderBudgetLimiting(CustomLogger):
     def __init__(self, router_cache: DualCache, provider_budget_config: dict):
         self.router_cache = router_cache
-        self.last_synced_values = {}  # To track last synced values for each key
+        self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
         asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis())

         # cast elements of provider_budget_config to ProviderBudgetInfo
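The constructor drops last_synced_values — no longer needed once whole increments, rather than deltas, are pushed — and seeds an empty operation queue before spawning the background sync task. periodic_sync_in_memory_spend_with_redis itself is not part of this diff; a minimal sketch of such a loop, assuming it simply sleeps for DEFAULT_REDIS_SYNC_INTERVAL between flushes and that the flush method carries the name used below, would be:

    import asyncio

    async def periodic_sync_in_memory_spend_with_redis(self):
        # Hypothetical sketch: flush queued increments to Redis on a fixed cadence.
        while True:
            await asyncio.sleep(DEFAULT_REDIS_SYNC_INTERVAL)  # 60s per the constant above
            await self._sync_in_memory_spend_with_redis()  # the method in the final hunk; name assumed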
@@ -214,14 +215,17 @@ class ProviderBudgetLimiting(CustomLogger):

         spend_key = f"provider_spend:{custom_llm_provider}:{budget_config.time_period}"
         ttl_seconds = duration_in_seconds(duration=budget_config.time_period)
         verbose_router_logger.debug(
             f"Incrementing spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}"
         )
+        # Create RedisPipelineIncrementOperation object
+        increment_op = RedisPipelineIncrementOperation(
+            key=spend_key, increment_value=response_cost, ttl_seconds=ttl_seconds
+        )

         # Increment the spend in Redis and set TTL
         await self.router_cache.in_memory_cache.async_increment(
             key=spend_key,
             value=response_cost,
         )
+        self.redis_increment_operation_queue.append(increment_op)
         verbose_router_logger.debug(
             f"Incremented spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}"
         )
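The success-logging path now does two things: it increments the in-memory counter immediately, so budget checks see new spend without waiting for Redis, and it appends the same increment to the queue for the next batched flush. As an illustration, two back-to-back $0.25 requests against a hypothetical one-day OpenAI budget would leave state like this:

    # Illustrative only; the provider, cost, and time_period values are made up.
    # duration_in_seconds("1d") == 86400
    self.redis_increment_operation_queue == [
        {"key": "provider_spend:openai:1d", "increment_value": 0.25, "ttl_seconds": 86400},
        {"key": "provider_spend:openai:1d", "increment_value": 0.25, "ttl_seconds": 86400},
    ]
    # while the in-memory cache already holds 0.50 for "provider_spend:openai:1d"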
@@ -256,40 +260,15 @@ class ProviderBudgetLimiting(CustomLogger):
         try:
             if not self.router_cache.redis_cache:
                 return  # Redis is not initialized

-            # Build cache keys for all providers
-            cache_keys = [
-                f"provider_spend:{provider}:{config.time_period}"
-                for provider, config in self.provider_budget_config.items()
-                if config is not None
-            ]
-
-            # Fetch current in-memory values
-            current_values = (
-                await self.router_cache.in_memory_cache.async_batch_get_cache(
-                    keys=cache_keys
-                )
-            )
-
+            verbose_router_logger.debug(
+                "Pushing Redis Increment Pipeline for queue: %s",
+                self.redis_increment_operation_queue,
+            )
+            await self.router_cache.redis_cache.async_increment_pipeline(
+                increment_list=self.redis_increment_operation_queue,
+            )
-            for key, current_value in zip(cache_keys, current_values):
-                if current_value is None:
-                    continue  # Skip keys with no in-memory value
-
-                # Get the last synced value (default to 0 if not synced before)
-                last_synced = self.last_synced_values.get(key, 0.0)
-                # Calculate the delta to push to Redis
-                delta = float(current_value) - last_synced
-                if delta > 0:  # Only push if there is a positive increment
-                    await self.router_cache.redis_cache.async_increment(
-                        key=key,
-                        value=delta,
-                    )
-                    verbose_router_logger.debug(
-                        f"Pushed delta to Redis for {key}: {delta} (last synced: {last_synced}, current: {current_value})"
-                    )
-
-                # Update last synced value
-                self.last_synced_values[key] = float(current_value)
+            self.redis_increment_operation_queue = []

         except Exception as e:
             verbose_router_logger.error(
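With the delta bookkeeping deleted, the sync method collapses to: log the queue, hand the whole list to async_increment_pipeline, then clear the queue. Compared with the removed logic, this avoids reading back in-memory values entirely and turns a flush of N keys into a single network round trip. litellm's actual pipeline implementation lives in its Redis cache class and is not shown in this diff; the general redis-py pattern it relies on — batching INCRBYFLOAT plus EXPIRE commands and executing them together — looks like this (a sketch under those assumptions, not the library's code):

    import redis.asyncio as redis

    async def increment_pipeline(client: redis.Redis, increment_list: list) -> None:
        # Queue an INCRBYFLOAT (and an EXPIRE when a TTL is set) per operation,
        # then send them all to Redis in one round trip.
        pipe = client.pipeline(transaction=False)
        for op in increment_list:
            pipe.incrbyfloat(op["key"], op["increment_value"])
            if op.get("ttl_seconds"):
                pipe.expire(op["key"], op["ttl_seconds"])
        await pipe.execute()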