diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 4bd47248b8..c82d8c0132 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1494,6 +1494,7 @@ class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase): budget_duration: Optional[str] = None budget_reset_at: Optional[datetime] = None allowed_cache_controls: Optional[list] = [] + allowed_routes: Optional[list] = [] permissions: Dict = {} model_spend: Dict = {} model_max_budget: Dict = {} diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 1e0c8a4609..bf8fd2350e 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -2,11 +2,11 @@ ## Common auth checks between jwt + key based auth """ Got Valid Token from Cache, DB -Run checks for: +Run checks for: 1. If user can call model -2. If user is in budget -3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget +2. If user is in budget +3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget """ import asyncio import re @@ -270,6 +270,11 @@ def _is_api_route_allowed( if valid_token is None: raise Exception("Invalid proxy server token passed. valid_token=None.") + # Check if Virtual Key is allowed to call the route - Applies to all Roles + RouteChecks.is_virtual_key_allowed_to_call_route( + route=route, valid_token=valid_token + ) + if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin RouteChecks.non_proxy_admin_allowed_routes_check( user_obj=user_obj, diff --git a/litellm/proxy/auth/route_checks.py b/litellm/proxy/auth/route_checks.py index 41529512b6..347a266da9 100644 --- a/litellm/proxy/auth/route_checks.py +++ b/litellm/proxy/auth/route_checks.py @@ -16,6 +16,31 @@ from .auth_checks_organization import _user_is_org_admin class RouteChecks: + @staticmethod + def is_virtual_key_allowed_to_call_route( + route: str, valid_token: UserAPIKeyAuth + ) -> bool: + """ + Raises Exception if Virtual Key is not allowed to call the route + """ + if ( + valid_token.allowed_routes + and isinstance(valid_token.allowed_routes, list) + and len(valid_token.allowed_routes) > 0 + ): + # explicit check for allowed routes + if route in valid_token.allowed_routes: + return True + + # check if wildcard pattern is allowed + for allowed_route in valid_token.allowed_routes: + if RouteChecks._route_matches_wildcard_pattern( + route=route, pattern=allowed_route + ): + return True + + raise Exception("Virtual key is not allowed to call this route.") + @staticmethod def non_proxy_admin_allowed_routes_check( user_obj: Optional[LiteLLM_UserTable], @@ -220,6 +245,35 @@ class RouteChecks: return True return False + @staticmethod + def _route_matches_wildcard_pattern(route: str, pattern: str) -> bool: + """ + Check if route matches the wildcard pattern + + eg. + + pattern: "/scim/v2/*" + route: "/scim/v2/Users" + - returns: True + + pattern: "/scim/v2/*" + route: "/chat/completions" + - returns: False + + + pattern: "/scim/v2/*" + route: "/scim/v2/Users/123" + - returns: True + + """ + if pattern.endswith("*"): + # Get the prefix (everything before the wildcard) + prefix = pattern[:-1] + return route.startswith(prefix) + else: + # If there's no wildcard, the pattern and route should match exactly + return route == pattern + @staticmethod def check_route_access(route: str, allowed_routes: List[str]) -> bool: """