Litellm dev 02 06 2025 p3 (#8343)

* feat(handle_jwt.py): initial commit to allow scope based model access

* feat(handle_jwt.py): allow model access based on token scopes

allow admin to control model access from IDP

* test(test_jwt.py): add unit testing for scope based model access

* docs(token_auth.md): add scope based model access to docs

* docs(token_auth.md): update docs

* docs(token_auth.md): update docs

* build: add gemini commercial rate limits

* fix: fix linting error
This commit is contained in:
Krish Dholakia 2025-02-06 23:15:33 -08:00 committed by GitHub
parent f87ab251b0
commit d720744656
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 238 additions and 7 deletions

View file

@ -370,4 +370,68 @@ Supported internal roles:
- `internal_user`: User object will be used for RBAC spend tracking. Use this for tracking spend for an 'individual user'. - `internal_user`: User object will be used for RBAC spend tracking. Use this for tracking spend for an 'individual user'.
- `proxy_admin`: Proxy admin will be used for RBAC spend tracking. Use this for granting admin access to a token. - `proxy_admin`: Proxy admin will be used for RBAC spend tracking. Use this for granting admin access to a token.
### [Architecture Diagram (Control Model Access)](./jwt_auth_arch) ### [Architecture Diagram (Control Model Access)](./jwt_auth_arch)
## [BETA] Control Model Access with Scopes
Control which models a JWT can access. Set `enforce_scope_based_access: true` to enforce scope-based access control.
### 1. Setup config.yaml with scope mappings.
```yaml
model_list:
- model_name: anthropic-claude
litellm_params:
model: anthropic/claude-3-5-sonnet
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: gpt-3.5-turbo-testing
litellm_params:
model: gpt-3.5-turbo
api_key: os.environ/OPENAI_API_KEY
general_settings:
enable_jwt_auth: True
litellm_jwtauth:
team_id_jwt_field: "client_id" # 👈 set the field in the JWT token that contains the team id
team_id_upsert: true # 👈 upsert the team to db, if team id is not found in db
scope_mappings:
- scope: litellm.api.consumer
models: ["anthropic-claude"]
- scope: litellm.api.gpt_3_5_turbo
models: ["gpt-3.5-turbo-testing"]
enforce_scope_based_access: true # 👈 enforce scope-based access control
enforce_rbac: true # 👈 enforces only a Team/User/ProxyAdmin can access the proxy.
```
#### Scope Mapping Spec
- `scope`: The scope to be used for the JWT token.
- `models`: The models that the JWT token can access. Value is the `model_name` in `model_list`. Note: Wildcard routes are not currently supported.
### 2. Create a JWT with the correct scopes.
Expected Token:
```
{
"scope": ["litellm.api.consumer", "litellm.api.gpt_3_5_turbo"]
}
```
### 3. Test the flow.
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer eyJhbGci...' \
-d '{
"model": "gpt-3.5-turbo-testing",
"messages": [
{
"role": "user",
"content": "Hey, how'\''s it going 1234?"
}
]
}'
```

View file

@ -27,3 +27,17 @@ model_list:
model: openai/fake model: openai/fake
api_key: fake-key api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/ api_base: https://exampleopenaiendpoint-production.up.railway.app/
general_settings:
enable_jwt_auth: True
litellm_jwtauth:
team_id_jwt_field: "client_id"
team_id_upsert: true
scope_mappings:
- scope: litellm.api.consumer
models: ["anthropic-claude"]
routes: ["/v1/chat/completions"]
- scope: litellm.api.gpt_3_5_turbo
models: ["gpt-3.5-turbo-testing"]
enforce_scope_based_access: true
enforce_rbac: true

View file

@ -1040,6 +1040,13 @@ class LiteLLM_TeamTable(TeamBase):
"model_max_budget", "model_max_budget",
"model_aliases", "model_aliases",
] ]
if (
isinstance(values.get("members_with_roles"), dict)
and not values["members_with_roles"]
):
values["members_with_roles"] = []
for field in dict_fields: for field in dict_fields:
value = values.get(field) value = values.get(field)
if value is not None and isinstance(value, str): if value is not None and isinstance(value, str):
@ -2279,11 +2286,14 @@ RBAC_ROLES = Literal[
] ]
class RoleBasedPermissions(LiteLLMPydanticObjectBase): class OIDCPermissions(LiteLLMPydanticObjectBase):
role: RBAC_ROLES
models: Optional[List[str]] = None models: Optional[List[str]] = None
routes: Optional[List[str]] = None routes: Optional[List[str]] = None
class RoleBasedPermissions(OIDCPermissions):
role: RBAC_ROLES
model_config = { model_config = {
"extra": "forbid", "extra": "forbid",
} }
@ -2294,6 +2304,14 @@ class RoleMapping(BaseModel):
internal_role: RBAC_ROLES internal_role: RBAC_ROLES
class ScopeMapping(OIDCPermissions):
scope: str
model_config = {
"extra": "forbid",
}
class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
""" """
A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth. A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.
@ -2323,6 +2341,7 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
"info_routes", "info_routes",
] ]
team_id_jwt_field: Optional[str] = None team_id_jwt_field: Optional[str] = None
team_id_upsert: bool = False
team_ids_jwt_field: Optional[str] = None team_ids_jwt_field: Optional[str] = None
upsert_sso_user_to_team: bool = False upsert_sso_user_to_team: bool = False
team_allowed_routes: List[ team_allowed_routes: List[
@ -2351,6 +2370,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
object_id_jwt_field: Optional[str] = ( object_id_jwt_field: Optional[str] = (
None # can be either user / team, inferred from the role mapping None # can be either user / team, inferred from the role mapping
) )
scope_mappings: Optional[List[ScopeMapping]] = None
enforce_scope_based_access: bool = False
def __init__(self, **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
# get the attribute names for this Pydantic model # get the attribute names for this Pydantic model
@ -2361,6 +2382,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
user_allowed_roles = kwargs.get("user_allowed_roles") user_allowed_roles = kwargs.get("user_allowed_roles")
object_id_jwt_field = kwargs.get("object_id_jwt_field") object_id_jwt_field = kwargs.get("object_id_jwt_field")
role_mappings = kwargs.get("role_mappings") role_mappings = kwargs.get("role_mappings")
scope_mappings = kwargs.get("scope_mappings")
enforce_scope_based_access = kwargs.get("enforce_scope_based_access")
if invalid_keys: if invalid_keys:
raise ValueError( raise ValueError(
@ -2378,4 +2401,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
"if object_id_jwt_field is set, role_mappings must also be set. Needed to infer if the caller is a user or team." "if object_id_jwt_field is set, role_mappings must also be set. Needed to infer if the caller is a user or team."
) )
if scope_mappings is not None and not enforce_scope_based_access:
raise ValueError(
"scope_mappings must be set if enforce_scope_based_access is true."
)
super().__init__(**kwargs) super().__init__(**kwargs)

View file

@ -655,11 +655,20 @@ async def _delete_cache_key_object(
@log_db_metrics @log_db_metrics
async def _get_team_db_check(team_id: str, prisma_client: PrismaClient): async def _get_team_db_check(
return await prisma_client.db.litellm_teamtable.find_unique( team_id: str, prisma_client: PrismaClient, team_id_upsert: Optional[bool] = None
):
response = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id} where={"team_id": team_id}
) )
if response is None and team_id_upsert:
response = await prisma_client.db.litellm_teamtable.create(
data={"team_id": team_id}
)
return response
async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient): async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
return await prisma_client.db.litellm_teamtable.find_unique( return await prisma_client.db.litellm_teamtable.find_unique(
@ -675,6 +684,7 @@ async def _get_team_object_from_user_api_key_cache(
db_cache_expiry: int, db_cache_expiry: int,
proxy_logging_obj: Optional[ProxyLogging], proxy_logging_obj: Optional[ProxyLogging],
key: str, key: str,
team_id_upsert: Optional[bool] = None,
) -> LiteLLM_TeamTableCachedObj: ) -> LiteLLM_TeamTableCachedObj:
db_access_time_key = key db_access_time_key = key
should_check_db = _should_check_db( should_check_db = _should_check_db(
@ -684,7 +694,7 @@ async def _get_team_object_from_user_api_key_cache(
) )
if should_check_db: if should_check_db:
response = await _get_team_db_check( response = await _get_team_db_check(
team_id=team_id, prisma_client=prisma_client team_id=team_id, prisma_client=prisma_client, team_id_upsert=team_id_upsert
) )
else: else:
response = None response = None
@ -752,6 +762,7 @@ async def get_team_object(
proxy_logging_obj: Optional[ProxyLogging] = None, proxy_logging_obj: Optional[ProxyLogging] = None,
check_cache_only: Optional[bool] = None, check_cache_only: Optional[bool] = None,
check_db_only: Optional[bool] = None, check_db_only: Optional[bool] = None,
team_id_upsert: Optional[bool] = None,
) -> LiteLLM_TeamTableCachedObj: ) -> LiteLLM_TeamTableCachedObj:
""" """
- Check if team id in proxy Team Table - Check if team id in proxy Team Table
@ -795,6 +806,7 @@ async def get_team_object(
last_db_access_time=last_db_access_time, last_db_access_time=last_db_access_time,
db_cache_expiry=db_cache_expiry, db_cache_expiry=db_cache_expiry,
key=key, key=key,
team_id_upsert=team_id_upsert,
) )
except Exception: except Exception:
raise Exception( raise Exception(

View file

@ -30,6 +30,7 @@ from litellm.proxy._types import (
LiteLLM_TeamTable, LiteLLM_TeamTable,
LiteLLM_UserTable, LiteLLM_UserTable,
LitellmUserRoles, LitellmUserRoles,
ScopeMapping,
Span, Span,
) )
from litellm.proxy.utils import PrismaClient, ProxyLogging from litellm.proxy.utils import PrismaClient, ProxyLogging
@ -318,7 +319,7 @@ class JWTHandler:
org_id = default_value org_id = default_value
return org_id return org_id
def get_scopes(self, token: dict) -> list: def get_scopes(self, token: dict) -> List[str]:
try: try:
if isinstance(token["scope"], str): if isinstance(token["scope"], str):
# Assuming the scopes are stored in 'scope' claim and are space-separated # Assuming the scopes are stored in 'scope' claim and are space-separated
@ -543,6 +544,40 @@ class JWTAuthManager:
return True return True
@staticmethod
def check_scope_based_access(
scope_mappings: List[ScopeMapping],
scopes: List[str],
request_data: dict,
general_settings: dict,
) -> None:
"""
Check if scope allows access to the requested model
"""
if not scope_mappings:
return None
allowed_models = []
for sm in scope_mappings:
if sm.scope in scopes and sm.models:
allowed_models.extend(sm.models)
requested_model = request_data.get("model")
if not requested_model:
return None
if requested_model not in allowed_models:
raise HTTPException(
status_code=403,
detail={
"error": "model={} not allowed. Allowed_models={}".format(
requested_model, allowed_models
)
},
)
return None
@staticmethod @staticmethod
async def check_rbac_role( async def check_rbac_role(
jwt_handler: JWTHandler, jwt_handler: JWTHandler,
@ -636,6 +671,7 @@ class JWTAuthManager:
user_api_key_cache=user_api_key_cache, user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span, parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj, proxy_logging_obj=proxy_logging_obj,
team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert,
) )
return individual_team_id, team_object return individual_team_id, team_object
@ -829,6 +865,19 @@ class JWTAuthManager:
rbac_role, rbac_role,
) )
# Check Scope Based Access
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
if (
jwt_handler.litellm_jwtauth.enforce_scope_based_access
and jwt_handler.litellm_jwtauth.scope_mappings
):
JWTAuthManager.check_scope_based_access(
scope_mappings=jwt_handler.litellm_jwtauth.scope_mappings,
scopes=scopes,
request_data=request_data,
general_settings=general_settings,
)
object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None) object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)
# Get basic user info # Get basic user info

View file

@ -1183,3 +1183,67 @@ def test_can_rbac_role_call_route():
}, },
route="/v1/embeddings", route="/v1/embeddings",
) )
@pytest.mark.parametrize(
"requested_model, should_work",
[
("gpt-3.5-turbo-testing", True),
("gpt-4o", False),
],
)
def test_check_scope_based_access(requested_model, should_work):
from litellm.proxy.auth.handle_jwt import JWTAuthManager
from litellm.proxy._types import ScopeMapping
args = {
"scope_mappings": [
ScopeMapping(
models=["anthropic-claude"],
routes=["/v1/chat/completions"],
scope="litellm.api.consumer",
),
ScopeMapping(
models=["gpt-3.5-turbo-testing"],
routes=None,
scope="litellm.api.gpt_3_5_turbo",
),
],
"scopes": [
"profile",
"groups-scope",
"email",
"litellm.api.gpt_3_5_turbo",
"litellm.api.consumer",
],
"request_data": {
"model": requested_model,
"messages": [{"role": "user", "content": "Hey, how's it going 1234?"}],
},
"general_settings": {
"enable_jwt_auth": True,
"litellm_jwtauth": {
"team_id_jwt_field": "client_id",
"team_id_upsert": True,
"scope_mappings": [
{
"scope": "litellm.api.consumer",
"models": ["anthropic-claude"],
"routes": ["/v1/chat/completions"],
},
{
"scope": "litellm.api.gpt_3_5_turbo",
"models": ["gpt-3.5-turbo-testing"],
},
],
"enforce_scope_based_access": True,
"enforce_rbac": True,
},
},
}
if should_work:
JWTAuthManager.check_scope_based_access(**args)
else:
with pytest.raises(HTTPException):
JWTAuthManager.check_scope_based_access(**args)