mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
add is_virtual_key_allowed_to_call_route
This commit is contained in:
parent
119ea80f60
commit
b8a1bc521a
3 changed files with 63 additions and 3 deletions
|
@ -1494,6 +1494,7 @@ class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase):
|
|||
budget_duration: Optional[str] = None
|
||||
budget_reset_at: Optional[datetime] = None
|
||||
allowed_cache_controls: Optional[list] = []
|
||||
allowed_routes: Optional[list] = []
|
||||
permissions: Dict = {}
|
||||
model_spend: Dict = {}
|
||||
model_max_budget: Dict = {}
|
||||
|
|
|
@ -270,6 +270,11 @@ def _is_api_route_allowed(
|
|||
if valid_token is None:
|
||||
raise Exception("Invalid proxy server token passed. valid_token=None.")
|
||||
|
||||
# Check if Virtual Key is allowed to call the route - Applies to all Roles
|
||||
RouteChecks.is_virtual_key_allowed_to_call_route(
|
||||
route=route, valid_token=valid_token
|
||||
)
|
||||
|
||||
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
|
||||
RouteChecks.non_proxy_admin_allowed_routes_check(
|
||||
user_obj=user_obj,
|
||||
|
|
|
@ -16,6 +16,31 @@ from .auth_checks_organization import _user_is_org_admin
|
|||
|
||||
|
||||
class RouteChecks:
|
||||
@staticmethod
|
||||
def is_virtual_key_allowed_to_call_route(
|
||||
route: str, valid_token: UserAPIKeyAuth
|
||||
) -> bool:
|
||||
"""
|
||||
Raises Exception if Virtual Key is not allowed to call the route
|
||||
"""
|
||||
if (
|
||||
valid_token.allowed_routes
|
||||
and isinstance(valid_token.allowed_routes, list)
|
||||
and len(valid_token.allowed_routes) > 0
|
||||
):
|
||||
# explicit check for allowed routes
|
||||
if route in valid_token.allowed_routes:
|
||||
return True
|
||||
|
||||
# check if wildcard pattern is allowed
|
||||
for allowed_route in valid_token.allowed_routes:
|
||||
if RouteChecks._route_matches_wildcard_pattern(
|
||||
route=route, pattern=allowed_route
|
||||
):
|
||||
return True
|
||||
|
||||
raise Exception("Virtual key is not allowed to call this route.")
|
||||
|
||||
@staticmethod
|
||||
def non_proxy_admin_allowed_routes_check(
|
||||
user_obj: Optional[LiteLLM_UserTable],
|
||||
|
@ -220,6 +245,35 @@ class RouteChecks:
|
|||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _route_matches_wildcard_pattern(route: str, pattern: str) -> bool:
|
||||
"""
|
||||
Check if route matches the wildcard pattern
|
||||
|
||||
eg.
|
||||
|
||||
pattern: "/scim/v2/*"
|
||||
route: "/scim/v2/Users"
|
||||
- returns: True
|
||||
|
||||
pattern: "/scim/v2/*"
|
||||
route: "/chat/completions"
|
||||
- returns: False
|
||||
|
||||
|
||||
pattern: "/scim/v2/*"
|
||||
route: "/scim/v2/Users/123"
|
||||
- returns: True
|
||||
|
||||
"""
|
||||
if pattern.endswith("*"):
|
||||
# Get the prefix (everything before the wildcard)
|
||||
prefix = pattern[:-1]
|
||||
return route.startswith(prefix)
|
||||
else:
|
||||
# If there's no wildcard, the pattern and route should match exactly
|
||||
return route == pattern
|
||||
|
||||
@staticmethod
|
||||
def check_route_access(route: str, allowed_routes: List[str]) -> bool:
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue