Merge pull request #2775 from BerriAI/litellm_redis_user_api_key_cache_v3
fix(tpm_rpm_limiter.py): enable redis caching for tpm/rpm checks on keys/user/teams
Commit: 1356f6cd32
9 changed files with 461 additions and 14 deletions
@@ -100,7 +100,13 @@ class RedisCache(BaseCache):
     # if users don't provider one, use the default litellm cache

     def __init__(
-        self, host=None, port=None, password=None, redis_flush_size=100, **kwargs
+        self,
+        host=None,
+        port=None,
+        password=None,
+        redis_flush_size=100,
+        namespace: Optional[str] = None,
+        **kwargs,
     ):
         from ._redis import get_redis_client, get_redis_connection_pool

@@ -116,9 +122,10 @@ class RedisCache(BaseCache):
         self.redis_client = get_redis_client(**redis_kwargs)
         self.redis_kwargs = redis_kwargs
         self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
+        # redis namespaces
+        self.namespace = namespace
         # for high traffic, we store the redis results in memory and then batch write to redis
-        self.redis_batch_writing_buffer = []
+        self.redis_batch_writing_buffer: list = []
         self.redis_flush_size = redis_flush_size
         self.redis_version = "Unknown"
         try:
@@ -133,11 +140,21 @@ class RedisCache(BaseCache):
             connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
         )

+    def check_and_fix_namespace(self, key: str) -> str:
+        """
+        Make sure each key starts with the given namespace
+        """
+        if self.namespace is not None and not key.startswith(self.namespace):
+            key = self.namespace + ":" + key
+
+        return key
+
     def set_cache(self, key, value, **kwargs):
         ttl = kwargs.get("ttl", None)
         print_verbose(
             f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}"
         )
+        key = self.check_and_fix_namespace(key=key)
         try:
             self.redis_client.set(name=key, value=str(value), ex=ttl)
         except Exception as e:
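For reference, the prefixing rule introduced by check_and_fix_namespace can be exercised standalone. A minimal sketch, assuming "litellm" as an example namespace value (not taken from this diff):

# Standalone restatement of the check_and_fix_namespace rule above.
from typing import Optional

def check_and_fix_namespace(namespace: Optional[str], key: str) -> str:
    # Prefix the key once; keys that already carry the namespace pass through.
    if namespace is not None and not key.startswith(namespace):
        key = namespace + ":" + key
    return key

assert check_and_fix_namespace("litellm", "usage:2024-03-30-12-05") == "litellm:usage:2024-03-30-12-05"
assert check_and_fix_namespace("litellm", "litellm:usage:2024-03-30-12-05") == "litellm:usage:2024-03-30-12-05"
assert check_and_fix_namespace(None, "usage:2024-03-30-12-05") == "usage:2024-03-30-12-05"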
@@ -158,6 +175,7 @@ class RedisCache(BaseCache):

     async def async_set_cache(self, key, value, **kwargs):
         _redis_client = self.init_async_client()
+        key = self.check_and_fix_namespace(key=key)
         async with _redis_client as redis_client:
             ttl = kwargs.get("ttl", None)
             print_verbose(
@@ -187,6 +205,7 @@ class RedisCache(BaseCache):
             async with redis_client.pipeline(transaction=True) as pipe:
                 # Iterate through each key-value pair in the cache_list and set them in the pipeline.
                 for cache_key, cache_value in cache_list:
+                    cache_key = self.check_and_fix_namespace(key=cache_key)
                     print_verbose(
                         f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
                     )
@@ -213,6 +232,7 @@ class RedisCache(BaseCache):
         print_verbose(
             f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}",
         )
+        key = self.check_and_fix_namespace(key=key)
         self.redis_batch_writing_buffer.append((key, value))
         if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
             await self.flush_cache_buffer()
@@ -242,6 +262,7 @@ class RedisCache(BaseCache):

     def get_cache(self, key, **kwargs):
         try:
+            key = self.check_and_fix_namespace(key=key)
             print_verbose(f"Get Redis Cache: key: {key}")
             cached_response = self.redis_client.get(key)
             print_verbose(
@@ -255,6 +276,7 @@ class RedisCache(BaseCache):

     async def async_get_cache(self, key, **kwargs):
         _redis_client = self.init_async_client()
+        key = self.check_and_fix_namespace(key=key)
         async with _redis_client as redis_client:
             try:
                 print_verbose(f"Get Async Redis Cache: key: {key}")
@@ -281,6 +303,7 @@ class RedisCache(BaseCache):
             async with redis_client.pipeline(transaction=True) as pipe:
                 # Queue the get operations in the pipeline for all keys.
                 for cache_key in key_list:
+                    cache_key = self.check_and_fix_namespace(key=cache_key)
                     pipe.get(cache_key)  # Queue GET command in pipeline

                 # Execute the pipeline and await the results.
@@ -796,6 +819,8 @@ class DualCache(BaseCache):
         self,
         in_memory_cache: Optional[InMemoryCache] = None,
         redis_cache: Optional[RedisCache] = None,
+        default_in_memory_ttl: Optional[float] = None,
+        default_redis_ttl: Optional[float] = None,
     ) -> None:
         super().__init__()
         # If in_memory_cache is not provided, use the default InMemoryCache
@@ -803,11 +828,17 @@ class DualCache(BaseCache):
         # If redis_cache is not provided, use the default RedisCache
         self.redis_cache = redis_cache

+        self.default_in_memory_ttl = default_in_memory_ttl
+        self.default_redis_ttl = default_redis_ttl
+
     def set_cache(self, key, value, local_only: bool = False, **kwargs):
         # Update both Redis and in-memory cache
         try:
             print_verbose(f"set cache: key: {key}; value: {value}")
             if self.in_memory_cache is not None:
+                if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
+                    kwargs["ttl"] = self.default_in_memory_ttl
+
                 self.in_memory_cache.set_cache(key, value, **kwargs)

             if self.redis_cache is not None and local_only == False:
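The new DualCache defaults let callers omit a ttl on every write. A sketch of the intended usage, assuming a reachable local Redis; the 10s/60s values mirror what the TPM/RPM limiter below passes:

from litellm.caching import DualCache, RedisCache

dual_cache = DualCache(
    redis_cache=RedisCache(host="localhost", port=6379, password=None),
    default_in_memory_ttl=10,  # in-memory entries expire after 10 seconds
    default_redis_ttl=60,      # redis entries expire after 60 seconds
)

# No explicit ttl -> set_cache injects default_in_memory_ttl for the in-memory
# layer, per the hunk above; the redis layer is handled analogously.
dual_cache.set_cache(key="usage:2024-03-30-12-05", value={"key": {}, "user": {}, "team": {}})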
@@ -823,7 +854,6 @@ class DualCache(BaseCache):
             if self.in_memory_cache is not None:
                 in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)

-                print_verbose(f"in_memory_result: {in_memory_result}")
                 if in_memory_result is not None:
                     result = in_memory_result

@@ -1008,6 +1038,9 @@ class Cache:
         self.redis_flush_size = redis_flush_size
         self.ttl = ttl

+        if self.namespace is not None and isinstance(self.cache, RedisCache):
+            self.cache.namespace = self.namespace
+
     def get_cache_key(self, *args, **kwargs):
         """
         Get the cache key for the given arguments.
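With this hunk, a namespace set on litellm.Cache is pushed down to the underlying RedisCache, so every redis key gets prefixed via check_and_fix_namespace. A sketch, assuming Cache accepts a namespace parameter as the self.namespace attribute above suggests:

import litellm
from litellm.caching import Cache

# Hypothetical connection values; "litellm" is an example namespace.
litellm.cache = Cache(type="redis", host="localhost", port=6379, password=None, namespace="litellm")
# Subsequent cache writes land under "litellm:<key>" in redis.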
@@ -12,6 +12,7 @@ from typing import Any, Literal, Union, BinaryIO
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
+
 import httpx
 import litellm
 from ._logging import verbose_logger
@@ -6,8 +6,9 @@ model_list:
       api_base: https://exampleopenaiendpoint-production.up.railway.app/

 litellm_settings:
-  max_budget: 600020
-  budget_duration: 30d
+  cache: true
+  # max_budget: 600020
+  # budget_duration: 30d

 general_settings:
   master_key: sk-1234
litellm/proxy/hooks/tpm_rpm_limiter.py (new file, 380 lines)
@@ -0,0 +1,380 @@
+# What is this?
+## Checks TPM/RPM Limits for a key/user/team on the proxy
+## Works with Redis - if given
+
+from typing import Optional, Literal
+import litellm, traceback, sys
+from litellm.caching import DualCache, RedisCache
+from litellm.proxy._types import (
+    UserAPIKeyAuth,
+    LiteLLM_VerificationTokenView,
+    LiteLLM_UserTable,
+    LiteLLM_TeamTable,
+)
+from litellm.integrations.custom_logger import CustomLogger
+from fastapi import HTTPException
+from litellm._logging import verbose_proxy_logger
+from litellm import ModelResponse
+from datetime import datetime
+
+
+class _PROXY_MaxTPMRPMLimiter(CustomLogger):
+    user_api_key_cache = None
+
+    # Class variables or attributes
+    def __init__(self, redis_usage_cache: Optional[RedisCache]):
+        self.redis_usage_cache = redis_usage_cache
+        self.internal_cache = DualCache(
+            redis_cache=redis_usage_cache,
+            default_in_memory_ttl=10,
+            default_redis_ttl=60,
+        )
+
+    def print_verbose(self, print_statement):
+        try:
+            verbose_proxy_logger.debug(print_statement)
+            if litellm.set_verbose:
+                print(print_statement)  # noqa
+        except:
+            pass
+
+    ## check if admin has set tpm/rpm limits for this key/user/team
+
+    def _check_limits_set(
+        self,
+        user_api_key_cache: DualCache,
+        key: Optional[str],
+        user_id: Optional[str],
+        team_id: Optional[str],
+    ) -> bool:
+        ## key
+        if key is not None:
+            key_val = user_api_key_cache.get_cache(key=key)
+            if isinstance(key_val, dict):
+                key_val = LiteLLM_VerificationTokenView(**key_val)
+
+            if isinstance(key_val, LiteLLM_VerificationTokenView):
+                user_api_key_tpm_limit = key_val.tpm_limit
+
+                user_api_key_rpm_limit = key_val.rpm_limit
+
+                if (
+                    user_api_key_tpm_limit is not None
+                    or user_api_key_rpm_limit is not None
+                ):
+                    return True
+
+        ## team
+        if team_id is not None:
+            team_val = user_api_key_cache.get_cache(key=team_id)
+            if isinstance(team_val, dict):
+                team_val = LiteLLM_TeamTable(**team_val)
+
+            if isinstance(team_val, LiteLLM_TeamTable):
+                team_tpm_limit = team_val.tpm_limit
+
+                team_rpm_limit = team_val.rpm_limit
+
+                if team_tpm_limit is not None or team_rpm_limit is not None:
+                    return True
+
+        ## user
+        if user_id is not None:
+            user_val = user_api_key_cache.get_cache(key=user_id)
+            if isinstance(user_val, dict):
+                user_val = LiteLLM_UserTable(**user_val)
+
+            if isinstance(user_val, LiteLLM_UserTable):
+                user_tpm_limit = user_val.tpm_limit
+
+                user_rpm_limit = user_val.rpm_limit
+
+                if user_tpm_limit is not None or user_rpm_limit is not None:
+                    return True
+        return False
+
+    async def check_key_in_limits(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        current_minute_dict: dict,
+        tpm_limit: int,
+        rpm_limit: int,
+        request_count_api_key: str,
+        type: Literal["key", "user", "team"],
+    ):
+        if type == "key" and user_api_key_dict.api_key is not None:
+            current = current_minute_dict["key"].get(user_api_key_dict.api_key, None)
+        elif type == "user" and user_api_key_dict.user_id is not None:
+            current = current_minute_dict["user"].get(user_api_key_dict.user_id, None)
+        elif type == "team" and user_api_key_dict.team_id is not None:
+            current = current_minute_dict["team"].get(user_api_key_dict.team_id, None)
+        else:
+            return
+
+        if current is None:
+            if tpm_limit == 0 or rpm_limit == 0:
+                # base case
+                raise HTTPException(
+                    status_code=429, detail="Max tpm/rpm limit reached."
+                )
+        elif current["current_tpm"] < tpm_limit and current["current_rpm"] < rpm_limit:
+            pass
+        else:
+            raise HTTPException(status_code=429, detail="Max tpm/rpm limit reached.")
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,
+    ):
+        self.print_verbose(
+            f"Inside Max TPM/RPM Limiter Pre-Call Hook - {user_api_key_dict}"
+        )
+        api_key = user_api_key_dict.api_key
+        # check if REQUEST ALLOWED for user_id
+        user_id = user_api_key_dict.user_id
+        ## get team tpm/rpm limits
+        team_id = user_api_key_dict.team_id
+
+        _set_limits = self._check_limits_set(
+            user_api_key_cache=cache, key=api_key, user_id=user_id, team_id=team_id
+        )
+
+        if _set_limits == False:
+            return
+
+        # ------------
+        # Setup values
+        # ------------
+
+        self.user_api_key_cache = cache
+
+        current_date = datetime.now().strftime("%Y-%m-%d")
+        current_hour = datetime.now().strftime("%H")
+        current_minute = datetime.now().strftime("%M")
+        precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+        cache_key = "usage:{}".format(precise_minute)
+        current_minute_dict = await self.internal_cache.async_get_cache(
+            key=cache_key
+        )  # {"usage:{curr_minute}": {"key": {<api_key>: {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}}}}
+
+        if current_minute_dict is None:
+            current_minute_dict = {"key": {}, "user": {}, "team": {}}
+
+        if api_key is not None:
+            tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
+            if tpm_limit is None:
+                tpm_limit = sys.maxsize
+            rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
+            if rpm_limit is None:
+                rpm_limit = sys.maxsize
+            request_count_api_key = f"{api_key}::{precise_minute}::request_count"
+            await self.check_key_in_limits(
+                user_api_key_dict=user_api_key_dict,
+                current_minute_dict=current_minute_dict,
+                request_count_api_key=request_count_api_key,
+                tpm_limit=tpm_limit,
+                rpm_limit=rpm_limit,
+                type="key",
+            )
+
+        if user_id is not None:
+            _user_id_rate_limits = user_api_key_dict.user_id_rate_limits
+
+            # get user tpm/rpm limits
+            if _user_id_rate_limits is not None and isinstance(
+                _user_id_rate_limits, dict
+            ):
+                user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None)
+                user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None)
+                if user_tpm_limit is None:
+                    user_tpm_limit = sys.maxsize
+                if user_rpm_limit is None:
+                    user_rpm_limit = sys.maxsize
+
+            # now do the same tpm/rpm checks
+            request_count_api_key = f"{user_id}::{precise_minute}::request_count"
+            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
+            await self.check_key_in_limits(
+                user_api_key_dict=user_api_key_dict,
+                current_minute_dict=current_minute_dict,
+                request_count_api_key=request_count_api_key,
+                tpm_limit=user_tpm_limit,
+                rpm_limit=user_rpm_limit,
+                type="user",
+            )
+
+        # TEAM RATE LIMITS
+        if team_id is not None:
+            team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize)
+
+            if team_tpm_limit is None:
+                team_tpm_limit = sys.maxsize
+            team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize)
+            if team_rpm_limit is None:
+                team_rpm_limit = sys.maxsize
+
+            if team_tpm_limit is None:
+                team_tpm_limit = sys.maxsize
+            if team_rpm_limit is None:
+                team_rpm_limit = sys.maxsize
+
+            # now do the same tpm/rpm checks
+            request_count_api_key = f"{team_id}::{precise_minute}::request_count"
+
+            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
+            await self.check_key_in_limits(
+                user_api_key_dict=user_api_key_dict,
+                current_minute_dict=current_minute_dict,
+                request_count_api_key=request_count_api_key,
+                tpm_limit=team_tpm_limit,
+                rpm_limit=team_rpm_limit,
+                type="team",
+            )
+
+        return
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            self.print_verbose(f"INSIDE TPM RPM Limiter ASYNC SUCCESS LOGGING")
+
+            user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
+            user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
+                "user_api_key_user_id", None
+            )
+            user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
+                "user_api_key_team_id", None
+            )
+
+            _limits_set = self._check_limits_set(
+                user_api_key_cache=self.user_api_key_cache,
+                key=user_api_key,
+                user_id=user_api_key_user_id,
+                team_id=user_api_key_team_id,
+            )
+
+            if _limits_set == False:  # don't waste cache calls if no tpm/rpm limits set
+                return
+
+            # ------------
+            # Setup values
+            # ------------
+
+            current_date = datetime.now().strftime("%Y-%m-%d")
+            current_hour = datetime.now().strftime("%H")
+            current_minute = datetime.now().strftime("%M")
+            precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+            total_tokens = 0
+
+            if isinstance(response_obj, ModelResponse):
+                total_tokens = response_obj.usage.total_tokens
+
+            """
+            - get value from redis
+            - increment requests + 1
+            - increment tpm + 1
+            - increment rpm + 1
+            - update value in-memory + redis
+            """
+            cache_key = "usage:{}".format(precise_minute)
+            if (
+                self.internal_cache.redis_cache is not None
+            ):  # get straight from redis if possible
+                current_minute_dict = (
+                    await self.internal_cache.redis_cache.async_get_cache(
+                        key=cache_key,
+                    )
+                )  # {"usage:{current_minute}": {"key": {}, "team": {}, "user": {}}}
+            else:
+                current_minute_dict = await self.internal_cache.async_get_cache(
+                    key=cache_key,
+                )
+
+            if current_minute_dict is None:
+                current_minute_dict = {"key": {}, "user": {}, "team": {}}
+
+            _cache_updated = False  # check if a cache update is required. prevent unnecessary rewrites.
+
+            # ------------
+            # Update usage - API Key
+            # ------------
+
+            if user_api_key is not None:
+                _cache_updated = True
+                ## API KEY ##
+                if user_api_key in current_minute_dict["key"]:
+                    current_key_usage = current_minute_dict["key"][user_api_key]
+                    new_val = {
+                        "current_tpm": current_key_usage["current_tpm"] + total_tokens,
+                        "current_rpm": current_key_usage["current_rpm"] + 1,
+                    }
+                else:
+                    new_val = {
+                        "current_tpm": total_tokens,
+                        "current_rpm": 1,
+                    }
+
+                current_minute_dict["key"][user_api_key] = new_val
+
+                self.print_verbose(
+                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+                )
+
+            # ------------
+            # Update usage - User
+            # ------------
+            if user_api_key_user_id is not None:
+                _cache_updated = True
+                total_tokens = 0
+
+                if isinstance(response_obj, ModelResponse):
+                    total_tokens = response_obj.usage.total_tokens
+
+                if user_api_key_user_id in current_minute_dict["key"]:
+                    current_key_usage = current_minute_dict["key"][user_api_key_user_id]
+                    new_val = {
+                        "current_tpm": current_key_usage["current_tpm"] + total_tokens,
+                        "current_rpm": current_key_usage["current_rpm"] + 1,
+                    }
+                else:
+                    new_val = {
+                        "current_tpm": total_tokens,
+                        "current_rpm": 1,
+                    }
+
+                current_minute_dict["user"][user_api_key_user_id] = new_val
+
+            # ------------
+            # Update usage - Team
+            # ------------
+            if user_api_key_team_id is not None:
+                _cache_updated = True
+                total_tokens = 0
+
+                if isinstance(response_obj, ModelResponse):
+                    total_tokens = response_obj.usage.total_tokens
+
+                if user_api_key_team_id in current_minute_dict["key"]:
+                    current_key_usage = current_minute_dict["key"][user_api_key_team_id]
+                    new_val = {
+                        "current_tpm": current_key_usage["current_tpm"] + total_tokens,
+                        "current_rpm": current_key_usage["current_rpm"] + 1,
+                    }
+                else:
+                    new_val = {
+                        "current_tpm": total_tokens,
+                        "current_rpm": 1,
+                    }
+
+                current_minute_dict["team"][user_api_key_team_id] = new_val
+
+            if _cache_updated == True:
+                await self.internal_cache.async_set_cache(
+                    key=cache_key, value=current_minute_dict
+                )
+
+        except Exception as e:
+            self.print_verbose(e)  # noqa
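The limiter's state is one dict per calendar minute, stored under "usage:{YYYY-MM-DD-HH-MM}" with per-key/user/team counters. A sketch of the shape and the check_key_in_limits decision, with illustrative values only:

current_minute_dict = {
    "key": {"hashed-api-key": {"current_tpm": 1200, "current_rpm": 4}},
    "user": {"user-1234": {"current_tpm": 1200, "current_rpm": 4}},
    "team": {"team-abcd": {"current_tpm": 1200, "current_rpm": 4}},
}

def is_allowed(entry, tpm_limit: int, rpm_limit: int) -> bool:
    # Mirrors check_key_in_limits: no entry yet is allowed unless a limit is 0;
    # otherwise both counters must be strictly under their limits.
    if entry is None:
        return not (tpm_limit == 0 or rpm_limit == 0)
    return entry["current_tpm"] < tpm_limit and entry["current_rpm"] < rpm_limit

assert is_allowed(current_minute_dict["key"]["hashed-api-key"], tpm_limit=2000, rpm_limit=10)
assert not is_allowed(current_minute_dict["key"]["hashed-api-key"], tpm_limit=1000, rpm_limit=10)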
@@ -102,7 +102,7 @@ from litellm.proxy.secret_managers.google_kms import load_google_kms
 from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
 import pydantic
 from litellm.proxy._types import *
-from litellm.caching import DualCache
+from litellm.caching import DualCache, RedisCache
 from litellm.proxy.health_check import perform_health_check
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
@@ -281,6 +281,9 @@ otel_logging = False
 prisma_client: Optional[PrismaClient] = None
 custom_db_client: Optional[DBClient] = None
 user_api_key_cache = DualCache()
+redis_usage_cache: Optional[RedisCache] = (
+    None  # redis cache used for tracking spend, tpm/rpm limits
+)
 user_custom_auth = None
 user_custom_key_generate = None
 use_background_health_checks = None
@@ -299,7 +302,9 @@ disable_spend_logs = False
 jwt_handler = JWTHandler()
 prompt_injection_detection_obj: Optional[_OPTIONAL_PromptInjectionDetection] = None
 ### INITIALIZE GLOBAL LOGGING OBJECT ###
-proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
+proxy_logging_obj = ProxyLogging(
+    user_api_key_cache=user_api_key_cache, redis_usage_cache=redis_usage_cache
+)
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -909,6 +914,10 @@ async def user_api_key_auth(
                     models=valid_token.team_models,
                 )

+                user_api_key_cache.set_cache(
+                    key=valid_token.team_id, value=_team_obj
+                )  # save team table in cache - used for tpm/rpm limiting - tpm_rpm_limiter.py
+
             _end_user_object = None
             if "user" in request_data:
                 _id = "end_user_id:{}".format(request_data["user"])
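This cached team row is what _check_limits_set in tpm_rpm_limiter.py later reads back to decide whether any tpm/rpm limits apply. A sketch of the round trip, with illustrative field values, assuming LiteLLM_TeamTable accepts these keyword fields as the limiter code implies:

from litellm.caching import DualCache
from litellm.proxy._types import LiteLLM_TeamTable

user_api_key_cache = DualCache()
_team_obj = LiteLLM_TeamTable(team_id="team-abcd", tpm_limit=1000, rpm_limit=10)
user_api_key_cache.set_cache(key="team-abcd", value=_team_obj)

team_val = user_api_key_cache.get_cache(key="team-abcd")
assert isinstance(team_val, LiteLLM_TeamTable)
assert team_val.tpm_limit is not None  # -> limits exist, so enforcement runs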
@@ -1905,7 +1914,7 @@ class ProxyConfig:
         """
         Load config values into proxy global state
         """
-        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj
+        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache

         # Load existing config
         config = await self.get_config(config_file_path=config_file_path)
@@ -1967,6 +1976,7 @@ class ProxyConfig:
                     "password": cache_password,
                 }
             )
+
             # Assuming cache_type, cache_host, cache_port, and cache_password are strings
             print(  # noqa
                 f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}"
@@ -1991,7 +2001,14 @@ class ProxyConfig:
                 cache_params[key] = litellm.get_secret(value)

             ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
+
             litellm.cache = Cache(**cache_params)
+
+            if litellm.cache is not None and isinstance(
+                litellm.cache.cache, RedisCache
+            ):
+                ## INIT PROXY REDIS USAGE CLIENT ##
+                redis_usage_cache = litellm.cache.cache
             print(  # noqa
                 f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
             )
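The effect of the hunk above is that the proxy's response cache and the usage-tracking cache share a single RedisCache client rather than opening two connections. A condensed sketch of the pattern, assuming litellm.cache was already built from the config's cache params:

import litellm
from litellm.caching import RedisCache

redis_usage_cache = None
if litellm.cache is not None and isinstance(litellm.cache.cache, RedisCache):
    # Reuse the already-configured redis client for tpm/rpm + spend tracking.
    redis_usage_cache = litellm.cache.cache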
@@ -12,13 +12,14 @@ from litellm.proxy._types import (
     LiteLLM_TeamTable,
     Member,
 )
-from litellm.caching import DualCache
+from litellm.caching import DualCache, RedisCache
 from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 from litellm.proxy.hooks.parallel_request_limiter import (
     _PROXY_MaxParallelRequestsHandler,
 )
 from litellm import ModelResponse, EmbeddingResponse, ImageResponse
 from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
+from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
 from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.proxy.db.base_client import CustomDB
@@ -46,16 +47,23 @@ class ProxyLogging:
     - support the max parallel request integration
     """

-    def __init__(self, user_api_key_cache: DualCache):
+    def __init__(
+        self,
+        user_api_key_cache: DualCache,
+        redis_usage_cache: Optional[RedisCache] = None,
+    ):
         ## INITIALIZE LITELLM CALLBACKS ##
         self.call_details: dict = {}
         self.call_details["user_api_key_cache"] = user_api_key_cache
         self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler()
+        self.max_tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(
+            redis_usage_cache=redis_usage_cache
+        )
         self.max_budget_limiter = _PROXY_MaxBudgetLimiter()
         self.cache_control_check = _PROXY_CacheControlCheck()
         self.alerting: Optional[List] = None
         self.alerting_threshold: float = 300  # default to 5 min. threshold
-        pass
+        self.redis_usage_cache = redis_usage_cache

     def update_values(
         self, alerting: Optional[List], alerting_threshold: Optional[float]
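Wiring this up from the caller's side mirrors the proxy_server.py change above. A sketch, assuming a reachable local Redis; passing redis_usage_cache=None keeps the limiter's checks purely in-memory:

from litellm.caching import DualCache, RedisCache
from litellm.proxy.utils import ProxyLogging

proxy_logging_obj = ProxyLogging(
    user_api_key_cache=DualCache(),
    redis_usage_cache=RedisCache(host="localhost", port=6379, password=None),
)
proxy_logging_obj._init_litellm_callbacks()  # registers max_tpm_rpm_limiter with litellm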
@@ -67,6 +75,7 @@ class ProxyLogging:
     def _init_litellm_callbacks(self):
         print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
         litellm.callbacks.append(self.max_parallel_request_limiter)
+        litellm.callbacks.append(self.max_tpm_rpm_limiter)
         litellm.callbacks.append(self.max_budget_limiter)
         litellm.callbacks.append(self.cache_control_check)
         litellm.success_callback.append(self.response_taking_too_long_callback)
@@ -61,7 +61,7 @@ from litellm.proxy.utils import DBClient
 from starlette.datastructures import URL
 from litellm.caching import DualCache

-proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
+proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache(), redis_usage_cache=None)


 @pytest.fixture
@@ -253,7 +253,12 @@ class CompletionCustomHandler(
         assert isinstance(end_time, datetime)
         ## RESPONSE OBJECT
         assert isinstance(
-            response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse)
+            response_obj,
+            (
+                litellm.ModelResponse,
+                litellm.EmbeddingResponse,
+                litellm.TextCompletionResponse,
+            ),
         )
         ## KWARGS
         assert isinstance(kwargs["model"], str)
@@ -167,6 +167,7 @@ async def chat_completion_streaming(session, key, model="gpt-4"):
             continue


+@pytest.mark.skip(reason="Global proxy now tracked via `/global/spend/logs`")
 @pytest.mark.asyncio
 async def test_global_proxy_budget_update():
     """