fix(tpm_rpm_limiter.py): fix cache init logic

Krrish Dholakia 2024-04-01 18:01:38 -07:00
parent 9c0aecf9b8
commit 6467dd4e11
5 changed files with 36 additions and 44 deletions
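
In short (per the hunks below): _PROXY_MaxTPMRPMLimiter no longer builds its own Redis-backed cache from a RedisCache handle; it now accepts a ready-made DualCache, and ProxyLogging hands it the same internal_usage_cache it uses itself. Redis is attached later via update_values() once the proxy config has been parsed. Roughly (placeholder variable names, not from this commit):

    # before: the limiter owned its cache wiring
    limiter = _PROXY_MaxTPMRPMLimiter(redis_usage_cache=maybe_redis_cache)

    # after: the caller owns the cache; the limiter stores it (or defaults to in-memory)
    limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=shared_dual_cache)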


@@ -10,13 +10,16 @@ model_list:
 # api_key: my-fake-key
 # api_base: https://exampleopenaiendpoint-production.up.railway.app/
-# litellm_settings:
-#   cache: true
-#   max_budget: 600020
-#   budget_duration: 30d
+litellm_settings:
+  drop_params: True
+  max_budget: 800021
+  budget_duration: 30d
+  # cache: true
 
 general_settings:
   master_key: sk-1234
+  alerting: ["slack"]
   # proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
   # enable_jwt_auth: True
   # alerting: ["slack"]


@@ -22,13 +22,11 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger):
     user_api_key_cache = None
 
     # Class variables or attributes
-    def __init__(self, redis_usage_cache: Optional[RedisCache]):
-        self.redis_usage_cache = redis_usage_cache
-        self.internal_cache = DualCache(
-            redis_cache=redis_usage_cache,
-            default_in_memory_ttl=10,
-            default_redis_ttl=60,
-        )
+    def __init__(self, internal_cache: Optional[DualCache]):
+        if internal_cache is None:
+            self.internal_cache = DualCache()
+        else:
+            self.internal_cache = internal_cache
 
     def print_verbose(self, print_statement):
         try:
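
A minimal sketch of the two construction paths the new __init__ allows (the import path is assumed from the file name in the commit title, and the RedisCache host/port are illustrative assumptions):

    from litellm.caching import DualCache, RedisCache
    from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter

    # Standalone: passing None gives the limiter a fresh in-memory DualCache.
    limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=None)

    # Shared: reuse one (optionally Redis-backed) DualCache across hooks so
    # TPM/RPM counters stay consistent across proxy workers.
    shared = DualCache(redis_cache=RedisCache(host="localhost", port=6379))
    limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=shared)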


@@ -302,9 +302,7 @@ disable_spend_logs = False
 jwt_handler = JWTHandler()
 prompt_injection_detection_obj: Optional[_OPTIONAL_PromptInjectionDetection] = None
 ### INITIALIZE GLOBAL LOGGING OBJECT ###
-proxy_logging_obj = ProxyLogging(
-    user_api_key_cache=user_api_key_cache, redis_usage_cache=redis_usage_cache
-)
+proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -2251,6 +2249,7 @@ class ProxyConfig:
         proxy_logging_obj.update_values(
             alerting=general_settings.get("alerting", None),
             alerting_threshold=general_settings.get("alerting_threshold", 600),
+            redis_cache=redis_usage_cache,
         )
         ### CONNECT TO DATABASE ###
         database_url = general_settings.get("database_url", None)
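
The two-step wiring above exists because proxy_logging_obj is constructed at module import, before the config file (and therefore any Redis settings) has been read. A condensed sketch of the resulting flow, using the names from this diff:

    # 1) import time: in-memory caches only
    proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)

    # 2) config load: attach Redis once settings are known
    proxy_logging_obj.update_values(
        alerting=general_settings.get("alerting", None),
        alerting_threshold=general_settings.get("alerting_threshold", 600),
        redis_cache=redis_usage_cache,  # may be None if Redis isn't configured
    )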
@@ -4976,31 +4975,13 @@ async def global_spend():
     if prisma_client is None:
         raise HTTPException(status_code=500, detail={"error": "No db connected"})
 
-    sql_query = f"""
-       SELECT SUM(spend) AS total_spend
-       FROM "LiteLLM_VerificationToken";
-       ;
-    """
+    sql_query = """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";"""
 
     response = await prisma_client.db.query_raw(query=sql_query)
 
     if response is not None:
         if isinstance(response, list) and len(response) > 0:
             total_spend = response[0].get("total_spend", 0.0)
 
-    sql_query = f"""
-       SELECT
-           *
-       FROM
-           "LiteLLM_UserTable"
-       WHERE
-           user_id = 'litellm-proxy-budget';
-    """
-    user_response = await prisma_client.db.query_raw(query=sql_query)
-
-    if user_response is not None:
-        if isinstance(user_response, list) and len(user_response) > 0:
-            total_proxy_budget = user_response[0].get("max_budget", 0.0)
-
-    return {"spend": total_spend, "max_budget": total_proxy_budget}
+    return {"spend": total_spend, "max_budget": litellm.max_budget}
 
 @router.get(
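
The endpoint now answers with one aggregate query plus the in-process budget, instead of a second raw lookup against "LiteLLM_UserTable". A hedged smoke test against a locally running proxy (the /global/spend route name and the port are assumptions based on the global_spend() handler; values are illustrative):

    import requests

    resp = requests.get(
        "http://0.0.0.0:4000/global/spend",
        headers={"Authorization": "Bearer sk-1234"},  # master_key from the config above
    )
    print(resp.json())  # e.g. {"spend": 104.2, "max_budget": 800021.0}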


@@ -50,28 +50,33 @@ class ProxyLogging:
     def __init__(
         self,
         user_api_key_cache: DualCache,
-        redis_usage_cache: Optional[RedisCache] = None,
     ):
         ## INITIALIZE LITELLM CALLBACKS ##
         self.call_details: dict = {}
         self.call_details["user_api_key_cache"] = user_api_key_cache
+        self.internal_usage_cache = DualCache()
         self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler()
         self.max_tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(
-            redis_usage_cache=redis_usage_cache
+            internal_cache=self.internal_usage_cache
         )
         self.max_budget_limiter = _PROXY_MaxBudgetLimiter()
         self.cache_control_check = _PROXY_CacheControlCheck()
         self.alerting: Optional[List] = None
         self.alerting_threshold: float = 300  # default to 5 min. threshold
-        self.internal_usage_cache = DualCache(redis_cache=redis_usage_cache)
 
     def update_values(
-        self, alerting: Optional[List], alerting_threshold: Optional[float]
+        self,
+        alerting: Optional[List],
+        alerting_threshold: Optional[float],
+        redis_cache: Optional[RedisCache],
     ):
         self.alerting = alerting
 
         if alerting_threshold is not None:
             self.alerting_threshold = alerting_threshold
+
+        if redis_cache is not None:
+            self.internal_usage_cache.redis_cache = redis_cache
 
     def _init_litellm_callbacks(self):
         print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
         litellm.callbacks.append(self.max_parallel_request_limiter)
@@ -265,10 +270,11 @@ class ProxyLogging:
         if self.alerting is None:
             # do nothing if alerting is not switched on
             return
+        _id: str = "default_id"  # used for caching
         if type == "user_and_proxy_budget":
             user_info = dict(user_info)
             user_id = user_info["user_id"]
+            _id = user_id
             max_budget = user_info["max_budget"]
             spend = user_info["spend"]
             user_email = user_info["user_email"]
@@ -276,12 +282,14 @@ class ProxyLogging:
         elif type == "token_budget":
             token_info = dict(user_info)
             token = token_info["token"]
+            _id = token
             spend = token_info["spend"]
             max_budget = token_info["max_budget"]
             user_id = token_info["user_id"]
             user_info = f"""\nToken: {token}\nSpend: ${spend}\nMax Budget: ${max_budget}\nUser ID: {user_id}"""
         elif type == "failed_tracking":
             user_id = str(user_info)
+            _id = user_id
             user_info = f"\nUser ID: {user_id}\n Error {error_message}"
             message = "Failed Tracking Cost for" + user_info
             await self.alerting_handler(
@@ -337,13 +345,15 @@ class ProxyLogging:
             # check if 5% of max budget is left
             if percent_left <= 0.05:
                 message = "5% budget left for" + user_info
-                result = await _cache.async_get_cache(key=message)
+                cache_key = "alerting:{}".format(_id)
+                result = await _cache.async_get_cache(key=cache_key)
                 if result is None:
                     await self.alerting_handler(
                         message=message,
                         level="Medium",
                     )
-                    await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
+                    await _cache.async_set_cache(key=cache_key, value="SENT", ttl=2419200)
                 return
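
Why the key change matters: the old dedup key was the rendered alert message, which embeds the current spend and so changes on every check, meaning the "SENT" marker was never found again and repeat alerts kept firing. Keying on the stable entity id makes the TTL-based dedup effective. The pattern, condensed from the hunk above:

    # one alert per user/token per TTL window
    cache_key = "alerting:{}".format(_id)  # stable per entity, not per message
    if await _cache.async_get_cache(key=cache_key) is None:
        await self.alerting_handler(message=message, level="Medium")
        # 2419200 s = 28 days: suppress repeats for roughly a billing month
        await _cache.async_set_cache(key=cache_key, value="SENT", ttl=2419200)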


@@ -38,7 +38,7 @@ async def test_pre_call_hook_rpm_limits():
         key=_api_key, value={"api_key": _api_key, "tpm_limit": 9, "rpm_limit": 1}
     )
-    tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(redis_usage_cache=None)
+    tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=DualCache())
 
     await tpm_rpm_limiter.async_pre_call_hook(
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -89,8 +89,8 @@ async def test_pre_call_hook_team_rpm_limits(
     user_api_key_dict = UserAPIKeyAuth(**_user_api_key_dict)  # type: ignore
     local_cache = DualCache()
     local_cache.set_cache(key=_api_key, value=_user_api_key_dict)
-    tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(redis_usage_cache=_redis_usage_cache)
+    internal_cache = DualCache(redis_cache=_redis_usage_cache)
+    tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=internal_cache)
 
     await tpm_rpm_limiter.async_pre_call_hook(
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
     )