mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
JWT Auth - enforce_rbac
support + UI team view, spend calc fix (#7863)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 12s
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 12s
* fix(user_dashboard.tsx): fix spend calculation when team selected sum all team keys, not user keys * docs(admin_ui_sso.md): fix docs tabbing * feat(user_api_key_auth.py): introduce new 'enforce_rbac' param on jwt auth allows proxy admin to prevent any unmapped yet authenticated jwt tokens from calling proxy Fixes https://github.com/BerriAI/litellm/issues/6793 * test: more unit testing + refactoring * fix: fix returning id when obj not found in db * fix(user_api_key_auth.py): add end user id tracking from jwt auth * docs(token_auth.md): add doc on rbac with JWTs * fix: fix unused params * test: remove old test
This commit is contained in:
parent
c306c2e0fc
commit
dca6904937
12 changed files with 449 additions and 197 deletions
|
@ -19,6 +19,7 @@ from fastapi.security.api_key import APIKeyHeader
|
|||
import litellm
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm._service_logger import ServiceLogging
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
_cache_key_object,
|
||||
|
@ -43,12 +44,13 @@ from litellm.proxy.auth.auth_utils import (
|
|||
route_in_additonal_public_routes,
|
||||
should_run_auth_on_pass_through_provider_route,
|
||||
)
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.auth.oauth2_check import check_oauth2_token
|
||||
from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.auth.service_account_checks import service_account_checks
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.utils import _to_ns
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, _to_ns
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
user_api_key_service_logger_obj = ServiceLogging() # used for tracking latency on OTEL
|
||||
|
@ -226,6 +228,221 @@ def update_valid_token_with_end_user_params(
|
|||
return valid_token
|
||||
|
||||
|
||||
async def get_global_proxy_spend(
|
||||
litellm_proxy_admin_name: str,
|
||||
user_api_key_cache: DualCache,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
token: str,
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
) -> Optional[float]:
|
||||
global_proxy_spend = None
|
||||
if litellm.max_budget > 0: # user set proxy max budget
|
||||
# check cache
|
||||
global_proxy_spend = await user_api_key_cache.async_get_cache(
|
||||
key="{}:spend".format(litellm_proxy_admin_name)
|
||||
)
|
||||
if global_proxy_spend is None and prisma_client is not None:
|
||||
# get from db
|
||||
sql_query = (
|
||||
"""SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";"""
|
||||
)
|
||||
|
||||
response = await prisma_client.db.query_raw(query=sql_query)
|
||||
|
||||
global_proxy_spend = response[0]["total_spend"]
|
||||
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key="{}:spend".format(litellm_proxy_admin_name),
|
||||
value=global_proxy_spend,
|
||||
)
|
||||
if global_proxy_spend is not None:
|
||||
user_info = CallInfo(
|
||||
user_id=litellm_proxy_admin_name,
|
||||
max_budget=litellm.max_budget,
|
||||
spend=global_proxy_spend,
|
||||
token=token,
|
||||
)
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.budget_alerts(
|
||||
type="proxy_budget",
|
||||
user_info=user_info,
|
||||
)
|
||||
)
|
||||
return global_proxy_spend
|
||||
|
||||
|
||||
def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
|
||||
is_admin = jwt_handler.is_admin(scopes=scopes)
|
||||
if is_admin:
|
||||
return LitellmUserRoles.PROXY_ADMIN
|
||||
else:
|
||||
return LitellmUserRoles.TEAM
|
||||
|
||||
|
||||
async def _jwt_auth_user_api_key_auth_builder(
|
||||
api_key: str,
|
||||
jwt_handler: JWTHandler,
|
||||
route: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
parent_otel_span: Optional[Span],
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
) -> JWTAuthBuilderResult:
|
||||
|
||||
# check if valid token
|
||||
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
|
||||
|
||||
# check if unmatched token and enforce_rbac is true
|
||||
if (
|
||||
jwt_handler.litellm_jwtauth.enforce_rbac is True
|
||||
and jwt_handler.get_rbac_role(token=jwt_valid_token) is None
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org",
|
||||
)
|
||||
# get scopes
|
||||
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
||||
|
||||
# [OPTIONAL] allowed user email domains
|
||||
valid_user_email: Optional[bool] = None
|
||||
user_email: Optional[str] = None
|
||||
if jwt_handler.is_enforced_email_domain():
|
||||
"""
|
||||
if 'allowed_email_subdomains' is set,
|
||||
|
||||
- checks if token contains 'email' field
|
||||
- checks if 'email' is from an allowed domain
|
||||
"""
|
||||
user_email = jwt_handler.get_user_email(
|
||||
token=jwt_valid_token, default_value=None
|
||||
)
|
||||
if user_email is None:
|
||||
valid_user_email = False
|
||||
else:
|
||||
valid_user_email = jwt_handler.is_allowed_domain(user_email=user_email)
|
||||
|
||||
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
|
||||
user_object = None
|
||||
user_id = jwt_handler.get_user_id(token=jwt_valid_token, default_value=user_email)
|
||||
|
||||
# get org id
|
||||
org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None)
|
||||
# get team id
|
||||
team_id = jwt_handler.get_team_id(token=jwt_valid_token, default_value=None)
|
||||
# get end user id
|
||||
end_user_id = jwt_handler.get_end_user_id(token=jwt_valid_token, default_value=None)
|
||||
|
||||
# check if admin
|
||||
is_admin = jwt_handler.is_admin(scopes=scopes)
|
||||
# if admin return
|
||||
if is_admin:
|
||||
# check allowed admin routes
|
||||
is_allowed = allowed_routes_check(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
user_route=route,
|
||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||
)
|
||||
if is_allowed:
|
||||
return JWTAuthBuilderResult(
|
||||
is_proxy_admin=True,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
org_object=None,
|
||||
token=api_key,
|
||||
team_id=team_id,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
org_id=org_id,
|
||||
)
|
||||
else:
|
||||
allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
|
||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||
raise Exception(
|
||||
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||
)
|
||||
|
||||
if team_id is None and jwt_handler.is_required_team_id() is True:
|
||||
raise Exception(
|
||||
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
|
||||
)
|
||||
|
||||
team_object: Optional[LiteLLM_TeamTable] = None
|
||||
if team_id is not None:
|
||||
# check allowed team routes
|
||||
is_allowed = allowed_routes_check(
|
||||
user_role=LitellmUserRoles.TEAM,
|
||||
user_route=route,
|
||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||
)
|
||||
if is_allowed is False:
|
||||
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
|
||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||
raise Exception(
|
||||
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||
)
|
||||
|
||||
# check if team in db
|
||||
team_object = await get_team_object(
|
||||
team_id=team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
|
||||
|
||||
org_object: Optional[LiteLLM_OrganizationTable] = None
|
||||
if org_id is not None:
|
||||
org_object = await get_org_object(
|
||||
org_id=org_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
if user_id is not None:
|
||||
# get the user object
|
||||
user_object = await get_user_object(
|
||||
user_id=user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
user_id_upsert=jwt_handler.is_upsert_user_id(
|
||||
valid_user_email=valid_user_email
|
||||
),
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
|
||||
end_user_object = None
|
||||
|
||||
if end_user_id is not None:
|
||||
# get the end-user object
|
||||
end_user_object = await get_end_user_object(
|
||||
end_user_id=end_user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
return {
|
||||
"is_proxy_admin": False,
|
||||
"team_id": team_id,
|
||||
"team_object": team_object,
|
||||
"user_id": user_id,
|
||||
"user_object": user_object,
|
||||
"org_id": org_id,
|
||||
"org_object": org_object,
|
||||
"end_user_id": end_user_id,
|
||||
"end_user_object": end_user_object,
|
||||
"token": api_key,
|
||||
}
|
||||
|
||||
|
||||
async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
request: Request,
|
||||
api_key: str,
|
||||
|
@ -361,164 +578,39 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
is_jwt = jwt_handler.is_jwt(token=api_key)
|
||||
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
|
||||
if is_jwt:
|
||||
# check if valid token
|
||||
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
|
||||
# get scopes
|
||||
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
||||
result = await _jwt_auth_user_api_key_auth_builder(
|
||||
api_key=api_key,
|
||||
jwt_handler=jwt_handler,
|
||||
route=route,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
|
||||
# check if admin
|
||||
is_admin = jwt_handler.is_admin(scopes=scopes)
|
||||
# if admin return
|
||||
if is_admin:
|
||||
# check allowed admin routes
|
||||
is_allowed = allowed_routes_check(
|
||||
is_proxy_admin = result["is_proxy_admin"]
|
||||
team_id = result["team_id"]
|
||||
team_object = result["team_object"]
|
||||
user_id = result["user_id"]
|
||||
user_object = result["user_object"]
|
||||
end_user_id = result["end_user_id"]
|
||||
end_user_object = result["end_user_object"]
|
||||
org_id = result["org_id"]
|
||||
token = result["token"]
|
||||
|
||||
global_proxy_spend = await get_global_proxy_spend(
|
||||
litellm_proxy_admin_name=litellm_proxy_admin_name,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
prisma_client=prisma_client,
|
||||
token=token,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
if is_proxy_admin:
|
||||
return UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
user_route=route,
|
||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||
)
|
||||
if is_allowed:
|
||||
return UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
else:
|
||||
allowed_routes: List[Any] = (
|
||||
jwt_handler.litellm_jwtauth.admin_allowed_routes
|
||||
)
|
||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||
raise Exception(
|
||||
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||
)
|
||||
|
||||
# get team id
|
||||
team_id = jwt_handler.get_team_id(
|
||||
token=jwt_valid_token, default_value=None
|
||||
)
|
||||
|
||||
if team_id is None and jwt_handler.is_required_team_id() is True:
|
||||
raise Exception(
|
||||
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
|
||||
)
|
||||
|
||||
team_object: Optional[LiteLLM_TeamTable] = None
|
||||
if team_id is not None:
|
||||
# check allowed team routes
|
||||
is_allowed = allowed_routes_check(
|
||||
user_role=LitellmUserRoles.TEAM,
|
||||
user_route=route,
|
||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||
)
|
||||
if is_allowed is False:
|
||||
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
|
||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||
raise Exception(
|
||||
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||
)
|
||||
|
||||
# check if team in db
|
||||
team_object = await get_team_object(
|
||||
team_id=team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
|
||||
org_id = jwt_handler.get_org_id(
|
||||
token=jwt_valid_token, default_value=None
|
||||
)
|
||||
if org_id is not None:
|
||||
_ = await get_org_object(
|
||||
org_id=org_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
# [OPTIONAL] allowed user email domains
|
||||
valid_user_email: Optional[bool] = None
|
||||
user_email: Optional[str] = None
|
||||
if jwt_handler.is_enforced_email_domain():
|
||||
"""
|
||||
if 'allowed_email_subdomains' is set,
|
||||
|
||||
- checks if token contains 'email' field
|
||||
- checks if 'email' is from an allowed domain
|
||||
"""
|
||||
user_email = jwt_handler.get_user_email(
|
||||
token=jwt_valid_token, default_value=None
|
||||
)
|
||||
if user_email is None:
|
||||
valid_user_email = False
|
||||
else:
|
||||
valid_user_email = jwt_handler.is_allowed_domain(
|
||||
user_email=user_email
|
||||
)
|
||||
|
||||
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
|
||||
user_object = None
|
||||
user_id = jwt_handler.get_user_id(
|
||||
token=jwt_valid_token, default_value=user_email
|
||||
)
|
||||
if user_id is not None:
|
||||
# get the user object
|
||||
user_object = await get_user_object(
|
||||
user_id=user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
user_id_upsert=jwt_handler.is_upsert_user_id(
|
||||
valid_user_email=valid_user_email
|
||||
),
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
|
||||
end_user_object = None
|
||||
end_user_id = jwt_handler.get_end_user_id(
|
||||
token=jwt_valid_token, default_value=None
|
||||
)
|
||||
if end_user_id is not None:
|
||||
# get the end-user object
|
||||
end_user_object = await get_end_user_object(
|
||||
end_user_id=end_user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
global_proxy_spend = None
|
||||
if litellm.max_budget > 0: # user set proxy max budget
|
||||
# check cache
|
||||
global_proxy_spend = await user_api_key_cache.async_get_cache(
|
||||
key="{}:spend".format(litellm_proxy_admin_name)
|
||||
)
|
||||
if global_proxy_spend is None and prisma_client is not None:
|
||||
# get from db
|
||||
sql_query = """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";"""
|
||||
|
||||
response = await prisma_client.db.query_raw(query=sql_query)
|
||||
|
||||
global_proxy_spend = response[0]["total_spend"]
|
||||
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key="{}:spend".format(litellm_proxy_admin_name),
|
||||
value=global_proxy_spend,
|
||||
)
|
||||
if global_proxy_spend is not None:
|
||||
user_info = CallInfo(
|
||||
user_id=litellm_proxy_admin_name,
|
||||
max_budget=litellm.max_budget,
|
||||
spend=global_proxy_spend,
|
||||
token=jwt_valid_token["token"],
|
||||
)
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.budget_alerts(
|
||||
type="proxy_budget",
|
||||
user_info=user_info,
|
||||
)
|
||||
)
|
||||
# run through common checks
|
||||
_ = common_checks(
|
||||
request_body=request_data,
|
||||
|
@ -534,7 +626,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
# return UserAPIKeyAuth object
|
||||
return UserAPIKeyAuth(
|
||||
api_key=None,
|
||||
team_id=team_object.team_id if team_object is not None else None,
|
||||
team_id=team_id,
|
||||
team_tpm_limit=(
|
||||
team_object.tpm_limit if team_object is not None else None
|
||||
),
|
||||
|
@ -548,6 +640,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
parent_otel_span=parent_otel_span,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
|
||||
#### ELSE ####
|
||||
## CHECK PASS-THROUGH ENDPOINTS ##
|
||||
is_mapped_pass_through_route: bool = False
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue