Litellm dev 02 06 2025 p3 (#8343)

* feat(handle_jwt.py): initial commit to allow scope based model access

* feat(handle_jwt.py): allow model access based on token scopes

allow admin to control model access from IDP

* test(test_jwt.py): add unit testing for scope based model access

* docs(token_auth.md): add scope based model access to docs

* docs(token_auth.md): update docs

* docs(token_auth.md): update docs

* build: add gemini commercial rate limits

* fix: fix linting error
This commit is contained in:
Krish Dholakia 2025-02-06 23:15:33 -08:00 committed by GitHub
parent f87ab251b0
commit d720744656
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 238 additions and 7 deletions

View file

@ -30,6 +30,7 @@ from litellm.proxy._types import (
LiteLLM_TeamTable,
LiteLLM_UserTable,
LitellmUserRoles,
ScopeMapping,
Span,
)
from litellm.proxy.utils import PrismaClient, ProxyLogging
@ -318,7 +319,7 @@ class JWTHandler:
org_id = default_value
return org_id
def get_scopes(self, token: dict) -> list:
def get_scopes(self, token: dict) -> List[str]:
try:
if isinstance(token["scope"], str):
# Assuming the scopes are stored in 'scope' claim and are space-separated
@ -543,6 +544,40 @@ class JWTAuthManager:
return True
@staticmethod
def check_scope_based_access(
scope_mappings: List[ScopeMapping],
scopes: List[str],
request_data: dict,
general_settings: dict,
) -> None:
"""
Check if scope allows access to the requested model
"""
if not scope_mappings:
return None
allowed_models = []
for sm in scope_mappings:
if sm.scope in scopes and sm.models:
allowed_models.extend(sm.models)
requested_model = request_data.get("model")
if not requested_model:
return None
if requested_model not in allowed_models:
raise HTTPException(
status_code=403,
detail={
"error": "model={} not allowed. Allowed_models={}".format(
requested_model, allowed_models
)
},
)
return None
@staticmethod
async def check_rbac_role(
jwt_handler: JWTHandler,
@ -636,6 +671,7 @@ class JWTAuthManager:
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert,
)
return individual_team_id, team_object
@ -829,6 +865,19 @@ class JWTAuthManager:
rbac_role,
)
# Check Scope Based Access
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
if (
jwt_handler.litellm_jwtauth.enforce_scope_based_access
and jwt_handler.litellm_jwtauth.scope_mappings
):
JWTAuthManager.check_scope_based_access(
scope_mappings=jwt_handler.litellm_jwtauth.scope_mappings,
scopes=scopes,
request_data=request_data,
general_settings=general_settings,
)
object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)
# Get basic user info