diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index af59869c2..28e6d1853 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -39,6 +39,11 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): key=request_count_api_key ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10} if current is None: + if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0: + # base case + raise HTTPException( + status_code=429, detail="Max parallel request limit reached." + ) new_val = { "current_requests": 1, "current_tpm": 0, @@ -81,9 +86,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): if rpm_limit is None: rpm_limit = sys.maxsize - if api_key is None: - return - self.user_api_key_cache = cache # save the api key cache for updating the value # ------------ # Setup values @@ -94,62 +96,99 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): current_minute = datetime.now().strftime("%M") precise_minute = f"{current_date}-{current_hour}-{current_minute}" - request_count_api_key = f"{api_key}::{precise_minute}::request_count" + if api_key is not None: + request_count_api_key = f"{api_key}::{precise_minute}::request_count" - # CHECK IF REQUEST ALLOWED for key - current = cache.get_cache( - key=request_count_api_key - ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10} - self.print_verbose(f"current: {current}") - if ( - max_parallel_requests == sys.maxsize - and tpm_limit == sys.maxsize - and rpm_limit == sys.maxsize - ): - pass - elif max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0: - raise HTTPException( - status_code=429, detail="Max parallel request limit reached." - ) - elif current is None: - new_val = { - "current_requests": 1, - "current_tpm": 0, - "current_rpm": 0, - } - cache.set_cache(request_count_api_key, new_val) - elif ( - int(current["current_requests"]) < max_parallel_requests - and current["current_tpm"] < tpm_limit - and current["current_rpm"] < rpm_limit - ): - # Increase count for this token - new_val = { - "current_requests": current["current_requests"] + 1, - "current_tpm": current["current_tpm"], - "current_rpm": current["current_rpm"], - } - cache.set_cache(request_count_api_key, new_val) - else: - raise HTTPException( - status_code=429, detail="Max parallel request limit reached." - ) + # CHECK IF REQUEST ALLOWED for key + + current = cache.get_cache( + key=request_count_api_key + ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10} + self.print_verbose(f"current: {current}") + if ( + max_parallel_requests == sys.maxsize + and tpm_limit == sys.maxsize + and rpm_limit == sys.maxsize + ): + pass + elif max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0: + raise HTTPException( + status_code=429, detail="Max parallel request limit reached." 
+ ) + elif current is None: + new_val = { + "current_requests": 1, + "current_tpm": 0, + "current_rpm": 0, + } + cache.set_cache(request_count_api_key, new_val) + elif ( + int(current["current_requests"]) < max_parallel_requests + and current["current_tpm"] < tpm_limit + and current["current_rpm"] < rpm_limit + ): + # Increase count for this token + new_val = { + "current_requests": current["current_requests"] + 1, + "current_tpm": current["current_tpm"], + "current_rpm": current["current_rpm"], + } + cache.set_cache(request_count_api_key, new_val) + else: + raise HTTPException( + status_code=429, detail="Max parallel request limit reached." + ) # check if REQUEST ALLOWED for user_id user_id = user_api_key_dict.user_id - _user_id_rate_limits = user_api_key_dict.user_id_rate_limits + if user_id is not None: + _user_id_rate_limits = user_api_key_dict.user_id_rate_limits - # get user tpm/rpm limits - if _user_id_rate_limits is not None and isinstance(_user_id_rate_limits, dict): - user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None) - user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None) - if user_tpm_limit is None: - user_tpm_limit = sys.maxsize - if user_rpm_limit is None: - user_rpm_limit = sys.maxsize + # get user tpm/rpm limits + if _user_id_rate_limits is not None and isinstance( + _user_id_rate_limits, dict + ): + user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None) + user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None) + if user_tpm_limit is None: + user_tpm_limit = sys.maxsize + if user_rpm_limit is None: + user_rpm_limit = sys.maxsize + + # now do the same tpm/rpm checks + request_count_api_key = f"{user_id}::{precise_minute}::request_count" + + # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") + await self.check_key_in_limits( + user_api_key_dict=user_api_key_dict, + cache=cache, + data=data, + call_type=call_type, + max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user + request_count_api_key=request_count_api_key, + tpm_limit=user_tpm_limit, + rpm_limit=user_rpm_limit, + ) + + # TEAM RATE LIMITS + ## get team tpm/rpm limits + team_id = user_api_key_dict.team_id + if team_id is not None: + team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize) + + if team_tpm_limit is None: + team_tpm_limit = sys.maxsize + team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize) + if team_rpm_limit is None: + team_rpm_limit = sys.maxsize + + if team_tpm_limit is None: + team_tpm_limit = sys.maxsize + if team_rpm_limit is None: + team_rpm_limit = sys.maxsize # now do the same tpm/rpm checks - request_count_api_key = f"{user_id}::{precise_minute}::request_count" + request_count_api_key = f"{team_id}::{precise_minute}::request_count" # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") await self.check_key_in_limits( @@ -157,41 +196,12 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): cache=cache, data=data, call_type=call_type, - max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user + max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a team request_count_api_key=request_count_api_key, - tpm_limit=user_tpm_limit, - rpm_limit=user_rpm_limit, + tpm_limit=team_tpm_limit, + rpm_limit=team_rpm_limit, ) - # TEAM RATE LIMITS - ## get team tpm/rpm limits - team_id = user_api_key_dict.team_id - team_tpm_limit = getattr(user_api_key_dict, 
"team_tpm_limit", sys.maxsize) - if team_tpm_limit is None: - team_tpm_limit = sys.maxsize - team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize) - if team_rpm_limit is None: - team_rpm_limit = sys.maxsize - - if team_tpm_limit is None: - team_tpm_limit = sys.maxsize - if team_rpm_limit is None: - team_rpm_limit = sys.maxsize - - # now do the same tpm/rpm checks - request_count_api_key = f"{team_id}::{precise_minute}::request_count" - - # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") - await self.check_key_in_limits( - user_api_key_dict=user_api_key_dict, - cache=cache, - data=data, - call_type=call_type, - max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user - request_count_api_key=request_count_api_key, - tpm_limit=team_tpm_limit, - rpm_limit=team_rpm_limit, - ) return async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): @@ -205,9 +215,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): "user_api_key_team_id", None ) - if user_api_key is None: - return - if self.user_api_key_cache is None: return @@ -225,30 +232,35 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): if isinstance(response_obj, ModelResponse): total_tokens = response_obj.usage.total_tokens - request_count_api_key = f"{user_api_key}::{precise_minute}::request_count" - - current = self.user_api_key_cache.get_cache(key=request_count_api_key) or { - "current_requests": 1, - "current_tpm": total_tokens, - "current_rpm": 1, - } - # ------------ # Update usage - API Key # ------------ - new_val = { - "current_requests": max(current["current_requests"] - 1, 0), - "current_tpm": current["current_tpm"] + total_tokens, - "current_rpm": current["current_rpm"] + 1, - } + if user_api_key is not None: + request_count_api_key = ( + f"{user_api_key}::{precise_minute}::request_count" + ) - self.print_verbose( - f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" - ) - self.user_api_key_cache.set_cache( - request_count_api_key, new_val, ttl=60 - ) # store in cache for 1 min. + current = self.user_api_key_cache.get_cache( + key=request_count_api_key + ) or { + "current_requests": 1, + "current_tpm": total_tokens, + "current_rpm": 1, + } + + new_val = { + "current_requests": max(current["current_requests"] - 1, 0), + "current_tpm": current["current_tpm"] + total_tokens, + "current_rpm": current["current_rpm"] + 1, + } + + self.print_verbose( + f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" + ) + self.user_api_key_cache.set_cache( + request_count_api_key, new_val, ttl=60 + ) # store in cache for 1 min. 
# ------------ # Update usage - User @@ -287,36 +299,36 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): # ------------ # Update usage - Team # ------------ - if user_api_key_team_id is None: - return + if user_api_key_team_id is not None: + total_tokens = 0 - total_tokens = 0 + if isinstance(response_obj, ModelResponse): + total_tokens = response_obj.usage.total_tokens - if isinstance(response_obj, ModelResponse): - total_tokens = response_obj.usage.total_tokens + request_count_api_key = ( + f"{user_api_key_team_id}::{precise_minute}::request_count" + ) - request_count_api_key = ( - f"{user_api_key_team_id}::{precise_minute}::request_count" - ) + current = self.user_api_key_cache.get_cache( + key=request_count_api_key + ) or { + "current_requests": 1, + "current_tpm": total_tokens, + "current_rpm": 1, + } - current = self.user_api_key_cache.get_cache(key=request_count_api_key) or { - "current_requests": 1, - "current_tpm": total_tokens, - "current_rpm": 1, - } + new_val = { + "current_requests": max(current["current_requests"] - 1, 0), + "current_tpm": current["current_tpm"] + total_tokens, + "current_rpm": current["current_rpm"] + 1, + } - new_val = { - "current_requests": max(current["current_requests"] - 1, 0), - "current_tpm": current["current_tpm"] + total_tokens, - "current_rpm": current["current_rpm"] + 1, - } - - self.print_verbose( - f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" - ) - self.user_api_key_cache.set_cache( - request_count_api_key, new_val, ttl=60 - ) # store in cache for 1 min. + self.print_verbose( + f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" + ) + self.user_api_key_cache.set_cache( + request_count_api_key, new_val, ttl=60 + ) # store in cache for 1 min. 
except Exception as e: self.print_verbose(e) # noqa diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f2628cbe1..750ea27dd 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -361,6 +361,7 @@ async def user_api_key_auth( valid_token = await jwt_handler.auth_jwt(token=api_key) # get scopes scopes = jwt_handler.get_scopes(token=valid_token) + # check if admin is_admin = jwt_handler.is_admin(scopes=scopes) # if admin return @@ -453,9 +454,9 @@ async def user_api_key_auth( return UserAPIKeyAuth( api_key=None, team_id=team_object.team_id, - tpm_limit=team_object.tpm_limit, - rpm_limit=team_object.rpm_limit, - models=team_object.models, + team_tpm_limit=team_object.tpm_limit, + team_rpm_limit=team_object.rpm_limit, + team_models=team_object.models, user_role="app_owner", ) #### ELSE #### @@ -5759,7 +5760,7 @@ async def new_team( }, ) - if data.models is not None: + if data.models is not None and len(user_api_key_dict.models) > 0: for m in data.models: if m not in user_api_key_dict.models: raise HTTPException( diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index fe5a70b9c..0699137cc 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -177,3 +177,171 @@ async def test_valid_invalid_token(): response = await jwt_handler.auth_jwt(token=token) except Exception as e: pytest.fail(f"An exception occurred - {str(e)}") + + +@pytest.fixture +def prisma_client(): + import litellm + from litellm.proxy.utils import PrismaClient, ProxyLogging + from litellm.proxy.proxy_cli import append_query_params + + proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) + + ### add connection pool + pool timeout args + params = {"connection_limit": 100, "pool_timeout": 60} + database_url = os.getenv("DATABASE_URL") + modified_url = append_query_params(database_url, params) + os.environ["DATABASE_URL"] = modified_url + + # Assuming DBClient is a class that needs to be instantiated + prisma_client = PrismaClient( + database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj + ) + + return prisma_client + + +@pytest.mark.asyncio +async def test_team_token_output(prisma_client): + import jwt, json + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.hazmat.backends import default_backend + from fastapi import Request + from starlette.datastructures import URL + from litellm.proxy.proxy_server import user_api_key_auth, new_team + from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth + import litellm + import uuid + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + await litellm.proxy.proxy_server.prisma_client.connect() + + # Generate a private / public key pair using RSA algorithm + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + # Get private key in PEM format + private_key = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + # Get public key in PEM format + public_key = key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + public_key_obj = serialization.load_pem_public_key( + public_key, backend=default_backend() + ) + + # Convert RSA public key object to JWK (JSON Web Key) + public_jwk = 
json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj)) + + assert isinstance(public_jwk, dict) + + # set cache + cache = DualCache() + + await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + + jwt_handler = JWTHandler() + + jwt_handler.user_api_key_cache = cache + + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth() + + # VALID TOKEN + ## GENERATE A TOKEN + # Assuming the current time is in UTC + expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp()) + + team_id = f"team123_{uuid.uuid4()}" + payload = { + "sub": "user123", + "exp": expiration_time, # set the token to expire in 10 minutes + "scope": "litellm_team", + "client_id": team_id, + } + + # Generate the JWT token + # But before, you should convert bytes to string + private_key_str = private_key.decode("utf-8") + + ## team token + token = jwt.encode(payload, private_key_str, algorithm="RS256") + + ## admin token + payload = { + "sub": "user123", + "exp": expiration_time, # set the token to expire in 10 minutes + "scope": "litellm_proxy_admin", + } + + admin_token = jwt.encode(payload, private_key_str, algorithm="RS256") + + ## VERIFY IT WORKS + + # verify token + + response = await jwt_handler.auth_jwt(token=token) + + ## RUN IT THROUGH USER API KEY AUTH + + """ + - 1. Initial call should fail -> team doesn't exist + - 2. Create team via admin token + - 3. 2nd call w/ same team -> call should succeed -> assert UserAPIKeyAuth object correctly formatted + """ + + bearer_token = "Bearer " + token + + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + ## 1. INITIAL TEAM CALL - should fail + # use generated key to auth in + setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True}) + setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler) + try: + result = await user_api_key_auth(request=request, api_key=bearer_token) + pytest.fail("Team doesn't exist. This should fail") + except Exception as e: + pass + + ## 2. CREATE TEAM W/ ADMIN TOKEN - should succeed + try: + bearer_token = "Bearer " + admin_token + + request._url = URL(url="/team/new") + result = await user_api_key_auth(request=request, api_key=bearer_token) + await new_team( + data=NewTeamRequest( + team_id=team_id, + tpm_limit=100, + rpm_limit=99, + models=["gpt-3.5-turbo", "gpt-4"], + ), + user_api_key_dict=result, + ) + except Exception as e: + pytest.fail(f"This should not fail - {str(e)}") + + ## 3. 2nd CALL W/ TEAM TOKEN - should succeed + bearer_token = "Bearer " + token + request._url = URL(url="/chat/completions") + try: + team_result: UserAPIKeyAuth = await user_api_key_auth( + request=request, api_key=bearer_token + ) + except Exception as e: + pytest.fail(f"Team exists. This should not fail - {e}") + + ## 4. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py) + + assert team_result.team_tpm_limit == 100 + assert team_result.team_rpm_limit == 99 + assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
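
The restructured async_pre_call_hook above applies one and the same per-minute check up to three times per request: for the api_key (skipped for team-scoped JWTs, where api_key is None), for the user_id, and for the team_id, each under its own {id}::{date}-{hour}-{minute}::request_count cache key. Below is a minimal plain-Python sketch of that check/update cycle, assuming a dict in place of DualCache and a local RateLimitError in place of fastapi's HTTPException(429); window_key and the function names are illustrative, not litellm APIs:

import sys
from datetime import datetime

cache: dict = {}  # stand-in for litellm's DualCache


class RateLimitError(Exception):
    """Stands in for HTTPException(status_code=429, detail=...)."""


def window_key(scope_id: str) -> str:
    # mirrors f"{id}::{current_date}-{current_hour}-{current_minute}::request_count"
    now = datetime.now()
    return f"{scope_id}::{now.strftime('%Y-%m-%d')}-{now.strftime('%H')}-{now.strftime('%M')}::request_count"


def pre_call_check(
    scope_id: str,
    max_parallel_requests: int = sys.maxsize,
    tpm_limit: int = sys.maxsize,
    rpm_limit: int = sys.maxsize,
) -> None:
    key = window_key(scope_id)
    current = cache.get(key)  # {"current_requests": ..., "current_tpm": ..., "current_rpm": ...}
    if (
        max_parallel_requests == sys.maxsize
        and tpm_limit == sys.maxsize
        and rpm_limit == sys.maxsize
    ):
        return  # no limits configured for this scope
    if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
        raise RateLimitError("zero limit: even the first request in the window is rejected")
    if current is None:
        cache[key] = {"current_requests": 1, "current_tpm": 0, "current_rpm": 0}
    elif (
        current["current_requests"] < max_parallel_requests
        and current["current_tpm"] < tpm_limit
        and current["current_rpm"] < rpm_limit
    ):
        current["current_requests"] += 1  # claim a parallel-request slot
    else:
        raise RateLimitError("Max parallel request limit reached.")


def on_success(scope_id: str, total_tokens: int) -> None:
    # release the slot, then book the tokens and the completed request; the
    # hook writes this back with ttl=60 so counters expire with the window
    key = window_key(scope_id)
    current = cache.get(key) or {
        "current_requests": 1,
        "current_tpm": total_tokens,
        "current_rpm": 1,
    }
    cache[key] = {
        "current_requests": max(current["current_requests"] - 1, 0),
        "current_tpm": current["current_tpm"] + total_tokens,
        "current_rpm": current["current_rpm"] + 1,
    }


pre_call_check("team123", tpm_limit=100, rpm_limit=99)  # first call in the window
on_success("team123", total_tokens=42)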
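
The field renames in user_api_key_auth (tpm_limit to team_tpm_limit, rpm_limit to team_rpm_limit, models to team_models) keep team budgets from colliding with key-level budgets on the same UserAPIKeyAuth object, which is what lets the limiter run separate key and team checks. A rough sketch of the object the team-JWT path now returns, using the values the test asserts; it mirrors, rather than calls, the auth path, and assumes litellm is importable:

import sys
from litellm.proxy._types import UserAPIKeyAuth

auth = UserAPIKeyAuth(
    api_key=None,  # team JWTs carry no virtual key, so the key-level check is skipped
    team_id="team123",
    team_tpm_limit=100,
    team_rpm_limit=99,
    team_models=["gpt-3.5-turbo", "gpt-4"],
    user_role="app_owner",
)

# the limiter reads the team fields defensively, as in the hook above
team_tpm_limit = getattr(auth, "team_tpm_limit", sys.maxsize)
if team_tpm_limit is None:
    team_tpm_limit = sys.maxsize  # unset limits fall back to "unlimited"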
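
In new_team, the added len(user_api_key_dict.models) > 0 condition changes the meaning of an empty models list on the calling key from "no model access" to "unrestricted": previously such a key, including the admin JWT created in the test, would fail the membership loop for every requested model. A small sketch of the resulting semantics, using a hypothetical validate_team_models helper (not a litellm API):

from typing import List, Optional


def validate_team_models(requested: Optional[List[str]], caller_models: List[str]) -> None:
    # an empty caller_models list now means "unrestricted" rather than
    # failing the membership test for every requested model
    if requested is not None and len(caller_models) > 0:
        for m in requested:
            if m not in caller_models:
                raise ValueError(f"no access to model={m}")


validate_team_models(["gpt-3.5-turbo", "gpt-4"], caller_models=[])  # unrestricted key: allowed
validate_team_models(["gpt-4"], caller_models=["gpt-4"])  # restricted key, subset: allowed
# validate_team_models(["gpt-4"], caller_models=["gpt-3.5-turbo"])  # would raise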