JWT Auth - enforce_rbac support + UI team view, spend calc fix (#7863)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 12s

* fix(user_dashboard.tsx): fix spend calculation when team selected

sum all team keys, not user keys

* docs(admin_ui_sso.md): fix docs tabbing

* feat(user_api_key_auth.py): introduce new 'enforce_rbac' param on jwt auth

allows proxy admin to prevent any unmapped yet authenticated jwt tokens from calling proxy

Fixes https://github.com/BerriAI/litellm/issues/6793

* test: more unit testing + refactoring

* fix: fix returning id when obj not found in db

* fix(user_api_key_auth.py): add end user id tracking from jwt auth

* docs(token_auth.md): add doc on rbac with JWTs

* fix: fix unused params

* test: remove old test
This commit is contained in:
Krish Dholakia 2025-01-19 21:28:55 -08:00 committed by GitHub
parent c306c2e0fc
commit dca6904937
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 449 additions and 197 deletions

View file

@ -1,3 +1,7 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# ✨ SSO for Admin UI
:::info

View file

@ -114,7 +114,7 @@ general_settings:
admin_jwt_scope: "litellm-proxy-admin"
```
## Advanced - Spend Tracking (End-Users / Internal Users / Team / Org)
## Tracking End-Users / Internal Users / Team / Org
Set the field in the jwt token, which corresponds to a litellm user / team / org.
@ -156,6 +156,33 @@ scope: ["litellm-proxy-admin",...]
scope: "litellm-proxy-admin ..."
```
## Enforce Role-Based Access Control (RBAC)
Reject a JWT token if it's valid but doesn't have the required scopes / fields.
Only tokens which with valid Admin (`admin_jwt_scope`), User (`user_id_jwt_field`), Team (`team_id_jwt_field`) are allowed.
```yaml
general_settings:
master_key: sk-1234
enable_jwt_auth: True
litellm_jwtauth:
admin_jwt_scope: "litellm_proxy_endpoints_access"
admin_allowed_routes:
- openai_routes
- info_routes
public_key_ttl: 600
enforce_rbac: true # 👈 Enforce RBAC
```
Expected Scope in JWT:
```
{
"scope": "litellm_proxy_endpoints_access"
}
```
## Advanced - Allowed Routes
Configure which routes a JWT can access via the config.

View file

@ -6587,6 +6587,27 @@
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"stability.sd3-5-large-v1:0": {
"max_tokens": 77,
"max_input_tokens": 77,
"output_cost_per_image": 0.08,
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"stability.stable-image-core-v1:0": {
"max_tokens": 77,
"max_input_tokens": 77,
"output_cost_per_image": 0.04,
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"stability.stable-image-core-v1:1": {
"max_tokens": 77,
"max_input_tokens": 77,
"output_cost_per_image": 0.04,
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"stability.stable-image-ultra-v1:0": {
"max_tokens": 77,
"max_input_tokens": 77,
@ -6594,6 +6615,13 @@
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"stability.stable-image-ultra-v1:1": {
"max_tokens": 77,
"max_input_tokens": 77,
"output_cost_per_image": 0.14,
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"sagemaker/meta-textgeneration-llama-2-7b": {
"max_tokens": 4096,
"max_input_tokens": 4096,

View file

@ -416,6 +416,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
- user_allowed_email_subdomain: If specified, only emails from specified subdomain will be allowed to access proxy.
- end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
- public_key_ttl: Default - 600s. TTL for caching public JWT keys.
- public_allowed_routes: list of allowed routes for authenticated but unknown litellm role jwt tokens.
- enforce_rbac: If true, enforce RBAC for all routes.
See `auth_checks.py` for the specific routes
"""
@ -446,6 +448,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
)
end_user_id_jwt_field: Optional[str] = None
public_key_ttl: float = 600
public_allowed_routes: List[str] = ["public_routes"]
enforce_rbac: bool = False
def __init__(self, **kwargs: Any) -> None:
# get the attribute names for this Pydantic model
@ -2284,6 +2288,19 @@ class ProxyStateVariables(TypedDict):
UI_TEAM_ID = "litellm-dashboard"
class JWTAuthBuilderResult(TypedDict):
is_proxy_admin: bool
team_object: Optional[LiteLLM_TeamTable]
user_object: Optional[LiteLLM_UserTable]
end_user_object: Optional[LiteLLM_EndUserTable]
org_object: Optional[LiteLLM_OrganizationTable]
token: str
team_id: Optional[str]
user_id: Optional[str]
end_user_id: Optional[str]
org_id: Optional[str]
class ClientSideFallbackModel(TypedDict, total=False):
"""
Dictionary passed when client configuring input

View file

@ -9,7 +9,6 @@ Run checks for:
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
import inspect
import time
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
@ -22,7 +21,6 @@ from litellm.caching.caching import DualCache
from litellm.caching.dual_cache import LimitedSizeOrderedDict
from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
CommonProxyErrors,
LiteLLM_EndUserTable,
LiteLLM_JWTAuth,
LiteLLM_OrganizationTable,
@ -55,33 +53,6 @@ db_cache_expiry = 5 # refresh every 5s
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
def _allowed_import_check() -> bool:
from litellm.proxy.auth.user_api_key_auth import _user_api_key_auth_builder
# Get the calling frame
caller_frame = inspect.stack()[2]
caller_function = caller_frame.function
caller_function_callable = caller_frame.frame.f_globals.get(caller_function)
allowed_function = "_user_api_key_auth_builder"
allowed_signature = inspect.signature(_user_api_key_auth_builder)
if caller_function_callable is None or not callable(caller_function_callable):
raise Exception(f"Caller function {caller_function} is not callable")
caller_signature = inspect.signature(caller_function_callable)
if caller_signature != allowed_signature:
raise TypeError(
f"The function '{caller_function}' does not match the required signature of 'user_api_key_auth'. {CommonProxyErrors.not_premium_user.value}"
)
# Check if the caller module is allowed
if caller_function != allowed_function:
raise ImportError(
f"This function can only be imported by '{allowed_function}'. {CommonProxyErrors.not_premium_user.value}"
)
return True
def common_checks( # noqa: PLR0915
request_body: dict,
team_object: Optional[LiteLLM_TeamTable],
@ -106,7 +77,6 @@ def common_checks( # noqa: PLR0915
9. Check if request body is safe
10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
"""
_allowed_import_check()
_model = request_body.get("model", None)
if team_object is not None and team_object.blocked is True:
raise Exception(
@ -844,7 +814,7 @@ async def get_org_object(
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
):
) -> Optional[LiteLLM_OrganizationTable]:
"""
- Check if org id in proxy Org Table
- if valid, return LiteLLM_OrganizationTable object
@ -859,7 +829,7 @@ async def get_org_object(
cached_org_obj = user_api_key_cache.async_get_cache(key="org_id:{}".format(org_id))
if cached_org_obj is not None:
if isinstance(cached_org_obj, dict):
return cached_org_obj
return LiteLLM_OrganizationTable(**cached_org_obj)
elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
return cached_org_obj
# else, check db

View file

@ -17,7 +17,12 @@ from cryptography.hazmat.primitives import serialization
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy._types import JWKKeyValue, JWTKeyItem, LiteLLM_JWTAuth
from litellm.proxy._types import (
JWKKeyValue,
JWTKeyItem,
LiteLLM_JWTAuth,
LitellmUserRoles,
)
from litellm.proxy.utils import PrismaClient
@ -54,6 +59,34 @@ class JWTHandler:
parts = token.split(".")
return len(parts) == 3
def get_rbac_role(self, token: dict) -> Optional[LitellmUserRoles]:
"""
Returns the RBAC role the token 'belongs' to.
RBAC roles allowed to make requests:
- PROXY_ADMIN: can make requests to all routes
- TEAM: can make requests to routes associated with a team
- INTERNAL_USER: can make requests to routes associated with a user
Resolves: https://github.com/BerriAI/litellm/issues/6793
Returns:
- PROXY_ADMIN: if token is admin
- TEAM: if token is associated with a team
- INTERNAL_USER: if token is associated with a user
- None: if token is not associated with a team or user
"""
scopes = self.get_scopes(token=token)
is_admin = self.is_admin(scopes=scopes)
if is_admin:
return LitellmUserRoles.PROXY_ADMIN
elif self.get_team_id(token=token, default_value=None) is not None:
return LitellmUserRoles.TEAM
elif self.get_user_id(token=token, default_value=None) is not None:
return LitellmUserRoles.INTERNAL_USER
return None
def is_admin(self, scopes: list) -> bool:
if self.litellm_jwtauth.admin_jwt_scope in scopes:
return True
@ -68,12 +101,14 @@ class JWTHandler:
self, token: dict, default_value: Optional[str]
) -> Optional[str]:
try:
if self.litellm_jwtauth.end_user_id_jwt_field is not None:
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
else:
user_id = None
except KeyError:
user_id = default_value
return user_id
def is_required_team_id(self) -> bool:
@ -169,6 +204,7 @@ class JWTHandler:
return scopes
async def get_public_key(self, kid: Optional[str]) -> dict:
keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
if keys_url is None:

View file

@ -19,6 +19,7 @@ from fastapi.security.api_key import APIKeyHeader
import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm._service_logger import ServiceLogging
from litellm.caching import DualCache
from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import (
_cache_key_object,
@ -43,12 +44,13 @@ from litellm.proxy.auth.auth_utils import (
route_in_additonal_public_routes,
should_run_auth_on_pass_through_provider_route,
)
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.auth.oauth2_check import check_oauth2_token
from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.auth.service_account_checks import service_account_checks
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.utils import _to_ns
from litellm.proxy.utils import PrismaClient, ProxyLogging, _to_ns
from litellm.types.services import ServiceTypes
user_api_key_service_logger_obj = ServiceLogging() # used for tracking latency on OTEL
@ -226,6 +228,221 @@ def update_valid_token_with_end_user_params(
return valid_token
async def get_global_proxy_spend(
litellm_proxy_admin_name: str,
user_api_key_cache: DualCache,
prisma_client: Optional[PrismaClient],
token: str,
proxy_logging_obj: ProxyLogging,
) -> Optional[float]:
global_proxy_spend = None
if litellm.max_budget > 0: # user set proxy max budget
# check cache
global_proxy_spend = await user_api_key_cache.async_get_cache(
key="{}:spend".format(litellm_proxy_admin_name)
)
if global_proxy_spend is None and prisma_client is not None:
# get from db
sql_query = (
"""SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";"""
)
response = await prisma_client.db.query_raw(query=sql_query)
global_proxy_spend = response[0]["total_spend"]
await user_api_key_cache.async_set_cache(
key="{}:spend".format(litellm_proxy_admin_name),
value=global_proxy_spend,
)
if global_proxy_spend is not None:
user_info = CallInfo(
user_id=litellm_proxy_admin_name,
max_budget=litellm.max_budget,
spend=global_proxy_spend,
token=token,
)
asyncio.create_task(
proxy_logging_obj.budget_alerts(
type="proxy_budget",
user_info=user_info,
)
)
return global_proxy_spend
def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
is_admin = jwt_handler.is_admin(scopes=scopes)
if is_admin:
return LitellmUserRoles.PROXY_ADMIN
else:
return LitellmUserRoles.TEAM
async def _jwt_auth_user_api_key_auth_builder(
api_key: str,
jwt_handler: JWTHandler,
route: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span],
proxy_logging_obj: ProxyLogging,
) -> JWTAuthBuilderResult:
# check if valid token
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
# check if unmatched token and enforce_rbac is true
if (
jwt_handler.litellm_jwtauth.enforce_rbac is True
and jwt_handler.get_rbac_role(token=jwt_valid_token) is None
):
raise HTTPException(
status_code=403,
detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org",
)
# get scopes
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
# [OPTIONAL] allowed user email domains
valid_user_email: Optional[bool] = None
user_email: Optional[str] = None
if jwt_handler.is_enforced_email_domain():
"""
if 'allowed_email_subdomains' is set,
- checks if token contains 'email' field
- checks if 'email' is from an allowed domain
"""
user_email = jwt_handler.get_user_email(
token=jwt_valid_token, default_value=None
)
if user_email is None:
valid_user_email = False
else:
valid_user_email = jwt_handler.is_allowed_domain(user_email=user_email)
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
user_object = None
user_id = jwt_handler.get_user_id(token=jwt_valid_token, default_value=user_email)
# get org id
org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None)
# get team id
team_id = jwt_handler.get_team_id(token=jwt_valid_token, default_value=None)
# get end user id
end_user_id = jwt_handler.get_end_user_id(token=jwt_valid_token, default_value=None)
# check if admin
is_admin = jwt_handler.is_admin(scopes=scopes)
# if admin return
if is_admin:
# check allowed admin routes
is_allowed = allowed_routes_check(
user_role=LitellmUserRoles.PROXY_ADMIN,
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed:
return JWTAuthBuilderResult(
is_proxy_admin=True,
team_object=None,
user_object=None,
end_user_object=None,
org_object=None,
token=api_key,
team_id=team_id,
user_id=user_id,
end_user_id=end_user_id,
org_id=org_id,
)
else:
allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
if team_id is None and jwt_handler.is_required_team_id() is True:
raise Exception(
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
)
team_object: Optional[LiteLLM_TeamTable] = None
if team_id is not None:
# check allowed team routes
is_allowed = allowed_routes_check(
user_role=LitellmUserRoles.TEAM,
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed is False:
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
# check if team in db
team_object = await get_team_object(
team_id=team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
org_object: Optional[LiteLLM_OrganizationTable] = None
if org_id is not None:
org_object = await get_org_object(
org_id=org_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
if user_id is not None:
# get the user object
user_object = await get_user_object(
user_id=user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
user_id_upsert=jwt_handler.is_upsert_user_id(
valid_user_email=valid_user_email
),
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
end_user_object = None
if end_user_id is not None:
# get the end-user object
end_user_object = await get_end_user_object(
end_user_id=end_user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
return {
"is_proxy_admin": False,
"team_id": team_id,
"team_object": team_object,
"user_id": user_id,
"user_object": user_object,
"org_id": org_id,
"org_object": org_object,
"end_user_id": end_user_id,
"end_user_object": end_user_object,
"token": api_key,
}
async def _user_api_key_auth_builder( # noqa: PLR0915
request: Request,
api_key: str,
@ -361,164 +578,39 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
is_jwt = jwt_handler.is_jwt(token=api_key)
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
if is_jwt:
# check if valid token
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
# get scopes
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
# check if admin
is_admin = jwt_handler.is_admin(scopes=scopes)
# if admin return
if is_admin:
# check allowed admin routes
is_allowed = allowed_routes_check(
user_role=LitellmUserRoles.PROXY_ADMIN,
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
result = await _jwt_auth_user_api_key_auth_builder(
api_key=api_key,
jwt_handler=jwt_handler,
route=route,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
parent_otel_span=parent_otel_span,
)
if is_allowed:
is_proxy_admin = result["is_proxy_admin"]
team_id = result["team_id"]
team_object = result["team_object"]
user_id = result["user_id"]
user_object = result["user_object"]
end_user_id = result["end_user_id"]
end_user_object = result["end_user_object"]
org_id = result["org_id"]
token = result["token"]
global_proxy_spend = await get_global_proxy_spend(
litellm_proxy_admin_name=litellm_proxy_admin_name,
user_api_key_cache=user_api_key_cache,
prisma_client=prisma_client,
token=token,
proxy_logging_obj=proxy_logging_obj,
)
if is_proxy_admin:
return UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
parent_otel_span=parent_otel_span,
)
else:
allowed_routes: List[Any] = (
jwt_handler.litellm_jwtauth.admin_allowed_routes
)
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
# get team id
team_id = jwt_handler.get_team_id(
token=jwt_valid_token, default_value=None
)
if team_id is None and jwt_handler.is_required_team_id() is True:
raise Exception(
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
)
team_object: Optional[LiteLLM_TeamTable] = None
if team_id is not None:
# check allowed team routes
is_allowed = allowed_routes_check(
user_role=LitellmUserRoles.TEAM,
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed is False:
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
# check if team in db
team_object = await get_team_object(
team_id=team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
org_id = jwt_handler.get_org_id(
token=jwt_valid_token, default_value=None
)
if org_id is not None:
_ = await get_org_object(
org_id=org_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# [OPTIONAL] allowed user email domains
valid_user_email: Optional[bool] = None
user_email: Optional[str] = None
if jwt_handler.is_enforced_email_domain():
"""
if 'allowed_email_subdomains' is set,
- checks if token contains 'email' field
- checks if 'email' is from an allowed domain
"""
user_email = jwt_handler.get_user_email(
token=jwt_valid_token, default_value=None
)
if user_email is None:
valid_user_email = False
else:
valid_user_email = jwt_handler.is_allowed_domain(
user_email=user_email
)
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
user_object = None
user_id = jwt_handler.get_user_id(
token=jwt_valid_token, default_value=user_email
)
if user_id is not None:
# get the user object
user_object = await get_user_object(
user_id=user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
user_id_upsert=jwt_handler.is_upsert_user_id(
valid_user_email=valid_user_email
),
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
end_user_object = None
end_user_id = jwt_handler.get_end_user_id(
token=jwt_valid_token, default_value=None
)
if end_user_id is not None:
# get the end-user object
end_user_object = await get_end_user_object(
end_user_id=end_user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
global_proxy_spend = None
if litellm.max_budget > 0: # user set proxy max budget
# check cache
global_proxy_spend = await user_api_key_cache.async_get_cache(
key="{}:spend".format(litellm_proxy_admin_name)
)
if global_proxy_spend is None and prisma_client is not None:
# get from db
sql_query = """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";"""
response = await prisma_client.db.query_raw(query=sql_query)
global_proxy_spend = response[0]["total_spend"]
await user_api_key_cache.async_set_cache(
key="{}:spend".format(litellm_proxy_admin_name),
value=global_proxy_spend,
)
if global_proxy_spend is not None:
user_info = CallInfo(
user_id=litellm_proxy_admin_name,
max_budget=litellm.max_budget,
spend=global_proxy_spend,
token=jwt_valid_token["token"],
)
asyncio.create_task(
proxy_logging_obj.budget_alerts(
type="proxy_budget",
user_info=user_info,
)
)
# run through common checks
_ = common_checks(
request_body=request_data,
@ -534,7 +626,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
# return UserAPIKeyAuth object
return UserAPIKeyAuth(
api_key=None,
team_id=team_object.team_id if team_object is not None else None,
team_id=team_id,
team_tpm_limit=(
team_object.tpm_limit if team_object is not None else None
),
@ -548,6 +640,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
parent_otel_span=parent_otel_span,
end_user_id=end_user_id,
)
#### ELSE ####
## CHECK PASS-THROUGH ENDPOINTS ##
is_mapped_pass_through_route: bool = False

View file

@ -1039,6 +1039,7 @@ async def test_end_user_jwt_auth(monkeypatch):
import json
monkeypatch.delenv("JWT_AUDIENCE", None)
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
jwt_handler = JWTHandler()
litellm_jwtauth = LiteLLM_JWTAuth(

View file

@ -794,3 +794,71 @@ async def test_user_api_key_auth_websocket():
assert (
mock_user_api_key_auth.call_args.kwargs["api_key"] == "Bearer some_api_key"
)
@pytest.mark.parametrize("enforce_rbac", [True, False])
@pytest.mark.asyncio
async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypatch):
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.auth.user_api_key_auth import _jwt_auth_user_api_key_auth_builder
from unittest.mock import patch, Mock
from litellm.proxy._types import LiteLLM_JWTAuth
from litellm.caching import DualCache
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "my-fake-url")
monkeypatch.setenv("JWT_AUDIENCE", "api://LiteLLM_Proxy-dev")
local_cache = DualCache()
keys = [
{
"kty": "RSA",
"use": "sig",
"kid": "z1rsYHHJ9-8mggt4HsZu8BKkBPw",
"x5t": "z1rsYHHJ9-8mggt4HsZu8BKkBPw",
"n": "pOe4GbleFDT1u5ioOQjNMmhvkDVoVD9cBKvX7AlErtWA_D6wc1w1iwkd6arYVCPObZbAB4vLSXrlpBSOuP6VYnXw_cTgniv_c82ra-mfqCpM-SbqzZ3sVqlcE_bwxvci_4PrxAW4R85ok12NXyZ2371H3yGevabi35AlVm-bQ24azo1hLK_0DzB6TxsAIOTOcKfIugOfqP-B2R4vR4u6pYftS8MWcxegr9iJ5JNtubI1X2JHpxJhkRoMVwKFna2GXmtzdxLi3yS_GffVCKfTbFMhalbJS1lSmLqhmLZZL-lrQZ6fansTl1vcGcoxnzPTwBkZMks0iVV4yfym_gKBXQ",
"e": "AQAB",
"x5c": [
"MIIC/TCCAeWgAwIBAgIIQk8Qok6pfXkwDQYJKoZIhvcNAQELBQAwLTErMCkGA1UEAxMiYWNjb3VudHMuYWNjZXNzY29udHJvbC53aW5kb3dzLm5ldDAeFw0yNDExMjcwOTA0MzlaFw0yOTExMjcwOTA0MzlaMC0xKzApBgNVBAMTImFjY291bnRzLmFjY2Vzc2NvbnRyb2wud2luZG93cy5uZXQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCk57gZuV4UNPW7mKg5CM0yaG+QNWhUP1wEq9fsCUSu1YD8PrBzXDWLCR3pqthUI85tlsAHi8tJeuWkFI64/pVidfD9xOCeK/9zzatr6Z+oKkz5JurNnexWqVwT9vDG9yL/g+vEBbhHzmiTXY1fJnbfvUffIZ69puLfkCVWb5tDbhrOjWEsr/QPMHpPGwAg5M5wp8i6A5+o/4HZHi9Hi7qlh+1LwxZzF6Cv2Inkk225sjVfYkenEmGRGgxXAoWdrYZea3N3EuLfJL8Z99UIp9NsUyFqVslLWVKYuqGYtlkv6WtBnp9qexOXW9wZyjGfM9PAGRkySzSJVXjJ/Kb+AoFdAgMBAAGjITAfMB0GA1UdDgQWBBSTO5FmUwwGS+1CNqg2uNgjxUjFijANBgkqhkiG9w0BAQsFAAOCAQEAok04z0ICMEHGqDTzx6eD7vvJP8itJTCSz8JcZcGVJofJpViGF3bNnyeSPa7vNDYP1Ps9XBvw3/n2s+yynZ8EwFxMyxCZRCSbLv0N+cAbH3rmZqGcgMJszZVwcFUtXQPTe1ZRyHtEyOB+PVFH7K7obysRVO/cC6EGqIF3pYWzez/dtMaXRAkdTNlz0ko62WoA4eMPwUFCITjW/Jxfxl0BNUbo82PXXKhaeVJb+EgFG5b/pWWPswWmBoQhmD5G1UODvEACHRl/cHsPPqe4YE+6D1/wMno/xqqyGltnk8v0d4TpNcQMn9oM19V+OGgrzWOvvXhvnhqUIVGMsRlyBGNHAw=="
],
"cloud_instance_name": "microsoftonline.com",
"issuer": "https://login.microsoftonline.com/bdfd79b3-8401-47fb-a764-6e595c455b05/v2.0",
}
]
local_cache.set_cache(
key="litellm_jwt_auth_keys",
value=keys,
)
litellm_jwtauth = LiteLLM_JWTAuth(
**{
"admin_jwt_scope": "litellm_proxy_endpoints_access",
"admin_allowed_routes": ["openai_routes", "info_routes"],
"public_key_ttl": 600,
"enforce_rbac": enforce_rbac,
}
)
jwt_handler = JWTHandler()
jwt_handler.update_environment(
prisma_client=None,
user_api_key_cache=local_cache,
litellm_jwtauth=litellm_jwtauth,
leeway=10000000000000,
)
args = {
"api_key": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsIng1dCI6InoxcnNZSEhKOS04bWdndDRIc1p1OEJLa0JQdyIsImtpZCI6InoxcnNZSEhKOS04bWdndDRIc1p1OEJLa0JQdyJ9.eyJhdWQiOiJhcGk6Ly9MaXRlTExNX1Byb3h5LWRldiIsImlzcyI6Imh0dHBzOi8vc3RzLndpbmRvd3MubmV0L2JkZmQ3OWIzLTg0MDEtNDdmYi1hNzY0LTZlNTk1YzQ1NWIwNS8iLCJpYXQiOjE3MzcyNDE3ODEsIm5iZiI6MTczNzI0MTc4MSwiZXhwIjoxNzM3MjQ1NjgxLCJhaW8iOiJrMlJnWUpBNE5hZGg4MGJQdXlyRmxlV1o3dHZiQUE9PSIsImFwcGlkIjoiOGNjZjNkMDItMmNkNi00N2I5LTgxODUtMGVkYjI0YWJjZjY5IiwiYXBwaWRhY3IiOiIxIiwiaWRwIjoiaHR0cHM6Ly9zdHMud2luZG93cy5uZXQvYmRmZDc5YjMtODQwMS00N2ZiLWE3NjQtNmU1OTVjNDU1YjA1LyIsIm9pZCI6IjQ0YTg3YTYzLWFiNTUtNDc4NS1iMmFmLTMzNjllZWM4ZTEzOSIsInJoIjoiMS5BYjBBczNuOXZRR0UtMGVuWkc1WlhFVmJCY0VDbkl6NHJxaE9wZ2E0UGZSZjBsbTlBQUM5QUEuIiwic3ViIjoiNDRhODdhNjMtYWI1NS00Nzg1LWIyYWYtMzM2OWVlYzhlMTM5IiwidGlkIjoiYmRmZDc5YjMtODQwMS00N2ZiLWE3NjQtNmU1OTVjNDU1YjA1IiwidXRpIjoiY3ltNVhlcmhIMHVMSlNZU1JyQmhBQSIsInZlciI6IjEuMCJ9.UooJjM9pS-wgYsExqgHdrYyQhp7NbwAsr7au9dWJaLpsufXeyHJSg-Xd5VJ4RsDVJiDes3jkC7WeoAiaCfzEHpAum-p_aqqLYXf1QIYbi1hLC0m7y_klFcqMp11WbDa9TSTvg-o8q3x2Y5su8X23ymlFih4OP17b7JA6a4_2MybU5QkCEW1tQK6VspuuXzeDHvbfGeGYcIptHFyfttHMHHXRtX1o9bX7gOR_dwFITAXD18T4ZdAN_0y6f1OtVF9TMWQhMXhKU8ahn8TSg_CXmPl9T_1gV3ZWLvVtcdVrWs82fDz3-2lEw28z4bQEr1Z5xoAz7srhx1WEBu_ioAcQiA",
"jwt_handler": jwt_handler,
"route": "/v1/chat/completions",
"prisma_client": None,
"user_api_key_cache": Mock(),
"parent_otel_span": None,
"proxy_logging_obj": Mock(),
}
if enforce_rbac:
with pytest.raises(HTTPException):
await _jwt_auth_user_api_key_auth_builder(**args)
else:
await _jwt_auth_user_api_key_auth_builder(**args)

View file

@ -11,6 +11,7 @@ interface DashboardTeamProps {
setProxySettings: React.Dispatch<React.SetStateAction<ProxySettings | null>>;
userInfo: UserInfo | null;
accessToken: string | null;
setKeys: React.Dispatch<React.SetStateAction<any | null>>;
}
type TeamInterface = {
@ -27,7 +28,8 @@ const DashboardTeam: React.FC<DashboardTeamProps> = ({
proxySettings,
setProxySettings,
userInfo,
accessToken
accessToken,
setKeys
}) => {
console.log(`userInfo: ${JSON.stringify(userInfo)}`)
const defaultTeam: TeamInterface = {
@ -80,7 +82,10 @@ const DashboardTeam: React.FC<DashboardTeamProps> = ({
<SelectItem
key={index}
value={String(index)}
onClick={() => setSelectedTeam(team)}
onClick={() => {
setSelectedTeam(team);
setKeys(team["keys"]);
}}
>
{team["team_alias"]}
</SelectItem>

View file

@ -206,13 +206,15 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
setUserSpendData(response["user_info"]);
console.log(`userSpendData: ${JSON.stringify(userSpendData)}`)
setKeys(response["keys"]); // Assuming this is the correct path to your data
const teamsArray = [...response["teams"]];
if (teamsArray.length > 0) {
console.log(`response['teams']: ${teamsArray}`);
console.log(`response['teams']: ${JSON.stringify(teamsArray)}`);
setSelectedTeam(teamsArray[0]);
setKeys(teamsArray[0]["keys"]); // Assuming this is the correct path to your data
} else {
setSelectedTeam(defaultTeam);
setKeys(response["keys"]); // Assuming this is the correct path to your data
}
sessionStorage.setItem(
"userData" + userID,
@ -261,6 +263,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
selectedTeam.team_id !== null
) {
let sum = 0;
console.log(`keys: ${JSON.stringify(keys)}`)
for (const key of keys) {
if (
selectedTeam.hasOwnProperty("team_id") &&
@ -270,6 +273,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
sum += key.spend;
}
}
console.log(`sum: ${sum}`)
setTeamSpend(sum);
} else if (keys !== null) {
// sum the keys which don't have team-id set (default team)
@ -367,6 +371,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
setProxySettings={setProxySettings}
userInfo={userSpendData}
accessToken={accessToken}
setKeys={setKeys}
/>
</Col>
</Grid>

View file

@ -114,8 +114,6 @@ const ViewUserSpend: React.FC<ViewUserSpendProps> = ({ userID, userRole, accessT
fetchData();
}, [userRole, accessToken, userID]);
useEffect(() => {
if (userSpend !== null) {
setSpend(userSpend)