[BETA] Support OIDC role based access to proxy (#8260)

* feat(proxy/_types.py): add new jwt field params

allows users + services to auth into proxy

* feat(handle_jwt.py): allow team role proxy access

allows proxy admin to set allowed team roles

* fix(proxy/_types.py): add 'routes' to role based permissions

allow proxy admin to restrict what routes a team can access easily

* feat(handle_jwt.py): support more flexible role based route access

v2 on role based 'allowed_routes'

* test(test_jwt.py): add unit test for rbac for proxy routes

* feat(handle_jwt.py): ensure cost tracking always works for any jwt request with `enforce_rbac=True`

* docs(token_auth.md): add documentation on controlling model access via OIDC Roles

* test: increase time delay before retrying

* test: handle model overloaded for test
This commit is contained in:
Krish Dholakia 2025-02-04 21:59:39 -08:00 committed by GitHub
parent 1d030ebed7
commit 015b822099
10 changed files with 413 additions and 143 deletions

View file

@ -3,7 +3,7 @@ import TabItem from '@theme/TabItem';
# OIDC - JWT-based Auth # OIDC - JWT-based Auth
Use JWT's to auth admins / projects into the proxy. Use JWT's to auth admins / users / projects into the proxy.
:::info :::info
@ -156,27 +156,6 @@ scope: ["litellm-proxy-admin",...]
scope: "litellm-proxy-admin ..." scope: "litellm-proxy-admin ..."
``` ```
## Control Model Access with Roles
Reject a JWT token if it's valid but doesn't have the required scopes / fields.
Only tokens which with valid Admin (`admin_jwt_scope`), User (`user_id_jwt_field`), Team (`team_id_jwt_field`) are allowed.
```yaml
general_settings:
enable_jwt_auth: True
litellm_jwtauth:
user_roles_jwt_field: "resource_access.litellm-test-client-id.roles"
user_allowed_roles: ["basic_user"] # roles that map to an 'internal_user' role on LiteLLM
enforce_rbac: true # if true, will check if the user has the correct role to access the model + endpoint
role_permissions: # control what models + endpointsare allowed for each role
- role: internal_user
models: ["anthropic-claude"]
```
**[Architecture Diagram (Control Model Access)](./jwt_auth_arch)**
## Control model access with Teams ## Control model access with Teams
@ -331,3 +310,64 @@ general_settings:
user_allowed_email_domain: "my-co.com" # allows user@my-co.com to call proxy user_allowed_email_domain: "my-co.com" # allows user@my-co.com to call proxy
user_id_upsert: true # 👈 upserts the user to db, if valid email but not in db user_id_upsert: true # 👈 upserts the user to db, if valid email but not in db
``` ```
## [BETA] Control Access with OIDC Roles
Allow JWT tokens with supported roles to access the proxy.
Let users and teams access the proxy, without needing to add them to the DB.
Very important, set `enforce_rbac: true` to ensure that the RBAC system is enabled.
**Note:** This is in beta and might change unexpectedly.
```yaml
general_settings:
enable_jwt_auth: True
litellm_jwtauth:
object_id_jwt_field: "oid" # can be either user / team, inferred from the role mapping
roles_jwt_field: "roles"
role_mappings:
- role: litellm.api.consumer
internal_role: "team"
enforce_rbac: true # 👈 VERY IMPORTANT
role_permissions: # default model + endpoint permissions for a role.
- role: team
models: ["anthropic-claude"]
routes: ["/v1/chat/completions"]
environment_variables:
JWT_AUDIENCE: "api://LiteLLM_Proxy" # ensures audience is validated
```
- `object_id_jwt_field`: The field in the JWT token that contains the object id. This id can be either a user id or a team id. Use this instead of `user_id_jwt_field` and `team_id_jwt_field`. If the same field could be both.
- `roles_jwt_field`: The field in the JWT token that contains the roles. This field is a list of roles that the user has. To index into a nested field, use dot notation - eg. `resource_access.litellm-test-client-id.roles`.
- `role_mappings`: A list of role mappings. Map the received role in the JWT token to an internal role on LiteLLM.
- `JWT_AUDIENCE`: The audience of the JWT token. This is used to validate the audience of the JWT token. Set via an environment variable.
### Example Token
```
{
"aud": "api://LiteLLM_Proxy",
"oid": "eec236bd-0135-4b28-9354-8fc4032d543e",
"roles": ["litellm.api.consumer"]
}
```
### Role Mapping Spec
- `role`: The expected role in the JWT token.
- `internal_role`: The internal role on LiteLLM that will be used to control access.
Supported internal roles:
- `team`: Team object will be used for RBAC spend tracking. Use this for tracking spend for a 'use case'.
- `internal_user`: User object will be used for RBAC spend tracking. Use this for tracking spend for an 'individual user'.
- `proxy_admin`: Proxy admin will be used for RBAC spend tracking. Use this for granting admin access to a token.
### [Architecture Diagram (Control Model Access)](./jwt_auth_arch)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -35,6 +35,15 @@ litellm_settings:
general_settings: general_settings:
enable_jwt_auth: True enable_jwt_auth: True
litellm_jwtauth: litellm_jwtauth:
user_id_jwt_field: "sub" object_id_jwt_field: "client_id" # can be either user / team, inferred from the role mapping
user_email_jwt_field: "email" roles_jwt_field: "resource_access.litellm-test-client-id.roles"
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD role_mappings:
- role: litellm.api.consumer
internal_role: "team"
enforce_rbac: true
role_permissions: # default model + endpoint permissions for a role.
- role: team
models: ["anthropic-claude"]
routes: ["openai_routes"]

View file

@ -397,92 +397,6 @@ class LiteLLMRoutes(enum.Enum):
) )
# class LiteLLMAllowedRoutes(LiteLLMPydanticObjectBase):
# """
# Defines allowed routes based on key type.
# Types = ["admin", "team", "user", "unmapped"]
# """
# admin_allowed_routes: List[
# Literal["openai_routes", "info_routes", "management_routes", "spend_tracking_routes", "global_spend_tracking_routes"]
# ] = ["management_routes"]
class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
"""
A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.
Attributes:
- admin_jwt_scope: The JWT scope required for proxy admin roles.
- admin_allowed_routes: list of allowed routes for proxy admin roles.
- team_jwt_scope: The JWT scope required for proxy team roles.
- team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
- team_allowed_routes: list of allowed routes for proxy team roles.
- user_id_jwt_field: The field in the JWT token that stores the user id (maps to `LiteLLMUserTable`). Use this for internal employees.
- user_email_jwt_field: The field in the JWT token that stores the user email (maps to `LiteLLMUserTable`). Use this for internal employees.
- user_allowed_email_subdomain: If specified, only emails from specified subdomain will be allowed to access proxy.
- end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
- public_key_ttl: Default - 600s. TTL for caching public JWT keys.
- public_allowed_routes: list of allowed routes for authenticated but unknown litellm role jwt tokens.
- enforce_rbac: If true, enforce RBAC for all routes.
See `auth_checks.py` for the specific routes
"""
admin_jwt_scope: str = "litellm_proxy_admin"
admin_allowed_routes: List[str] = [
"management_routes",
"spend_tracking_routes",
"global_spend_tracking_routes",
"info_routes",
]
team_id_jwt_field: Optional[str] = None
team_ids_jwt_field: Optional[str] = None
upsert_sso_user_to_team: bool = False
team_allowed_routes: List[
Literal["openai_routes", "info_routes", "management_routes"]
] = ["openai_routes", "info_routes"]
team_id_default: Optional[str] = Field(
default=None,
description="If no team_id given, default permissions/spend-tracking to this team.s",
)
org_id_jwt_field: Optional[str] = None
user_id_jwt_field: Optional[str] = None
user_email_jwt_field: Optional[str] = None
user_allowed_email_domain: Optional[str] = None
user_roles_jwt_field: Optional[str] = None
user_allowed_roles: Optional[List[str]] = None
user_id_upsert: bool = Field(
default=False, description="If user doesn't exist, upsert them into the db."
)
end_user_id_jwt_field: Optional[str] = None
public_key_ttl: float = 600
public_allowed_routes: List[str] = ["public_routes"]
enforce_rbac: bool = False
def __init__(self, **kwargs: Any) -> None:
# get the attribute names for this Pydantic model
allowed_keys = self.__annotations__.keys()
invalid_keys = set(kwargs.keys()) - allowed_keys
user_roles_jwt_field = kwargs.get("user_roles_jwt_field")
user_allowed_roles = kwargs.get("user_allowed_roles")
if invalid_keys:
raise ValueError(
f"Invalid arguments provided: {', '.join(invalid_keys)}. Allowed arguments are: {', '.join(allowed_keys)}."
)
if (user_roles_jwt_field is not None and user_allowed_roles is None) or (
user_roles_jwt_field is None and user_allowed_roles is not None
):
raise ValueError(
"user_allowed_roles must be provided if user_roles_jwt_field is set."
)
super().__init__(**kwargs)
class LiteLLMPromptInjectionParams(LiteLLMPydanticObjectBase): class LiteLLMPromptInjectionParams(LiteLLMPydanticObjectBase):
heuristics_check: bool = False heuristics_check: bool = False
vector_db_check: bool = False vector_db_check: bool = False
@ -2364,6 +2278,103 @@ RBAC_ROLES = Literal[
] ]
class RoleBasedPermissions(TypedDict): class RoleBasedPermissions(LiteLLMPydanticObjectBase):
role: Required[RBAC_ROLES] role: RBAC_ROLES
models: Required[List[str]] models: Optional[List[str]] = None
routes: Optional[List[str]] = None
model_config = {
"extra": "forbid",
}
class RoleMapping(BaseModel):
role: str
internal_role: RBAC_ROLES
class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
"""
A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.
Attributes:
- admin_jwt_scope: The JWT scope required for proxy admin roles.
- admin_allowed_routes: list of allowed routes for proxy admin roles.
- team_jwt_scope: The JWT scope required for proxy team roles.
- team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
- team_allowed_routes: list of allowed routes for proxy team roles.
- user_id_jwt_field: The field in the JWT token that stores the user id (maps to `LiteLLMUserTable`). Use this for internal employees.
- user_email_jwt_field: The field in the JWT token that stores the user email (maps to `LiteLLMUserTable`). Use this for internal employees.
- user_allowed_email_subdomain: If specified, only emails from specified subdomain will be allowed to access proxy.
- end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
- public_key_ttl: Default - 600s. TTL for caching public JWT keys.
- public_allowed_routes: list of allowed routes for authenticated but unknown litellm role jwt tokens.
- enforce_rbac: If true, enforce RBAC for all routes.
See `auth_checks.py` for the specific routes
"""
admin_jwt_scope: str = "litellm_proxy_admin"
admin_allowed_routes: List[str] = [
"management_routes",
"spend_tracking_routes",
"global_spend_tracking_routes",
"info_routes",
]
team_id_jwt_field: Optional[str] = None
team_ids_jwt_field: Optional[str] = None
upsert_sso_user_to_team: bool = False
team_allowed_routes: List[
Literal["openai_routes", "info_routes", "management_routes"]
] = ["openai_routes", "info_routes"]
team_id_default: Optional[str] = Field(
default=None,
description="If no team_id given, default permissions/spend-tracking to this team.s",
)
org_id_jwt_field: Optional[str] = None
user_id_jwt_field: Optional[str] = None
user_email_jwt_field: Optional[str] = None
user_allowed_email_domain: Optional[str] = None
user_roles_jwt_field: Optional[str] = None
user_allowed_roles: Optional[List[str]] = None
user_id_upsert: bool = Field(
default=False, description="If user doesn't exist, upsert them into the db."
)
end_user_id_jwt_field: Optional[str] = None
public_key_ttl: float = 600
public_allowed_routes: List[str] = ["public_routes"]
enforce_rbac: bool = False
roles_jwt_field: Optional[str] = None # v2 on role mappings
role_mappings: Optional[List[RoleMapping]] = None
object_id_jwt_field: Optional[str] = (
None # can be either user / team, inferred from the role mapping
)
def __init__(self, **kwargs: Any) -> None:
# get the attribute names for this Pydantic model
allowed_keys = self.__annotations__.keys()
invalid_keys = set(kwargs.keys()) - allowed_keys
user_roles_jwt_field = kwargs.get("user_roles_jwt_field")
user_allowed_roles = kwargs.get("user_allowed_roles")
object_id_jwt_field = kwargs.get("object_id_jwt_field")
role_mappings = kwargs.get("role_mappings")
if invalid_keys:
raise ValueError(
f"Invalid arguments provided: {', '.join(invalid_keys)}. Allowed arguments are: {', '.join(allowed_keys)}."
)
if (user_roles_jwt_field is not None and user_allowed_roles is None) or (
user_roles_jwt_field is None and user_allowed_roles is not None
):
raise ValueError(
"user_allowed_roles must be provided if user_roles_jwt_field is set."
)
if object_id_jwt_field is not None and role_mappings is None:
raise ValueError(
"if object_id_jwt_field is set, role_mappings must also be set. Needed to infer if the caller is a user or team."
)
super().__init__(**kwargs)

View file

@ -200,6 +200,7 @@ def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
- user_route: str - the route the user is trying to call - user_route: str - the route the user is trying to call
- allowed_routes: List[str|LiteLLMRoutes] - the list of allowed routes for the user. - allowed_routes: List[str|LiteLLMRoutes] - the list of allowed routes for the user.
""" """
for allowed_route in allowed_routes: for allowed_route in allowed_routes:
if ( if (
allowed_route in LiteLLMRoutes.__members__ allowed_route in LiteLLMRoutes.__members__
@ -402,6 +403,29 @@ def _update_last_db_access_time(
last_db_access_time[key] = (value, time.time()) last_db_access_time[key] = (value, time.time())
def _get_role_based_permissions(
rbac_role: RBAC_ROLES,
general_settings: dict,
key: Literal["models", "routes"],
) -> Optional[List[str]]:
"""
Get the role based permissions from the general settings.
"""
role_based_permissions = cast(
Optional[List[RoleBasedPermissions]],
general_settings.get("role_permissions", []),
)
if role_based_permissions is None:
return None
for role_based_permission in role_based_permissions:
if role_based_permission.role == rbac_role:
return getattr(role_based_permission, key)
return None
def get_role_based_models( def get_role_based_models(
rbac_role: RBAC_ROLES, rbac_role: RBAC_ROLES,
general_settings: dict, general_settings: dict,
@ -412,18 +436,26 @@ def get_role_based_models(
Used by JWT Auth. Used by JWT Auth.
""" """
role_based_permissions = cast( return _get_role_based_permissions(
Optional[List[RoleBasedPermissions]], rbac_role=rbac_role,
general_settings.get("role_permissions", []), general_settings=general_settings,
key="models",
) )
if role_based_permissions is None:
return None
for role_based_permission in role_based_permissions:
if role_based_permission["role"] == rbac_role:
return role_based_permission["models"]
return None def get_role_based_routes(
rbac_role: RBAC_ROLES,
general_settings: dict,
) -> Optional[List[str]]:
"""
Get the routes allowed for a user role.
"""
return _get_role_based_permissions(
rbac_role=rbac_role,
general_settings=general_settings,
key="routes",
)
async def _get_fuzzy_user_object( async def _get_fuzzy_user_object(

View file

@ -35,11 +35,13 @@ from litellm.proxy._types import (
from litellm.proxy.utils import PrismaClient, ProxyLogging from litellm.proxy.utils import PrismaClient, ProxyLogging
from .auth_checks import ( from .auth_checks import (
_allowed_routes_check,
allowed_routes_check, allowed_routes_check,
get_actual_routes, get_actual_routes,
get_end_user_object, get_end_user_object,
get_org_object, get_org_object,
get_role_based_models, get_role_based_models,
get_role_based_routes,
get_team_object, get_team_object,
get_user_object, get_user_object,
) )
@ -78,6 +80,37 @@ class JWTHandler:
parts = token.split(".") parts = token.split(".")
return len(parts) == 3 return len(parts) == 3
def _rbac_role_from_role_mapping(self, token: dict) -> Optional[RBAC_ROLES]:
"""
Returns the RBAC role the token 'belongs' to based on role mappings.
Args:
token (dict): The JWT token containing role information
Returns:
Optional[RBAC_ROLES]: The mapped internal RBAC role if a mapping exists,
None otherwise
Note:
The function handles both single string roles and lists of roles from the JWT.
If multiple mappings match the JWT roles, the first matching mapping is returned.
"""
if self.litellm_jwtauth.role_mappings is None:
return None
jwt_role = self.get_jwt_role(token=token, default_value=None)
if not jwt_role:
return None
jwt_role_set = set(jwt_role)
for role_mapping in self.litellm_jwtauth.role_mappings:
# Check if the mapping role matches any of the JWT roles
if role_mapping.role in jwt_role_set:
return role_mapping.internal_role
return None
def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]: def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]:
""" """
Returns the RBAC role the token 'belongs' to. Returns the RBAC role the token 'belongs' to.
@ -109,6 +142,8 @@ class JWTHandler:
user_roles=user_roles user_roles=user_roles
): ):
return LitellmUserRoles.INTERNAL_USER return LitellmUserRoles.INTERNAL_USER
elif rbac_role := self._rbac_role_from_role_mapping(token=token):
return rbac_role
return None return None
@ -212,6 +247,29 @@ class JWTHandler:
user_roles = default_value user_roles = default_value
return user_roles return user_roles
def get_jwt_role(
self, token: dict, default_value: Optional[List[str]]
) -> Optional[List[str]]:
"""
Generic implementation of `get_user_roles` that can be used for both user and team roles.
Returns the jwt role from the token.
Set via 'roles_jwt_field' in the config.
"""
try:
if self.litellm_jwtauth.roles_jwt_field is not None:
user_roles = get_nested_value(
data=token,
key_path=self.litellm_jwtauth.roles_jwt_field,
default=default_value,
)
else:
user_roles = default_value
except KeyError:
user_roles = default_value
return user_roles
def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool: def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool:
""" """
Returns the user role from the token. Returns the user role from the token.
@ -240,6 +298,16 @@ class JWTHandler:
user_email = default_value user_email = default_value
return user_email return user_email
def get_object_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try:
if self.litellm_jwtauth.object_id_jwt_field is not None:
object_id = token[self.litellm_jwtauth.object_id_jwt_field]
else:
object_id = default_value
except KeyError:
object_id = default_value
return object_id
def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try: try:
if self.litellm_jwtauth.org_id_jwt_field is not None: if self.litellm_jwtauth.org_id_jwt_field is not None:
@ -423,6 +491,35 @@ class JWTHandler:
class JWTAuthManager: class JWTAuthManager:
"""Manages JWT authentication and authorization operations""" """Manages JWT authentication and authorization operations"""
@staticmethod
def can_rbac_role_call_route(
rbac_role: RBAC_ROLES,
general_settings: dict,
route: str,
) -> Literal[True]:
"""
Checks if user is allowed to access the route, based on their role.
"""
role_based_routes = get_role_based_routes(
rbac_role=rbac_role, general_settings=general_settings
)
if role_based_routes is None or route is None:
return True
is_allowed = _allowed_routes_check(
user_route=route,
allowed_routes=role_based_routes,
)
if not is_allowed:
raise HTTPException(
status_code=403,
detail=f"Role={rbac_role} not allowed to call route={route}. Allowed routes={role_based_routes}",
)
return True
@staticmethod @staticmethod
def can_rbac_role_call_model( def can_rbac_role_call_model(
rbac_role: RBAC_ROLES, rbac_role: RBAC_ROLES,
@ -441,7 +538,7 @@ class JWTAuthManager:
if model not in role_based_models: if model not in role_based_models:
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
detail=f"User role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}", detail=f"Role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
) )
return True return True
@ -452,10 +549,11 @@ class JWTAuthManager:
jwt_valid_token: dict, jwt_valid_token: dict,
general_settings: dict, general_settings: dict,
request_data: dict, request_data: dict,
route: str,
rbac_role: Optional[RBAC_ROLES],
) -> None: ) -> None:
"""Validate RBAC role and model access permissions""" """Validate RBAC role and model access permissions"""
if jwt_handler.litellm_jwtauth.enforce_rbac is True: if jwt_handler.litellm_jwtauth.enforce_rbac is True:
rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
if rbac_role is None: if rbac_role is None:
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
@ -466,6 +564,11 @@ class JWTAuthManager:
general_settings=general_settings, general_settings=general_settings,
model=request_data.get("model"), model=request_data.get("model"),
) )
JWTAuthManager.can_rbac_role_call_route(
rbac_role=rbac_role,
general_settings=general_settings,
route=route,
)
@staticmethod @staticmethod
async def check_admin_access( async def check_admin_access(
@ -685,6 +788,21 @@ class JWTAuthManager:
return user_object, org_object, end_user_object return user_object, org_object, end_user_object
@staticmethod
def validate_object_id(
user_id: Optional[str],
team_id: Optional[str],
enforce_rbac: bool,
is_proxy_admin: bool,
) -> Literal[True]:
"""If enforce_rbac is true, validate that a valid rbac id is returned for spend tracking"""
if enforce_rbac and not is_proxy_admin and not user_id and not team_id:
raise HTTPException(
status_code=403,
detail="No user or team id found in token. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.",
)
return True
@staticmethod @staticmethod
async def auth_builder( async def auth_builder(
api_key: str, api_key: str,
@ -701,10 +819,18 @@ class JWTAuthManager:
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key) jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
# Check RBAC # Check RBAC
rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
await JWTAuthManager.check_rbac_role( await JWTAuthManager.check_rbac_role(
jwt_handler, jwt_valid_token, general_settings, request_data jwt_handler,
jwt_valid_token,
general_settings,
request_data,
route,
rbac_role,
) )
object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)
# Get basic user info # Get basic user info
scopes = jwt_handler.get_scopes(token=jwt_valid_token) scopes = jwt_handler.get_scopes(token=jwt_valid_token)
user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info( user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info(
@ -716,6 +842,16 @@ class JWTAuthManager:
end_user_id = jwt_handler.get_end_user_id( end_user_id = jwt_handler.get_end_user_id(
token=jwt_valid_token, default_value=None token=jwt_valid_token, default_value=None
) )
team_id: Optional[str] = None
team_object: Optional[LiteLLM_TeamTable] = None
object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)
if rbac_role and object_id:
if rbac_role == LitellmUserRoles.TEAM:
team_id = object_id
elif rbac_role == LitellmUserRoles.INTERNAL_USER:
user_id = object_id
# Check admin access # Check admin access
admin_result = await JWTAuthManager.check_admin_access( admin_result = await JWTAuthManager.check_admin_access(
@ -726,15 +862,20 @@ class JWTAuthManager:
# Get team with model access # Get team with model access
## SPECIFIC TEAM ID ## SPECIFIC TEAM ID
team_id, team_object = await JWTAuthManager.find_and_validate_specific_team_id(
jwt_handler, if not team_id:
jwt_valid_token, team_id, team_object = (
prisma_client, await JWTAuthManager.find_and_validate_specific_team_id(
user_api_key_cache, jwt_handler,
parent_otel_span, jwt_valid_token,
proxy_logging_obj, prisma_client,
) user_api_key_cache,
if not team_object: parent_otel_span,
proxy_logging_obj,
)
)
if not team_object and not team_id:
## CHECK USER GROUP ACCESS ## CHECK USER GROUP ACCESS
all_team_ids = JWTAuthManager.get_all_team_ids(jwt_handler, jwt_valid_token) all_team_ids = JWTAuthManager.get_all_team_ids(jwt_handler, jwt_valid_token)
team_id, team_object = await JWTAuthManager.find_team_with_model_access( team_id, team_object = await JWTAuthManager.find_team_with_model_access(
@ -762,6 +903,14 @@ class JWTAuthManager:
proxy_logging_obj=proxy_logging_obj, proxy_logging_obj=proxy_logging_obj,
) )
# Validate that a valid rbac id is returned for spend tracking
JWTAuthManager.validate_object_id(
user_id=user_id,
team_id=team_id,
enforce_rbac=general_settings.get("enforce_rbac", False),
is_proxy_admin=False,
)
return JWTAuthBuilderResult( return JWTAuthBuilderResult(
is_proxy_admin=False, is_proxy_admin=False,
team_id=team_id, team_id=team_id,

View file

@ -2093,6 +2093,14 @@ class ProxyConfig:
health_check_interval = general_settings.get("health_check_interval", 300) health_check_interval = general_settings.get("health_check_interval", 300)
health_check_details = general_settings.get("health_check_details", True) health_check_details = general_settings.get("health_check_details", True)
### RBAC ###
rbac_role_permissions = general_settings.get("role_permissions", None)
if rbac_role_permissions is not None:
general_settings["role_permissions"] = [ # validate role permissions
RoleBasedPermissions(**role_permission)
for role_permission in rbac_role_permissions
]
## check if user has set a premium feature in general_settings ## check if user has set a premium feature in general_settings
if ( if (
general_settings.get("enforced_params") is not None general_settings.get("enforced_params") is not None

View file

@ -468,7 +468,7 @@ class BaseLLMChatTest(ABC):
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
], ],
) )
@pytest.mark.flaky(retries=4, delay=1) @pytest.mark.flaky(retries=4, delay=2)
def test_image_url(self, detail, image_url): def test_image_url(self, detail, image_url):
litellm.set_verbose = True litellm.set_verbose = True
from litellm.utils import supports_vision from litellm.utils import supports_vision
@ -515,9 +515,13 @@ class BaseLLMChatTest(ABC):
], ],
} }
] ]
response = self.completion_function( try:
**base_completion_call_args, messages=messages response = self.completion_function(
) **base_completion_call_args, messages=messages
)
except litellm.InternalServerError:
pytest.skip("Model is overloaded")
assert response is not None assert response is not None
@pytest.mark.flaky(retries=4, delay=1) @pytest.mark.flaky(retries=4, delay=1)

View file

@ -21,7 +21,7 @@ from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from fastapi import Request from fastapi import Request, HTTPException
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from fastapi.responses import Response from fastapi.responses import Response
import litellm import litellm
@ -1164,3 +1164,22 @@ async def test_end_user_jwt_auth(monkeypatch):
mock_client.call_args.kwargs[ mock_client.call_args.kwargs[
"end_user_id" "end_user_id"
] == "81b3e52a-67a6-4efb-9645-70527e101479" ] == "81b3e52a-67a6-4efb-9645-70527e101479"
def test_can_rbac_role_call_route():
from litellm.proxy.auth.handle_jwt import JWTAuthManager
from litellm.proxy._types import RoleBasedPermissions
from litellm.proxy._types import LitellmUserRoles
with pytest.raises(HTTPException):
JWTAuthManager.can_rbac_role_call_route(
rbac_role=LitellmUserRoles.TEAM,
general_settings={
"role_permissions": [
RoleBasedPermissions(
role=LitellmUserRoles.TEAM, routes=["/v1/chat/completions"]
)
]
},
route="/v1/embeddings",
)