mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
Litellm dev 02 06 2025 p3 (#8343)
* feat(handle_jwt.py): initial commit to allow scope based model access * feat(handle_jwt.py): allow model access based on token scopes allow admin to control model access from IDP * test(test_jwt.py): add unit testing for scope based model access * docs(token_auth.md): add scope based model access to docs * docs(token_auth.md): update docs * docs(token_auth.md): update docs * build: add gemini commercial rate limits * fix: fix linting error
This commit is contained in:
parent
561b4fcb89
commit
8a9e8838e2
6 changed files with 238 additions and 7 deletions
|
@ -370,4 +370,68 @@ Supported internal roles:
|
||||||
- `internal_user`: User object will be used for RBAC spend tracking. Use this for tracking spend for an 'individual user'.
|
- `internal_user`: User object will be used for RBAC spend tracking. Use this for tracking spend for an 'individual user'.
|
||||||
- `proxy_admin`: Proxy admin will be used for RBAC spend tracking. Use this for granting admin access to a token.
|
- `proxy_admin`: Proxy admin will be used for RBAC spend tracking. Use this for granting admin access to a token.
|
||||||
|
|
||||||
### [Architecture Diagram (Control Model Access)](./jwt_auth_arch)
|
### [Architecture Diagram (Control Model Access)](./jwt_auth_arch)
|
||||||
|
|
||||||
|
## [BETA] Control Model Access with Scopes
|
||||||
|
|
||||||
|
Control which models a JWT can access. Set `enforce_scope_based_access: true` to enforce scope-based access control.
|
||||||
|
|
||||||
|
### 1. Setup config.yaml with scope mappings.
|
||||||
|
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: anthropic-claude
|
||||||
|
litellm_params:
|
||||||
|
model: anthropic/claude-3-5-sonnet
|
||||||
|
api_key: os.environ/ANTHROPIC_API_KEY
|
||||||
|
- model_name: gpt-3.5-turbo-testing
|
||||||
|
litellm_params:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
|
|
||||||
|
general_settings:
|
||||||
|
enable_jwt_auth: True
|
||||||
|
litellm_jwtauth:
|
||||||
|
team_id_jwt_field: "client_id" # 👈 set the field in the JWT token that contains the team id
|
||||||
|
team_id_upsert: true # 👈 upsert the team to db, if team id is not found in db
|
||||||
|
scope_mappings:
|
||||||
|
- scope: litellm.api.consumer
|
||||||
|
models: ["anthropic-claude"]
|
||||||
|
- scope: litellm.api.gpt_3_5_turbo
|
||||||
|
models: ["gpt-3.5-turbo-testing"]
|
||||||
|
enforce_scope_based_access: true # 👈 enforce scope-based access control
|
||||||
|
enforce_rbac: true # 👈 enforces only a Team/User/ProxyAdmin can access the proxy.
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Scope Mapping Spec
|
||||||
|
|
||||||
|
- `scope`: The scope to be used for the JWT token.
|
||||||
|
- `models`: The models that the JWT token can access. Value is the `model_name` in `model_list`. Note: Wildcard routes are not currently supported.
|
||||||
|
|
||||||
|
### 2. Create a JWT with the correct scopes.
|
||||||
|
|
||||||
|
Expected Token:
|
||||||
|
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"scope": ["litellm.api.consumer", "litellm.api.gpt_3_5_turbo"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Test the flow.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-H 'Authorization: Bearer eyJhbGci...' \
|
||||||
|
-d '{
|
||||||
|
"model": "gpt-3.5-turbo-testing",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hey, how'\''s it going 1234?"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
|
@ -27,3 +27,17 @@ model_list:
|
||||||
model: openai/fake
|
model: openai/fake
|
||||||
api_key: fake-key
|
api_key: fake-key
|
||||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||||
|
|
||||||
|
general_settings:
|
||||||
|
enable_jwt_auth: True
|
||||||
|
litellm_jwtauth:
|
||||||
|
team_id_jwt_field: "client_id"
|
||||||
|
team_id_upsert: true
|
||||||
|
scope_mappings:
|
||||||
|
- scope: litellm.api.consumer
|
||||||
|
models: ["anthropic-claude"]
|
||||||
|
routes: ["/v1/chat/completions"]
|
||||||
|
- scope: litellm.api.gpt_3_5_turbo
|
||||||
|
models: ["gpt-3.5-turbo-testing"]
|
||||||
|
enforce_scope_based_access: true
|
||||||
|
enforce_rbac: true
|
||||||
|
|
|
@ -1040,6 +1040,13 @@ class LiteLLM_TeamTable(TeamBase):
|
||||||
"model_max_budget",
|
"model_max_budget",
|
||||||
"model_aliases",
|
"model_aliases",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if (
|
||||||
|
isinstance(values.get("members_with_roles"), dict)
|
||||||
|
and not values["members_with_roles"]
|
||||||
|
):
|
||||||
|
values["members_with_roles"] = []
|
||||||
|
|
||||||
for field in dict_fields:
|
for field in dict_fields:
|
||||||
value = values.get(field)
|
value = values.get(field)
|
||||||
if value is not None and isinstance(value, str):
|
if value is not None and isinstance(value, str):
|
||||||
|
@ -2279,11 +2286,14 @@ RBAC_ROLES = Literal[
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class RoleBasedPermissions(LiteLLMPydanticObjectBase):
|
class OIDCPermissions(LiteLLMPydanticObjectBase):
|
||||||
role: RBAC_ROLES
|
|
||||||
models: Optional[List[str]] = None
|
models: Optional[List[str]] = None
|
||||||
routes: Optional[List[str]] = None
|
routes: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class RoleBasedPermissions(OIDCPermissions):
|
||||||
|
role: RBAC_ROLES
|
||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"extra": "forbid",
|
"extra": "forbid",
|
||||||
}
|
}
|
||||||
|
@ -2294,6 +2304,14 @@ class RoleMapping(BaseModel):
|
||||||
internal_role: RBAC_ROLES
|
internal_role: RBAC_ROLES
|
||||||
|
|
||||||
|
|
||||||
|
class ScopeMapping(OIDCPermissions):
|
||||||
|
scope: str
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"extra": "forbid",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||||
"""
|
"""
|
||||||
A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.
|
A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.
|
||||||
|
@ -2323,6 +2341,7 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||||
"info_routes",
|
"info_routes",
|
||||||
]
|
]
|
||||||
team_id_jwt_field: Optional[str] = None
|
team_id_jwt_field: Optional[str] = None
|
||||||
|
team_id_upsert: bool = False
|
||||||
team_ids_jwt_field: Optional[str] = None
|
team_ids_jwt_field: Optional[str] = None
|
||||||
upsert_sso_user_to_team: bool = False
|
upsert_sso_user_to_team: bool = False
|
||||||
team_allowed_routes: List[
|
team_allowed_routes: List[
|
||||||
|
@ -2351,6 +2370,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||||
object_id_jwt_field: Optional[str] = (
|
object_id_jwt_field: Optional[str] = (
|
||||||
None # can be either user / team, inferred from the role mapping
|
None # can be either user / team, inferred from the role mapping
|
||||||
)
|
)
|
||||||
|
scope_mappings: Optional[List[ScopeMapping]] = None
|
||||||
|
enforce_scope_based_access: bool = False
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
# get the attribute names for this Pydantic model
|
# get the attribute names for this Pydantic model
|
||||||
|
@ -2361,6 +2382,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||||
user_allowed_roles = kwargs.get("user_allowed_roles")
|
user_allowed_roles = kwargs.get("user_allowed_roles")
|
||||||
object_id_jwt_field = kwargs.get("object_id_jwt_field")
|
object_id_jwt_field = kwargs.get("object_id_jwt_field")
|
||||||
role_mappings = kwargs.get("role_mappings")
|
role_mappings = kwargs.get("role_mappings")
|
||||||
|
scope_mappings = kwargs.get("scope_mappings")
|
||||||
|
enforce_scope_based_access = kwargs.get("enforce_scope_based_access")
|
||||||
|
|
||||||
if invalid_keys:
|
if invalid_keys:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -2378,4 +2401,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||||
"if object_id_jwt_field is set, role_mappings must also be set. Needed to infer if the caller is a user or team."
|
"if object_id_jwt_field is set, role_mappings must also be set. Needed to infer if the caller is a user or team."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if scope_mappings is not None and not enforce_scope_based_access:
|
||||||
|
raise ValueError(
|
||||||
|
"scope_mappings must be set if enforce_scope_based_access is true."
|
||||||
|
)
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
|
@ -655,11 +655,20 @@ async def _delete_cache_key_object(
|
||||||
|
|
||||||
|
|
||||||
@log_db_metrics
|
@log_db_metrics
|
||||||
async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
|
async def _get_team_db_check(
|
||||||
return await prisma_client.db.litellm_teamtable.find_unique(
|
team_id: str, prisma_client: PrismaClient, team_id_upsert: Optional[bool] = None
|
||||||
|
):
|
||||||
|
response = await prisma_client.db.litellm_teamtable.find_unique(
|
||||||
where={"team_id": team_id}
|
where={"team_id": team_id}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if response is None and team_id_upsert:
|
||||||
|
response = await prisma_client.db.litellm_teamtable.create(
|
||||||
|
data={"team_id": team_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
|
async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
|
||||||
return await prisma_client.db.litellm_teamtable.find_unique(
|
return await prisma_client.db.litellm_teamtable.find_unique(
|
||||||
|
@ -675,6 +684,7 @@ async def _get_team_object_from_user_api_key_cache(
|
||||||
db_cache_expiry: int,
|
db_cache_expiry: int,
|
||||||
proxy_logging_obj: Optional[ProxyLogging],
|
proxy_logging_obj: Optional[ProxyLogging],
|
||||||
key: str,
|
key: str,
|
||||||
|
team_id_upsert: Optional[bool] = None,
|
||||||
) -> LiteLLM_TeamTableCachedObj:
|
) -> LiteLLM_TeamTableCachedObj:
|
||||||
db_access_time_key = key
|
db_access_time_key = key
|
||||||
should_check_db = _should_check_db(
|
should_check_db = _should_check_db(
|
||||||
|
@ -684,7 +694,7 @@ async def _get_team_object_from_user_api_key_cache(
|
||||||
)
|
)
|
||||||
if should_check_db:
|
if should_check_db:
|
||||||
response = await _get_team_db_check(
|
response = await _get_team_db_check(
|
||||||
team_id=team_id, prisma_client=prisma_client
|
team_id=team_id, prisma_client=prisma_client, team_id_upsert=team_id_upsert
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = None
|
response = None
|
||||||
|
@ -752,6 +762,7 @@ async def get_team_object(
|
||||||
proxy_logging_obj: Optional[ProxyLogging] = None,
|
proxy_logging_obj: Optional[ProxyLogging] = None,
|
||||||
check_cache_only: Optional[bool] = None,
|
check_cache_only: Optional[bool] = None,
|
||||||
check_db_only: Optional[bool] = None,
|
check_db_only: Optional[bool] = None,
|
||||||
|
team_id_upsert: Optional[bool] = None,
|
||||||
) -> LiteLLM_TeamTableCachedObj:
|
) -> LiteLLM_TeamTableCachedObj:
|
||||||
"""
|
"""
|
||||||
- Check if team id in proxy Team Table
|
- Check if team id in proxy Team Table
|
||||||
|
@ -795,6 +806,7 @@ async def get_team_object(
|
||||||
last_db_access_time=last_db_access_time,
|
last_db_access_time=last_db_access_time,
|
||||||
db_cache_expiry=db_cache_expiry,
|
db_cache_expiry=db_cache_expiry,
|
||||||
key=key,
|
key=key,
|
||||||
|
team_id_upsert=team_id_upsert,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
|
|
@ -30,6 +30,7 @@ from litellm.proxy._types import (
|
||||||
LiteLLM_TeamTable,
|
LiteLLM_TeamTable,
|
||||||
LiteLLM_UserTable,
|
LiteLLM_UserTable,
|
||||||
LitellmUserRoles,
|
LitellmUserRoles,
|
||||||
|
ScopeMapping,
|
||||||
Span,
|
Span,
|
||||||
)
|
)
|
||||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||||
|
@ -318,7 +319,7 @@ class JWTHandler:
|
||||||
org_id = default_value
|
org_id = default_value
|
||||||
return org_id
|
return org_id
|
||||||
|
|
||||||
def get_scopes(self, token: dict) -> list:
|
def get_scopes(self, token: dict) -> List[str]:
|
||||||
try:
|
try:
|
||||||
if isinstance(token["scope"], str):
|
if isinstance(token["scope"], str):
|
||||||
# Assuming the scopes are stored in 'scope' claim and are space-separated
|
# Assuming the scopes are stored in 'scope' claim and are space-separated
|
||||||
|
@ -543,6 +544,40 @@ class JWTAuthManager:
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_scope_based_access(
|
||||||
|
scope_mappings: List[ScopeMapping],
|
||||||
|
scopes: List[str],
|
||||||
|
request_data: dict,
|
||||||
|
general_settings: dict,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Check if scope allows access to the requested model
|
||||||
|
"""
|
||||||
|
if not scope_mappings:
|
||||||
|
return None
|
||||||
|
|
||||||
|
allowed_models = []
|
||||||
|
for sm in scope_mappings:
|
||||||
|
if sm.scope in scopes and sm.models:
|
||||||
|
allowed_models.extend(sm.models)
|
||||||
|
|
||||||
|
requested_model = request_data.get("model")
|
||||||
|
|
||||||
|
if not requested_model:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if requested_model not in allowed_models:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail={
|
||||||
|
"error": "model={} not allowed. Allowed_models={}".format(
|
||||||
|
requested_model, allowed_models
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def check_rbac_role(
|
async def check_rbac_role(
|
||||||
jwt_handler: JWTHandler,
|
jwt_handler: JWTHandler,
|
||||||
|
@ -636,6 +671,7 @@ class JWTAuthManager:
|
||||||
user_api_key_cache=user_api_key_cache,
|
user_api_key_cache=user_api_key_cache,
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert,
|
||||||
)
|
)
|
||||||
|
|
||||||
return individual_team_id, team_object
|
return individual_team_id, team_object
|
||||||
|
@ -829,6 +865,19 @@ class JWTAuthManager:
|
||||||
rbac_role,
|
rbac_role,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check Scope Based Access
|
||||||
|
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
||||||
|
if (
|
||||||
|
jwt_handler.litellm_jwtauth.enforce_scope_based_access
|
||||||
|
and jwt_handler.litellm_jwtauth.scope_mappings
|
||||||
|
):
|
||||||
|
JWTAuthManager.check_scope_based_access(
|
||||||
|
scope_mappings=jwt_handler.litellm_jwtauth.scope_mappings,
|
||||||
|
scopes=scopes,
|
||||||
|
request_data=request_data,
|
||||||
|
general_settings=general_settings,
|
||||||
|
)
|
||||||
|
|
||||||
object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)
|
object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)
|
||||||
|
|
||||||
# Get basic user info
|
# Get basic user info
|
||||||
|
|
|
@ -1183,3 +1183,67 @@ def test_can_rbac_role_call_route():
|
||||||
},
|
},
|
||||||
route="/v1/embeddings",
|
route="/v1/embeddings",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"requested_model, should_work",
|
||||||
|
[
|
||||||
|
("gpt-3.5-turbo-testing", True),
|
||||||
|
("gpt-4o", False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_check_scope_based_access(requested_model, should_work):
|
||||||
|
from litellm.proxy.auth.handle_jwt import JWTAuthManager
|
||||||
|
from litellm.proxy._types import ScopeMapping
|
||||||
|
|
||||||
|
args = {
|
||||||
|
"scope_mappings": [
|
||||||
|
ScopeMapping(
|
||||||
|
models=["anthropic-claude"],
|
||||||
|
routes=["/v1/chat/completions"],
|
||||||
|
scope="litellm.api.consumer",
|
||||||
|
),
|
||||||
|
ScopeMapping(
|
||||||
|
models=["gpt-3.5-turbo-testing"],
|
||||||
|
routes=None,
|
||||||
|
scope="litellm.api.gpt_3_5_turbo",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
"scopes": [
|
||||||
|
"profile",
|
||||||
|
"groups-scope",
|
||||||
|
"email",
|
||||||
|
"litellm.api.gpt_3_5_turbo",
|
||||||
|
"litellm.api.consumer",
|
||||||
|
],
|
||||||
|
"request_data": {
|
||||||
|
"model": requested_model,
|
||||||
|
"messages": [{"role": "user", "content": "Hey, how's it going 1234?"}],
|
||||||
|
},
|
||||||
|
"general_settings": {
|
||||||
|
"enable_jwt_auth": True,
|
||||||
|
"litellm_jwtauth": {
|
||||||
|
"team_id_jwt_field": "client_id",
|
||||||
|
"team_id_upsert": True,
|
||||||
|
"scope_mappings": [
|
||||||
|
{
|
||||||
|
"scope": "litellm.api.consumer",
|
||||||
|
"models": ["anthropic-claude"],
|
||||||
|
"routes": ["/v1/chat/completions"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"scope": "litellm.api.gpt_3_5_turbo",
|
||||||
|
"models": ["gpt-3.5-turbo-testing"],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"enforce_scope_based_access": True,
|
||||||
|
"enforce_rbac": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if should_work:
|
||||||
|
JWTAuthManager.check_scope_based_access(**args)
|
||||||
|
else:
|
||||||
|
with pytest.raises(HTTPException):
|
||||||
|
JWTAuthManager.check_scope_based_access(**args)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue