diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index ad90173a44..771e8526c6 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Extra, Field, root_validator -import enum +import enum, sys from typing import Optional, List, Union, Dict, Literal from datetime import datetime import uuid, json @@ -161,6 +161,8 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api max_parallel_requests: Optional[int] = None duration: str = "1h" metadata: dict = {} + tpm_limit: int = sys.maxsize + rpm_limit: int = sys.maxsize class GenerateKeyResponse(LiteLLMBase): diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index 7bec1d5d66..2ef19a1498 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -1,5 +1,5 @@ from typing import Optional -import litellm, traceback +import litellm, traceback, sys from litellm.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.integrations.custom_logger import CustomLogger @@ -28,12 +28,18 @@ class MaxParallelRequestsHandler(CustomLogger): ): self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook") api_key = user_api_key_dict.api_key - max_parallel_requests = user_api_key_dict.max_parallel_requests + max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize + tpm_limit = user_api_key_dict.tpm_limit + rpm_limit = user_api_key_dict.rpm_limit if api_key is None: return - if max_parallel_requests is None: + if ( + max_parallel_requests == sys.maxsize + and tpm_limit == sys.maxsize + and rpm_limit == sys.maxsize + ): return self.user_api_key_cache = cache # save the api key cache for updating the value @@ -60,7 +66,11 @@ class MaxParallelRequestsHandler(CustomLogger): "current_rpm": 0, } cache.set_cache(request_count_api_key, new_val) - elif int(current["current_requests"]) < max_parallel_requests: + elif ( + int(current["current_requests"]) < max_parallel_requests + and current["current_tpm"] < tpm_limit + and current["current_rpm"] < rpm_limit + ): # Increase count for this token new_val = { "current_requests": current["current_requests"] + 1,