Merge pull request #2775 from BerriAI/litellm_redis_user_api_key_cache_v3

fix(tpm_rpm_limiter.py): enable redis caching for tpm/rpm checks on keys/user/teams
Krish Dholakia 2024-03-30 22:07:05 -07:00 committed by GitHub
commit 1356f6cd32
9 changed files with 461 additions and 14 deletions
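At a high level, this PR threads the proxy's shared Redis client into a new per-minute TPM/RPM limiter, so usage counters are visible across proxy instances instead of living only in one worker's memory. A minimal sketch of how the pieces introduced below fit together (the Redis connection details and the standalone setup are illustrative assumptions, not part of this diff):

from litellm.caching import DualCache, RedisCache
from litellm.proxy.utils import ProxyLogging

# Placeholder connection settings - point these at your own Redis instance.
redis_usage_cache = RedisCache(host="localhost", port=6379, password=None)

# ProxyLogging now accepts the Redis client and hands it to the new
# _PROXY_MaxTPMRPMLimiter, which wraps it in a DualCache (10s in-memory TTL,
# 60s Redis TTL) for the per-minute usage counters.
proxy_logging_obj = ProxyLogging(
    user_api_key_cache=DualCache(),
    redis_usage_cache=redis_usage_cache,
)
proxy_logging_obj._init_litellm_callbacks()  # registers max_tpm_rpm_limiter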


@@ -100,7 +100,13 @@ class RedisCache(BaseCache):
     # if users don't provider one, use the default litellm cache
     def __init__(
-        self, host=None, port=None, password=None, redis_flush_size=100, **kwargs
+        self,
+        host=None,
+        port=None,
+        password=None,
+        redis_flush_size=100,
+        namespace: Optional[str] = None,
+        **kwargs,
     ):
         from ._redis import get_redis_client, get_redis_connection_pool
@@ -116,9 +122,10 @@ class RedisCache(BaseCache):
         self.redis_client = get_redis_client(**redis_kwargs)
         self.redis_kwargs = redis_kwargs
         self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
+        # redis namespaces
+        self.namespace = namespace
         # for high traffic, we store the redis results in memory and then batch write to redis
-        self.redis_batch_writing_buffer = []
+        self.redis_batch_writing_buffer: list = []
         self.redis_flush_size = redis_flush_size
         self.redis_version = "Unknown"
         try:
@@ -133,11 +140,21 @@ class RedisCache(BaseCache):
             connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
         )

+    def check_and_fix_namespace(self, key: str) -> str:
+        """
+        Make sure each key starts with the given namespace
+        """
+        if self.namespace is not None and not key.startswith(self.namespace):
+            key = self.namespace + ":" + key
+
+        return key
+
     def set_cache(self, key, value, **kwargs):
         ttl = kwargs.get("ttl", None)
         print_verbose(
             f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}"
         )
+        key = self.check_and_fix_namespace(key=key)
         try:
             self.redis_client.set(name=key, value=str(value), ex=ttl)
         except Exception as e:
@@ -158,6 +175,7 @@ class RedisCache(BaseCache):
     async def async_set_cache(self, key, value, **kwargs):
         _redis_client = self.init_async_client()
+        key = self.check_and_fix_namespace(key=key)
         async with _redis_client as redis_client:
             ttl = kwargs.get("ttl", None)
             print_verbose(
@@ -187,6 +205,7 @@ class RedisCache(BaseCache):
             async with redis_client.pipeline(transaction=True) as pipe:
                 # Iterate through each key-value pair in the cache_list and set them in the pipeline.
                 for cache_key, cache_value in cache_list:
+                    cache_key = self.check_and_fix_namespace(key=cache_key)
                     print_verbose(
                         f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
                     )
@@ -213,6 +232,7 @@ class RedisCache(BaseCache):
         print_verbose(
             f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}",
         )
+        key = self.check_and_fix_namespace(key=key)
         self.redis_batch_writing_buffer.append((key, value))
         if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
             await self.flush_cache_buffer()
@ -242,6 +262,7 @@ class RedisCache(BaseCache):
def get_cache(self, key, **kwargs): def get_cache(self, key, **kwargs):
try: try:
key = self.check_and_fix_namespace(key=key)
print_verbose(f"Get Redis Cache: key: {key}") print_verbose(f"Get Redis Cache: key: {key}")
cached_response = self.redis_client.get(key) cached_response = self.redis_client.get(key)
print_verbose( print_verbose(
@@ -255,6 +276,7 @@ class RedisCache(BaseCache):
     async def async_get_cache(self, key, **kwargs):
         _redis_client = self.init_async_client()
+        key = self.check_and_fix_namespace(key=key)
         async with _redis_client as redis_client:
             try:
                 print_verbose(f"Get Async Redis Cache: key: {key}")
@@ -281,6 +303,7 @@ class RedisCache(BaseCache):
             async with redis_client.pipeline(transaction=True) as pipe:
                 # Queue the get operations in the pipeline for all keys.
                 for cache_key in key_list:
+                    cache_key = self.check_and_fix_namespace(key=cache_key)
                     pipe.get(cache_key)  # Queue GET command in pipeline

                 # Execute the pipeline and await the results.
@@ -796,6 +819,8 @@ class DualCache(BaseCache):
         self,
         in_memory_cache: Optional[InMemoryCache] = None,
         redis_cache: Optional[RedisCache] = None,
+        default_in_memory_ttl: Optional[float] = None,
+        default_redis_ttl: Optional[float] = None,
     ) -> None:
         super().__init__()
         # If in_memory_cache is not provided, use the default InMemoryCache
@@ -803,11 +828,17 @@ class DualCache(BaseCache):
         # If redis_cache is not provided, use the default RedisCache
         self.redis_cache = redis_cache

+        self.default_in_memory_ttl = default_in_memory_ttl
+        self.default_redis_ttl = default_redis_ttl
+
     def set_cache(self, key, value, local_only: bool = False, **kwargs):
         # Update both Redis and in-memory cache
         try:
             print_verbose(f"set cache: key: {key}; value: {value}")
             if self.in_memory_cache is not None:
+                if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
+                    kwargs["ttl"] = self.default_in_memory_ttl
+
                 self.in_memory_cache.set_cache(key, value, **kwargs)

             if self.redis_cache is not None and local_only == False:
@@ -823,7 +854,6 @@ class DualCache(BaseCache):
             if self.in_memory_cache is not None:
                 in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)

-                print_verbose(f"in_memory_result: {in_memory_result}")
                 if in_memory_result is not None:
                     result = in_memory_result
@@ -1008,6 +1038,9 @@ class Cache:
         self.redis_flush_size = redis_flush_size
         self.ttl = ttl

+        if self.namespace is not None and isinstance(self.cache, RedisCache):
+            self.cache.namespace = self.namespace
+
     def get_cache_key(self, *args, **kwargs):
         """
         Get the cache key for the given arguments.
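To make the caching.py additions above concrete, here is a small sketch of the new key namespacing and the per-layer default TTLs (connection values are placeholders, and a reachable Redis is assumed only for the final write):

from litellm.caching import DualCache, InMemoryCache, RedisCache

redis_cache = RedisCache(host="localhost", port=6379, namespace="litellm")

# Keys are prefixed with "<namespace>:" before they reach Redis.
assert redis_cache.check_and_fix_namespace(key="usage:2024-03-30-22-07") == "litellm:usage:2024-03-30-22-07"

# DualCache can now carry a different default TTL per layer, applied when the
# caller does not pass an explicit ttl.
dual_cache = DualCache(
    in_memory_cache=InMemoryCache(),
    redis_cache=redis_cache,
    default_in_memory_ttl=10,  # seconds kept in process memory
    default_redis_ttl=60,  # seconds kept in Redis
)
dual_cache.set_cache(key="usage:2024-03-30-22-07", value={"key": {}, "user": {}, "team": {}})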


@@ -12,6 +12,7 @@ from typing import Any, Literal, Union, BinaryIO
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx
 import litellm
 from ._logging import verbose_logger


@@ -6,8 +6,9 @@ model_list:
       api_base: https://exampleopenaiendpoint-production.up.railway.app/

 litellm_settings:
-  max_budget: 600020
-  budget_duration: 30d
+  cache: true
+  # max_budget: 600020
+  # budget_duration: 30d

 general_settings:
   master_key: sk-1234
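The settings change above switches the test proxy config to `cache: true`; as the proxy_server.py hunks later in this diff show, the resulting `litellm.cache` is then reused as the proxy's Redis usage cache. A rough programmatic equivalent of that flow (host/port/password are placeholders for values normally read from the config or environment):

import litellm
from litellm.caching import Cache, RedisCache

litellm.cache = Cache(type="redis", host="localhost", port="6379", password=None)

# The proxy then adopts the underlying RedisCache for spend and tpm/rpm tracking.
if isinstance(litellm.cache.cache, RedisCache):
    redis_usage_cache = litellm.cache.cache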


@@ -0,0 +1,380 @@
# What is this?
## Checks TPM/RPM Limits for a key/user/team on the proxy
## Works with Redis - if given
from typing import Optional, Literal
import litellm, traceback, sys
from litellm.caching import DualCache, RedisCache
from litellm.proxy._types import (
UserAPIKeyAuth,
LiteLLM_VerificationTokenView,
LiteLLM_UserTable,
LiteLLM_TeamTable,
)
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm import ModelResponse
from datetime import datetime
class _PROXY_MaxTPMRPMLimiter(CustomLogger):
user_api_key_cache = None
# Class variables or attributes
def __init__(self, redis_usage_cache: Optional[RedisCache]):
self.redis_usage_cache = redis_usage_cache
self.internal_cache = DualCache(
redis_cache=redis_usage_cache,
default_in_memory_ttl=10,
default_redis_ttl=60,
)
def print_verbose(self, print_statement):
try:
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except:
pass
## check if admin has set tpm/rpm limits for this key/user/team
def _check_limits_set(
self,
user_api_key_cache: DualCache,
key: Optional[str],
user_id: Optional[str],
team_id: Optional[str],
) -> bool:
## key
if key is not None:
key_val = user_api_key_cache.get_cache(key=key)
if isinstance(key_val, dict):
key_val = LiteLLM_VerificationTokenView(**key_val)
if isinstance(key_val, LiteLLM_VerificationTokenView):
user_api_key_tpm_limit = key_val.tpm_limit
user_api_key_rpm_limit = key_val.rpm_limit
if (
user_api_key_tpm_limit is not None
or user_api_key_rpm_limit is not None
):
return True
## team
if team_id is not None:
team_val = user_api_key_cache.get_cache(key=team_id)
if isinstance(team_val, dict):
team_val = LiteLLM_TeamTable(**team_val)
if isinstance(team_val, LiteLLM_TeamTable):
team_tpm_limit = team_val.tpm_limit
team_rpm_limit = team_val.rpm_limit
if team_tpm_limit is not None or team_rpm_limit is not None:
return True
## user
if user_id is not None:
user_val = user_api_key_cache.get_cache(key=user_id)
if isinstance(user_val, dict):
user_val = LiteLLM_UserTable(**user_val)
if isinstance(user_val, LiteLLM_UserTable):
user_tpm_limit = user_val.tpm_limit
user_rpm_limit = user_val.rpm_limit
if user_tpm_limit is not None or user_rpm_limit is not None:
return True
return False
async def check_key_in_limits(
self,
user_api_key_dict: UserAPIKeyAuth,
current_minute_dict: dict,
tpm_limit: int,
rpm_limit: int,
request_count_api_key: str,
type: Literal["key", "user", "team"],
):
if type == "key" and user_api_key_dict.api_key is not None:
current = current_minute_dict["key"].get(user_api_key_dict.api_key, None)
elif type == "user" and user_api_key_dict.user_id is not None:
current = current_minute_dict["user"].get(user_api_key_dict.user_id, None)
elif type == "team" and user_api_key_dict.team_id is not None:
current = current_minute_dict["team"].get(user_api_key_dict.team_id, None)
else:
return
if current is None:
if tpm_limit == 0 or rpm_limit == 0:
# base case
raise HTTPException(
status_code=429, detail="Max tpm/rpm limit reached."
)
elif current["current_tpm"] < tpm_limit and current["current_rpm"] < rpm_limit:
pass
else:
raise HTTPException(status_code=429, detail="Max tpm/rpm limit reached.")
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
self.print_verbose(
f"Inside Max TPM/RPM Limiter Pre-Call Hook - {user_api_key_dict}"
)
api_key = user_api_key_dict.api_key
# check if REQUEST ALLOWED for user_id
user_id = user_api_key_dict.user_id
## get team tpm/rpm limits
team_id = user_api_key_dict.team_id
_set_limits = self._check_limits_set(
user_api_key_cache=cache, key=api_key, user_id=user_id, team_id=team_id
)
if _set_limits == False:
return
# ------------
# Setup values
# ------------
self.user_api_key_cache = cache
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
cache_key = "usage:{}".format(precise_minute)
current_minute_dict = await self.internal_cache.async_get_cache(
key=cache_key
) # {"usage:{curr_minute}": {"key": {<api_key>: {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}}}}
if current_minute_dict is None:
current_minute_dict = {"key": {}, "user": {}, "team": {}}
if api_key is not None:
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
if tpm_limit is None:
tpm_limit = sys.maxsize
rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
if rpm_limit is None:
rpm_limit = sys.maxsize
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
current_minute_dict=current_minute_dict,
request_count_api_key=request_count_api_key,
tpm_limit=tpm_limit,
rpm_limit=rpm_limit,
type="key",
)
if user_id is not None:
_user_id_rate_limits = user_api_key_dict.user_id_rate_limits
# get user tpm/rpm limits
if _user_id_rate_limits is not None and isinstance(
_user_id_rate_limits, dict
):
user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None)
user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None)
if user_tpm_limit is None:
user_tpm_limit = sys.maxsize
if user_rpm_limit is None:
user_rpm_limit = sys.maxsize
# now do the same tpm/rpm checks
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
current_minute_dict=current_minute_dict,
request_count_api_key=request_count_api_key,
tpm_limit=user_tpm_limit,
rpm_limit=user_rpm_limit,
type="user",
)
# TEAM RATE LIMITS
if team_id is not None:
team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize)
if team_tpm_limit is None:
team_tpm_limit = sys.maxsize
team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize)
if team_rpm_limit is None:
team_rpm_limit = sys.maxsize
if team_tpm_limit is None:
team_tpm_limit = sys.maxsize
if team_rpm_limit is None:
team_rpm_limit = sys.maxsize
# now do the same tpm/rpm checks
request_count_api_key = f"{team_id}::{precise_minute}::request_count"
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
current_minute_dict=current_minute_dict,
request_count_api_key=request_count_api_key,
tpm_limit=team_tpm_limit,
rpm_limit=team_rpm_limit,
type="team",
)
return
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.print_verbose(f"INSIDE TPM RPM Limiter ASYNC SUCCESS LOGGING")
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_user_id", None
)
user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_team_id", None
)
_limits_set = self._check_limits_set(
user_api_key_cache=self.user_api_key_cache,
key=user_api_key,
user_id=user_api_key_user_id,
team_id=user_api_key_team_id,
)
if _limits_set == False: # don't waste cache calls if no tpm/rpm limits set
return
# ------------
# Setup values
# ------------
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
total_tokens = 0
if isinstance(response_obj, ModelResponse):
total_tokens = response_obj.usage.total_tokens
"""
- get value from redis
- increment requests + 1
- increment tpm + 1
- increment rpm + 1
- update value in-memory + redis
"""
cache_key = "usage:{}".format(precise_minute)
if (
self.internal_cache.redis_cache is not None
): # get straight from redis if possible
current_minute_dict = (
await self.internal_cache.redis_cache.async_get_cache(
key=cache_key,
)
) # {"usage:{current_minute}": {"key": {}, "team": {}, "user": {}}}
else:
current_minute_dict = await self.internal_cache.async_get_cache(
key=cache_key,
)
if current_minute_dict is None:
current_minute_dict = {"key": {}, "user": {}, "team": {}}
_cache_updated = False # check if a cache update is required. prevent unnecessary rewrites.
# ------------
# Update usage - API Key
# ------------
if user_api_key is not None:
_cache_updated = True
## API KEY ##
if user_api_key in current_minute_dict["key"]:
current_key_usage = current_minute_dict["key"][user_api_key]
new_val = {
"current_tpm": current_key_usage["current_tpm"] + total_tokens,
"current_rpm": current_key_usage["current_rpm"] + 1,
}
else:
new_val = {
"current_tpm": total_tokens,
"current_rpm": 1,
}
current_minute_dict["key"][user_api_key] = new_val
self.print_verbose(
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
)
# ------------
# Update usage - User
# ------------
if user_api_key_user_id is not None:
_cache_updated = True
total_tokens = 0
if isinstance(response_obj, ModelResponse):
total_tokens = response_obj.usage.total_tokens
if user_api_key_user_id in current_minute_dict["key"]:
current_key_usage = current_minute_dict["key"][user_api_key_user_id]
new_val = {
"current_tpm": current_key_usage["current_tpm"] + total_tokens,
"current_rpm": current_key_usage["current_rpm"] + 1,
}
else:
new_val = {
"current_tpm": total_tokens,
"current_rpm": 1,
}
current_minute_dict["user"][user_api_key_user_id] = new_val
# ------------
# Update usage - Team
# ------------
if user_api_key_team_id is not None:
_cache_updated = True
total_tokens = 0
if isinstance(response_obj, ModelResponse):
total_tokens = response_obj.usage.total_tokens
if user_api_key_team_id in current_minute_dict["key"]:
current_key_usage = current_minute_dict["key"][user_api_key_team_id]
new_val = {
"current_tpm": current_key_usage["current_tpm"] + total_tokens,
"current_rpm": current_key_usage["current_rpm"] + 1,
}
else:
new_val = {
"current_tpm": total_tokens,
"current_rpm": 1,
}
current_minute_dict["team"][user_api_key_team_id] = new_val
if _cache_updated == True:
await self.internal_cache.async_set_cache(
key=cache_key, value=current_minute_dict
)
except Exception as e:
self.print_verbose(e) # noqa
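For orientation, the limiter above keeps every counter for a given minute under a single cache entry. A sketch of the record that async_pre_call_hook reads and async_log_success_event increments (the ids and numbers are illustrative):

from datetime import datetime

# One entry per minute, keyed "usage:<YYYY-MM-DD>-<HH>-<MM>".
precise_minute = datetime.now().strftime("%Y-%m-%d-%H-%M")
cache_key = "usage:{}".format(precise_minute)

# Usage is bucketed by API key hash, user id, and team id.
current_minute_dict = {
    "key": {"88dc28..example-hash": {"current_tpm": 1250, "current_rpm": 3}},
    "user": {"user-1234": {"current_tpm": 1250, "current_rpm": 3}},
    "team": {"team-5678": {"current_tpm": 1250, "current_rpm": 3}},
}
# The pre-call hook raises HTTP 429 once current_tpm/current_rpm hit the limits
# set on the key, user, or team; the success hook adds total_tokens and one request.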


@@ -102,7 +102,7 @@ from litellm.proxy.secret_managers.google_kms import load_google_kms
 from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
 import pydantic
 from litellm.proxy._types import *
-from litellm.caching import DualCache
+from litellm.caching import DualCache, RedisCache
 from litellm.proxy.health_check import perform_health_check
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
@@ -281,6 +281,9 @@ otel_logging = False
 prisma_client: Optional[PrismaClient] = None
 custom_db_client: Optional[DBClient] = None
 user_api_key_cache = DualCache()
+redis_usage_cache: Optional[RedisCache] = (
+    None  # redis cache used for tracking spend, tpm/rpm limits
+)
 user_custom_auth = None
 user_custom_key_generate = None
 use_background_health_checks = None
@@ -299,7 +302,9 @@ disable_spend_logs = False
 jwt_handler = JWTHandler()
 prompt_injection_detection_obj: Optional[_OPTIONAL_PromptInjectionDetection] = None
 ### INITIALIZE GLOBAL LOGGING OBJECT ###
-proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
+proxy_logging_obj = ProxyLogging(
+    user_api_key_cache=user_api_key_cache, redis_usage_cache=redis_usage_cache
+)
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -909,6 +914,10 @@ async def user_api_key_auth(
                         models=valid_token.team_models,
                     )
+                    user_api_key_cache.set_cache(
+                        key=valid_token.team_id, value=_team_obj
+                    )  # save team table in cache - used for tpm/rpm limiting - tpm_rpm_limiter.py
+
                 _end_user_object = None
                 if "user" in request_data:
                     _id = "end_user_id:{}".format(request_data["user"])
@@ -1905,7 +1914,7 @@ class ProxyConfig:
         """
         Load config values into proxy global state
         """
-        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj
+        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache

         # Load existing config
         config = await self.get_config(config_file_path=config_file_path)
@@ -1967,6 +1976,7 @@ class ProxyConfig:
                     "password": cache_password,
                 }
            )
+            # Assuming cache_type, cache_host, cache_port, and cache_password are strings
            print(  # noqa
                f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}"
@@ -1991,7 +2001,14 @@ class ProxyConfig:
                         cache_params[key] = litellm.get_secret(value)

             ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
             litellm.cache = Cache(**cache_params)
+
+            if litellm.cache is not None and isinstance(
+                litellm.cache.cache, RedisCache
+            ):
+                ## INIT PROXY REDIS USAGE CLIENT ##
+                redis_usage_cache = litellm.cache.cache
+
             print(  # noqa
                 f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
             )


@@ -12,13 +12,14 @@ from litellm.proxy._types import (
     LiteLLM_TeamTable,
     Member,
 )
-from litellm.caching import DualCache
+from litellm.caching import DualCache, RedisCache
 from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 from litellm.proxy.hooks.parallel_request_limiter import (
     _PROXY_MaxParallelRequestsHandler,
 )
 from litellm import ModelResponse, EmbeddingResponse, ImageResponse
 from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
+from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
 from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.proxy.db.base_client import CustomDB
@@ -46,16 +47,23 @@ class ProxyLogging:
     - support the max parallel request integration
     """

-    def __init__(self, user_api_key_cache: DualCache):
+    def __init__(
+        self,
+        user_api_key_cache: DualCache,
+        redis_usage_cache: Optional[RedisCache] = None,
+    ):
         ## INITIALIZE LITELLM CALLBACKS ##
         self.call_details: dict = {}
         self.call_details["user_api_key_cache"] = user_api_key_cache
         self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler()
+        self.max_tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(
+            redis_usage_cache=redis_usage_cache
+        )
         self.max_budget_limiter = _PROXY_MaxBudgetLimiter()
         self.cache_control_check = _PROXY_CacheControlCheck()
         self.alerting: Optional[List] = None
         self.alerting_threshold: float = 300  # default to 5 min. threshold
-        pass
+        self.redis_usage_cache = redis_usage_cache

     def update_values(
         self, alerting: Optional[List], alerting_threshold: Optional[float]
@@ -67,6 +75,7 @@ class ProxyLogging:
     def _init_litellm_callbacks(self):
         print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
         litellm.callbacks.append(self.max_parallel_request_limiter)
+        litellm.callbacks.append(self.max_tpm_rpm_limiter)
         litellm.callbacks.append(self.max_budget_limiter)
         litellm.callbacks.append(self.cache_control_check)
         litellm.success_callback.append(self.response_taking_too_long_callback)


@@ -61,7 +61,7 @@ from litellm.proxy.utils import DBClient
 from starlette.datastructures import URL
 from litellm.caching import DualCache

-proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
+proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache(), redis_usage_cache=None)


 @pytest.fixture


@@ -253,7 +253,12 @@ class CompletionCustomHandler(
         assert isinstance(end_time, datetime)
         ## RESPONSE OBJECT
         assert isinstance(
-            response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse)
+            response_obj,
+            (
+                litellm.ModelResponse,
+                litellm.EmbeddingResponse,
+                litellm.TextCompletionResponse,
+            ),
         )
         ## KWARGS
         assert isinstance(kwargs["model"], str)


@@ -167,6 +167,7 @@ async def chat_completion_streaming(session, key, model="gpt-4"):
             continue


+@pytest.mark.skip(reason="Global proxy now tracked via `/global/spend/logs`")
 @pytest.mark.asyncio
 async def test_global_proxy_budget_update():
     """