diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index e0c85cee0..a4fb70c57 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -38,7 +38,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): current = cache.get_cache( key=request_count_api_key ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10} - # print(f"current: {current}") if current is None: new_val = { "current_requests": 1, @@ -73,8 +72,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook") api_key = user_api_key_dict.api_key max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize - tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize - rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize + tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize) + rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize) if api_key is None: return @@ -131,35 +130,34 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): _user_id_rate_limits = user_api_key_dict.user_id_rate_limits # get user tpm/rpm limits - if _user_id_rate_limits is None or _user_id_rate_limits == {}: - return - user_tpm_limit = _user_id_rate_limits.get("tpm_limit") - user_rpm_limit = _user_id_rate_limits.get("rpm_limit") - if user_tpm_limit is None: - user_tpm_limit = sys.maxsize - if user_rpm_limit is None: - user_rpm_limit = sys.maxsize + if _user_id_rate_limits is not None and isinstance(_user_id_rate_limits, dict): + user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None) + user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None) + if user_tpm_limit is None: + user_tpm_limit = sys.maxsize + if user_rpm_limit is None: + user_rpm_limit = sys.maxsize - # now do the same tpm/rpm checks - request_count_api_key = f"{user_id}::{precise_minute}::request_count" + # now do the same tpm/rpm checks + request_count_api_key = f"{user_id}::{precise_minute}::request_count" - # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") - await self.check_key_in_limits( - user_api_key_dict=user_api_key_dict, - cache=cache, - data=data, - call_type=call_type, - max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user - request_count_api_key=request_count_api_key, - tpm_limit=user_tpm_limit, - rpm_limit=user_rpm_limit, - ) + # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") + await self.check_key_in_limits( + user_api_key_dict=user_api_key_dict, + cache=cache, + data=data, + call_type=call_type, + max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user + request_count_api_key=request_count_api_key, + tpm_limit=user_tpm_limit, + rpm_limit=user_rpm_limit, + ) # TEAM RATE LIMITS ## get team tpm/rpm limits team_id = user_api_key_dict.team_id - team_tpm_limit = user_api_key_dict.team_tpm_limit or sys.maxsize - team_rpm_limit = user_api_key_dict.team_rpm_limit or sys.maxsize + team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize) + team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize) if team_tpm_limit is None: team_tpm_limit = sys.maxsize @@ -241,36 +239,36 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): # ------------ # Update usage - User # ------------ - if user_api_key_user_id is None: - return + if user_api_key_user_id is not None: + total_tokens = 0 - total_tokens = 0 + if isinstance(response_obj, ModelResponse): + total_tokens = response_obj.usage.total_tokens - if isinstance(response_obj, ModelResponse): - total_tokens = response_obj.usage.total_tokens + request_count_api_key = ( + f"{user_api_key_user_id}::{precise_minute}::request_count" + ) - request_count_api_key = ( - f"{user_api_key_user_id}::{precise_minute}::request_count" - ) + current = self.user_api_key_cache.get_cache( + key=request_count_api_key + ) or { + "current_requests": 1, + "current_tpm": total_tokens, + "current_rpm": 1, + } - current = self.user_api_key_cache.get_cache(key=request_count_api_key) or { - "current_requests": 1, - "current_tpm": total_tokens, - "current_rpm": 1, - } + new_val = { + "current_requests": max(current["current_requests"] - 1, 0), + "current_tpm": current["current_tpm"] + total_tokens, + "current_rpm": current["current_rpm"] + 1, + } - new_val = { - "current_requests": max(current["current_requests"] - 1, 0), - "current_tpm": current["current_tpm"] + total_tokens, - "current_rpm": current["current_rpm"] + 1, - } - - self.print_verbose( - f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" - ) - self.user_api_key_cache.set_cache( - request_count_api_key, new_val, ttl=60 - ) # store in cache for 1 min. + self.print_verbose( + f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" + ) + self.user_api_key_cache.set_cache( + request_count_api_key, new_val, ttl=60 + ) # store in cache for 1 min. # ------------ # Update usage - Team diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e1f604096..ba756d05c 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -350,8 +350,7 @@ async def user_api_key_auth( original_api_key = api_key # (Patch: For DynamoDB Backwards Compatibility) if api_key.startswith("sk-"): api_key = hash_token(token=api_key) - # valid_token = user_api_key_cache.get_cache(key=api_key) - valid_token = None + valid_token = user_api_key_cache.get_cache(key=api_key) if valid_token is None: ## check db verbose_proxy_logger.debug(f"api key: {api_key}") @@ -384,7 +383,6 @@ async def user_api_key_auth( # 6. If token spend per model is under budget per model # 7. If token spend is under team budget # 8. If team spend is under team budget - request_data = await _read_request_body( request=request ) # request data, used across all checks. Making this easily available @@ -627,7 +625,7 @@ async def user_api_key_auth( ) ) - if valid_token.spend > valid_token.team_max_budget: + if valid_token.spend >= valid_token.team_max_budget: raise Exception( f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Team: {valid_token.team_max_budget}" ) @@ -646,7 +644,7 @@ async def user_api_key_auth( ) ) - if valid_token.team_spend > valid_token.team_max_budget: + if valid_token.team_spend >= valid_token.team_max_budget: raise Exception( f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}" ) diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py index e402b617b..bd5185a23 100644 --- a/litellm/tests/test_parallel_request_limiter.py +++ b/litellm/tests/test_parallel_request_limiter.py @@ -99,6 +99,59 @@ async def test_pre_call_hook_rpm_limits(): assert e.status_code == 429 +@pytest.mark.asyncio +async def test_pre_call_hook_team_rpm_limits(): + """ + Test if error raised on hitting team rpm limits + """ + litellm.set_verbose = True + _api_key = "sk-12345" + _team_id = "unique-team-id" + user_api_key_dict = UserAPIKeyAuth( + api_key=_api_key, + max_parallel_requests=1, + tpm_limit=9, + rpm_limit=10, + team_rpm_limit=1, + team_id=_team_id, + ) + local_cache = DualCache() + parallel_request_handler = MaxParallelRequestsHandler() + + await parallel_request_handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type="" + ) + + kwargs = { + "litellm_params": { + "metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id} + } + } + + await parallel_request_handler.async_log_success_event( + kwargs=kwargs, + response_obj="", + start_time="", + end_time="", + ) + + print(f"local_cache: {local_cache}") + + ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1} + + try: + await parallel_request_handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=local_cache, + data={}, + call_type="", + ) + + pytest.fail(f"Expected call to fail") + except Exception as e: + assert e.status_code == 429 + + @pytest.mark.asyncio async def test_pre_call_hook_tpm_limits(): """