Krish Dholakia 2025-04-24 00:54:57 -07:00 committed by GitHub
commit 127964494f
4 changed files with 171 additions and 44 deletions

View file

@@ -975,6 +975,7 @@ class RedisCache(BaseCache):
- increment_value: float
- ttl_seconds: int
"""
# don't waste a network request if there's nothing to increment
if len(increment_list) == 0:
return None
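The empty-list guard above avoids a Redis round-trip when a caller submits nothing to increment. A hedged, self-contained sketch of the same batched-increment pattern using redis-py's asyncio pipeline (illustrative only, not the actual RedisCache implementation; the function name, signature, and TTL handling are assumptions):

```python
from typing import List, Optional, Tuple

import redis.asyncio as redis


async def batch_increment(
    client: redis.Redis,
    increment_list: List[Tuple[str, float]],
    ttl_seconds: int,
) -> Optional[List[float]]:
    # Skip the network call entirely when there is nothing to increment.
    if len(increment_list) == 0:
        return None
    pipe = client.pipeline(transaction=False)
    for key, value in increment_list:
        pipe.incrbyfloat(key, value)
        pipe.expire(key, ttl_seconds)
    raw = await pipe.execute()
    # Results alternate: new counter value, then the EXPIRE acknowledgement.
    return [raw[i] for i in range(0, len(raw), 2)]
```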

View file

@@ -35,6 +35,9 @@ litellm_settings:
num_retries: 0
callbacks: ["datadog_llm_observability"]
check_provider_endpoint: true
cache: true
cache_params:
type: redis
files_settings:
- custom_llm_provider: gemini
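For context, `cache: true` with `cache_params.type: redis` turns on the proxy's Redis-backed cache. A hedged sketch of the equivalent SDK-side setup (not part of this commit; connection values are placeholders, and the proxy itself typically reads REDIS_HOST/REDIS_PORT/REDIS_PASSWORD from the environment):

```python
# Hedged sketch, not taken from this commit: enabling the same Redis cache from
# Python code. Host, port, and password values below are placeholders.
import litellm
from litellm.caching.caching import Cache

litellm.cache = Cache(
    type="redis",
    host="localhost",  # placeholder
    port="6379",       # placeholder
    password="",       # placeholder
)
```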

View file

@@ -1,7 +1,17 @@
import asyncio
import sys
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (
TYPE_CHECKING,
Any,
List,
Literal,
Optional,
Tuple,
TypedDict,
Union,
cast,
)
from fastapi import HTTPException
from pydantic import BaseModel
@@ -16,6 +26,7 @@ from litellm.proxy.auth.auth_utils import (
get_key_model_rpm_limit,
get_key_model_tpm_limit,
)
from litellm.router_strategy.base_routing_strategy import BaseRoutingStrategy
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
@@ -31,17 +42,28 @@ else:
class CacheObject(TypedDict):
current_global_requests: Optional[dict]
request_count_api_key: Optional[dict]
request_count_api_key: Optional[int]
request_count_api_key_model: Optional[dict]
request_count_user_id: Optional[dict]
request_count_team_id: Optional[dict]
request_count_end_user_id: Optional[dict]
rpm_api_key: Optional[int]
tpm_api_key: Optional[int]
class _PROXY_MaxParallelRequestsHandler(CustomLogger):
RateLimitGroups = Literal["request_count", "tpm", "rpm"]
class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
# Class variables or attributes
def __init__(self, internal_usage_cache: InternalUsageCache):
self.internal_usage_cache = internal_usage_cache
BaseRoutingStrategy.__init__(
self,
dual_cache=internal_usage_cache.dual_cache,
should_batch_redis_writes=True,
default_sync_interval=0.01,
)
def print_verbose(self, print_statement):
try:
@@ -51,6 +73,68 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
except Exception:
pass
def _get_current_usage_key(
self,
user_api_key_dict: UserAPIKeyAuth,
precise_minute: str,
data: dict,
rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
group: RateLimitGroups,
) -> str:
if rate_limit_type == "key":
return f"{user_api_key_dict.api_key}::{precise_minute}::{group}"
elif rate_limit_type == "model_per_key":
return f"{user_api_key_dict.api_key}::{data.get('model')}::{precise_minute}::{group}"
elif rate_limit_type == "user":
return f"{user_api_key_dict.user_id}::{precise_minute}::{group}"
elif rate_limit_type == "customer":
return f"{user_api_key_dict.end_user_id}::{precise_minute}::{group}"
elif rate_limit_type == "team":
return f"{user_api_key_dict.team_id}::{precise_minute}::{group}"
else:
raise ValueError(f"Invalid rate limit type: {rate_limit_type}")
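For illustration, the counter keys this helper produces for the "key" scope, with made-up values for the hashed API key and the minute bucket (the real `precise_minute` format is whatever the caller passes in):

```python
# Illustrative only -- the hashed key and minute-bucket format below are made up.
api_key = "88dc28..7463"
precise_minute = "2025-04-24-00-54"

for group in ("request_count", "rpm", "tpm"):
    print(f"{api_key}::{precise_minute}::{group}")
# 88dc28..7463::2025-04-24-00-54::request_count
# 88dc28..7463::2025-04-24-00-54::rpm
# 88dc28..7463::2025-04-24-00-54::tpm
```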
async def check_key_in_limits_v2(
self,
user_api_key_dict: UserAPIKeyAuth,
data: dict,
max_parallel_requests: int,
precise_minute: str,
tpm_limit: int,
rpm_limit: int,
current_tpm: int,
rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
):
verbose_proxy_logger.info(
f"Current Usage of {rate_limit_type} in this minute: {current_tpm}"
)
if current_tpm >= tpm_limit:
raise self.raise_rate_limit_error(
additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}"
)
## INCREMENT CURRENT USAGE
increment_list: List[Tuple[str, int]] = []
for group in ["request_count", "rpm"]:
key = self._get_current_usage_key(
user_api_key_dict=user_api_key_dict,
precise_minute=precise_minute,
data=data,
rate_limit_type=rate_limit_type,
group=cast(RateLimitGroups, group),
)
increment_list.append((key, 1))
results = await self._increment_value_list_in_current_window(
increment_list=increment_list,
ttl=60,
)
if results[0] >= max_parallel_requests or results[1] >= rpm_limit:
raise self.raise_rate_limit_error(
additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}"
)
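In outline, the v2 check reads the cached TPM counter and fails fast if it is already at the limit, then increments the `request_count` and `rpm` counters for the current minute window and compares the post-increment values against `max_parallel_requests` and `rpm_limit`. A simplified, in-memory sketch of that check-then-increment flow (the real method delegates the increments to BaseRoutingStrategy, which batches them to Redis):

```python
# Simplified in-memory sketch of the flow above; the real implementation batches
# the increments to Redis through BaseRoutingStrategy and DualCache.
from typing import Dict, List, Tuple

counters: Dict[str, int] = {}


def increment_all(items: List[Tuple[str, int]]) -> List[int]:
    results = []
    for key, value in items:
        counters[key] = counters.get(key, 0) + value
        results.append(counters[key])
    return results


def check_key(api_key: str, minute: str, max_parallel: int, tpm_limit: int, rpm_limit: int) -> None:
    # Fail fast on the TPM counter, then bump and re-check the per-minute counters.
    current_tpm = counters.get(f"{api_key}::{minute}::tpm", 0)
    if current_tpm >= tpm_limit:
        raise RuntimeError("TPM limit reached")
    new_counts = increment_all(
        [(f"{api_key}::{minute}::request_count", 1), (f"{api_key}::{minute}::rpm", 1)]
    )
    if new_counts[0] >= max_parallel or new_counts[1] >= rpm_limit:
        raise RuntimeError("parallel-request / RPM limit reached")


check_key("sk-example", "2025-04-24-00-54", max_parallel=10, tpm_limit=1000, rpm_limit=5)
```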
async def check_key_in_limits(
self,
user_api_key_dict: UserAPIKeyAuth,
@@ -143,6 +227,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
request_count_user_id: Optional[str],
request_count_team_id: Optional[str],
request_count_end_user_id: Optional[str],
rpm_api_key: Optional[str],
tpm_api_key: Optional[str],
parent_otel_span: Optional[Span] = None,
) -> CacheObject:
keys = [
@@ -152,7 +238,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
request_count_user_id,
request_count_team_id,
request_count_end_user_id,
rpm_api_key,
tpm_api_key,
]
results = await self.internal_usage_cache.async_batch_get_cache(
keys=keys,
parent_otel_span=parent_otel_span,
@@ -166,6 +255,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
request_count_user_id=None,
request_count_team_id=None,
request_count_end_user_id=None,
rpm_api_key=None,
tpm_api_key=None,
)
return CacheObject(
@@ -175,6 +266,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
request_count_user_id=results[3],
request_count_team_id=results[4],
request_count_end_user_id=results[5],
rpm_api_key=results[6],
tpm_api_key=results[7],
)
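Note the positional contract here: the order of the `keys` list passed to `async_batch_get_cache` has to line up with the indices used to build the returned `CacheObject`. A small illustrative mapping (field names taken from the TypedDict above; the zip-based rebuild is not the actual code):

```python
# Illustrative: the positional keys list and the CacheObject fields must stay in sync.
field_order = [
    "current_global_requests",
    "request_count_api_key",
    "request_count_api_key_model",
    "request_count_user_id",
    "request_count_team_id",
    "request_count_end_user_id",
    "rpm_api_key",
    "tpm_api_key",
]
batch_results = [None] * len(field_order)  # stand-in for async_batch_get_cache()'s response
cache_object = dict(zip(field_order, batch_results))
```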
async def async_pre_call_hook( # noqa: PLR0915
@@ -254,6 +347,12 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
if api_key is not None
else None
),
rpm_api_key=(
f"{api_key}::{precise_minute}::rpm" if api_key is not None else None
),
tpm_api_key=(
f"{api_key}::{precise_minute}::tpm" if api_key is not None else None
),
request_count_api_key_model=(
f"{api_key}::{_model}::{precise_minute}::request_count"
if api_key is not None and _model is not None
@@ -279,20 +378,31 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
if api_key is not None:
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
# CHECK IF REQUEST ALLOWED for key
await self.check_key_in_limits(
await self.check_key_in_limits_v2(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=max_parallel_requests,
current=cache_objects["request_count_api_key"],
request_count_api_key=request_count_api_key,
precise_minute=precise_minute,
tpm_limit=tpm_limit,
rpm_limit=rpm_limit,
current_tpm=cache_objects["tpm_api_key"] or 0,
rate_limit_type="key",
values_to_update_in_cache=values_to_update_in_cache,
)
# await self.check_key_in_limits(
# user_api_key_dict=user_api_key_dict,
# cache=cache,
# data=data,
# call_type=call_type,
# max_parallel_requests=max_parallel_requests,
# current=cache_objects["request_count_api_key"],
# request_count_api_key=request_count_api_key,
# tpm_limit=tpm_limit,
# rpm_limit=rpm_limit,
# rate_limit_type="key",
# values_to_update_in_cache=values_to_update_in_cache,
# )
# Check if request under RPM/TPM per model for a given API Key
if (
get_key_model_tpm_limit(user_api_key_dict) is not None
@@ -434,13 +544,13 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
values_to_update_in_cache=values_to_update_in_cache,
)
asyncio.create_task(
self.internal_usage_cache.async_batch_set_cache(
cache_list=values_to_update_in_cache,
ttl=60,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
) # don't block execution for cache updates
)
# asyncio.create_task(
# self.internal_usage_cache.async_batch_set_cache(
# cache_list=values_to_update_in_cache,
# ttl=60,
# litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
# ) # don't block execution for cache updates
# )
return
@@ -803,6 +913,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
"""
Retrieve the key's remaining rate limits.
"""
return
api_key = user_api_key_dict.api_key
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")

View file

@@ -5,7 +5,7 @@ Base class across routing strategies to abstract common functions like batch in
import asyncio
import threading
from abc import ABC
from typing import List, Optional, Set, Union
from typing import List, Optional, Set, Tuple, Union
from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache
@@ -22,26 +22,51 @@ class BaseRoutingStrategy(ABC):
):
self.dual_cache = dual_cache
self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
self._sync_task: Optional[asyncio.Task[None]] = None
if should_batch_redis_writes:
try:
# Try to get existing event loop
loop = asyncio.get_event_loop()
if loop.is_running():
# If loop exists and is running, create task in existing loop
loop.create_task(
self.periodic_sync_in_memory_spend_with_redis(
default_sync_interval=default_sync_interval
)
)
else:
self._create_sync_thread(default_sync_interval)
except RuntimeError: # No event loop in current thread
self._create_sync_thread(default_sync_interval)
self.setup_sync_task(default_sync_interval)
self.in_memory_keys_to_update: set[
str
] = set() # Set with max size of 1000 keys
def setup_sync_task(self, default_sync_interval: Optional[Union[int, float]]):
"""Setup the sync task in a way that's compatible with FastAPI"""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self._sync_task = loop.create_task(
self.periodic_sync_in_memory_spend_with_redis(
default_sync_interval=default_sync_interval
)
)
async def cleanup(self):
"""Cleanup method to be called when shutting down"""
if self._sync_task is not None:
self._sync_task.cancel()
try:
await self._sync_task
except asyncio.CancelledError:
pass
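With this change, the periodic Redis sync is scheduled as an asyncio task on the running loop (or a fresh loop if none is running) rather than in a daemon thread, and `cleanup()` lets the owner cancel that task at shutdown. A self-contained sketch of the task lifecycle (simplified; the real loop flushes in-memory counters to Redis):

```python
import asyncio


class PeriodicSync:
    """Minimal stand-in for the setup_sync_task()/cleanup() lifecycle above."""

    def __init__(self, interval: float = 0.01) -> None:
        # Assumes construction happens inside a running event loop, as in setup_sync_task().
        self._task = asyncio.get_running_loop().create_task(self._run(interval))

    async def _run(self, interval: float) -> None:
        while True:
            await asyncio.sleep(interval)
            # Flush in-memory counters to Redis here (omitted in this sketch).

    async def cleanup(self) -> None:
        # Cancel the background task and wait for it to finish.
        self._task.cancel()
        try:
            await self._task
        except asyncio.CancelledError:
            pass


async def main() -> None:
    syncer = PeriodicSync()
    await asyncio.sleep(0.05)
    await syncer.cleanup()


asyncio.run(main())
```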
async def _increment_value_list_in_current_window(
self, increment_list: List[Tuple[str, int]], ttl: int
) -> List[float]:
"""
Increment a list of values in the current window
"""
results = []
for key, value in increment_list:
result = await self._increment_value_in_current_window(
key=key, value=value, ttl=ttl
)
results.append(result)
return results
async def _increment_value_in_current_window(
self, key: str, value: Union[int, float], ttl: int
):
@@ -175,16 +200,3 @@ class BaseRoutingStrategy(ABC):
verbose_router_logger.exception(
f"Error syncing in-memory cache with Redis: {str(e)}"
)
def _create_sync_thread(self, default_sync_interval):
"""Helper method to create a new thread for periodic sync"""
thread = threading.Thread(
target=asyncio.run,
args=(
self.periodic_sync_in_memory_spend_with_redis(
default_sync_interval=default_sync_interval
),
),
daemon=True,
)
thread.start()