add is_virtual_key_allowed_to_call_route

This commit is contained in:
Ishaan Jaff 2025-04-16 17:27:18 -07:00
parent 119ea80f60
commit b8a1bc521a
3 changed files with 63 additions and 3 deletions

View file

@ -1494,6 +1494,7 @@ class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase):
budget_duration: Optional[str] = None budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None budget_reset_at: Optional[datetime] = None
allowed_cache_controls: Optional[list] = [] allowed_cache_controls: Optional[list] = []
allowed_routes: Optional[list] = []
permissions: Dict = {} permissions: Dict = {}
model_spend: Dict = {} model_spend: Dict = {}
model_max_budget: Dict = {} model_max_budget: Dict = {}

View file

@ -270,6 +270,11 @@ def _is_api_route_allowed(
if valid_token is None: if valid_token is None:
raise Exception("Invalid proxy server token passed. valid_token=None.") raise Exception("Invalid proxy server token passed. valid_token=None.")
# Check if Virtual Key is allowed to call the route - Applies to all Roles
RouteChecks.is_virtual_key_allowed_to_call_route(
route=route, valid_token=valid_token
)
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
RouteChecks.non_proxy_admin_allowed_routes_check( RouteChecks.non_proxy_admin_allowed_routes_check(
user_obj=user_obj, user_obj=user_obj,

View file

@ -16,6 +16,31 @@ from .auth_checks_organization import _user_is_org_admin
class RouteChecks: class RouteChecks:
@staticmethod
def is_virtual_key_allowed_to_call_route(
route: str, valid_token: UserAPIKeyAuth
) -> bool:
"""
Raises Exception if Virtual Key is not allowed to call the route
"""
if (
valid_token.allowed_routes
and isinstance(valid_token.allowed_routes, list)
and len(valid_token.allowed_routes) > 0
):
# explicit check for allowed routes
if route in valid_token.allowed_routes:
return True
# check if wildcard pattern is allowed
for allowed_route in valid_token.allowed_routes:
if RouteChecks._route_matches_wildcard_pattern(
route=route, pattern=allowed_route
):
return True
raise Exception("Virtual key is not allowed to call this route.")
@staticmethod @staticmethod
def non_proxy_admin_allowed_routes_check( def non_proxy_admin_allowed_routes_check(
user_obj: Optional[LiteLLM_UserTable], user_obj: Optional[LiteLLM_UserTable],
@ -220,6 +245,35 @@ class RouteChecks:
return True return True
return False return False
@staticmethod
def _route_matches_wildcard_pattern(route: str, pattern: str) -> bool:
"""
Check if route matches the wildcard pattern
eg.
pattern: "/scim/v2/*"
route: "/scim/v2/Users"
- returns: True
pattern: "/scim/v2/*"
route: "/chat/completions"
- returns: False
pattern: "/scim/v2/*"
route: "/scim/v2/Users/123"
- returns: True
"""
if pattern.endswith("*"):
# Get the prefix (everything before the wildcard)
prefix = pattern[:-1]
return route.startswith(prefix)
else:
# If there's no wildcard, the pattern and route should match exactly
return route == pattern
@staticmethod @staticmethod
def check_route_access(route: str, allowed_routes: List[str]) -> bool: def check_route_access(route: str, allowed_routes: List[str]) -> bool:
""" """