Litellm dev 02 06 2025 p3 (#8343)

* feat(handle_jwt.py): initial commit to allow scope based model access

* feat(handle_jwt.py): allow model access based on token scopes

allow admin to control model access from IDP

* test(test_jwt.py): add unit testing for scope based model access

* docs(token_auth.md): add scope based model access to docs

* docs(token_auth.md): update docs

* docs(token_auth.md): update docs

* build: add gemini commercial rate limits

* fix: fix linting error
This commit is contained in:
Krish Dholakia 2025-02-06 23:15:33 -08:00 committed by GitHub
parent f87ab251b0
commit d720744656
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 238 additions and 7 deletions

View file

@ -371,3 +371,67 @@ Supported internal roles:
- `proxy_admin`: Proxy admin will be used for RBAC spend tracking. Use this for granting admin access to a token.
### [Architecture Diagram (Control Model Access)](./jwt_auth_arch)
## [BETA] Control Model Access with Scopes
Control which models a JWT can access. Set `enforce_scope_based_access: true` to enforce scope-based access control.
### 1. Setup config.yaml with scope mappings.
```yaml
model_list:
- model_name: anthropic-claude
litellm_params:
model: anthropic/claude-3-5-sonnet
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: gpt-3.5-turbo-testing
litellm_params:
model: gpt-3.5-turbo
api_key: os.environ/OPENAI_API_KEY
general_settings:
enable_jwt_auth: True
litellm_jwtauth:
team_id_jwt_field: "client_id" # 👈 set the field in the JWT token that contains the team id
team_id_upsert: true # 👈 upsert the team to db, if team id is not found in db
scope_mappings:
- scope: litellm.api.consumer
models: ["anthropic-claude"]
- scope: litellm.api.gpt_3_5_turbo
models: ["gpt-3.5-turbo-testing"]
enforce_scope_based_access: true # 👈 enforce scope-based access control
enforce_rbac: true # 👈 enforces only a Team/User/ProxyAdmin can access the proxy.
```
#### Scope Mapping Spec
- `scope`: The scope string to match against the `scope` claim in the presented JWT token.
- `models`: The models that the JWT token can access. Value is the `model_name` in `model_list`. Note: Wildcard routes are not currently supported.
### 2. Create a JWT with the correct scopes.
Expected Token:
```
{
"scope": ["litellm.api.consumer", "litellm.api.gpt_3_5_turbo"]
}
```
### 3. Test the flow.
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer eyJhbGci...' \
-d '{
"model": "gpt-3.5-turbo-testing",
"messages": [
{
"role": "user",
"content": "Hey, how'\''s it going 1234?"
}
]
}'
```

View file

@ -27,3 +27,17 @@ model_list:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
general_settings:
enable_jwt_auth: True
litellm_jwtauth:
team_id_jwt_field: "client_id"
team_id_upsert: true
scope_mappings:
- scope: litellm.api.consumer
models: ["anthropic-claude"]
routes: ["/v1/chat/completions"]
- scope: litellm.api.gpt_3_5_turbo
models: ["gpt-3.5-turbo-testing"]
enforce_scope_based_access: true
enforce_rbac: true

View file

@ -1040,6 +1040,13 @@ class LiteLLM_TeamTable(TeamBase):
"model_max_budget",
"model_aliases",
]
if (
isinstance(values.get("members_with_roles"), dict)
and not values["members_with_roles"]
):
values["members_with_roles"] = []
for field in dict_fields:
value = values.get(field)
if value is not None and isinstance(value, str):
@ -2279,11 +2286,14 @@ RBAC_ROLES = Literal[
]
class RoleBasedPermissions(LiteLLMPydanticObjectBase):
role: RBAC_ROLES
class OIDCPermissions(LiteLLMPydanticObjectBase):
    # Optional allow-lists attached to an OIDC scope/role mapping:
    # - models: `model_name` values (from model_list) the caller may request
    # - routes: proxy routes the caller may hit
    models: Optional[List[str]] = None
    routes: Optional[List[str]] = None
class RoleBasedPermissions(OIDCPermissions):
    # RBAC role this permission set applies to.
    role: RBAC_ROLES

    # Reject unknown keys so config typos fail loudly instead of being ignored.
    model_config = {
        "extra": "forbid",
    }
@ -2294,6 +2304,14 @@ class RoleMapping(BaseModel):
internal_role: RBAC_ROLES
class ScopeMapping(OIDCPermissions):
    # JWT scope string that grants the models/routes inherited from OIDCPermissions.
    scope: str

    # Reject unknown keys so config typos fail loudly instead of being ignored.
    model_config = {
        "extra": "forbid",
    }
class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
"""
A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.
@ -2323,6 +2341,7 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
"info_routes",
]
team_id_jwt_field: Optional[str] = None
team_id_upsert: bool = False
team_ids_jwt_field: Optional[str] = None
upsert_sso_user_to_team: bool = False
team_allowed_routes: List[
@ -2351,6 +2370,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
object_id_jwt_field: Optional[str] = (
None # can be either user / team, inferred from the role mapping
)
scope_mappings: Optional[List[ScopeMapping]] = None
enforce_scope_based_access: bool = False
def __init__(self, **kwargs: Any) -> None:
# get the attribute names for this Pydantic model
@ -2361,6 +2382,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
user_allowed_roles = kwargs.get("user_allowed_roles")
object_id_jwt_field = kwargs.get("object_id_jwt_field")
role_mappings = kwargs.get("role_mappings")
scope_mappings = kwargs.get("scope_mappings")
enforce_scope_based_access = kwargs.get("enforce_scope_based_access")
if invalid_keys:
raise ValueError(
@ -2378,4 +2401,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
"if object_id_jwt_field is set, role_mappings must also be set. Needed to infer if the caller is a user or team."
)
if scope_mappings is not None and not enforce_scope_based_access:
raise ValueError(
"scope_mappings must be set if enforce_scope_based_access is true."
)
super().__init__(**kwargs)

View file

@ -655,11 +655,20 @@ async def _delete_cache_key_object(
@log_db_metrics
async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
return await prisma_client.db.litellm_teamtable.find_unique(
async def _get_team_db_check(
    team_id: str, prisma_client: PrismaClient, team_id_upsert: Optional[bool] = None
):
    """
    Look up a team row by id; optionally auto-provision it when missing.

    If no row exists and ``team_id_upsert`` is truthy, a new team row is
    created with just the ``team_id`` and returned instead of ``None``.
    Used by JWT auth's ``team_id_upsert`` flow.
    """
    team_row = await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": team_id}
    )
    if team_row is None and team_id_upsert:
        # Team not in DB yet — upsert a minimal row so downstream checks pass.
        team_row = await prisma_client.db.litellm_teamtable.create(
            data={"team_id": team_id}
        )
    return team_row
async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
return await prisma_client.db.litellm_teamtable.find_unique(
@ -675,6 +684,7 @@ async def _get_team_object_from_user_api_key_cache(
db_cache_expiry: int,
proxy_logging_obj: Optional[ProxyLogging],
key: str,
team_id_upsert: Optional[bool] = None,
) -> LiteLLM_TeamTableCachedObj:
db_access_time_key = key
should_check_db = _should_check_db(
@ -684,7 +694,7 @@ async def _get_team_object_from_user_api_key_cache(
)
if should_check_db:
response = await _get_team_db_check(
team_id=team_id, prisma_client=prisma_client
team_id=team_id, prisma_client=prisma_client, team_id_upsert=team_id_upsert
)
else:
response = None
@ -752,6 +762,7 @@ async def get_team_object(
proxy_logging_obj: Optional[ProxyLogging] = None,
check_cache_only: Optional[bool] = None,
check_db_only: Optional[bool] = None,
team_id_upsert: Optional[bool] = None,
) -> LiteLLM_TeamTableCachedObj:
"""
- Check if team id in proxy Team Table
@ -795,6 +806,7 @@ async def get_team_object(
last_db_access_time=last_db_access_time,
db_cache_expiry=db_cache_expiry,
key=key,
team_id_upsert=team_id_upsert,
)
except Exception:
raise Exception(

View file

@ -30,6 +30,7 @@ from litellm.proxy._types import (
LiteLLM_TeamTable,
LiteLLM_UserTable,
LitellmUserRoles,
ScopeMapping,
Span,
)
from litellm.proxy.utils import PrismaClient, ProxyLogging
@ -318,7 +319,7 @@ class JWTHandler:
org_id = default_value
return org_id
def get_scopes(self, token: dict) -> list:
def get_scopes(self, token: dict) -> List[str]:
try:
if isinstance(token["scope"], str):
# Assuming the scopes are stored in 'scope' claim and are space-separated
@ -543,6 +544,40 @@ class JWTAuthManager:
return True
@staticmethod
def check_scope_based_access(
    scope_mappings: List[ScopeMapping],
    scopes: List[str],
    request_data: dict,
    general_settings: dict,
) -> None:
    """
    Verify that the scopes carried by the JWT permit the model named in the
    request body.

    Args:
        scope_mappings: configured scope -> allowed-models mappings.
        scopes: scopes extracted from the validated JWT.
        request_data: parsed request body; only the "model" key is inspected.
        general_settings: proxy settings (currently unused here).

    Raises:
        HTTPException: 403 when a model was requested and it is not covered
            by any mapping whose scope the token holds.
    """
    if not scope_mappings:
        return None

    requested_model = request_data.get("model")
    if not requested_model:
        # No model in the request body -> nothing to enforce.
        return None

    # Union of models granted by every mapping whose scope the token holds.
    allowed_models: List[str] = []
    for mapping in scope_mappings:
        if mapping.scope in scopes and mapping.models:
            allowed_models.extend(mapping.models)

    if requested_model not in allowed_models:
        raise HTTPException(
            status_code=403,
            detail={
                "error": "model={} not allowed. Allowed_models={}".format(
                    requested_model, allowed_models
                )
            },
        )
    return None
@staticmethod
async def check_rbac_role(
jwt_handler: JWTHandler,
@ -636,6 +671,7 @@ class JWTAuthManager:
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert,
)
return individual_team_id, team_object
@ -829,6 +865,19 @@ class JWTAuthManager:
rbac_role,
)
# Check Scope Based Access
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
if (
jwt_handler.litellm_jwtauth.enforce_scope_based_access
and jwt_handler.litellm_jwtauth.scope_mappings
):
JWTAuthManager.check_scope_based_access(
scope_mappings=jwt_handler.litellm_jwtauth.scope_mappings,
scopes=scopes,
request_data=request_data,
general_settings=general_settings,
)
object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)
# Get basic user info

View file

@ -1183,3 +1183,67 @@ def test_can_rbac_role_call_route():
},
route="/v1/embeddings",
)
@pytest.mark.parametrize(
    "requested_model, should_work",
    [
        ("gpt-3.5-turbo-testing", True),
        ("gpt-4o", False),
    ],
)
def test_check_scope_based_access(requested_model, should_work):
    """Scope-gated model access: a mapped model passes, an unmapped one -> 403."""
    from litellm.proxy._types import ScopeMapping
    from litellm.proxy.auth.handle_jwt import JWTAuthManager

    token_scopes = [
        "profile",
        "groups-scope",
        "email",
        "litellm.api.gpt_3_5_turbo",
        "litellm.api.consumer",
    ]
    mappings = [
        ScopeMapping(
            models=["anthropic-claude"],
            routes=["/v1/chat/completions"],
            scope="litellm.api.consumer",
        ),
        ScopeMapping(
            models=["gpt-3.5-turbo-testing"],
            routes=None,
            scope="litellm.api.gpt_3_5_turbo",
        ),
    ]
    settings = {
        "enable_jwt_auth": True,
        "litellm_jwtauth": {
            "team_id_jwt_field": "client_id",
            "team_id_upsert": True,
            "scope_mappings": [
                {
                    "scope": "litellm.api.consumer",
                    "models": ["anthropic-claude"],
                    "routes": ["/v1/chat/completions"],
                },
                {
                    "scope": "litellm.api.gpt_3_5_turbo",
                    "models": ["gpt-3.5-turbo-testing"],
                },
            ],
            "enforce_scope_based_access": True,
            "enforce_rbac": True,
        },
    }
    call_kwargs = {
        "scope_mappings": mappings,
        "scopes": token_scopes,
        "request_data": {
            "model": requested_model,
            "messages": [{"role": "user", "content": "Hey, how's it going 1234?"}],
        },
        "general_settings": settings,
    }

    if should_work:
        # Allowed model: must not raise.
        JWTAuthManager.check_scope_based_access(**call_kwargs)
    else:
        with pytest.raises(HTTPException):
            JWTAuthManager.check_scope_based_access(**call_kwargs)