mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Litellm dev 02 06 2025 p3 (#8343)
* feat(handle_jwt.py): initial commit to allow scope based model access * feat(handle_jwt.py): allow model access based on token scopes allow admin to control model access from IDP * test(test_jwt.py): add unit testing for scope based model access * docs(token_auth.md): add scope based model access to docs * docs(token_auth.md): update docs * docs(token_auth.md): update docs * build: add gemini commercial rate limits * fix: fix linting error
This commit is contained in:
parent
f87ab251b0
commit
d720744656
6 changed files with 238 additions and 7 deletions
|
@ -370,4 +370,68 @@ Supported internal roles:
|
|||
- `internal_user`: User object will be used for RBAC spend tracking. Use this for tracking spend for an 'individual user'.
|
||||
- `proxy_admin`: Proxy admin will be used for RBAC spend tracking. Use this for granting admin access to a token.
|
||||
|
||||
### [Architecture Diagram (Control Model Access)](./jwt_auth_arch)
|
||||
### [Architecture Diagram (Control Model Access)](./jwt_auth_arch)
|
||||
|
||||
## [BETA] Control Model Access with Scopes
|
||||
|
||||
Control which models a JWT can access. Set `enforce_scope_based_access: true` to enforce scope-based access control.
|
||||
|
||||
### 1. Setup config.yaml with scope mappings.
|
||||
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: anthropic-claude
|
||||
litellm_params:
|
||||
model: anthropic/claude-3-5-sonnet
|
||||
api_key: os.environ/ANTHROPIC_API_KEY
|
||||
- model_name: gpt-3.5-turbo-testing
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
general_settings:
|
||||
enable_jwt_auth: True
|
||||
litellm_jwtauth:
|
||||
team_id_jwt_field: "client_id" # 👈 set the field in the JWT token that contains the team id
|
||||
team_id_upsert: true # 👈 upsert the team to db, if team id is not found in db
|
||||
scope_mappings:
|
||||
- scope: litellm.api.consumer
|
||||
models: ["anthropic-claude"]
|
||||
- scope: litellm.api.gpt_3_5_turbo
|
||||
models: ["gpt-3.5-turbo-testing"]
|
||||
enforce_scope_based_access: true # 👈 enforce scope-based access control
|
||||
enforce_rbac: true # 👈 enforces only a Team/User/ProxyAdmin can access the proxy.
|
||||
```
|
||||
|
||||
#### Scope Mapping Spec
|
||||
|
||||
- `scope`: The scope to be used for the JWT token.
|
||||
- `models`: The models that the JWT token can access. Value is the `model_name` in `model_list`. Note: Wildcard routes are not currently supported.
|
||||
|
||||
### 2. Create a JWT with the correct scopes.
|
||||
|
||||
Expected Token:
|
||||
|
||||
```
|
||||
{
|
||||
"scope": ["litellm.api.consumer", "litellm.api.gpt_3_5_turbo"]
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Test the flow.
|
||||
|
||||
```bash
|
||||
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer eyJhbGci...' \
|
||||
-d '{
|
||||
"model": "gpt-3.5-turbo-testing",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hey, how'\''s it going 1234?"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
|
@ -27,3 +27,17 @@ model_list:
|
|||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
|
||||
general_settings:
|
||||
enable_jwt_auth: True
|
||||
litellm_jwtauth:
|
||||
team_id_jwt_field: "client_id"
|
||||
team_id_upsert: true
|
||||
scope_mappings:
|
||||
- scope: litellm.api.consumer
|
||||
models: ["anthropic-claude"]
|
||||
routes: ["/v1/chat/completions"]
|
||||
- scope: litellm.api.gpt_3_5_turbo
|
||||
models: ["gpt-3.5-turbo-testing"]
|
||||
enforce_scope_based_access: true
|
||||
enforce_rbac: true
|
||||
|
|
|
@ -1040,6 +1040,13 @@ class LiteLLM_TeamTable(TeamBase):
|
|||
"model_max_budget",
|
||||
"model_aliases",
|
||||
]
|
||||
|
||||
if (
|
||||
isinstance(values.get("members_with_roles"), dict)
|
||||
and not values["members_with_roles"]
|
||||
):
|
||||
values["members_with_roles"] = []
|
||||
|
||||
for field in dict_fields:
|
||||
value = values.get(field)
|
||||
if value is not None and isinstance(value, str):
|
||||
|
@ -2279,11 +2286,14 @@ RBAC_ROLES = Literal[
|
|||
]
|
||||
|
||||
|
||||
class RoleBasedPermissions(LiteLLMPydanticObjectBase):
|
||||
role: RBAC_ROLES
|
||||
class OIDCPermissions(LiteLLMPydanticObjectBase):
|
||||
models: Optional[List[str]] = None
|
||||
routes: Optional[List[str]] = None
|
||||
|
||||
|
||||
class RoleBasedPermissions(OIDCPermissions):
|
||||
role: RBAC_ROLES
|
||||
|
||||
model_config = {
|
||||
"extra": "forbid",
|
||||
}
|
||||
|
@ -2294,6 +2304,14 @@ class RoleMapping(BaseModel):
|
|||
internal_role: RBAC_ROLES
|
||||
|
||||
|
||||
class ScopeMapping(OIDCPermissions):
|
||||
scope: str
|
||||
|
||||
model_config = {
|
||||
"extra": "forbid",
|
||||
}
|
||||
|
||||
|
||||
class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||
"""
|
||||
A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.
|
||||
|
@ -2323,6 +2341,7 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
|||
"info_routes",
|
||||
]
|
||||
team_id_jwt_field: Optional[str] = None
|
||||
team_id_upsert: bool = False
|
||||
team_ids_jwt_field: Optional[str] = None
|
||||
upsert_sso_user_to_team: bool = False
|
||||
team_allowed_routes: List[
|
||||
|
@ -2351,6 +2370,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
|||
object_id_jwt_field: Optional[str] = (
|
||||
None # can be either user / team, inferred from the role mapping
|
||||
)
|
||||
scope_mappings: Optional[List[ScopeMapping]] = None
|
||||
enforce_scope_based_access: bool = False
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
# get the attribute names for this Pydantic model
|
||||
|
@ -2361,6 +2382,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
|||
user_allowed_roles = kwargs.get("user_allowed_roles")
|
||||
object_id_jwt_field = kwargs.get("object_id_jwt_field")
|
||||
role_mappings = kwargs.get("role_mappings")
|
||||
scope_mappings = kwargs.get("scope_mappings")
|
||||
enforce_scope_based_access = kwargs.get("enforce_scope_based_access")
|
||||
|
||||
if invalid_keys:
|
||||
raise ValueError(
|
||||
|
@ -2378,4 +2401,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
|||
"if object_id_jwt_field is set, role_mappings must also be set. Needed to infer if the caller is a user or team."
|
||||
)
|
||||
|
||||
if scope_mappings is not None and not enforce_scope_based_access:
|
||||
raise ValueError(
|
||||
"scope_mappings must be set if enforce_scope_based_access is true."
|
||||
)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
|
|
@ -655,11 +655,20 @@ async def _delete_cache_key_object(
|
|||
|
||||
|
||||
@log_db_metrics
|
||||
async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
|
||||
return await prisma_client.db.litellm_teamtable.find_unique(
|
||||
async def _get_team_db_check(
|
||||
team_id: str, prisma_client: PrismaClient, team_id_upsert: Optional[bool] = None
|
||||
):
|
||||
response = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
|
||||
if response is None and team_id_upsert:
|
||||
response = await prisma_client.db.litellm_teamtable.create(
|
||||
data={"team_id": team_id}
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
|
||||
return await prisma_client.db.litellm_teamtable.find_unique(
|
||||
|
@ -675,6 +684,7 @@ async def _get_team_object_from_user_api_key_cache(
|
|||
db_cache_expiry: int,
|
||||
proxy_logging_obj: Optional[ProxyLogging],
|
||||
key: str,
|
||||
team_id_upsert: Optional[bool] = None,
|
||||
) -> LiteLLM_TeamTableCachedObj:
|
||||
db_access_time_key = key
|
||||
should_check_db = _should_check_db(
|
||||
|
@ -684,7 +694,7 @@ async def _get_team_object_from_user_api_key_cache(
|
|||
)
|
||||
if should_check_db:
|
||||
response = await _get_team_db_check(
|
||||
team_id=team_id, prisma_client=prisma_client
|
||||
team_id=team_id, prisma_client=prisma_client, team_id_upsert=team_id_upsert
|
||||
)
|
||||
else:
|
||||
response = None
|
||||
|
@ -752,6 +762,7 @@ async def get_team_object(
|
|||
proxy_logging_obj: Optional[ProxyLogging] = None,
|
||||
check_cache_only: Optional[bool] = None,
|
||||
check_db_only: Optional[bool] = None,
|
||||
team_id_upsert: Optional[bool] = None,
|
||||
) -> LiteLLM_TeamTableCachedObj:
|
||||
"""
|
||||
- Check if team id in proxy Team Table
|
||||
|
@ -795,6 +806,7 @@ async def get_team_object(
|
|||
last_db_access_time=last_db_access_time,
|
||||
db_cache_expiry=db_cache_expiry,
|
||||
key=key,
|
||||
team_id_upsert=team_id_upsert,
|
||||
)
|
||||
except Exception:
|
||||
raise Exception(
|
||||
|
|
|
@ -30,6 +30,7 @@ from litellm.proxy._types import (
|
|||
LiteLLM_TeamTable,
|
||||
LiteLLM_UserTable,
|
||||
LitellmUserRoles,
|
||||
ScopeMapping,
|
||||
Span,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
|
@ -318,7 +319,7 @@ class JWTHandler:
|
|||
org_id = default_value
|
||||
return org_id
|
||||
|
||||
def get_scopes(self, token: dict) -> list:
|
||||
def get_scopes(self, token: dict) -> List[str]:
|
||||
try:
|
||||
if isinstance(token["scope"], str):
|
||||
# Assuming the scopes are stored in 'scope' claim and are space-separated
|
||||
|
@ -543,6 +544,40 @@ class JWTAuthManager:
|
|||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def check_scope_based_access(
|
||||
scope_mappings: List[ScopeMapping],
|
||||
scopes: List[str],
|
||||
request_data: dict,
|
||||
general_settings: dict,
|
||||
) -> None:
|
||||
"""
|
||||
Check if scope allows access to the requested model
|
||||
"""
|
||||
if not scope_mappings:
|
||||
return None
|
||||
|
||||
allowed_models = []
|
||||
for sm in scope_mappings:
|
||||
if sm.scope in scopes and sm.models:
|
||||
allowed_models.extend(sm.models)
|
||||
|
||||
requested_model = request_data.get("model")
|
||||
|
||||
if not requested_model:
|
||||
return None
|
||||
|
||||
if requested_model not in allowed_models:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "model={} not allowed. Allowed_models={}".format(
|
||||
requested_model, allowed_models
|
||||
)
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def check_rbac_role(
|
||||
jwt_handler: JWTHandler,
|
||||
|
@ -636,6 +671,7 @@ class JWTAuthManager:
|
|||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert,
|
||||
)
|
||||
|
||||
return individual_team_id, team_object
|
||||
|
@ -829,6 +865,19 @@ class JWTAuthManager:
|
|||
rbac_role,
|
||||
)
|
||||
|
||||
# Check Scope Based Access
|
||||
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
||||
if (
|
||||
jwt_handler.litellm_jwtauth.enforce_scope_based_access
|
||||
and jwt_handler.litellm_jwtauth.scope_mappings
|
||||
):
|
||||
JWTAuthManager.check_scope_based_access(
|
||||
scope_mappings=jwt_handler.litellm_jwtauth.scope_mappings,
|
||||
scopes=scopes,
|
||||
request_data=request_data,
|
||||
general_settings=general_settings,
|
||||
)
|
||||
|
||||
object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)
|
||||
|
||||
# Get basic user info
|
||||
|
|
|
@ -1183,3 +1183,67 @@ def test_can_rbac_role_call_route():
|
|||
},
|
||||
route="/v1/embeddings",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"requested_model, should_work",
|
||||
[
|
||||
("gpt-3.5-turbo-testing", True),
|
||||
("gpt-4o", False),
|
||||
],
|
||||
)
|
||||
def test_check_scope_based_access(requested_model, should_work):
|
||||
from litellm.proxy.auth.handle_jwt import JWTAuthManager
|
||||
from litellm.proxy._types import ScopeMapping
|
||||
|
||||
args = {
|
||||
"scope_mappings": [
|
||||
ScopeMapping(
|
||||
models=["anthropic-claude"],
|
||||
routes=["/v1/chat/completions"],
|
||||
scope="litellm.api.consumer",
|
||||
),
|
||||
ScopeMapping(
|
||||
models=["gpt-3.5-turbo-testing"],
|
||||
routes=None,
|
||||
scope="litellm.api.gpt_3_5_turbo",
|
||||
),
|
||||
],
|
||||
"scopes": [
|
||||
"profile",
|
||||
"groups-scope",
|
||||
"email",
|
||||
"litellm.api.gpt_3_5_turbo",
|
||||
"litellm.api.consumer",
|
||||
],
|
||||
"request_data": {
|
||||
"model": requested_model,
|
||||
"messages": [{"role": "user", "content": "Hey, how's it going 1234?"}],
|
||||
},
|
||||
"general_settings": {
|
||||
"enable_jwt_auth": True,
|
||||
"litellm_jwtauth": {
|
||||
"team_id_jwt_field": "client_id",
|
||||
"team_id_upsert": True,
|
||||
"scope_mappings": [
|
||||
{
|
||||
"scope": "litellm.api.consumer",
|
||||
"models": ["anthropic-claude"],
|
||||
"routes": ["/v1/chat/completions"],
|
||||
},
|
||||
{
|
||||
"scope": "litellm.api.gpt_3_5_turbo",
|
||||
"models": ["gpt-3.5-turbo-testing"],
|
||||
},
|
||||
],
|
||||
"enforce_scope_based_access": True,
|
||||
"enforce_rbac": True,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if should_work:
|
||||
JWTAuthManager.check_scope_based_access(**args)
|
||||
else:
|
||||
with pytest.raises(HTTPException):
|
||||
JWTAuthManager.check_scope_based_access(**args)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue