mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
JWT Auth - enforce_rbac
support + UI team view, spend calc fix (#7863)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 12s
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 12s
* fix(user_dashboard.tsx): fix spend calculation when team selected sum all team keys, not user keys * docs(admin_ui_sso.md): fix docs tabbing * feat(user_api_key_auth.py): introduce new 'enforce_rbac' param on jwt auth allows proxy admin to prevent any unmapped yet authenticated jwt tokens from calling proxy Fixes https://github.com/BerriAI/litellm/issues/6793 * test: more unit testing + refactoring * fix: fix returning id when obj not found in db * fix(user_api_key_auth.py): add end user id tracking from jwt auth * docs(token_auth.md): add doc on rbac with JWTs * fix: fix unused params * test: remove old test
This commit is contained in:
parent
c306c2e0fc
commit
dca6904937
12 changed files with 449 additions and 197 deletions
|
@ -1,3 +1,7 @@
|
|||
import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# ✨ SSO for Admin UI
|
||||
|
||||
:::info
|
||||
|
|
|
@ -114,7 +114,7 @@ general_settings:
|
|||
admin_jwt_scope: "litellm-proxy-admin"
|
||||
```
|
||||
|
||||
## Advanced - Spend Tracking (End-Users / Internal Users / Team / Org)
|
||||
## Tracking End-Users / Internal Users / Team / Org
|
||||
|
||||
Set the field in the jwt token, which corresponds to a litellm user / team / org.
|
||||
|
||||
|
@ -156,6 +156,33 @@ scope: ["litellm-proxy-admin",...]
|
|||
scope: "litellm-proxy-admin ..."
|
||||
```
|
||||
|
||||
## Enforce Role-Based Access Control (RBAC)
|
||||
|
||||
Reject a JWT token if it's valid but doesn't have the required scopes / fields.
|
||||
|
||||
Only tokens which with valid Admin (`admin_jwt_scope`), User (`user_id_jwt_field`), Team (`team_id_jwt_field`) are allowed.
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
enable_jwt_auth: True
|
||||
litellm_jwtauth:
|
||||
admin_jwt_scope: "litellm_proxy_endpoints_access"
|
||||
admin_allowed_routes:
|
||||
- openai_routes
|
||||
- info_routes
|
||||
public_key_ttl: 600
|
||||
enforce_rbac: true # 👈 Enforce RBAC
|
||||
```
|
||||
|
||||
Expected Scope in JWT:
|
||||
|
||||
```
|
||||
{
|
||||
"scope": "litellm_proxy_endpoints_access"
|
||||
}
|
||||
```
|
||||
|
||||
## Advanced - Allowed Routes
|
||||
|
||||
Configure which routes a JWT can access via the config.
|
||||
|
|
|
@ -6587,6 +6587,27 @@
|
|||
"litellm_provider": "bedrock",
|
||||
"mode": "image_generation"
|
||||
},
|
||||
"stability.sd3-5-large-v1:0": {
|
||||
"max_tokens": 77,
|
||||
"max_input_tokens": 77,
|
||||
"output_cost_per_image": 0.08,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "image_generation"
|
||||
},
|
||||
"stability.stable-image-core-v1:0": {
|
||||
"max_tokens": 77,
|
||||
"max_input_tokens": 77,
|
||||
"output_cost_per_image": 0.04,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "image_generation"
|
||||
},
|
||||
"stability.stable-image-core-v1:1": {
|
||||
"max_tokens": 77,
|
||||
"max_input_tokens": 77,
|
||||
"output_cost_per_image": 0.04,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "image_generation"
|
||||
},
|
||||
"stability.stable-image-ultra-v1:0": {
|
||||
"max_tokens": 77,
|
||||
"max_input_tokens": 77,
|
||||
|
@ -6594,6 +6615,13 @@
|
|||
"litellm_provider": "bedrock",
|
||||
"mode": "image_generation"
|
||||
},
|
||||
"stability.stable-image-ultra-v1:1": {
|
||||
"max_tokens": 77,
|
||||
"max_input_tokens": 77,
|
||||
"output_cost_per_image": 0.14,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "image_generation"
|
||||
},
|
||||
"sagemaker/meta-textgeneration-llama-2-7b": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 4096,
|
||||
|
|
|
@ -416,6 +416,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
|||
- user_allowed_email_subdomain: If specified, only emails from specified subdomain will be allowed to access proxy.
|
||||
- end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
|
||||
- public_key_ttl: Default - 600s. TTL for caching public JWT keys.
|
||||
- public_allowed_routes: list of allowed routes for authenticated but unknown litellm role jwt tokens.
|
||||
- enforce_rbac: If true, enforce RBAC for all routes.
|
||||
|
||||
See `auth_checks.py` for the specific routes
|
||||
"""
|
||||
|
@ -446,6 +448,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
|||
)
|
||||
end_user_id_jwt_field: Optional[str] = None
|
||||
public_key_ttl: float = 600
|
||||
public_allowed_routes: List[str] = ["public_routes"]
|
||||
enforce_rbac: bool = False
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
# get the attribute names for this Pydantic model
|
||||
|
@ -2284,6 +2288,19 @@ class ProxyStateVariables(TypedDict):
|
|||
UI_TEAM_ID = "litellm-dashboard"
|
||||
|
||||
|
||||
|
||||
class JWTAuthBuilderResult(TypedDict):
|
||||
is_proxy_admin: bool
|
||||
team_object: Optional[LiteLLM_TeamTable]
|
||||
user_object: Optional[LiteLLM_UserTable]
|
||||
end_user_object: Optional[LiteLLM_EndUserTable]
|
||||
org_object: Optional[LiteLLM_OrganizationTable]
|
||||
token: str
|
||||
team_id: Optional[str]
|
||||
user_id: Optional[str]
|
||||
end_user_id: Optional[str]
|
||||
org_id: Optional[str]
|
||||
|
||||
class ClientSideFallbackModel(TypedDict, total=False):
|
||||
"""
|
||||
Dictionary passed when client configuring input
|
||||
|
|
|
@ -9,7 +9,6 @@ Run checks for:
|
|||
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import time
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
||||
|
@ -22,7 +21,6 @@ from litellm.caching.caching import DualCache
|
|||
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||
from litellm.proxy._types import (
|
||||
DB_CONNECTION_ERROR_TYPES,
|
||||
CommonProxyErrors,
|
||||
LiteLLM_EndUserTable,
|
||||
LiteLLM_JWTAuth,
|
||||
LiteLLM_OrganizationTable,
|
||||
|
@ -55,33 +53,6 @@ db_cache_expiry = 5 # refresh every 5s
|
|||
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
||||
|
||||
|
||||
def _allowed_import_check() -> bool:
|
||||
from litellm.proxy.auth.user_api_key_auth import _user_api_key_auth_builder
|
||||
|
||||
# Get the calling frame
|
||||
caller_frame = inspect.stack()[2]
|
||||
caller_function = caller_frame.function
|
||||
caller_function_callable = caller_frame.frame.f_globals.get(caller_function)
|
||||
|
||||
allowed_function = "_user_api_key_auth_builder"
|
||||
allowed_signature = inspect.signature(_user_api_key_auth_builder)
|
||||
if caller_function_callable is None or not callable(caller_function_callable):
|
||||
raise Exception(f"Caller function {caller_function} is not callable")
|
||||
caller_signature = inspect.signature(caller_function_callable)
|
||||
|
||||
if caller_signature != allowed_signature:
|
||||
raise TypeError(
|
||||
f"The function '{caller_function}' does not match the required signature of 'user_api_key_auth'. {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
# Check if the caller module is allowed
|
||||
if caller_function != allowed_function:
|
||||
raise ImportError(
|
||||
f"This function can only be imported by '{allowed_function}'. {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def common_checks( # noqa: PLR0915
|
||||
request_body: dict,
|
||||
team_object: Optional[LiteLLM_TeamTable],
|
||||
|
@ -106,7 +77,6 @@ def common_checks( # noqa: PLR0915
|
|||
9. Check if request body is safe
|
||||
10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
|
||||
"""
|
||||
_allowed_import_check()
|
||||
_model = request_body.get("model", None)
|
||||
if team_object is not None and team_object.blocked is True:
|
||||
raise Exception(
|
||||
|
@ -844,7 +814,7 @@ async def get_org_object(
|
|||
user_api_key_cache: DualCache,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
proxy_logging_obj: Optional[ProxyLogging] = None,
|
||||
):
|
||||
) -> Optional[LiteLLM_OrganizationTable]:
|
||||
"""
|
||||
- Check if org id in proxy Org Table
|
||||
- if valid, return LiteLLM_OrganizationTable object
|
||||
|
@ -859,7 +829,7 @@ async def get_org_object(
|
|||
cached_org_obj = user_api_key_cache.async_get_cache(key="org_id:{}".format(org_id))
|
||||
if cached_org_obj is not None:
|
||||
if isinstance(cached_org_obj, dict):
|
||||
return cached_org_obj
|
||||
return LiteLLM_OrganizationTable(**cached_org_obj)
|
||||
elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
|
||||
return cached_org_obj
|
||||
# else, check db
|
||||
|
|
|
@ -17,7 +17,12 @@ from cryptography.hazmat.primitives import serialization
|
|||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
||||
from litellm.proxy._types import JWKKeyValue, JWTKeyItem, LiteLLM_JWTAuth
|
||||
from litellm.proxy._types import (
|
||||
JWKKeyValue,
|
||||
JWTKeyItem,
|
||||
LiteLLM_JWTAuth,
|
||||
LitellmUserRoles,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
|
||||
|
@ -54,6 +59,34 @@ class JWTHandler:
|
|||
parts = token.split(".")
|
||||
return len(parts) == 3
|
||||
|
||||
def get_rbac_role(self, token: dict) -> Optional[LitellmUserRoles]:
|
||||
"""
|
||||
Returns the RBAC role the token 'belongs' to.
|
||||
|
||||
RBAC roles allowed to make requests:
|
||||
- PROXY_ADMIN: can make requests to all routes
|
||||
- TEAM: can make requests to routes associated with a team
|
||||
- INTERNAL_USER: can make requests to routes associated with a user
|
||||
|
||||
Resolves: https://github.com/BerriAI/litellm/issues/6793
|
||||
|
||||
Returns:
|
||||
- PROXY_ADMIN: if token is admin
|
||||
- TEAM: if token is associated with a team
|
||||
- INTERNAL_USER: if token is associated with a user
|
||||
- None: if token is not associated with a team or user
|
||||
"""
|
||||
scopes = self.get_scopes(token=token)
|
||||
is_admin = self.is_admin(scopes=scopes)
|
||||
if is_admin:
|
||||
return LitellmUserRoles.PROXY_ADMIN
|
||||
elif self.get_team_id(token=token, default_value=None) is not None:
|
||||
return LitellmUserRoles.TEAM
|
||||
elif self.get_user_id(token=token, default_value=None) is not None:
|
||||
return LitellmUserRoles.INTERNAL_USER
|
||||
|
||||
return None
|
||||
|
||||
def is_admin(self, scopes: list) -> bool:
|
||||
if self.litellm_jwtauth.admin_jwt_scope in scopes:
|
||||
return True
|
||||
|
@ -68,12 +101,14 @@ class JWTHandler:
|
|||
self, token: dict, default_value: Optional[str]
|
||||
) -> Optional[str]:
|
||||
try:
|
||||
|
||||
if self.litellm_jwtauth.end_user_id_jwt_field is not None:
|
||||
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
|
||||
else:
|
||||
user_id = None
|
||||
except KeyError:
|
||||
user_id = default_value
|
||||
|
||||
return user_id
|
||||
|
||||
def is_required_team_id(self) -> bool:
|
||||
|
@ -169,6 +204,7 @@ class JWTHandler:
|
|||
return scopes
|
||||
|
||||
async def get_public_key(self, kid: Optional[str]) -> dict:
|
||||
|
||||
keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
|
||||
|
||||
if keys_url is None:
|
||||
|
|
|
@ -19,6 +19,7 @@ from fastapi.security.api_key import APIKeyHeader
|
|||
import litellm
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm._service_logger import ServiceLogging
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
_cache_key_object,
|
||||
|
@ -43,12 +44,13 @@ from litellm.proxy.auth.auth_utils import (
|
|||
route_in_additonal_public_routes,
|
||||
should_run_auth_on_pass_through_provider_route,
|
||||
)
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.auth.oauth2_check import check_oauth2_token
|
||||
from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.auth.service_account_checks import service_account_checks
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.utils import _to_ns
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, _to_ns
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
user_api_key_service_logger_obj = ServiceLogging() # used for tracking latency on OTEL
|
||||
|
@ -226,6 +228,221 @@ def update_valid_token_with_end_user_params(
|
|||
return valid_token
|
||||
|
||||
|
||||
async def get_global_proxy_spend(
|
||||
litellm_proxy_admin_name: str,
|
||||
user_api_key_cache: DualCache,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
token: str,
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
) -> Optional[float]:
|
||||
global_proxy_spend = None
|
||||
if litellm.max_budget > 0: # user set proxy max budget
|
||||
# check cache
|
||||
global_proxy_spend = await user_api_key_cache.async_get_cache(
|
||||
key="{}:spend".format(litellm_proxy_admin_name)
|
||||
)
|
||||
if global_proxy_spend is None and prisma_client is not None:
|
||||
# get from db
|
||||
sql_query = (
|
||||
"""SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";"""
|
||||
)
|
||||
|
||||
response = await prisma_client.db.query_raw(query=sql_query)
|
||||
|
||||
global_proxy_spend = response[0]["total_spend"]
|
||||
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key="{}:spend".format(litellm_proxy_admin_name),
|
||||
value=global_proxy_spend,
|
||||
)
|
||||
if global_proxy_spend is not None:
|
||||
user_info = CallInfo(
|
||||
user_id=litellm_proxy_admin_name,
|
||||
max_budget=litellm.max_budget,
|
||||
spend=global_proxy_spend,
|
||||
token=token,
|
||||
)
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.budget_alerts(
|
||||
type="proxy_budget",
|
||||
user_info=user_info,
|
||||
)
|
||||
)
|
||||
return global_proxy_spend
|
||||
|
||||
|
||||
def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
|
||||
is_admin = jwt_handler.is_admin(scopes=scopes)
|
||||
if is_admin:
|
||||
return LitellmUserRoles.PROXY_ADMIN
|
||||
else:
|
||||
return LitellmUserRoles.TEAM
|
||||
|
||||
|
||||
async def _jwt_auth_user_api_key_auth_builder(
|
||||
api_key: str,
|
||||
jwt_handler: JWTHandler,
|
||||
route: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
parent_otel_span: Optional[Span],
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
) -> JWTAuthBuilderResult:
|
||||
|
||||
# check if valid token
|
||||
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
|
||||
|
||||
# check if unmatched token and enforce_rbac is true
|
||||
if (
|
||||
jwt_handler.litellm_jwtauth.enforce_rbac is True
|
||||
and jwt_handler.get_rbac_role(token=jwt_valid_token) is None
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org",
|
||||
)
|
||||
# get scopes
|
||||
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
||||
|
||||
# [OPTIONAL] allowed user email domains
|
||||
valid_user_email: Optional[bool] = None
|
||||
user_email: Optional[str] = None
|
||||
if jwt_handler.is_enforced_email_domain():
|
||||
"""
|
||||
if 'allowed_email_subdomains' is set,
|
||||
|
||||
- checks if token contains 'email' field
|
||||
- checks if 'email' is from an allowed domain
|
||||
"""
|
||||
user_email = jwt_handler.get_user_email(
|
||||
token=jwt_valid_token, default_value=None
|
||||
)
|
||||
if user_email is None:
|
||||
valid_user_email = False
|
||||
else:
|
||||
valid_user_email = jwt_handler.is_allowed_domain(user_email=user_email)
|
||||
|
||||
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
|
||||
user_object = None
|
||||
user_id = jwt_handler.get_user_id(token=jwt_valid_token, default_value=user_email)
|
||||
|
||||
# get org id
|
||||
org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None)
|
||||
# get team id
|
||||
team_id = jwt_handler.get_team_id(token=jwt_valid_token, default_value=None)
|
||||
# get end user id
|
||||
end_user_id = jwt_handler.get_end_user_id(token=jwt_valid_token, default_value=None)
|
||||
|
||||
# check if admin
|
||||
is_admin = jwt_handler.is_admin(scopes=scopes)
|
||||
# if admin return
|
||||
if is_admin:
|
||||
# check allowed admin routes
|
||||
is_allowed = allowed_routes_check(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
user_route=route,
|
||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||
)
|
||||
if is_allowed:
|
||||
return JWTAuthBuilderResult(
|
||||
is_proxy_admin=True,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
org_object=None,
|
||||
token=api_key,
|
||||
team_id=team_id,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
org_id=org_id,
|
||||
)
|
||||
else:
|
||||
allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
|
||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||
raise Exception(
|
||||
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||
)
|
||||
|
||||
if team_id is None and jwt_handler.is_required_team_id() is True:
|
||||
raise Exception(
|
||||
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
|
||||
)
|
||||
|
||||
team_object: Optional[LiteLLM_TeamTable] = None
|
||||
if team_id is not None:
|
||||
# check allowed team routes
|
||||
is_allowed = allowed_routes_check(
|
||||
user_role=LitellmUserRoles.TEAM,
|
||||
user_route=route,
|
||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||
)
|
||||
if is_allowed is False:
|
||||
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
|
||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||
raise Exception(
|
||||
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||
)
|
||||
|
||||
# check if team in db
|
||||
team_object = await get_team_object(
|
||||
team_id=team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
|
||||
|
||||
org_object: Optional[LiteLLM_OrganizationTable] = None
|
||||
if org_id is not None:
|
||||
org_object = await get_org_object(
|
||||
org_id=org_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
if user_id is not None:
|
||||
# get the user object
|
||||
user_object = await get_user_object(
|
||||
user_id=user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
user_id_upsert=jwt_handler.is_upsert_user_id(
|
||||
valid_user_email=valid_user_email
|
||||
),
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
|
||||
end_user_object = None
|
||||
|
||||
if end_user_id is not None:
|
||||
# get the end-user object
|
||||
end_user_object = await get_end_user_object(
|
||||
end_user_id=end_user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
return {
|
||||
"is_proxy_admin": False,
|
||||
"team_id": team_id,
|
||||
"team_object": team_object,
|
||||
"user_id": user_id,
|
||||
"user_object": user_object,
|
||||
"org_id": org_id,
|
||||
"org_object": org_object,
|
||||
"end_user_id": end_user_id,
|
||||
"end_user_object": end_user_object,
|
||||
"token": api_key,
|
||||
}
|
||||
|
||||
|
||||
async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
request: Request,
|
||||
api_key: str,
|
||||
|
@ -361,164 +578,39 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
is_jwt = jwt_handler.is_jwt(token=api_key)
|
||||
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
|
||||
if is_jwt:
|
||||
# check if valid token
|
||||
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
|
||||
# get scopes
|
||||
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
||||
|
||||
# check if admin
|
||||
is_admin = jwt_handler.is_admin(scopes=scopes)
|
||||
# if admin return
|
||||
if is_admin:
|
||||
# check allowed admin routes
|
||||
is_allowed = allowed_routes_check(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
user_route=route,
|
||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||
result = await _jwt_auth_user_api_key_auth_builder(
|
||||
api_key=api_key,
|
||||
jwt_handler=jwt_handler,
|
||||
route=route,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
if is_allowed:
|
||||
|
||||
is_proxy_admin = result["is_proxy_admin"]
|
||||
team_id = result["team_id"]
|
||||
team_object = result["team_object"]
|
||||
user_id = result["user_id"]
|
||||
user_object = result["user_object"]
|
||||
end_user_id = result["end_user_id"]
|
||||
end_user_object = result["end_user_object"]
|
||||
org_id = result["org_id"]
|
||||
token = result["token"]
|
||||
|
||||
global_proxy_spend = await get_global_proxy_spend(
|
||||
litellm_proxy_admin_name=litellm_proxy_admin_name,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
prisma_client=prisma_client,
|
||||
token=token,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
if is_proxy_admin:
|
||||
return UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
else:
|
||||
allowed_routes: List[Any] = (
|
||||
jwt_handler.litellm_jwtauth.admin_allowed_routes
|
||||
)
|
||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||
raise Exception(
|
||||
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||
)
|
||||
|
||||
# get team id
|
||||
team_id = jwt_handler.get_team_id(
|
||||
token=jwt_valid_token, default_value=None
|
||||
)
|
||||
|
||||
if team_id is None and jwt_handler.is_required_team_id() is True:
|
||||
raise Exception(
|
||||
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
|
||||
)
|
||||
|
||||
team_object: Optional[LiteLLM_TeamTable] = None
|
||||
if team_id is not None:
|
||||
# check allowed team routes
|
||||
is_allowed = allowed_routes_check(
|
||||
user_role=LitellmUserRoles.TEAM,
|
||||
user_route=route,
|
||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||
)
|
||||
if is_allowed is False:
|
||||
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
|
||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||
raise Exception(
|
||||
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||
)
|
||||
|
||||
# check if team in db
|
||||
team_object = await get_team_object(
|
||||
team_id=team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
|
||||
org_id = jwt_handler.get_org_id(
|
||||
token=jwt_valid_token, default_value=None
|
||||
)
|
||||
if org_id is not None:
|
||||
_ = await get_org_object(
|
||||
org_id=org_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
# [OPTIONAL] allowed user email domains
|
||||
valid_user_email: Optional[bool] = None
|
||||
user_email: Optional[str] = None
|
||||
if jwt_handler.is_enforced_email_domain():
|
||||
"""
|
||||
if 'allowed_email_subdomains' is set,
|
||||
|
||||
- checks if token contains 'email' field
|
||||
- checks if 'email' is from an allowed domain
|
||||
"""
|
||||
user_email = jwt_handler.get_user_email(
|
||||
token=jwt_valid_token, default_value=None
|
||||
)
|
||||
if user_email is None:
|
||||
valid_user_email = False
|
||||
else:
|
||||
valid_user_email = jwt_handler.is_allowed_domain(
|
||||
user_email=user_email
|
||||
)
|
||||
|
||||
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
|
||||
user_object = None
|
||||
user_id = jwt_handler.get_user_id(
|
||||
token=jwt_valid_token, default_value=user_email
|
||||
)
|
||||
if user_id is not None:
|
||||
# get the user object
|
||||
user_object = await get_user_object(
|
||||
user_id=user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
user_id_upsert=jwt_handler.is_upsert_user_id(
|
||||
valid_user_email=valid_user_email
|
||||
),
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
|
||||
end_user_object = None
|
||||
end_user_id = jwt_handler.get_end_user_id(
|
||||
token=jwt_valid_token, default_value=None
|
||||
)
|
||||
if end_user_id is not None:
|
||||
# get the end-user object
|
||||
end_user_object = await get_end_user_object(
|
||||
end_user_id=end_user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
global_proxy_spend = None
|
||||
if litellm.max_budget > 0: # user set proxy max budget
|
||||
# check cache
|
||||
global_proxy_spend = await user_api_key_cache.async_get_cache(
|
||||
key="{}:spend".format(litellm_proxy_admin_name)
|
||||
)
|
||||
if global_proxy_spend is None and prisma_client is not None:
|
||||
# get from db
|
||||
sql_query = """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";"""
|
||||
|
||||
response = await prisma_client.db.query_raw(query=sql_query)
|
||||
|
||||
global_proxy_spend = response[0]["total_spend"]
|
||||
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key="{}:spend".format(litellm_proxy_admin_name),
|
||||
value=global_proxy_spend,
|
||||
)
|
||||
if global_proxy_spend is not None:
|
||||
user_info = CallInfo(
|
||||
user_id=litellm_proxy_admin_name,
|
||||
max_budget=litellm.max_budget,
|
||||
spend=global_proxy_spend,
|
||||
token=jwt_valid_token["token"],
|
||||
)
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.budget_alerts(
|
||||
type="proxy_budget",
|
||||
user_info=user_info,
|
||||
)
|
||||
)
|
||||
# run through common checks
|
||||
_ = common_checks(
|
||||
request_body=request_data,
|
||||
|
@ -534,7 +626,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
# return UserAPIKeyAuth object
|
||||
return UserAPIKeyAuth(
|
||||
api_key=None,
|
||||
team_id=team_object.team_id if team_object is not None else None,
|
||||
team_id=team_id,
|
||||
team_tpm_limit=(
|
||||
team_object.tpm_limit if team_object is not None else None
|
||||
),
|
||||
|
@ -548,6 +640,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
parent_otel_span=parent_otel_span,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
|
||||
#### ELSE ####
|
||||
## CHECK PASS-THROUGH ENDPOINTS ##
|
||||
is_mapped_pass_through_route: bool = False
|
||||
|
|
|
@ -1039,6 +1039,7 @@ async def test_end_user_jwt_auth(monkeypatch):
|
|||
import json
|
||||
|
||||
monkeypatch.delenv("JWT_AUDIENCE", None)
|
||||
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
|
||||
jwt_handler = JWTHandler()
|
||||
|
||||
litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
|
|
|
@ -794,3 +794,71 @@ async def test_user_api_key_auth_websocket():
|
|||
assert (
|
||||
mock_user_api_key_auth.call_args.kwargs["api_key"] == "Bearer some_api_key"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enforce_rbac", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypatch):
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.auth.user_api_key_auth import _jwt_auth_user_api_key_auth_builder
|
||||
from unittest.mock import patch, Mock
|
||||
from litellm.proxy._types import LiteLLM_JWTAuth
|
||||
from litellm.caching import DualCache
|
||||
|
||||
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "my-fake-url")
|
||||
monkeypatch.setenv("JWT_AUDIENCE", "api://LiteLLM_Proxy-dev")
|
||||
|
||||
local_cache = DualCache()
|
||||
|
||||
keys = [
|
||||
{
|
||||
"kty": "RSA",
|
||||
"use": "sig",
|
||||
"kid": "z1rsYHHJ9-8mggt4HsZu8BKkBPw",
|
||||
"x5t": "z1rsYHHJ9-8mggt4HsZu8BKkBPw",
|
||||
"n": "pOe4GbleFDT1u5ioOQjNMmhvkDVoVD9cBKvX7AlErtWA_D6wc1w1iwkd6arYVCPObZbAB4vLSXrlpBSOuP6VYnXw_cTgniv_c82ra-mfqCpM-SbqzZ3sVqlcE_bwxvci_4PrxAW4R85ok12NXyZ2371H3yGevabi35AlVm-bQ24azo1hLK_0DzB6TxsAIOTOcKfIugOfqP-B2R4vR4u6pYftS8MWcxegr9iJ5JNtubI1X2JHpxJhkRoMVwKFna2GXmtzdxLi3yS_GffVCKfTbFMhalbJS1lSmLqhmLZZL-lrQZ6fansTl1vcGcoxnzPTwBkZMks0iVV4yfym_gKBXQ",
|
||||
"e": "AQAB",
|
||||
"x5c": [
|
||||
"MIIC/TCCAeWgAwIBAgIIQk8Qok6pfXkwDQYJKoZIhvcNAQELBQAwLTErMCkGA1UEAxMiYWNjb3VudHMuYWNjZXNzY29udHJvbC53aW5kb3dzLm5ldDAeFw0yNDExMjcwOTA0MzlaFw0yOTExMjcwOTA0MzlaMC0xKzApBgNVBAMTImFjY291bnRzLmFjY2Vzc2NvbnRyb2wud2luZG93cy5uZXQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCk57gZuV4UNPW7mKg5CM0yaG+QNWhUP1wEq9fsCUSu1YD8PrBzXDWLCR3pqthUI85tlsAHi8tJeuWkFI64/pVidfD9xOCeK/9zzatr6Z+oKkz5JurNnexWqVwT9vDG9yL/g+vEBbhHzmiTXY1fJnbfvUffIZ69puLfkCVWb5tDbhrOjWEsr/QPMHpPGwAg5M5wp8i6A5+o/4HZHi9Hi7qlh+1LwxZzF6Cv2Inkk225sjVfYkenEmGRGgxXAoWdrYZea3N3EuLfJL8Z99UIp9NsUyFqVslLWVKYuqGYtlkv6WtBnp9qexOXW9wZyjGfM9PAGRkySzSJVXjJ/Kb+AoFdAgMBAAGjITAfMB0GA1UdDgQWBBSTO5FmUwwGS+1CNqg2uNgjxUjFijANBgkqhkiG9w0BAQsFAAOCAQEAok04z0ICMEHGqDTzx6eD7vvJP8itJTCSz8JcZcGVJofJpViGF3bNnyeSPa7vNDYP1Ps9XBvw3/n2s+yynZ8EwFxMyxCZRCSbLv0N+cAbH3rmZqGcgMJszZVwcFUtXQPTe1ZRyHtEyOB+PVFH7K7obysRVO/cC6EGqIF3pYWzez/dtMaXRAkdTNlz0ko62WoA4eMPwUFCITjW/Jxfxl0BNUbo82PXXKhaeVJb+EgFG5b/pWWPswWmBoQhmD5G1UODvEACHRl/cHsPPqe4YE+6D1/wMno/xqqyGltnk8v0d4TpNcQMn9oM19V+OGgrzWOvvXhvnhqUIVGMsRlyBGNHAw=="
|
||||
],
|
||||
"cloud_instance_name": "microsoftonline.com",
|
||||
"issuer": "https://login.microsoftonline.com/bdfd79b3-8401-47fb-a764-6e595c455b05/v2.0",
|
||||
}
|
||||
]
|
||||
|
||||
local_cache.set_cache(
|
||||
key="litellm_jwt_auth_keys",
|
||||
value=keys,
|
||||
)
|
||||
|
||||
litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
**{
|
||||
"admin_jwt_scope": "litellm_proxy_endpoints_access",
|
||||
"admin_allowed_routes": ["openai_routes", "info_routes"],
|
||||
"public_key_ttl": 600,
|
||||
"enforce_rbac": enforce_rbac,
|
||||
}
|
||||
)
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.update_environment(
|
||||
prisma_client=None,
|
||||
user_api_key_cache=local_cache,
|
||||
litellm_jwtauth=litellm_jwtauth,
|
||||
leeway=10000000000000,
|
||||
)
|
||||
args = {
|
||||
"api_key": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsIng1dCI6InoxcnNZSEhKOS04bWdndDRIc1p1OEJLa0JQdyIsImtpZCI6InoxcnNZSEhKOS04bWdndDRIc1p1OEJLa0JQdyJ9.eyJhdWQiOiJhcGk6Ly9MaXRlTExNX1Byb3h5LWRldiIsImlzcyI6Imh0dHBzOi8vc3RzLndpbmRvd3MubmV0L2JkZmQ3OWIzLTg0MDEtNDdmYi1hNzY0LTZlNTk1YzQ1NWIwNS8iLCJpYXQiOjE3MzcyNDE3ODEsIm5iZiI6MTczNzI0MTc4MSwiZXhwIjoxNzM3MjQ1NjgxLCJhaW8iOiJrMlJnWUpBNE5hZGg4MGJQdXlyRmxlV1o3dHZiQUE9PSIsImFwcGlkIjoiOGNjZjNkMDItMmNkNi00N2I5LTgxODUtMGVkYjI0YWJjZjY5IiwiYXBwaWRhY3IiOiIxIiwiaWRwIjoiaHR0cHM6Ly9zdHMud2luZG93cy5uZXQvYmRmZDc5YjMtODQwMS00N2ZiLWE3NjQtNmU1OTVjNDU1YjA1LyIsIm9pZCI6IjQ0YTg3YTYzLWFiNTUtNDc4NS1iMmFmLTMzNjllZWM4ZTEzOSIsInJoIjoiMS5BYjBBczNuOXZRR0UtMGVuWkc1WlhFVmJCY0VDbkl6NHJxaE9wZ2E0UGZSZjBsbTlBQUM5QUEuIiwic3ViIjoiNDRhODdhNjMtYWI1NS00Nzg1LWIyYWYtMzM2OWVlYzhlMTM5IiwidGlkIjoiYmRmZDc5YjMtODQwMS00N2ZiLWE3NjQtNmU1OTVjNDU1YjA1IiwidXRpIjoiY3ltNVhlcmhIMHVMSlNZU1JyQmhBQSIsInZlciI6IjEuMCJ9.UooJjM9pS-wgYsExqgHdrYyQhp7NbwAsr7au9dWJaLpsufXeyHJSg-Xd5VJ4RsDVJiDes3jkC7WeoAiaCfzEHpAum-p_aqqLYXf1QIYbi1hLC0m7y_klFcqMp11WbDa9TSTvg-o8q3x2Y5su8X23ymlFih4OP17b7JA6a4_2MybU5QkCEW1tQK6VspuuXzeDHvbfGeGYcIptHFyfttHMHHXRtX1o9bX7gOR_dwFITAXD18T4ZdAN_0y6f1OtVF9TMWQhMXhKU8ahn8TSg_CXmPl9T_1gV3ZWLvVtcdVrWs82fDz3-2lEw28z4bQEr1Z5xoAz7srhx1WEBu_ioAcQiA",
|
||||
"jwt_handler": jwt_handler,
|
||||
"route": "/v1/chat/completions",
|
||||
"prisma_client": None,
|
||||
"user_api_key_cache": Mock(),
|
||||
"parent_otel_span": None,
|
||||
"proxy_logging_obj": Mock(),
|
||||
}
|
||||
|
||||
if enforce_rbac:
|
||||
with pytest.raises(HTTPException):
|
||||
await _jwt_auth_user_api_key_auth_builder(**args)
|
||||
else:
|
||||
await _jwt_auth_user_api_key_auth_builder(**args)
|
||||
|
|
|
@ -11,6 +11,7 @@ interface DashboardTeamProps {
|
|||
setProxySettings: React.Dispatch<React.SetStateAction<ProxySettings | null>>;
|
||||
userInfo: UserInfo | null;
|
||||
accessToken: string | null;
|
||||
setKeys: React.Dispatch<React.SetStateAction<any | null>>;
|
||||
}
|
||||
|
||||
type TeamInterface = {
|
||||
|
@ -27,7 +28,8 @@ const DashboardTeam: React.FC<DashboardTeamProps> = ({
|
|||
proxySettings,
|
||||
setProxySettings,
|
||||
userInfo,
|
||||
accessToken
|
||||
accessToken,
|
||||
setKeys
|
||||
}) => {
|
||||
console.log(`userInfo: ${JSON.stringify(userInfo)}`)
|
||||
const defaultTeam: TeamInterface = {
|
||||
|
@ -80,7 +82,10 @@ const DashboardTeam: React.FC<DashboardTeamProps> = ({
|
|||
<SelectItem
|
||||
key={index}
|
||||
value={String(index)}
|
||||
onClick={() => setSelectedTeam(team)}
|
||||
onClick={() => {
|
||||
setSelectedTeam(team);
|
||||
setKeys(team["keys"]);
|
||||
}}
|
||||
>
|
||||
{team["team_alias"]}
|
||||
</SelectItem>
|
||||
|
|
|
@ -206,13 +206,15 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
|||
|
||||
setUserSpendData(response["user_info"]);
|
||||
console.log(`userSpendData: ${JSON.stringify(userSpendData)}`)
|
||||
setKeys(response["keys"]); // Assuming this is the correct path to your data
|
||||
|
||||
const teamsArray = [...response["teams"]];
|
||||
if (teamsArray.length > 0) {
|
||||
console.log(`response['teams']: ${teamsArray}`);
|
||||
console.log(`response['teams']: ${JSON.stringify(teamsArray)}`);
|
||||
setSelectedTeam(teamsArray[0]);
|
||||
setKeys(teamsArray[0]["keys"]); // Assuming this is the correct path to your data
|
||||
} else {
|
||||
setSelectedTeam(defaultTeam);
|
||||
setKeys(response["keys"]); // Assuming this is the correct path to your data
|
||||
}
|
||||
sessionStorage.setItem(
|
||||
"userData" + userID,
|
||||
|
@ -261,6 +263,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
|||
selectedTeam.team_id !== null
|
||||
) {
|
||||
let sum = 0;
|
||||
console.log(`keys: ${JSON.stringify(keys)}`)
|
||||
for (const key of keys) {
|
||||
if (
|
||||
selectedTeam.hasOwnProperty("team_id") &&
|
||||
|
@ -270,6 +273,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
|||
sum += key.spend;
|
||||
}
|
||||
}
|
||||
console.log(`sum: ${sum}`)
|
||||
setTeamSpend(sum);
|
||||
} else if (keys !== null) {
|
||||
// sum the keys which don't have team-id set (default team)
|
||||
|
@ -367,6 +371,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
|||
setProxySettings={setProxySettings}
|
||||
userInfo={userSpendData}
|
||||
accessToken={accessToken}
|
||||
setKeys={setKeys}
|
||||
/>
|
||||
</Col>
|
||||
</Grid>
|
||||
|
|
|
@ -114,8 +114,6 @@ const ViewUserSpend: React.FC<ViewUserSpendProps> = ({ userID, userRole, accessT
|
|||
fetchData();
|
||||
}, [userRole, accessToken, userID]);
|
||||
|
||||
|
||||
|
||||
useEffect(() => {
|
||||
if (userSpend !== null) {
|
||||
setSpend(userSpend)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue