Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)

Commit 127964494f: Merge 897eb46320 into b82af5b826
4 changed files with 171 additions and 44 deletions
```diff
@@ -975,6 +975,7 @@ class RedisCache(BaseCache):
         - increment_value: float
         - ttl_seconds: int
         """
 
+        # don't waste a network request if there's nothing to increment
         if len(increment_list) == 0:
             return None
```
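The new guard in `RedisCache` short-circuits before any pipeline work when `increment_list` is empty. A minimal sketch of the same pattern, assuming a redis-py async client (the function name, client setup, and key shapes are illustrative, not taken from this diff):

```python
import redis.asyncio as redis  # redis-py >= 4.2


async def increment_pipeline(
    client: redis.Redis, increment_list: list[tuple[str, float]]
) -> None:
    # don't waste a network request if there's nothing to increment
    if len(increment_list) == 0:
        return None
    async with client.pipeline(transaction=False) as pipe:
        for key, value in increment_list:
            pipe.incrbyfloat(key, value)  # queue locally, flush once
        await pipe.execute()
```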
```diff
@@ -35,6 +35,9 @@ litellm_settings:
     num_retries: 0
   callbacks: ["datadog_llm_observability"]
   check_provider_endpoint: true
+  cache: true
+  cache_params:
+    type: redis
 
 files_settings:
   - custom_llm_provider: gemini
```
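The config hunk enables proxy-level caching backed by Redis. For reference, the same thing can be set up from Python; this is a hedged sketch using litellm's `Cache` helper (connection values are placeholders, and the exact parameter set should be checked against the litellm docs):

```python
import litellm

# programmatic equivalent of `cache: true` + `cache_params: {type: redis}`
litellm.cache = litellm.Cache(
    type="redis",
    host="localhost",  # placeholder connection details
    port=6379,
)
```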
```diff
@@ -1,7 +1,17 @@
 import asyncio
 import sys
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, TypedDict, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    TypedDict,
+    Union,
+    cast,
+)
 
 from fastapi import HTTPException
 from pydantic import BaseModel
```
```diff
@@ -16,6 +26,7 @@ from litellm.proxy.auth.auth_utils import (
     get_key_model_rpm_limit,
     get_key_model_tpm_limit,
 )
+from litellm.router_strategy.base_routing_strategy import BaseRoutingStrategy
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
```
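The `if TYPE_CHECKING:` guard keeps `opentelemetry` a types-only dependency. A self-contained sketch of the pattern (the `else` branch mirrors the `else:` context visible in the next hunk):

```python
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # seen only by type checkers; opentelemetry never imports at runtime
    from opentelemetry.trace import Span as _Span

    Span = _Span
else:
    # at runtime the alias degrades to Any, so annotations still resolve
    Span = Any
```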
```diff
@@ -31,17 +42,28 @@ else:
 
 class CacheObject(TypedDict):
     current_global_requests: Optional[dict]
-    request_count_api_key: Optional[dict]
+    request_count_api_key: Optional[int]
     request_count_api_key_model: Optional[dict]
     request_count_user_id: Optional[dict]
     request_count_team_id: Optional[dict]
     request_count_end_user_id: Optional[dict]
+    rpm_api_key: Optional[int]
+    tpm_api_key: Optional[int]
 
 
-class _PROXY_MaxParallelRequestsHandler(CustomLogger):
+RateLimitGroups = Literal["request_count", "tpm", "rpm"]
+
+
+class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
     # Class variables or attributes
     def __init__(self, internal_usage_cache: InternalUsageCache):
         self.internal_usage_cache = internal_usage_cache
+        BaseRoutingStrategy.__init__(
+            self,
+            dual_cache=internal_usage_cache.dual_cache,
+            should_batch_redis_writes=True,
+            default_sync_interval=0.01,
+        )
 
     def print_verbose(self, print_statement):
         try:
```
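With the handler now inheriting from both `BaseRoutingStrategy` and `CustomLogger`, its `__init__` invokes `BaseRoutingStrategy.__init__` explicitly instead of going through `super()`. A small illustration of that design choice (class names below are made up for the example):

```python
class Batcher:
    def __init__(self, sync_interval: float):
        self.sync_interval = sync_interval


class Logger:
    def log(self, msg: str) -> None:
        print(msg)


class Handler(Batcher, Logger):
    def __init__(self) -> None:
        # explicit base-class call: super().__init__() would need every
        # class in the MRO to cooperatively forward **kwargs
        Batcher.__init__(self, sync_interval=0.01)


print(Handler().sync_interval)  # 0.01
```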
```diff
@@ -51,6 +73,68 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         except Exception:
             pass
 
+    def _get_current_usage_key(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        precise_minute: str,
+        data: dict,
+        rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
+        group: RateLimitGroups,
+    ) -> str:
+        if rate_limit_type == "key":
+            return f"{user_api_key_dict.api_key}::{precise_minute}::{group}"
+        elif rate_limit_type == "model_per_key":
+            return f"{user_api_key_dict.api_key}::{data.get('model')}::{precise_minute}::{group}"
+        elif rate_limit_type == "user":
+            return f"{user_api_key_dict.user_id}::{precise_minute}::{group}"
+        elif rate_limit_type == "customer":
+            return f"{user_api_key_dict.end_user_id}::{precise_minute}::{group}"
+        elif rate_limit_type == "team":
+            return f"{user_api_key_dict.team_id}::{precise_minute}::{group}"
+        else:
+            raise ValueError(f"Invalid rate limit type: {rate_limit_type}")
+
+    async def check_key_in_limits_v2(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: dict,
+        max_parallel_requests: int,
+        precise_minute: str,
+        tpm_limit: int,
+        rpm_limit: int,
+        current_tpm: int,
+        rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
+    ):
+        verbose_proxy_logger.info(
+            f"Current Usage of {rate_limit_type} in this minute: {current_tpm}"
+        )
+        if current_tpm >= tpm_limit:
+            raise self.raise_rate_limit_error(
+                additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}"
+            )
+
+        ## INCREMENT CURRENT USAGE
+        increment_list: List[Tuple[str, int]] = []
+        for group in ["request_count", "rpm"]:
+            key = self._get_current_usage_key(
+                user_api_key_dict=user_api_key_dict,
+                precise_minute=precise_minute,
+                data=data,
+                rate_limit_type=rate_limit_type,
+                group=cast(RateLimitGroups, group),
+            )
+            increment_list.append((key, 1))
+
+        results = await self._increment_value_list_in_current_window(
+            increment_list=increment_list,
+            ttl=60,
+        )
+
+        if results[0] >= max_parallel_requests or results[1] >= rpm_limit:
+            raise self.raise_rate_limit_error(
+                additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}"
+            )
+
     async def check_key_in_limits(
         self,
         user_api_key_dict: UserAPIKeyAuth,
```
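`_get_current_usage_key` builds flat, colon-delimited counter names, one per scope and group, all bucketed by `precise_minute`. A sketch of the resulting key shapes (the key value and the minute format are assumptions; this capture doesn't show how `precise_minute` is computed):

```python
from datetime import datetime

api_key = "sk-hypothetical"
# assumed minute-resolution bucket
precise_minute = datetime.now().strftime("%Y-%m-%d-%H-%M")

for group in ("request_count", "rpm", "tpm"):
    print(f"{api_key}::{precise_minute}::{group}")
# sk-hypothetical::2025-04-24-18-24::request_count, and so on
```

Because every key embeds the minute and is written with a 60-second TTL, the counters form per-minute fixed windows that expire on their own.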
```diff
@@ -143,6 +227,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         request_count_user_id: Optional[str],
         request_count_team_id: Optional[str],
         request_count_end_user_id: Optional[str],
+        rpm_api_key: Optional[str],
+        tpm_api_key: Optional[str],
         parent_otel_span: Optional[Span] = None,
     ) -> CacheObject:
         keys = [
```
```diff
@@ -152,7 +238,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             request_count_user_id,
             request_count_team_id,
             request_count_end_user_id,
+            rpm_api_key,
+            tpm_api_key,
         ]
 
         results = await self.internal_usage_cache.async_batch_get_cache(
             keys=keys,
             parent_otel_span=parent_otel_span,
```
```diff
@@ -166,6 +255,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 request_count_user_id=None,
                 request_count_team_id=None,
                 request_count_end_user_id=None,
+                rpm_api_key=None,
+                tpm_api_key=None,
             )
 
         return CacheObject(
```
```diff
@@ -175,6 +266,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             request_count_user_id=results[3],
             request_count_team_id=results[4],
             request_count_end_user_id=results[5],
+            rpm_api_key=results[6],
+            tpm_api_key=results[7],
         )
 
     async def async_pre_call_hook(  # noqa: PLR0915
```
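The batch read depends on positional agreement between the `keys` list and `results`; adding `rpm_api_key` and `tpm_api_key` means extending both the list and the index-based unpacking in step. A sketch of a name-keyed alternative that avoids the index coupling (illustrative only, not what this PR does):

```python
from typing import Any, Dict, List, Optional


async def batch_get(keys: List[Optional[str]]) -> List[Any]:
    # stand-in for internal_usage_cache.async_batch_get_cache
    return [None] * len(keys)


async def get_all_cache_objects(**named_keys: Optional[str]) -> Dict[str, Any]:
    # zip names with results so adding a key can't shift later indexes
    names = list(named_keys)
    results = await batch_get([named_keys[n] for n in names])
    return dict(zip(names, results))
```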
```diff
@@ -254,6 +347,12 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 if api_key is not None
                 else None
             ),
+            rpm_api_key=(
+                f"{api_key}::{precise_minute}::rpm" if api_key is not None else None
+            ),
+            tpm_api_key=(
+                f"{api_key}::{precise_minute}::tpm" if api_key is not None else None
+            ),
             request_count_api_key_model=(
                 f"{api_key}::{_model}::{precise_minute}::request_count"
                 if api_key is not None and _model is not None
```
```diff
@@ -279,20 +378,31 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         if api_key is not None:
             request_count_api_key = f"{api_key}::{precise_minute}::request_count"
             # CHECK IF REQUEST ALLOWED for key
-            await self.check_key_in_limits(
+            await self.check_key_in_limits_v2(
                 user_api_key_dict=user_api_key_dict,
-                cache=cache,
                 data=data,
-                call_type=call_type,
                 max_parallel_requests=max_parallel_requests,
-                current=cache_objects["request_count_api_key"],
-                request_count_api_key=request_count_api_key,
+                precise_minute=precise_minute,
                 tpm_limit=tpm_limit,
                 rpm_limit=rpm_limit,
+                current_tpm=cache_objects["tpm_api_key"] or 0,
                 rate_limit_type="key",
-                values_to_update_in_cache=values_to_update_in_cache,
             )
+
+            # await self.check_key_in_limits(
+            #     user_api_key_dict=user_api_key_dict,
+            #     cache=cache,
+            #     data=data,
+            #     call_type=call_type,
+            #     max_parallel_requests=max_parallel_requests,
+            #     current=cache_objects["request_count_api_key"],
+            #     request_count_api_key=request_count_api_key,
+            #     tpm_limit=tpm_limit,
+            #     rpm_limit=rpm_limit,
+            #     rate_limit_type="key",
+            #     values_to_update_in_cache=values_to_update_in_cache,
+            # )
 
         # Check if request under RPM/TPM per model for a given API Key
         if (
             get_key_model_tpm_limit(user_api_key_dict) is not None
```
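The v2 path replaces read-then-write bookkeeping with increment-then-compare on per-minute counters, i.e. fixed-window rate limiting. A self-contained toy version of the comparison logic (an in-memory dict standing in for the Redis counters):

```python
import time
from collections import defaultdict

WINDOW_SECONDS = 60
counters: dict[tuple[str, int], int] = defaultdict(int)


def allow(key: str, rpm_limit: int) -> bool:
    window = int(time.time()) // WINDOW_SECONDS  # current minute bucket
    counters[(key, window)] += 1  # increment first, like check_key_in_limits_v2
    # v2 raises once the incremented count reaches the limit
    return counters[(key, window)] < rpm_limit


for _ in range(3):
    print(allow("sk-demo::rpm", rpm_limit=2))  # True, False, False
```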
```diff
@@ -434,13 +544,13 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 values_to_update_in_cache=values_to_update_in_cache,
             )
 
-        asyncio.create_task(
-            self.internal_usage_cache.async_batch_set_cache(
-                cache_list=values_to_update_in_cache,
-                ttl=60,
-                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
-            )  # don't block execution for cache updates
-        )
+        # asyncio.create_task(
+        #     self.internal_usage_cache.async_batch_set_cache(
+        #         cache_list=values_to_update_in_cache,
+        #         ttl=60,
+        #         litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+        #     )  # don't block execution for cache updates
+        # )
 
         return
 
```
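The per-request `asyncio.create_task(...)` write is retired here in favor of the batched Redis sync inherited from `BaseRoutingStrategy`. Worth noting about the retired pattern: a task created with `create_task` and never awaited can be garbage-collected before it finishes unless a reference is kept (a documented CPython caveat). A sketch of the safer fire-and-forget shape:

```python
import asyncio

background_tasks: set[asyncio.Task] = set()


def fire_and_forget(coro) -> None:
    # hold a strong reference so the event loop can't drop the task early
    task = asyncio.create_task(coro)
    background_tasks.add(task)
    task.add_done_callback(background_tasks.discard)


async def main() -> None:
    async def write_cache() -> None:
        await asyncio.sleep(0.01)  # stand-in for async_batch_set_cache

    fire_and_forget(write_cache())
    await asyncio.sleep(0.05)  # let the background write finish


asyncio.run(main())
```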
```diff
@@ -803,6 +913,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         """
         Retrieve the key's remaining rate limits.
         """
+        return
         api_key = user_api_key_dict.api_key
         current_date = datetime.now().strftime("%Y-%m-%d")
         current_hour = datetime.now().strftime("%H")
```
```diff
@@ -5,7 +5,7 @@ Base class across routing strategies to abstract commmon functions like batch in
 import asyncio
 import threading
 from abc import ABC
-from typing import List, Optional, Set, Union
+from typing import List, Optional, Set, Tuple, Union
 
 from litellm._logging import verbose_router_logger
 from litellm.caching.caching import DualCache
```
```diff
@@ -22,26 +22,51 @@ class BaseRoutingStrategy(ABC):
     ):
         self.dual_cache = dual_cache
         self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
+        self._sync_task: Optional[asyncio.Task[None]] = None
         if should_batch_redis_writes:
-            try:
-                # Try to get existing event loop
-                loop = asyncio.get_event_loop()
-                if loop.is_running():
-                    # If loop exists and is running, create task in existing loop
-                    loop.create_task(
-                        self.periodic_sync_in_memory_spend_with_redis(
-                            default_sync_interval=default_sync_interval
-                        )
-                    )
-                else:
-                    self._create_sync_thread(default_sync_interval)
-            except RuntimeError:  # No event loop in current thread
-                self._create_sync_thread(default_sync_interval)
+            self.setup_sync_task(default_sync_interval)
 
         self.in_memory_keys_to_update: set[
             str
         ] = set()  # Set with max size of 1000 keys
 
+    def setup_sync_task(self, default_sync_interval: Optional[Union[int, float]]):
+        """Setup the sync task in a way that's compatible with FastAPI"""
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+        self._sync_task = loop.create_task(
+            self.periodic_sync_in_memory_spend_with_redis(
+                default_sync_interval=default_sync_interval
+            )
+        )
+
+    async def cleanup(self):
+        """Cleanup method to be called when shutting down"""
+        if self._sync_task is not None:
+            self._sync_task.cancel()
+            try:
+                await self._sync_task
+            except asyncio.CancelledError:
+                pass
+
+    async def _increment_value_list_in_current_window(
+        self, increment_list: List[Tuple[str, int]], ttl: int
+    ) -> List[float]:
+        """
+        Increment a list of values in the current window
+        """
+        results = []
+        for key, value in increment_list:
+            result = await self._increment_value_in_current_window(
+                key=key, value=value, ttl=ttl
+            )
+            results.append(result)
+        return results
+
     async def _increment_value_in_current_window(
         self, key: str, value: Union[int, float], ttl: int
     ):
```
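`setup_sync_task` binds the periodic sync to the running loop and stores the handle so `cleanup` can cancel it on shutdown. A hedged sketch of wiring that lifecycle into a FastAPI app (the `DemoStrategy` class is a stand-in with the same surface, not litellm code):

```python
import asyncio
from contextlib import asynccontextmanager

from fastapi import FastAPI


class DemoStrategy:
    """Stand-in mimicking BaseRoutingStrategy's sync-task lifecycle."""

    def __init__(self) -> None:
        self._sync_task = asyncio.get_running_loop().create_task(self._sync())

    async def _sync(self) -> None:
        while True:  # placeholder for periodic_sync_in_memory_spend_with_redis
            await asyncio.sleep(0.01)

    async def cleanup(self) -> None:
        self._sync_task.cancel()
        try:
            await self._sync_task
        except asyncio.CancelledError:
            pass


@asynccontextmanager
async def lifespan(app: FastAPI):
    app.state.strategy = DemoStrategy()  # the loop is running here
    yield
    await app.state.strategy.cleanup()  # cancel the sync task on shutdown


app = FastAPI(lifespan=lifespan)
```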
```diff
@@ -175,16 +200,3 @@ class BaseRoutingStrategy(ABC):
             verbose_router_logger.exception(
                 f"Error syncing in-memory cache with Redis: {str(e)}"
             )
-
-    def _create_sync_thread(self, default_sync_interval):
-        """Helper method to create a new thread for periodic sync"""
-        thread = threading.Thread(
-            target=asyncio.run,
-            args=(
-                self.periodic_sync_in_memory_spend_with_redis(
-                    default_sync_interval=default_sync_interval
-                ),
-            ),
-            daemon=True,
-        )
-        thread.start()
```