Litellm dev 01 31 2025 p2 (#8164)

* docs(token_auth.md): clarify title

* refactor(handle_jwt.py): add jwt auth manager + refactor to handle groups

allows user to call model if user belongs to group with model access

* refactor(handle_jwt.py): refactor to first check if service call then check user call

* feat(handle_jwt.py): new `enforce_team_access` param

only allows user to call model if a team they belong to has model access

allows controlling user model access by team

* fix(handle_jwt.py): fix error string, remove unecessary param

* docs(token_auth.md): add controlling model access for jwt tokens via teams to docs

* test: fix tests post refactor

* fix: fix linting errors

* fix: fix linting error

* test: fix import error
This commit is contained in:
Krish Dholakia 2025-01-31 22:52:35 -08:00 committed by GitHub
parent 2674fa4dc3
commit a008a2d4f4
6 changed files with 447 additions and 242 deletions

View file

@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs'; import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem'; import TabItem from '@theme/TabItem';
# SSO - JWT-based Auth # OIDC - JWT-based Auth
Use JWT's to auth admins / projects into the proxy. Use JWT's to auth admins / projects into the proxy.
@ -156,35 +156,12 @@ scope: ["litellm-proxy-admin",...]
scope: "litellm-proxy-admin ..." scope: "litellm-proxy-admin ..."
``` ```
## Enforce Role-Based Access Control (RBAC) ## Control Model Access with Roles
Reject a JWT token if it's valid but doesn't have the required scopes / fields. Reject a JWT token if it's valid but doesn't have the required scopes / fields.
Only tokens which with valid Admin (`admin_jwt_scope`), User (`user_id_jwt_field`), Team (`team_id_jwt_field`) are allowed. Only tokens which with valid Admin (`admin_jwt_scope`), User (`user_id_jwt_field`), Team (`team_id_jwt_field`) are allowed.
```yaml
general_settings:
master_key: sk-1234
enable_jwt_auth: True
litellm_jwtauth:
admin_jwt_scope: "litellm_proxy_endpoints_access"
admin_allowed_routes:
- openai_routes
- info_routes
public_key_ttl: 600
enforce_rbac: true # 👈 Enforce RBAC
```
Expected Scope in JWT:
```
{
"scope": "litellm_proxy_endpoints_access"
}
```
### Control Model Access
```yaml ```yaml
general_settings: general_settings:
enable_jwt_auth: True enable_jwt_auth: True
@ -198,9 +175,57 @@ general_settings:
models: ["anthropic-claude"] models: ["anthropic-claude"]
``` ```
**[Architecture Diagram (Control Model Access)](./jwt_auth_arch)** **[Architecture Diagram (Control Model Access)](./jwt_auth_arch)**
## Control model access with Teams
1. Specify the JWT field that contains the team ids, that the user belongs to.
```yaml
general_settings:
master_key: sk-1234
litellm_jwtauth:
user_id_jwt_field: "sub"
team_ids_jwt_field: "groups"
```
This is assuming your token looks like this:
```
{
...,
"sub": "my-unique-user",
"groups": ["team_id_1", "team_id_2"]
}
```
2. Create the teams on LiteLLM
```bash
curl -X POST '<PROXY_BASE_URL>/team/new' \
-H 'Authorization: Bearer <PROXY_MASTER_KEY>' \
-H 'Content-Type: application/json' \
-D '{
"team_alias": "team_1",
"team_id": "team_id_1" # 👈 MUST BE THE SAME AS THE SSO GROUP ID
}'
```
3. Test the flow
SSO for UI: [**See Walkthrough**](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e)
OIDC Auth for API: [**See Walkthrough**](https://www.loom.com/share/00fe2deab59a426183a46b1e2b522200?sid=4ed6d497-ead6-47f9-80c0-ca1c4b6b4814)
### Flow
- Validate if user id is in the DB (LiteLLM_UserTable)
- Validate if any of the groups are in the DB (LiteLLM_TeamTable)
- Validate if any group has model access
- If all checks pass, allow the request
## Advanced - Allowed Routes ## Advanced - Allowed Routes
Configure which routes a JWT can access via the config. Configure which routes a JWT can access via the config.

View file

@ -3,6 +3,10 @@ model_list:
litellm_params: litellm_params:
model: gpt-3.5-turbo model: gpt-3.5-turbo
rpm: 3 rpm: 3
- model_name: o3-mini
litellm_params:
model: o3-mini
rpm: 3
- model_name: anthropic-claude - model_name: anthropic-claude
litellm_params: litellm_params:
model: claude-3-5-haiku-20241022 model: claude-3-5-haiku-20241022
@ -19,3 +23,10 @@ model_list:
litellm_settings: litellm_settings:
callbacks: ["langsmith"] callbacks: ["langsmith"]
disable_no_log_param: true disable_no_log_param: true
general_settings:
enable_jwt_auth: True
litellm_jwtauth:
user_id_jwt_field: "sub"
user_email_jwt_field: "email"
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD

View file

@ -8,11 +8,12 @@ JWT token must have 'litellm_proxy_admin' in scope.
import json import json
import os import os
from typing import List, Optional, cast from typing import Any, List, Literal, Optional, Set, Tuple, cast
from cryptography import x509 from cryptography import x509
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache from litellm.caching.caching import DualCache
@ -21,11 +22,27 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy._types import ( from litellm.proxy._types import (
RBAC_ROLES, RBAC_ROLES,
JWKKeyValue, JWKKeyValue,
JWTAuthBuilderResult,
JWTKeyItem, JWTKeyItem,
LiteLLM_EndUserTable,
LiteLLM_JWTAuth, LiteLLM_JWTAuth,
LiteLLM_OrganizationTable,
LiteLLM_TeamTable,
LiteLLM_UserTable,
LitellmUserRoles, LitellmUserRoles,
Span,
)
from litellm.proxy.utils import PrismaClient, ProxyLogging
from .auth_checks import (
allowed_routes_check,
get_actual_routes,
get_end_user_object,
get_org_object,
get_role_based_models,
get_team_object,
get_user_object,
) )
from litellm.proxy.utils import PrismaClient
class JWTHandler: class JWTHandler:
@ -401,3 +418,355 @@ class JWTHandler:
async def close(self): async def close(self):
await self.http_handler.close() await self.http_handler.close()
class JWTAuthManager:
"""Manages JWT authentication and authorization operations"""
@staticmethod
def can_rbac_role_call_model(
rbac_role: RBAC_ROLES,
general_settings: dict,
model: Optional[str],
) -> Literal[True]:
"""
Checks if user is allowed to access the model, based on their role.
"""
role_based_models = get_role_based_models(
rbac_role=rbac_role, general_settings=general_settings
)
if role_based_models is None or model is None:
return True
if model not in role_based_models:
raise HTTPException(
status_code=403,
detail=f"User role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
)
return True
@staticmethod
async def check_rbac_role(
jwt_handler: JWTHandler,
jwt_valid_token: dict,
general_settings: dict,
request_data: dict,
) -> None:
"""Validate RBAC role and model access permissions"""
if jwt_handler.litellm_jwtauth.enforce_rbac is True:
rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
if rbac_role is None:
raise HTTPException(
status_code=403,
detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.",
)
JWTAuthManager.can_rbac_role_call_model(
rbac_role=rbac_role,
general_settings=general_settings,
model=request_data.get("model"),
)
@staticmethod
async def check_admin_access(
jwt_handler: JWTHandler,
scopes: list,
route: str,
user_id: Optional[str],
org_id: Optional[str],
api_key: str,
) -> Optional[JWTAuthBuilderResult]:
"""Check admin status and route access permissions"""
if not jwt_handler.is_admin(scopes=scopes):
return None
is_allowed = allowed_routes_check(
user_role=LitellmUserRoles.PROXY_ADMIN,
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if not is_allowed:
allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
return JWTAuthBuilderResult(
is_proxy_admin=True,
team_object=None,
user_object=None,
end_user_object=None,
org_object=None,
token=api_key,
team_id=None,
user_id=user_id,
end_user_id=None,
org_id=org_id,
)
@staticmethod
async def find_and_validate_specific_team_id(
jwt_handler: JWTHandler,
jwt_valid_token: dict,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span],
proxy_logging_obj: ProxyLogging,
) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
"""Find and validate specific team ID"""
individual_team_id = jwt_handler.get_team_id(
token=jwt_valid_token, default_value=None
)
if not individual_team_id and jwt_handler.is_required_team_id() is True:
raise Exception(
f"No team id found in token. Checked team_id field '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
)
## VALIDATE TEAM OBJECT ###
team_object: Optional[LiteLLM_TeamTable] = None
if individual_team_id:
team_object = await get_team_object(
team_id=individual_team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
return individual_team_id, team_object
@staticmethod
def get_all_team_ids(jwt_handler: JWTHandler, jwt_valid_token: dict) -> Set[str]:
"""Get combined team IDs from groups and individual team_id"""
team_ids_from_groups = jwt_handler.get_team_ids_from_jwt(token=jwt_valid_token)
all_team_ids = set(team_ids_from_groups)
return all_team_ids
@staticmethod
async def find_team_with_model_access(
team_ids: Set[str],
requested_model: Optional[str],
route: str,
jwt_handler: JWTHandler,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span],
proxy_logging_obj: ProxyLogging,
) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
"""Find first team with access to the requested model"""
if not team_ids:
return None, None
for team_id in team_ids:
try:
team_object = await get_team_object(
team_id=team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
if team_object and team_object.models is not None:
team_models = team_object.models
if isinstance(team_models, list) and (
not requested_model
or requested_model in team_models
or "*" in team_models
):
is_allowed = allowed_routes_check(
user_role=LitellmUserRoles.TEAM,
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed:
return team_id, team_object
except Exception:
continue
if requested_model:
raise HTTPException(
status_code=403,
detail=f"No team has access to the requested model: {requested_model}. Checked teams={team_ids}",
)
return None, None
@staticmethod
async def get_user_info(
jwt_handler: JWTHandler,
jwt_valid_token: dict,
) -> Tuple[Optional[str], Optional[str], Optional[bool]]:
"""Get user email and validation status"""
user_email = jwt_handler.get_user_email(
token=jwt_valid_token, default_value=None
)
valid_user_email = None
if jwt_handler.is_enforced_email_domain():
valid_user_email = (
False
if user_email is None
else jwt_handler.is_allowed_domain(user_email=user_email)
)
user_id = jwt_handler.get_user_id(
token=jwt_valid_token, default_value=user_email
)
return user_id, user_email, valid_user_email
@staticmethod
async def get_objects(
user_id: Optional[str],
org_id: Optional[str],
end_user_id: Optional[str],
valid_user_email: Optional[bool],
jwt_handler: JWTHandler,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span],
proxy_logging_obj: ProxyLogging,
) -> Tuple[
Optional[LiteLLM_UserTable],
Optional[LiteLLM_OrganizationTable],
Optional[LiteLLM_EndUserTable],
]:
"""Get user, org, and end user objects"""
org_object: Optional[LiteLLM_OrganizationTable] = None
if org_id:
org_object = (
await get_org_object(
org_id=org_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
if org_id
else None
)
user_object: Optional[LiteLLM_UserTable] = None
if user_id:
user_object = (
await get_user_object(
user_id=user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
user_id_upsert=jwt_handler.is_upsert_user_id(
valid_user_email=valid_user_email
),
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
if user_id
else None
)
end_user_object: Optional[LiteLLM_EndUserTable] = None
if end_user_id:
end_user_object = (
await get_end_user_object(
end_user_id=end_user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
if end_user_id
else None
)
return user_object, org_object, end_user_object
@staticmethod
async def auth_builder(
api_key: str,
jwt_handler: JWTHandler,
request_data: dict,
general_settings: dict,
route: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span],
proxy_logging_obj: ProxyLogging,
) -> JWTAuthBuilderResult:
"""Main authentication and authorization builder"""
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
# Check RBAC
await JWTAuthManager.check_rbac_role(
jwt_handler, jwt_valid_token, general_settings, request_data
)
# Get basic user info
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
user_id, _, valid_user_email = await JWTAuthManager.get_user_info(
jwt_handler, jwt_valid_token
)
# Get IDs
org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None)
end_user_id = jwt_handler.get_end_user_id(
token=jwt_valid_token, default_value=None
)
# Check admin access
admin_result = await JWTAuthManager.check_admin_access(
jwt_handler, scopes, route, user_id, org_id, api_key
)
if admin_result:
return admin_result
# Get team with model access
## SPECIFIC TEAM ID
team_id, team_object = await JWTAuthManager.find_and_validate_specific_team_id(
jwt_handler,
jwt_valid_token,
prisma_client,
user_api_key_cache,
parent_otel_span,
proxy_logging_obj,
)
if not team_object:
## CHECK USER GROUP ACCESS
all_team_ids = JWTAuthManager.get_all_team_ids(jwt_handler, jwt_valid_token)
team_id, team_object = await JWTAuthManager.find_team_with_model_access(
team_ids=all_team_ids,
requested_model=request_data.get("model"),
route=route,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# Get other objects
user_object, org_object, end_user_object = await JWTAuthManager.get_objects(
user_id=user_id,
org_id=org_id,
end_user_id=end_user_id,
valid_user_email=valid_user_email,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
return JWTAuthBuilderResult(
is_proxy_admin=False,
team_id=team_id,
team_object=team_object,
user_id=user_id,
user_object=user_object,
org_id=org_id,
org_object=org_object,
end_user_id=end_user_id,
end_user_object=end_user_object,
token=api_key,
)

View file

@ -26,14 +26,10 @@ from litellm.proxy.auth.auth_checks import (
_handle_failed_db_connection_for_get_key_object, _handle_failed_db_connection_for_get_key_object,
_virtual_key_max_budget_check, _virtual_key_max_budget_check,
_virtual_key_soft_budget_check, _virtual_key_soft_budget_check,
allowed_routes_check,
can_key_call_model, can_key_call_model,
common_checks, common_checks,
get_actual_routes,
get_end_user_object, get_end_user_object,
get_key_object, get_key_object,
get_org_object,
get_role_based_models,
get_team_object, get_team_object,
get_user_object, get_user_object,
is_valid_fallback_model, is_valid_fallback_model,
@ -47,7 +43,7 @@ from litellm.proxy.auth.auth_utils import (
route_in_additonal_public_routes, route_in_additonal_public_routes,
should_run_auth_on_pass_through_provider_route, should_run_auth_on_pass_through_provider_route,
) )
from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.auth.handle_jwt import JWTAuthManager, JWTHandler
from litellm.proxy.auth.oauth2_check import check_oauth2_token from litellm.proxy.auth.oauth2_check import check_oauth2_token
from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
from litellm.proxy.auth.route_checks import RouteChecks from litellm.proxy.auth.route_checks import RouteChecks
@ -282,201 +278,6 @@ def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
return LitellmUserRoles.TEAM return LitellmUserRoles.TEAM
def can_rbac_role_call_model(
rbac_role: RBAC_ROLES,
general_settings: dict,
model: Optional[str],
) -> Literal[True]:
"""
Checks if user is allowed to access the model, based on their role.
"""
role_based_models = get_role_based_models(
rbac_role=rbac_role, general_settings=general_settings
)
if role_based_models is None or model is None:
return True
if model not in role_based_models:
raise HTTPException(
status_code=403,
detail=f"User role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
)
return True
async def _jwt_auth_user_api_key_auth_builder(
api_key: str,
jwt_handler: JWTHandler,
request_data: dict,
general_settings: dict,
route: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span],
proxy_logging_obj: ProxyLogging,
) -> JWTAuthBuilderResult:
# check if valid token
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
# check if unmatched token and enforce_rbac is true
if jwt_handler.litellm_jwtauth.enforce_rbac is True:
rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
if rbac_role is None:
raise HTTPException(
status_code=403,
detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org",
)
else:
# run rbac validation checks
can_rbac_role_call_model(
rbac_role=rbac_role,
general_settings=general_settings,
model=request_data.get("model"),
)
# get scopes
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
# [OPTIONAL] allowed user email domains
valid_user_email: Optional[bool] = None
user_email: Optional[str] = None
if jwt_handler.is_enforced_email_domain():
"""
if 'allowed_email_subdomains' is set,
- checks if token contains 'email' field
- checks if 'email' is from an allowed domain
"""
user_email = jwt_handler.get_user_email(
token=jwt_valid_token, default_value=None
)
if user_email is None:
valid_user_email = False
else:
valid_user_email = jwt_handler.is_allowed_domain(user_email=user_email)
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
user_object = None
user_id = jwt_handler.get_user_id(token=jwt_valid_token, default_value=user_email)
# get org id
org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None)
# get team id
team_id = jwt_handler.get_team_id(token=jwt_valid_token, default_value=None)
# get end user id
end_user_id = jwt_handler.get_end_user_id(token=jwt_valid_token, default_value=None)
# check if admin
is_admin = jwt_handler.is_admin(scopes=scopes)
# if admin return
if is_admin:
# check allowed admin routes
is_allowed = allowed_routes_check(
user_role=LitellmUserRoles.PROXY_ADMIN,
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed:
return JWTAuthBuilderResult(
is_proxy_admin=True,
team_object=None,
user_object=None,
end_user_object=None,
org_object=None,
token=api_key,
team_id=team_id,
user_id=user_id,
end_user_id=end_user_id,
org_id=org_id,
)
else:
allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
if team_id is None and jwt_handler.is_required_team_id() is True:
raise Exception(
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
)
team_object: Optional[LiteLLM_TeamTable] = None
if team_id is not None:
# check allowed team routes
is_allowed = allowed_routes_check(
user_role=LitellmUserRoles.TEAM,
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed is False:
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
# check if team in db
team_object = await get_team_object(
team_id=team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
org_object: Optional[LiteLLM_OrganizationTable] = None
if org_id is not None:
org_object = await get_org_object(
org_id=org_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
if user_id is not None:
# get the user object
user_object = await get_user_object(
user_id=user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
user_id_upsert=jwt_handler.is_upsert_user_id(
valid_user_email=valid_user_email
),
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
end_user_object = None
if end_user_id is not None:
# get the end-user object
end_user_object = await get_end_user_object(
end_user_id=end_user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
return JWTAuthBuilderResult(
is_proxy_admin=False,
team_id=team_id,
team_object=team_object,
user_id=user_id,
user_object=user_object,
org_id=org_id,
org_object=org_object,
end_user_id=end_user_id,
end_user_object=end_user_object,
token=api_key,
)
async def _user_api_key_auth_builder( # noqa: PLR0915 async def _user_api_key_auth_builder( # noqa: PLR0915
request: Request, request: Request,
api_key: str, api_key: str,
@ -612,7 +413,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
is_jwt = jwt_handler.is_jwt(token=api_key) is_jwt = jwt_handler.is_jwt(token=api_key)
verbose_proxy_logger.debug("is_jwt: %s", is_jwt) verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
if is_jwt: if is_jwt:
result = await _jwt_auth_user_api_key_auth_builder( result = await JWTAuthManager.auth_builder(
request_data=request_data, request_data=request_data,
general_settings=general_settings, general_settings=general_settings,
api_key=api_key, api_key=api_key,

View file

@ -986,7 +986,7 @@ async def test_allow_access_by_email(public_jwt_key, user_email, should_work):
# ) # )
# ), # ),
with patch.object( with patch.object(
litellm.proxy.auth.user_api_key_auth, litellm.proxy.auth.handle_jwt,
"get_user_object", "get_user_object",
side_effect=mock_user_object, side_effect=mock_user_object,
) as mock_client: ) as mock_client:

View file

@ -799,8 +799,7 @@ async def test_user_api_key_auth_websocket():
@pytest.mark.parametrize("enforce_rbac", [True, False]) @pytest.mark.parametrize("enforce_rbac", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypatch): async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypatch):
from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.auth.handle_jwt import JWTHandler, JWTAuthManager
from litellm.proxy.auth.user_api_key_auth import _jwt_auth_user_api_key_auth_builder
from unittest.mock import patch, Mock from unittest.mock import patch, Mock
from litellm.proxy._types import LiteLLM_JWTAuth from litellm.proxy._types import LiteLLM_JWTAuth
from litellm.caching import DualCache from litellm.caching import DualCache
@ -861,9 +860,9 @@ async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypa
if enforce_rbac: if enforce_rbac:
with pytest.raises(HTTPException): with pytest.raises(HTTPException):
await _jwt_auth_user_api_key_auth_builder(**args) await JWTAuthManager.auth_builder(**args)
else: else:
await _jwt_auth_user_api_key_auth_builder(**args) await JWTAuthManager.auth_builder(**args)
def test_user_api_key_auth_end_user_str(): def test_user_api_key_auth_end_user_str():
@ -882,7 +881,7 @@ def test_user_api_key_auth_end_user_str():
def test_can_rbac_role_call_model(): def test_can_rbac_role_call_model():
from litellm.proxy.auth.user_api_key_auth import can_rbac_role_call_model from litellm.proxy.auth.handle_jwt import JWTAuthManager
from litellm.proxy._types import RoleBasedPermissions from litellm.proxy._types import RoleBasedPermissions
roles_based_permissions = [ roles_based_permissions = [
@ -896,21 +895,21 @@ def test_can_rbac_role_call_model():
), ),
] ]
assert can_rbac_role_call_model( assert JWTAuthManager.can_rbac_role_call_model(
rbac_role=LitellmUserRoles.INTERNAL_USER, rbac_role=LitellmUserRoles.INTERNAL_USER,
general_settings={"role_permissions": roles_based_permissions}, general_settings={"role_permissions": roles_based_permissions},
model="gpt-4", model="gpt-4",
) )
with pytest.raises(HTTPException): with pytest.raises(HTTPException):
can_rbac_role_call_model( JWTAuthManager.can_rbac_role_call_model(
rbac_role=LitellmUserRoles.INTERNAL_USER, rbac_role=LitellmUserRoles.INTERNAL_USER,
general_settings={"role_permissions": roles_based_permissions}, general_settings={"role_permissions": roles_based_permissions},
model="gpt-4o", model="gpt-4o",
) )
with pytest.raises(HTTPException): with pytest.raises(HTTPException):
can_rbac_role_call_model( JWTAuthManager.can_rbac_role_call_model(
rbac_role=LitellmUserRoles.PROXY_ADMIN, rbac_role=LitellmUserRoles.PROXY_ADMIN,
general_settings={"role_permissions": roles_based_permissions}, general_settings={"role_permissions": roles_based_permissions},
model="gpt-4o", model="gpt-4o",
@ -918,15 +917,15 @@ def test_can_rbac_role_call_model():
def test_can_rbac_role_call_model_no_role_permissions(): def test_can_rbac_role_call_model_no_role_permissions():
from litellm.proxy.auth.user_api_key_auth import can_rbac_role_call_model from litellm.proxy.auth.handle_jwt import JWTAuthManager
assert can_rbac_role_call_model( assert JWTAuthManager.can_rbac_role_call_model(
rbac_role=LitellmUserRoles.INTERNAL_USER, rbac_role=LitellmUserRoles.INTERNAL_USER,
general_settings={}, general_settings={},
model="gpt-4", model="gpt-4",
) )
assert can_rbac_role_call_model( assert JWTAuthManager.can_rbac_role_call_model(
rbac_role=LitellmUserRoles.PROXY_ADMIN, rbac_role=LitellmUserRoles.PROXY_ADMIN,
general_settings={"role_permissions": []}, general_settings={"role_permissions": []},
model="anthropic-claude", model="anthropic-claude",