diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 908cb58cfd..0869713660 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -470,7 +470,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): class UserAPIKeyAuth( - LiteLLM_VerificationToken + LiteLLM_VerificationTokenView ): # the expected response object for user api key auth """ Return the row in the db diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index fb61fe3da6..e0c85cee0f 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -154,6 +154,32 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): tpm_limit=user_tpm_limit, rpm_limit=user_rpm_limit, ) + + # TEAM RATE LIMITS + ## get team tpm/rpm limits + team_id = user_api_key_dict.team_id + team_tpm_limit = user_api_key_dict.team_tpm_limit or sys.maxsize + team_rpm_limit = user_api_key_dict.team_rpm_limit or sys.maxsize + + if team_tpm_limit is None: + team_tpm_limit = sys.maxsize + if team_rpm_limit is None: + team_rpm_limit = sys.maxsize + + # now do the same tpm/rpm checks + request_count_api_key = f"{team_id}::{precise_minute}::request_count" + + # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") + await self.check_key_in_limits( + user_api_key_dict=user_api_key_dict, + cache=cache, + data=data, + call_type=call_type, + max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user + request_count_api_key=request_count_api_key, + tpm_limit=team_tpm_limit, + rpm_limit=team_rpm_limit, + ) return async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): @@ -163,6 +189,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): user_api_key_user_id = kwargs["litellm_params"]["metadata"].get( "user_api_key_user_id", None ) + user_api_key_team_id = kwargs["litellm_params"]["metadata"].get( + "user_api_key_team_id", None + ) if user_api_key is None: return @@ -243,6 +272,40 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): request_count_api_key, new_val, ttl=60 ) # store in cache for 1 min. + # ------------ + # Update usage - Team + # ------------ + if user_api_key_team_id is None: + return + + total_tokens = 0 + + if isinstance(response_obj, ModelResponse): + total_tokens = response_obj.usage.total_tokens + + request_count_api_key = ( + f"{user_api_key_team_id}::{precise_minute}::request_count" + ) + + current = self.user_api_key_cache.get_cache(key=request_count_api_key) or { + "current_requests": 1, + "current_tpm": total_tokens, + "current_rpm": 1, + } + + new_val = { + "current_requests": max(current["current_requests"] - 1, 0), + "current_tpm": current["current_tpm"] + total_tokens, + "current_rpm": current["current_rpm"] + 1, + } + + self.print_verbose( + f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" + ) + self.user_api_key_cache.set_cache( + request_count_api_key, new_val, ttl=60 + ) # store in cache for 1 min. + except Exception as e: self.print_verbose(e) # noqa