mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
Litellm dev 01 31 2025 p2 (#8164)
* docs(token_auth.md): clarify title * refactor(handle_jwt.py): add jwt auth manager + refactor to handle groups allows user to call model if user belongs to group with model access * refactor(handle_jwt.py): refactor to first check if service call then check user call * feat(handle_jwt.py): new `enforce_team_access` param only allows user to call model if a team they belong to has model access allows controlling user model access by team * fix(handle_jwt.py): fix error string, remove unecessary param * docs(token_auth.md): add controlling model access for jwt tokens via teams to docs * test: fix tests post refactor * fix: fix linting errors * fix: fix linting error * test: fix import error
This commit is contained in:
parent
2674fa4dc3
commit
a008a2d4f4
6 changed files with 447 additions and 242 deletions
|
@ -1,7 +1,7 @@
|
||||||
import Tabs from '@theme/Tabs';
|
import Tabs from '@theme/Tabs';
|
||||||
import TabItem from '@theme/TabItem';
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# SSO - JWT-based Auth
|
# OIDC - JWT-based Auth
|
||||||
|
|
||||||
Use JWT's to auth admins / projects into the proxy.
|
Use JWT's to auth admins / projects into the proxy.
|
||||||
|
|
||||||
|
@ -156,35 +156,12 @@ scope: ["litellm-proxy-admin",...]
|
||||||
scope: "litellm-proxy-admin ..."
|
scope: "litellm-proxy-admin ..."
|
||||||
```
|
```
|
||||||
|
|
||||||
## Enforce Role-Based Access Control (RBAC)
|
## Control Model Access with Roles
|
||||||
|
|
||||||
Reject a JWT token if it's valid but doesn't have the required scopes / fields.
|
Reject a JWT token if it's valid but doesn't have the required scopes / fields.
|
||||||
|
|
||||||
Only tokens which with valid Admin (`admin_jwt_scope`), User (`user_id_jwt_field`), Team (`team_id_jwt_field`) are allowed.
|
Only tokens which with valid Admin (`admin_jwt_scope`), User (`user_id_jwt_field`), Team (`team_id_jwt_field`) are allowed.
|
||||||
|
|
||||||
```yaml
|
|
||||||
general_settings:
|
|
||||||
master_key: sk-1234
|
|
||||||
enable_jwt_auth: True
|
|
||||||
litellm_jwtauth:
|
|
||||||
admin_jwt_scope: "litellm_proxy_endpoints_access"
|
|
||||||
admin_allowed_routes:
|
|
||||||
- openai_routes
|
|
||||||
- info_routes
|
|
||||||
public_key_ttl: 600
|
|
||||||
enforce_rbac: true # 👈 Enforce RBAC
|
|
||||||
```
|
|
||||||
|
|
||||||
Expected Scope in JWT:
|
|
||||||
|
|
||||||
```
|
|
||||||
{
|
|
||||||
"scope": "litellm_proxy_endpoints_access"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Control Model Access
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
general_settings:
|
general_settings:
|
||||||
enable_jwt_auth: True
|
enable_jwt_auth: True
|
||||||
|
@ -198,9 +175,57 @@ general_settings:
|
||||||
models: ["anthropic-claude"]
|
models: ["anthropic-claude"]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
**[Architecture Diagram (Control Model Access)](./jwt_auth_arch)**
|
**[Architecture Diagram (Control Model Access)](./jwt_auth_arch)**
|
||||||
|
|
||||||
|
## Control model access with Teams
|
||||||
|
|
||||||
|
|
||||||
|
1. Specify the JWT field that contains the team ids, that the user belongs to.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
general_settings:
|
||||||
|
master_key: sk-1234
|
||||||
|
litellm_jwtauth:
|
||||||
|
user_id_jwt_field: "sub"
|
||||||
|
team_ids_jwt_field: "groups"
|
||||||
|
```
|
||||||
|
|
||||||
|
This is assuming your token looks like this:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
...,
|
||||||
|
"sub": "my-unique-user",
|
||||||
|
"groups": ["team_id_1", "team_id_2"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Create the teams on LiteLLM
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST '<PROXY_BASE_URL>/team/new' \
|
||||||
|
-H 'Authorization: Bearer <PROXY_MASTER_KEY>' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-D '{
|
||||||
|
"team_alias": "team_1",
|
||||||
|
"team_id": "team_id_1" # 👈 MUST BE THE SAME AS THE SSO GROUP ID
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Test the flow
|
||||||
|
|
||||||
|
SSO for UI: [**See Walkthrough**](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e)
|
||||||
|
|
||||||
|
OIDC Auth for API: [**See Walkthrough**](https://www.loom.com/share/00fe2deab59a426183a46b1e2b522200?sid=4ed6d497-ead6-47f9-80c0-ca1c4b6b4814)
|
||||||
|
|
||||||
|
|
||||||
|
### Flow
|
||||||
|
|
||||||
|
- Validate if user id is in the DB (LiteLLM_UserTable)
|
||||||
|
- Validate if any of the groups are in the DB (LiteLLM_TeamTable)
|
||||||
|
- Validate if any group has model access
|
||||||
|
- If all checks pass, allow the request
|
||||||
|
|
||||||
|
|
||||||
## Advanced - Allowed Routes
|
## Advanced - Allowed Routes
|
||||||
|
|
||||||
Configure which routes a JWT can access via the config.
|
Configure which routes a JWT can access via the config.
|
||||||
|
|
|
@ -3,6 +3,10 @@ model_list:
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: gpt-3.5-turbo
|
model: gpt-3.5-turbo
|
||||||
rpm: 3
|
rpm: 3
|
||||||
|
- model_name: o3-mini
|
||||||
|
litellm_params:
|
||||||
|
model: o3-mini
|
||||||
|
rpm: 3
|
||||||
- model_name: anthropic-claude
|
- model_name: anthropic-claude
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: claude-3-5-haiku-20241022
|
model: claude-3-5-haiku-20241022
|
||||||
|
@ -18,4 +22,11 @@ model_list:
|
||||||
|
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
callbacks: ["langsmith"]
|
callbacks: ["langsmith"]
|
||||||
disable_no_log_param: true
|
disable_no_log_param: true
|
||||||
|
|
||||||
|
general_settings:
|
||||||
|
enable_jwt_auth: True
|
||||||
|
litellm_jwtauth:
|
||||||
|
user_id_jwt_field: "sub"
|
||||||
|
user_email_jwt_field: "email"
|
||||||
|
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD
|
|
@ -8,11 +8,12 @@ JWT token must have 'litellm_proxy_admin' in scope.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional, cast
|
from typing import Any, List, Literal, Optional, Set, Tuple, cast
|
||||||
|
|
||||||
from cryptography import x509
|
from cryptography import x509
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
from cryptography.hazmat.primitives import serialization
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.caching.caching import DualCache
|
from litellm.caching.caching import DualCache
|
||||||
|
@ -21,11 +22,27 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
RBAC_ROLES,
|
RBAC_ROLES,
|
||||||
JWKKeyValue,
|
JWKKeyValue,
|
||||||
|
JWTAuthBuilderResult,
|
||||||
JWTKeyItem,
|
JWTKeyItem,
|
||||||
|
LiteLLM_EndUserTable,
|
||||||
LiteLLM_JWTAuth,
|
LiteLLM_JWTAuth,
|
||||||
|
LiteLLM_OrganizationTable,
|
||||||
|
LiteLLM_TeamTable,
|
||||||
|
LiteLLM_UserTable,
|
||||||
LitellmUserRoles,
|
LitellmUserRoles,
|
||||||
|
Span,
|
||||||
|
)
|
||||||
|
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||||
|
|
||||||
|
from .auth_checks import (
|
||||||
|
allowed_routes_check,
|
||||||
|
get_actual_routes,
|
||||||
|
get_end_user_object,
|
||||||
|
get_org_object,
|
||||||
|
get_role_based_models,
|
||||||
|
get_team_object,
|
||||||
|
get_user_object,
|
||||||
)
|
)
|
||||||
from litellm.proxy.utils import PrismaClient
|
|
||||||
|
|
||||||
|
|
||||||
class JWTHandler:
|
class JWTHandler:
|
||||||
|
@ -401,3 +418,355 @@ class JWTHandler:
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
await self.http_handler.close()
|
await self.http_handler.close()
|
||||||
|
|
||||||
|
|
||||||
|
class JWTAuthManager:
|
||||||
|
"""Manages JWT authentication and authorization operations"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def can_rbac_role_call_model(
|
||||||
|
rbac_role: RBAC_ROLES,
|
||||||
|
general_settings: dict,
|
||||||
|
model: Optional[str],
|
||||||
|
) -> Literal[True]:
|
||||||
|
"""
|
||||||
|
Checks if user is allowed to access the model, based on their role.
|
||||||
|
"""
|
||||||
|
role_based_models = get_role_based_models(
|
||||||
|
rbac_role=rbac_role, general_settings=general_settings
|
||||||
|
)
|
||||||
|
if role_based_models is None or model is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if model not in role_based_models:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail=f"User role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def check_rbac_role(
|
||||||
|
jwt_handler: JWTHandler,
|
||||||
|
jwt_valid_token: dict,
|
||||||
|
general_settings: dict,
|
||||||
|
request_data: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Validate RBAC role and model access permissions"""
|
||||||
|
if jwt_handler.litellm_jwtauth.enforce_rbac is True:
|
||||||
|
rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
|
||||||
|
if rbac_role is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.",
|
||||||
|
)
|
||||||
|
JWTAuthManager.can_rbac_role_call_model(
|
||||||
|
rbac_role=rbac_role,
|
||||||
|
general_settings=general_settings,
|
||||||
|
model=request_data.get("model"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def check_admin_access(
|
||||||
|
jwt_handler: JWTHandler,
|
||||||
|
scopes: list,
|
||||||
|
route: str,
|
||||||
|
user_id: Optional[str],
|
||||||
|
org_id: Optional[str],
|
||||||
|
api_key: str,
|
||||||
|
) -> Optional[JWTAuthBuilderResult]:
|
||||||
|
"""Check admin status and route access permissions"""
|
||||||
|
if not jwt_handler.is_admin(scopes=scopes):
|
||||||
|
return None
|
||||||
|
|
||||||
|
is_allowed = allowed_routes_check(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
user_route=route,
|
||||||
|
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||||
|
)
|
||||||
|
if not is_allowed:
|
||||||
|
allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
|
||||||
|
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||||
|
raise Exception(
|
||||||
|
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return JWTAuthBuilderResult(
|
||||||
|
is_proxy_admin=True,
|
||||||
|
team_object=None,
|
||||||
|
user_object=None,
|
||||||
|
end_user_object=None,
|
||||||
|
org_object=None,
|
||||||
|
token=api_key,
|
||||||
|
team_id=None,
|
||||||
|
user_id=user_id,
|
||||||
|
end_user_id=None,
|
||||||
|
org_id=org_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def find_and_validate_specific_team_id(
|
||||||
|
jwt_handler: JWTHandler,
|
||||||
|
jwt_valid_token: dict,
|
||||||
|
prisma_client: Optional[PrismaClient],
|
||||||
|
user_api_key_cache: DualCache,
|
||||||
|
parent_otel_span: Optional[Span],
|
||||||
|
proxy_logging_obj: ProxyLogging,
|
||||||
|
) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
|
||||||
|
"""Find and validate specific team ID"""
|
||||||
|
individual_team_id = jwt_handler.get_team_id(
|
||||||
|
token=jwt_valid_token, default_value=None
|
||||||
|
)
|
||||||
|
|
||||||
|
if not individual_team_id and jwt_handler.is_required_team_id() is True:
|
||||||
|
raise Exception(
|
||||||
|
f"No team id found in token. Checked team_id field '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
## VALIDATE TEAM OBJECT ###
|
||||||
|
team_object: Optional[LiteLLM_TeamTable] = None
|
||||||
|
if individual_team_id:
|
||||||
|
team_object = await get_team_object(
|
||||||
|
team_id=individual_team_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
parent_otel_span=parent_otel_span,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
return individual_team_id, team_object
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_all_team_ids(jwt_handler: JWTHandler, jwt_valid_token: dict) -> Set[str]:
|
||||||
|
"""Get combined team IDs from groups and individual team_id"""
|
||||||
|
team_ids_from_groups = jwt_handler.get_team_ids_from_jwt(token=jwt_valid_token)
|
||||||
|
|
||||||
|
all_team_ids = set(team_ids_from_groups)
|
||||||
|
|
||||||
|
return all_team_ids
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def find_team_with_model_access(
|
||||||
|
team_ids: Set[str],
|
||||||
|
requested_model: Optional[str],
|
||||||
|
route: str,
|
||||||
|
jwt_handler: JWTHandler,
|
||||||
|
prisma_client: Optional[PrismaClient],
|
||||||
|
user_api_key_cache: DualCache,
|
||||||
|
parent_otel_span: Optional[Span],
|
||||||
|
proxy_logging_obj: ProxyLogging,
|
||||||
|
) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
|
||||||
|
"""Find first team with access to the requested model"""
|
||||||
|
|
||||||
|
if not team_ids:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
for team_id in team_ids:
|
||||||
|
try:
|
||||||
|
team_object = await get_team_object(
|
||||||
|
team_id=team_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
parent_otel_span=parent_otel_span,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
if team_object and team_object.models is not None:
|
||||||
|
team_models = team_object.models
|
||||||
|
if isinstance(team_models, list) and (
|
||||||
|
not requested_model
|
||||||
|
or requested_model in team_models
|
||||||
|
or "*" in team_models
|
||||||
|
):
|
||||||
|
is_allowed = allowed_routes_check(
|
||||||
|
user_role=LitellmUserRoles.TEAM,
|
||||||
|
user_route=route,
|
||||||
|
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||||
|
)
|
||||||
|
if is_allowed:
|
||||||
|
return team_id, team_object
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if requested_model:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail=f"No team has access to the requested model: {requested_model}. Checked teams={team_ids}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_info(
|
||||||
|
jwt_handler: JWTHandler,
|
||||||
|
jwt_valid_token: dict,
|
||||||
|
) -> Tuple[Optional[str], Optional[str], Optional[bool]]:
|
||||||
|
"""Get user email and validation status"""
|
||||||
|
user_email = jwt_handler.get_user_email(
|
||||||
|
token=jwt_valid_token, default_value=None
|
||||||
|
)
|
||||||
|
valid_user_email = None
|
||||||
|
if jwt_handler.is_enforced_email_domain():
|
||||||
|
valid_user_email = (
|
||||||
|
False
|
||||||
|
if user_email is None
|
||||||
|
else jwt_handler.is_allowed_domain(user_email=user_email)
|
||||||
|
)
|
||||||
|
user_id = jwt_handler.get_user_id(
|
||||||
|
token=jwt_valid_token, default_value=user_email
|
||||||
|
)
|
||||||
|
return user_id, user_email, valid_user_email
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_objects(
|
||||||
|
user_id: Optional[str],
|
||||||
|
org_id: Optional[str],
|
||||||
|
end_user_id: Optional[str],
|
||||||
|
valid_user_email: Optional[bool],
|
||||||
|
jwt_handler: JWTHandler,
|
||||||
|
prisma_client: Optional[PrismaClient],
|
||||||
|
user_api_key_cache: DualCache,
|
||||||
|
parent_otel_span: Optional[Span],
|
||||||
|
proxy_logging_obj: ProxyLogging,
|
||||||
|
) -> Tuple[
|
||||||
|
Optional[LiteLLM_UserTable],
|
||||||
|
Optional[LiteLLM_OrganizationTable],
|
||||||
|
Optional[LiteLLM_EndUserTable],
|
||||||
|
]:
|
||||||
|
"""Get user, org, and end user objects"""
|
||||||
|
org_object: Optional[LiteLLM_OrganizationTable] = None
|
||||||
|
if org_id:
|
||||||
|
org_object = (
|
||||||
|
await get_org_object(
|
||||||
|
org_id=org_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
parent_otel_span=parent_otel_span,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
if org_id
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
user_object: Optional[LiteLLM_UserTable] = None
|
||||||
|
if user_id:
|
||||||
|
user_object = (
|
||||||
|
await get_user_object(
|
||||||
|
user_id=user_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
user_id_upsert=jwt_handler.is_upsert_user_id(
|
||||||
|
valid_user_email=valid_user_email
|
||||||
|
),
|
||||||
|
parent_otel_span=parent_otel_span,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
if user_id
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
end_user_object: Optional[LiteLLM_EndUserTable] = None
|
||||||
|
if end_user_id:
|
||||||
|
end_user_object = (
|
||||||
|
await get_end_user_object(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
parent_otel_span=parent_otel_span,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
if end_user_id
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return user_object, org_object, end_user_object
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def auth_builder(
|
||||||
|
api_key: str,
|
||||||
|
jwt_handler: JWTHandler,
|
||||||
|
request_data: dict,
|
||||||
|
general_settings: dict,
|
||||||
|
route: str,
|
||||||
|
prisma_client: Optional[PrismaClient],
|
||||||
|
user_api_key_cache: DualCache,
|
||||||
|
parent_otel_span: Optional[Span],
|
||||||
|
proxy_logging_obj: ProxyLogging,
|
||||||
|
) -> JWTAuthBuilderResult:
|
||||||
|
"""Main authentication and authorization builder"""
|
||||||
|
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
|
||||||
|
|
||||||
|
# Check RBAC
|
||||||
|
await JWTAuthManager.check_rbac_role(
|
||||||
|
jwt_handler, jwt_valid_token, general_settings, request_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get basic user info
|
||||||
|
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
||||||
|
user_id, _, valid_user_email = await JWTAuthManager.get_user_info(
|
||||||
|
jwt_handler, jwt_valid_token
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get IDs
|
||||||
|
org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None)
|
||||||
|
end_user_id = jwt_handler.get_end_user_id(
|
||||||
|
token=jwt_valid_token, default_value=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check admin access
|
||||||
|
admin_result = await JWTAuthManager.check_admin_access(
|
||||||
|
jwt_handler, scopes, route, user_id, org_id, api_key
|
||||||
|
)
|
||||||
|
if admin_result:
|
||||||
|
return admin_result
|
||||||
|
|
||||||
|
# Get team with model access
|
||||||
|
## SPECIFIC TEAM ID
|
||||||
|
team_id, team_object = await JWTAuthManager.find_and_validate_specific_team_id(
|
||||||
|
jwt_handler,
|
||||||
|
jwt_valid_token,
|
||||||
|
prisma_client,
|
||||||
|
user_api_key_cache,
|
||||||
|
parent_otel_span,
|
||||||
|
proxy_logging_obj,
|
||||||
|
)
|
||||||
|
if not team_object:
|
||||||
|
## CHECK USER GROUP ACCESS
|
||||||
|
all_team_ids = JWTAuthManager.get_all_team_ids(jwt_handler, jwt_valid_token)
|
||||||
|
team_id, team_object = await JWTAuthManager.find_team_with_model_access(
|
||||||
|
team_ids=all_team_ids,
|
||||||
|
requested_model=request_data.get("model"),
|
||||||
|
route=route,
|
||||||
|
jwt_handler=jwt_handler,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
parent_otel_span=parent_otel_span,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get other objects
|
||||||
|
user_object, org_object, end_user_object = await JWTAuthManager.get_objects(
|
||||||
|
user_id=user_id,
|
||||||
|
org_id=org_id,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
valid_user_email=valid_user_email,
|
||||||
|
jwt_handler=jwt_handler,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
parent_otel_span=parent_otel_span,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
return JWTAuthBuilderResult(
|
||||||
|
is_proxy_admin=False,
|
||||||
|
team_id=team_id,
|
||||||
|
team_object=team_object,
|
||||||
|
user_id=user_id,
|
||||||
|
user_object=user_object,
|
||||||
|
org_id=org_id,
|
||||||
|
org_object=org_object,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
end_user_object=end_user_object,
|
||||||
|
token=api_key,
|
||||||
|
)
|
||||||
|
|
|
@ -26,14 +26,10 @@ from litellm.proxy.auth.auth_checks import (
|
||||||
_handle_failed_db_connection_for_get_key_object,
|
_handle_failed_db_connection_for_get_key_object,
|
||||||
_virtual_key_max_budget_check,
|
_virtual_key_max_budget_check,
|
||||||
_virtual_key_soft_budget_check,
|
_virtual_key_soft_budget_check,
|
||||||
allowed_routes_check,
|
|
||||||
can_key_call_model,
|
can_key_call_model,
|
||||||
common_checks,
|
common_checks,
|
||||||
get_actual_routes,
|
|
||||||
get_end_user_object,
|
get_end_user_object,
|
||||||
get_key_object,
|
get_key_object,
|
||||||
get_org_object,
|
|
||||||
get_role_based_models,
|
|
||||||
get_team_object,
|
get_team_object,
|
||||||
get_user_object,
|
get_user_object,
|
||||||
is_valid_fallback_model,
|
is_valid_fallback_model,
|
||||||
|
@ -47,7 +43,7 @@ from litellm.proxy.auth.auth_utils import (
|
||||||
route_in_additonal_public_routes,
|
route_in_additonal_public_routes,
|
||||||
should_run_auth_on_pass_through_provider_route,
|
should_run_auth_on_pass_through_provider_route,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
from litellm.proxy.auth.handle_jwt import JWTAuthManager, JWTHandler
|
||||||
from litellm.proxy.auth.oauth2_check import check_oauth2_token
|
from litellm.proxy.auth.oauth2_check import check_oauth2_token
|
||||||
from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
|
from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
|
||||||
from litellm.proxy.auth.route_checks import RouteChecks
|
from litellm.proxy.auth.route_checks import RouteChecks
|
||||||
|
@ -282,201 +278,6 @@ def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
|
||||||
return LitellmUserRoles.TEAM
|
return LitellmUserRoles.TEAM
|
||||||
|
|
||||||
|
|
||||||
def can_rbac_role_call_model(
|
|
||||||
rbac_role: RBAC_ROLES,
|
|
||||||
general_settings: dict,
|
|
||||||
model: Optional[str],
|
|
||||||
) -> Literal[True]:
|
|
||||||
"""
|
|
||||||
Checks if user is allowed to access the model, based on their role.
|
|
||||||
"""
|
|
||||||
role_based_models = get_role_based_models(
|
|
||||||
rbac_role=rbac_role, general_settings=general_settings
|
|
||||||
)
|
|
||||||
if role_based_models is None or model is None:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if model not in role_based_models:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403,
|
|
||||||
detail=f"User role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
async def _jwt_auth_user_api_key_auth_builder(
|
|
||||||
api_key: str,
|
|
||||||
jwt_handler: JWTHandler,
|
|
||||||
request_data: dict,
|
|
||||||
general_settings: dict,
|
|
||||||
route: str,
|
|
||||||
prisma_client: Optional[PrismaClient],
|
|
||||||
user_api_key_cache: DualCache,
|
|
||||||
parent_otel_span: Optional[Span],
|
|
||||||
proxy_logging_obj: ProxyLogging,
|
|
||||||
) -> JWTAuthBuilderResult:
|
|
||||||
|
|
||||||
# check if valid token
|
|
||||||
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
|
|
||||||
|
|
||||||
# check if unmatched token and enforce_rbac is true
|
|
||||||
if jwt_handler.litellm_jwtauth.enforce_rbac is True:
|
|
||||||
rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
|
|
||||||
if rbac_role is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403,
|
|
||||||
detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# run rbac validation checks
|
|
||||||
can_rbac_role_call_model(
|
|
||||||
rbac_role=rbac_role,
|
|
||||||
general_settings=general_settings,
|
|
||||||
model=request_data.get("model"),
|
|
||||||
)
|
|
||||||
# get scopes
|
|
||||||
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
|
||||||
|
|
||||||
# [OPTIONAL] allowed user email domains
|
|
||||||
valid_user_email: Optional[bool] = None
|
|
||||||
user_email: Optional[str] = None
|
|
||||||
if jwt_handler.is_enforced_email_domain():
|
|
||||||
"""
|
|
||||||
if 'allowed_email_subdomains' is set,
|
|
||||||
|
|
||||||
- checks if token contains 'email' field
|
|
||||||
- checks if 'email' is from an allowed domain
|
|
||||||
"""
|
|
||||||
user_email = jwt_handler.get_user_email(
|
|
||||||
token=jwt_valid_token, default_value=None
|
|
||||||
)
|
|
||||||
if user_email is None:
|
|
||||||
valid_user_email = False
|
|
||||||
else:
|
|
||||||
valid_user_email = jwt_handler.is_allowed_domain(user_email=user_email)
|
|
||||||
|
|
||||||
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
|
|
||||||
user_object = None
|
|
||||||
user_id = jwt_handler.get_user_id(token=jwt_valid_token, default_value=user_email)
|
|
||||||
|
|
||||||
# get org id
|
|
||||||
org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None)
|
|
||||||
# get team id
|
|
||||||
team_id = jwt_handler.get_team_id(token=jwt_valid_token, default_value=None)
|
|
||||||
# get end user id
|
|
||||||
end_user_id = jwt_handler.get_end_user_id(token=jwt_valid_token, default_value=None)
|
|
||||||
|
|
||||||
# check if admin
|
|
||||||
is_admin = jwt_handler.is_admin(scopes=scopes)
|
|
||||||
# if admin return
|
|
||||||
if is_admin:
|
|
||||||
# check allowed admin routes
|
|
||||||
is_allowed = allowed_routes_check(
|
|
||||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
|
||||||
user_route=route,
|
|
||||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
|
||||||
)
|
|
||||||
if is_allowed:
|
|
||||||
return JWTAuthBuilderResult(
|
|
||||||
is_proxy_admin=True,
|
|
||||||
team_object=None,
|
|
||||||
user_object=None,
|
|
||||||
end_user_object=None,
|
|
||||||
org_object=None,
|
|
||||||
token=api_key,
|
|
||||||
team_id=team_id,
|
|
||||||
user_id=user_id,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
org_id=org_id,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
|
|
||||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
|
||||||
raise Exception(
|
|
||||||
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if team_id is None and jwt_handler.is_required_team_id() is True:
|
|
||||||
raise Exception(
|
|
||||||
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
team_object: Optional[LiteLLM_TeamTable] = None
|
|
||||||
if team_id is not None:
|
|
||||||
# check allowed team routes
|
|
||||||
is_allowed = allowed_routes_check(
|
|
||||||
user_role=LitellmUserRoles.TEAM,
|
|
||||||
user_route=route,
|
|
||||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
|
||||||
)
|
|
||||||
if is_allowed is False:
|
|
||||||
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
|
|
||||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
|
||||||
raise Exception(
|
|
||||||
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# check if team in db
|
|
||||||
team_object = await get_team_object(
|
|
||||||
team_id=team_id,
|
|
||||||
prisma_client=prisma_client,
|
|
||||||
user_api_key_cache=user_api_key_cache,
|
|
||||||
parent_otel_span=parent_otel_span,
|
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
|
|
||||||
|
|
||||||
org_object: Optional[LiteLLM_OrganizationTable] = None
|
|
||||||
if org_id is not None:
|
|
||||||
org_object = await get_org_object(
|
|
||||||
org_id=org_id,
|
|
||||||
prisma_client=prisma_client,
|
|
||||||
user_api_key_cache=user_api_key_cache,
|
|
||||||
parent_otel_span=parent_otel_span,
|
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
if user_id is not None:
|
|
||||||
# get the user object
|
|
||||||
user_object = await get_user_object(
|
|
||||||
user_id=user_id,
|
|
||||||
prisma_client=prisma_client,
|
|
||||||
user_api_key_cache=user_api_key_cache,
|
|
||||||
user_id_upsert=jwt_handler.is_upsert_user_id(
|
|
||||||
valid_user_email=valid_user_email
|
|
||||||
),
|
|
||||||
parent_otel_span=parent_otel_span,
|
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
|
||||||
)
|
|
||||||
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
|
|
||||||
end_user_object = None
|
|
||||||
|
|
||||||
if end_user_id is not None:
|
|
||||||
# get the end-user object
|
|
||||||
end_user_object = await get_end_user_object(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
prisma_client=prisma_client,
|
|
||||||
user_api_key_cache=user_api_key_cache,
|
|
||||||
parent_otel_span=parent_otel_span,
|
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
return JWTAuthBuilderResult(
|
|
||||||
is_proxy_admin=False,
|
|
||||||
team_id=team_id,
|
|
||||||
team_object=team_object,
|
|
||||||
user_id=user_id,
|
|
||||||
user_object=user_object,
|
|
||||||
org_id=org_id,
|
|
||||||
org_object=org_object,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
end_user_object=end_user_object,
|
|
||||||
token=api_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _user_api_key_auth_builder( # noqa: PLR0915
|
async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
request: Request,
|
request: Request,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
|
@ -612,7 +413,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
is_jwt = jwt_handler.is_jwt(token=api_key)
|
is_jwt = jwt_handler.is_jwt(token=api_key)
|
||||||
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
|
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
|
||||||
if is_jwt:
|
if is_jwt:
|
||||||
result = await _jwt_auth_user_api_key_auth_builder(
|
result = await JWTAuthManager.auth_builder(
|
||||||
request_data=request_data,
|
request_data=request_data,
|
||||||
general_settings=general_settings,
|
general_settings=general_settings,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
|
|
@ -986,7 +986,7 @@ async def test_allow_access_by_email(public_jwt_key, user_email, should_work):
|
||||||
# )
|
# )
|
||||||
# ),
|
# ),
|
||||||
with patch.object(
|
with patch.object(
|
||||||
litellm.proxy.auth.user_api_key_auth,
|
litellm.proxy.auth.handle_jwt,
|
||||||
"get_user_object",
|
"get_user_object",
|
||||||
side_effect=mock_user_object,
|
side_effect=mock_user_object,
|
||||||
) as mock_client:
|
) as mock_client:
|
||||||
|
|
|
@ -799,8 +799,7 @@ async def test_user_api_key_auth_websocket():
|
||||||
@pytest.mark.parametrize("enforce_rbac", [True, False])
|
@pytest.mark.parametrize("enforce_rbac", [True, False])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypatch):
|
async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypatch):
|
||||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
from litellm.proxy.auth.handle_jwt import JWTHandler, JWTAuthManager
|
||||||
from litellm.proxy.auth.user_api_key_auth import _jwt_auth_user_api_key_auth_builder
|
|
||||||
from unittest.mock import patch, Mock
|
from unittest.mock import patch, Mock
|
||||||
from litellm.proxy._types import LiteLLM_JWTAuth
|
from litellm.proxy._types import LiteLLM_JWTAuth
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
|
@ -861,9 +860,9 @@ async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypa
|
||||||
|
|
||||||
if enforce_rbac:
|
if enforce_rbac:
|
||||||
with pytest.raises(HTTPException):
|
with pytest.raises(HTTPException):
|
||||||
await _jwt_auth_user_api_key_auth_builder(**args)
|
await JWTAuthManager.auth_builder(**args)
|
||||||
else:
|
else:
|
||||||
await _jwt_auth_user_api_key_auth_builder(**args)
|
await JWTAuthManager.auth_builder(**args)
|
||||||
|
|
||||||
|
|
||||||
def test_user_api_key_auth_end_user_str():
|
def test_user_api_key_auth_end_user_str():
|
||||||
|
@ -882,7 +881,7 @@ def test_user_api_key_auth_end_user_str():
|
||||||
|
|
||||||
|
|
||||||
def test_can_rbac_role_call_model():
|
def test_can_rbac_role_call_model():
|
||||||
from litellm.proxy.auth.user_api_key_auth import can_rbac_role_call_model
|
from litellm.proxy.auth.handle_jwt import JWTAuthManager
|
||||||
from litellm.proxy._types import RoleBasedPermissions
|
from litellm.proxy._types import RoleBasedPermissions
|
||||||
|
|
||||||
roles_based_permissions = [
|
roles_based_permissions = [
|
||||||
|
@ -896,21 +895,21 @@ def test_can_rbac_role_call_model():
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
assert can_rbac_role_call_model(
|
assert JWTAuthManager.can_rbac_role_call_model(
|
||||||
rbac_role=LitellmUserRoles.INTERNAL_USER,
|
rbac_role=LitellmUserRoles.INTERNAL_USER,
|
||||||
general_settings={"role_permissions": roles_based_permissions},
|
general_settings={"role_permissions": roles_based_permissions},
|
||||||
model="gpt-4",
|
model="gpt-4",
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(HTTPException):
|
with pytest.raises(HTTPException):
|
||||||
can_rbac_role_call_model(
|
JWTAuthManager.can_rbac_role_call_model(
|
||||||
rbac_role=LitellmUserRoles.INTERNAL_USER,
|
rbac_role=LitellmUserRoles.INTERNAL_USER,
|
||||||
general_settings={"role_permissions": roles_based_permissions},
|
general_settings={"role_permissions": roles_based_permissions},
|
||||||
model="gpt-4o",
|
model="gpt-4o",
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(HTTPException):
|
with pytest.raises(HTTPException):
|
||||||
can_rbac_role_call_model(
|
JWTAuthManager.can_rbac_role_call_model(
|
||||||
rbac_role=LitellmUserRoles.PROXY_ADMIN,
|
rbac_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
general_settings={"role_permissions": roles_based_permissions},
|
general_settings={"role_permissions": roles_based_permissions},
|
||||||
model="gpt-4o",
|
model="gpt-4o",
|
||||||
|
@ -918,15 +917,15 @@ def test_can_rbac_role_call_model():
|
||||||
|
|
||||||
|
|
||||||
def test_can_rbac_role_call_model_no_role_permissions():
|
def test_can_rbac_role_call_model_no_role_permissions():
|
||||||
from litellm.proxy.auth.user_api_key_auth import can_rbac_role_call_model
|
from litellm.proxy.auth.handle_jwt import JWTAuthManager
|
||||||
|
|
||||||
assert can_rbac_role_call_model(
|
assert JWTAuthManager.can_rbac_role_call_model(
|
||||||
rbac_role=LitellmUserRoles.INTERNAL_USER,
|
rbac_role=LitellmUserRoles.INTERNAL_USER,
|
||||||
general_settings={},
|
general_settings={},
|
||||||
model="gpt-4",
|
model="gpt-4",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert can_rbac_role_call_model(
|
assert JWTAuthManager.can_rbac_role_call_model(
|
||||||
rbac_role=LitellmUserRoles.PROXY_ADMIN,
|
rbac_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
general_settings={"role_permissions": []},
|
general_settings={"role_permissions": []},
|
||||||
model="anthropic-claude",
|
model="anthropic-claude",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue