forked from phoenix/litellm-mirror
fix(tpm_rpm_limiter.py): fix cache init logic
This commit is contained in: parent 9c0aecf9b8 · commit 6467dd4e11
5 changed files with 36 additions and 44 deletions
@@ -10,13 +10,16 @@ model_list:
 #     api_key: my-fake-key
 #     api_base: https://exampleopenaiendpoint-production.up.railway.app/
 
-# litellm_settings:
-#   cache: true
-#   max_budget: 600020
-#   budget_duration: 30d
+litellm_settings:
+  drop_params: True
+  max_budget: 800021
+  budget_duration: 30d
+  # cache: true
+
 
 general_settings:
   master_key: sk-1234
+  alerting: ["slack"]
   # proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
   # enable_jwt_auth: True
   # alerting: ["slack"]
@@ -22,13 +22,11 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger):
     user_api_key_cache = None
 
     # Class variables or attributes
-    def __init__(self, redis_usage_cache: Optional[RedisCache]):
-        self.redis_usage_cache = redis_usage_cache
-        self.internal_cache = DualCache(
-            redis_cache=redis_usage_cache,
-            default_in_memory_ttl=10,
-            default_redis_ttl=60,
-        )
+    def __init__(self, internal_cache: Optional[DualCache]):
+        if internal_cache is None:
+            self.internal_cache = DualCache()
+        else:
+            self.internal_cache = internal_cache
 
     def print_verbose(self, print_statement):
         try:
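In short, the limiter no longer builds its own Redis-backed cache; callers hand it a DualCache, or it falls back to a plain in-memory one. A minimal sketch of the two call patterns this enables — the import paths and Redis connection details below are assumptions, not taken from the diff:

    # Sketch: the two init paths enabled by the new constructor.
    from litellm.caching import DualCache, RedisCache
    from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter  # path assumed

    # 1) No cache supplied: the limiter creates a standalone in-memory DualCache.
    standalone = _PROXY_MaxTPMRPMLimiter(internal_cache=None)

    # 2) Shared cache supplied: the limiter reuses the caller's DualCache, so its
    #    tpm/rpm counters live in the same (optionally Redis-backed) store.
    shared_cache = DualCache(redis_cache=RedisCache(host="localhost", port=6379))
    shared = _PROXY_MaxTPMRPMLimiter(internal_cache=shared_cache)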
@@ -302,9 +302,7 @@ disable_spend_logs = False
 jwt_handler = JWTHandler()
 prompt_injection_detection_obj: Optional[_OPTIONAL_PromptInjectionDetection] = None
 ### INITIALIZE GLOBAL LOGGING OBJECT ###
-proxy_logging_obj = ProxyLogging(
-    user_api_key_cache=user_api_key_cache, redis_usage_cache=redis_usage_cache
-)
+proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -2251,6 +2249,7 @@ class ProxyConfig:
         proxy_logging_obj.update_values(
             alerting=general_settings.get("alerting", None),
             alerting_threshold=general_settings.get("alerting_threshold", 600),
+            redis_cache=redis_usage_cache,
         )
         ### CONNECT TO DATABASE ###
         database_url = general_settings.get("database_url", None)
@@ -4976,31 +4975,13 @@ async def global_spend():
 
     if prisma_client is None:
         raise HTTPException(status_code=500, detail={"error": "No db connected"})
-    sql_query = f"""
-       SELECT SUM(spend) AS total_spend
-       FROM "LiteLLM_VerificationToken";
-       ;
-    """
+    sql_query = """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";"""
     response = await prisma_client.db.query_raw(query=sql_query)
     if response is not None:
         if isinstance(response, list) and len(response) > 0:
             total_spend = response[0].get("total_spend", 0.0)
 
-    sql_query = f"""
-       SELECT
-         *
-       FROM
-         "LiteLLM_UserTable"
-       WHERE
-         user_id = 'litellm-proxy-budget';
-    """
-    user_response = await prisma_client.db.query_raw(query=sql_query)
-
-    if user_response is not None:
-        if isinstance(user_response, list) and len(user_response) > 0:
-            total_proxy_budget = user_response[0].get("max_budget", 0.0)
-
-    return {"spend": total_spend, "max_budget": total_proxy_budget}
+    return {"spend": total_spend, "max_budget": litellm.max_budget}
 
 
 @router.get(
@@ -50,28 +50,33 @@ class ProxyLogging:
     def __init__(
         self,
         user_api_key_cache: DualCache,
-        redis_usage_cache: Optional[RedisCache] = None,
     ):
         ## INITIALIZE LITELLM CALLBACKS ##
         self.call_details: dict = {}
         self.call_details["user_api_key_cache"] = user_api_key_cache
+        self.internal_usage_cache = DualCache()
         self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler()
         self.max_tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(
-            redis_usage_cache=redis_usage_cache
+            internal_cache=self.internal_usage_cache
         )
         self.max_budget_limiter = _PROXY_MaxBudgetLimiter()
         self.cache_control_check = _PROXY_CacheControlCheck()
         self.alerting: Optional[List] = None
         self.alerting_threshold: float = 300  # default to 5 min. threshold
-        self.internal_usage_cache = DualCache(redis_cache=redis_usage_cache)
 
     def update_values(
-        self, alerting: Optional[List], alerting_threshold: Optional[float]
+        self,
+        alerting: Optional[List],
+        alerting_threshold: Optional[float],
+        redis_cache: Optional[RedisCache],
     ):
         self.alerting = alerting
         if alerting_threshold is not None:
             self.alerting_threshold = alerting_threshold
 
+        if redis_cache is not None:
+            self.internal_usage_cache.redis_cache = redis_cache
+
     def _init_litellm_callbacks(self):
         print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
         litellm.callbacks.append(self.max_parallel_request_limiter)
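The net effect is a two-phase wiring: ProxyLogging creates the shared internal_usage_cache up front, and Redis is only attached later via update_values, once general_settings are known. Because the limiter holds a reference to the same DualCache instance, the attached Redis becomes visible to it without re-initialization. A hedged sketch of that flow (import path and Redis arguments are assumptions):

    from litellm.caching import DualCache, RedisCache
    from litellm.proxy.utils import ProxyLogging  # import path assumed

    # Phase 1: start-up, before the config is parsed. In-memory cache only.
    user_api_key_cache = DualCache()
    proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)

    # Phase 2: config load. Attaching Redis mutates the existing DualCache, so
    # max_tpm_rpm_limiter (which shares the instance) picks it up immediately.
    proxy_logging_obj.update_values(
        alerting=["slack"],
        alerting_threshold=600,
        redis_cache=RedisCache(host="localhost", port=6379),  # placeholder args
    )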
@@ -265,10 +270,11 @@ class ProxyLogging:
         if self.alerting is None:
             # do nothing if alerting is not switched on
             return
-
+        _id: str = "default_id"  # used for caching
         if type == "user_and_proxy_budget":
             user_info = dict(user_info)
             user_id = user_info["user_id"]
+            _id = user_id
             max_budget = user_info["max_budget"]
             spend = user_info["spend"]
             user_email = user_info["user_email"]
@@ -276,12 +282,14 @@ class ProxyLogging:
         elif type == "token_budget":
             token_info = dict(user_info)
             token = token_info["token"]
+            _id = token
             spend = token_info["spend"]
             max_budget = token_info["max_budget"]
             user_id = token_info["user_id"]
             user_info = f"""\nToken: {token}\nSpend: ${spend}\nMax Budget: ${max_budget}\nUser ID: {user_id}"""
         elif type == "failed_tracking":
             user_id = str(user_info)
+            _id = user_id
             user_info = f"\nUser ID: {user_id}\n Error {error_message}"
             message = "Failed Tracking Cost for" + user_info
             await self.alerting_handler(
@@ -337,13 +345,15 @@ class ProxyLogging:
         # check if 5% of max budget is left
         if percent_left <= 0.05:
             message = "5% budget left for" + user_info
-            result = await _cache.async_get_cache(key=message)
+            cache_key = "alerting:{}".format(_id)
+            result = await _cache.async_get_cache(key=cache_key)
             if result is None:
                 await self.alerting_handler(
                     message=message,
                     level="Medium",
                 )
-                await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
+                await _cache.async_set_cache(key=cache_key, value="SENT", ttl=2419200)
 
         return
 
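Keying the sent-marker on a stable id rather than on the rendered message text means a re-rendered message (say, with an updated spend figure) no longer re-fires the alert inside the TTL window (2419200 seconds, i.e. 28 days). The dedup pattern in isolation — a sketch against any DualCache-like async cache, with names invented for illustration:

    # Hypothetical helper showing the dedup pattern from the diff above.
    async def alert_once(cache, _id: str, message: str, send_alert):
        cache_key = "alerting:{}".format(_id)  # stable key, independent of message text
        if await cache.async_get_cache(key=cache_key) is None:
            await send_alert(message)
            # 2419200 s = 28 days: suppress repeat alerts for this id until then.
            await cache.async_set_cache(key=cache_key, value="SENT", ttl=2419200)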
@@ -38,7 +38,7 @@ async def test_pre_call_hook_rpm_limits():
         key=_api_key, value={"api_key": _api_key, "tpm_limit": 9, "rpm_limit": 1}
     )
 
-    tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(redis_usage_cache=None)
+    tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=DualCache())
 
     await tpm_rpm_limiter.async_pre_call_hook(
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -89,8 +89,8 @@ async def test_pre_call_hook_team_rpm_limits(
     user_api_key_dict = UserAPIKeyAuth(**_user_api_key_dict)  # type: ignore
     local_cache = DualCache()
     local_cache.set_cache(key=_api_key, value=_user_api_key_dict)
-    tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(redis_usage_cache=_redis_usage_cache)
-
+    internal_cache = DualCache(redis_cache=_redis_usage_cache)
+    tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=internal_cache)
     await tpm_rpm_limiter.async_pre_call_hook(
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
     )
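The tests mirror the new injection point: instead of passing a RedisCache (or None) into the limiter, they build the DualCache themselves and hand it over, which also makes it easy to point the limiter at a test Redis. A sketch of that wiring, with import path and connection details assumed:

    from litellm.caching import DualCache, RedisCache
    from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter  # path assumed

    # Redis-backed variant, as in test_pre_call_hook_team_rpm_limits.
    _redis_usage_cache = RedisCache(host="localhost", port=6379)  # placeholder args
    limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=DualCache(redis_cache=_redis_usage_cache))

    # Pure in-memory variant, as in test_pre_call_hook_rpm_limits.
    in_memory_limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=DualCache())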