mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
add is_virtual_key_allowed_to_call_route
This commit is contained in:
parent
119ea80f60
commit
b8a1bc521a
3 changed files with 63 additions and 3 deletions
|
@ -1494,6 +1494,7 @@ class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase):
|
||||||
budget_duration: Optional[str] = None
|
budget_duration: Optional[str] = None
|
||||||
budget_reset_at: Optional[datetime] = None
|
budget_reset_at: Optional[datetime] = None
|
||||||
allowed_cache_controls: Optional[list] = []
|
allowed_cache_controls: Optional[list] = []
|
||||||
|
allowed_routes: Optional[list] = []
|
||||||
permissions: Dict = {}
|
permissions: Dict = {}
|
||||||
model_spend: Dict = {}
|
model_spend: Dict = {}
|
||||||
model_max_budget: Dict = {}
|
model_max_budget: Dict = {}
|
||||||
|
|
|
@ -2,11 +2,11 @@
|
||||||
## Common auth checks between jwt + key based auth
|
## Common auth checks between jwt + key based auth
|
||||||
"""
|
"""
|
||||||
Got Valid Token from Cache, DB
|
Got Valid Token from Cache, DB
|
||||||
Run checks for:
|
Run checks for:
|
||||||
|
|
||||||
1. If user can call model
|
1. If user can call model
|
||||||
2. If user is in budget
|
2. If user is in budget
|
||||||
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
|
@ -270,6 +270,11 @@ def _is_api_route_allowed(
|
||||||
if valid_token is None:
|
if valid_token is None:
|
||||||
raise Exception("Invalid proxy server token passed. valid_token=None.")
|
raise Exception("Invalid proxy server token passed. valid_token=None.")
|
||||||
|
|
||||||
|
# Check if Virtual Key is allowed to call the route - Applies to all Roles
|
||||||
|
RouteChecks.is_virtual_key_allowed_to_call_route(
|
||||||
|
route=route, valid_token=valid_token
|
||||||
|
)
|
||||||
|
|
||||||
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
|
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
|
||||||
RouteChecks.non_proxy_admin_allowed_routes_check(
|
RouteChecks.non_proxy_admin_allowed_routes_check(
|
||||||
user_obj=user_obj,
|
user_obj=user_obj,
|
||||||
|
|
|
@ -16,6 +16,31 @@ from .auth_checks_organization import _user_is_org_admin
|
||||||
|
|
||||||
|
|
||||||
class RouteChecks:
|
class RouteChecks:
|
||||||
|
@staticmethod
|
||||||
|
def is_virtual_key_allowed_to_call_route(
|
||||||
|
route: str, valid_token: UserAPIKeyAuth
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Raises Exception if Virtual Key is not allowed to call the route
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
valid_token.allowed_routes
|
||||||
|
and isinstance(valid_token.allowed_routes, list)
|
||||||
|
and len(valid_token.allowed_routes) > 0
|
||||||
|
):
|
||||||
|
# explicit check for allowed routes
|
||||||
|
if route in valid_token.allowed_routes:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# check if wildcard pattern is allowed
|
||||||
|
for allowed_route in valid_token.allowed_routes:
|
||||||
|
if RouteChecks._route_matches_wildcard_pattern(
|
||||||
|
route=route, pattern=allowed_route
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
raise Exception("Virtual key is not allowed to call this route.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def non_proxy_admin_allowed_routes_check(
|
def non_proxy_admin_allowed_routes_check(
|
||||||
user_obj: Optional[LiteLLM_UserTable],
|
user_obj: Optional[LiteLLM_UserTable],
|
||||||
|
@ -220,6 +245,35 @@ class RouteChecks:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _route_matches_wildcard_pattern(route: str, pattern: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if route matches the wildcard pattern
|
||||||
|
|
||||||
|
eg.
|
||||||
|
|
||||||
|
pattern: "/scim/v2/*"
|
||||||
|
route: "/scim/v2/Users"
|
||||||
|
- returns: True
|
||||||
|
|
||||||
|
pattern: "/scim/v2/*"
|
||||||
|
route: "/chat/completions"
|
||||||
|
- returns: False
|
||||||
|
|
||||||
|
|
||||||
|
pattern: "/scim/v2/*"
|
||||||
|
route: "/scim/v2/Users/123"
|
||||||
|
- returns: True
|
||||||
|
|
||||||
|
"""
|
||||||
|
if pattern.endswith("*"):
|
||||||
|
# Get the prefix (everything before the wildcard)
|
||||||
|
prefix = pattern[:-1]
|
||||||
|
return route.startswith(prefix)
|
||||||
|
else:
|
||||||
|
# If there's no wildcard, the pattern and route should match exactly
|
||||||
|
return route == pattern
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_route_access(route: str, allowed_routes: List[str]) -> bool:
|
def check_route_access(route: str, allowed_routes: List[str]) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue