add is_virtual_key_allowed_to_call_route

This commit is contained in:
Ishaan Jaff 2025-04-16 17:27:18 -07:00
parent 119ea80f60
commit b8a1bc521a
3 changed files with 63 additions and 3 deletions

View file

@ -1494,6 +1494,7 @@ class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase):
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
allowed_cache_controls: Optional[list] = []
allowed_routes: Optional[list] = []
permissions: Dict = {}
model_spend: Dict = {}
model_max_budget: Dict = {}

View file

@ -2,11 +2,11 @@
## Common auth checks between jwt + key based auth
"""
Got Valid Token from Cache, DB
Run checks for:
Run checks for:
1. If user can call model
2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
import asyncio
import re
@ -270,6 +270,11 @@ def _is_api_route_allowed(
if valid_token is None:
raise Exception("Invalid proxy server token passed. valid_token=None.")
# Check if Virtual Key is allowed to call the route - Applies to all Roles
RouteChecks.is_virtual_key_allowed_to_call_route(
route=route, valid_token=valid_token
)
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
RouteChecks.non_proxy_admin_allowed_routes_check(
user_obj=user_obj,

View file

@ -16,6 +16,31 @@ from .auth_checks_organization import _user_is_org_admin
class RouteChecks:
@staticmethod
def is_virtual_key_allowed_to_call_route(
route: str, valid_token: UserAPIKeyAuth
) -> bool:
"""
Raises Exception if Virtual Key is not allowed to call the route
"""
if (
valid_token.allowed_routes
and isinstance(valid_token.allowed_routes, list)
and len(valid_token.allowed_routes) > 0
):
# explicit check for allowed routes
if route in valid_token.allowed_routes:
return True
# check if wildcard pattern is allowed
for allowed_route in valid_token.allowed_routes:
if RouteChecks._route_matches_wildcard_pattern(
route=route, pattern=allowed_route
):
return True
raise Exception("Virtual key is not allowed to call this route.")
@staticmethod
def non_proxy_admin_allowed_routes_check(
user_obj: Optional[LiteLLM_UserTable],
@ -220,6 +245,35 @@ class RouteChecks:
return True
return False
@staticmethod
def _route_matches_wildcard_pattern(route: str, pattern: str) -> bool:
"""
Check if route matches the wildcard pattern
eg.
pattern: "/scim/v2/*"
route: "/scim/v2/Users"
- returns: True
pattern: "/scim/v2/*"
route: "/chat/completions"
- returns: False
pattern: "/scim/v2/*"
route: "/scim/v2/Users/123"
- returns: True
"""
if pattern.endswith("*"):
# Get the prefix (everything before the wildcard)
prefix = pattern[:-1]
return route.startswith(prefix)
else:
# If there's no wildcard, the pattern and route should match exactly
return route == pattern
@staticmethod
def check_route_access(route: str, allowed_routes: List[str]) -> bool:
"""