mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
JWT Auth - enforce_rbac
support + UI team view, spend calc fix (#7863)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 12s
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 12s
* fix(user_dashboard.tsx): fix spend calculation when team selected sum all team keys, not user keys * docs(admin_ui_sso.md): fix docs tabbing * feat(user_api_key_auth.py): introduce new 'enforce_rbac' param on jwt auth allows proxy admin to prevent any unmapped yet authenticated jwt tokens from calling proxy Fixes https://github.com/BerriAI/litellm/issues/6793 * test: more unit testing + refactoring * fix: fix returning id when obj not found in db * fix(user_api_key_auth.py): add end user id tracking from jwt auth * docs(token_auth.md): add doc on rbac with JWTs * fix: fix unused params * test: remove old test
This commit is contained in:
parent
c306c2e0fc
commit
dca6904937
12 changed files with 449 additions and 197 deletions
|
@ -1,3 +1,7 @@
|
||||||
|
import Image from '@theme/IdealImage';
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# ✨ SSO for Admin UI
|
# ✨ SSO for Admin UI
|
||||||
|
|
||||||
:::info
|
:::info
|
||||||
|
|
|
@ -114,7 +114,7 @@ general_settings:
|
||||||
admin_jwt_scope: "litellm-proxy-admin"
|
admin_jwt_scope: "litellm-proxy-admin"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Advanced - Spend Tracking (End-Users / Internal Users / Team / Org)
|
## Tracking End-Users / Internal Users / Team / Org
|
||||||
|
|
||||||
Set the field in the jwt token, which corresponds to a litellm user / team / org.
|
Set the field in the jwt token, which corresponds to a litellm user / team / org.
|
||||||
|
|
||||||
|
@ -156,6 +156,33 @@ scope: ["litellm-proxy-admin",...]
|
||||||
scope: "litellm-proxy-admin ..."
|
scope: "litellm-proxy-admin ..."
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Enforce Role-Based Access Control (RBAC)
|
||||||
|
|
||||||
|
Reject a JWT token if it's valid but doesn't have the required scopes / fields.
|
||||||
|
|
||||||
|
Only tokens which with valid Admin (`admin_jwt_scope`), User (`user_id_jwt_field`), Team (`team_id_jwt_field`) are allowed.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
general_settings:
|
||||||
|
master_key: sk-1234
|
||||||
|
enable_jwt_auth: True
|
||||||
|
litellm_jwtauth:
|
||||||
|
admin_jwt_scope: "litellm_proxy_endpoints_access"
|
||||||
|
admin_allowed_routes:
|
||||||
|
- openai_routes
|
||||||
|
- info_routes
|
||||||
|
public_key_ttl: 600
|
||||||
|
enforce_rbac: true # 👈 Enforce RBAC
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected Scope in JWT:
|
||||||
|
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"scope": "litellm_proxy_endpoints_access"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## Advanced - Allowed Routes
|
## Advanced - Allowed Routes
|
||||||
|
|
||||||
Configure which routes a JWT can access via the config.
|
Configure which routes a JWT can access via the config.
|
||||||
|
|
|
@ -6587,6 +6587,27 @@
|
||||||
"litellm_provider": "bedrock",
|
"litellm_provider": "bedrock",
|
||||||
"mode": "image_generation"
|
"mode": "image_generation"
|
||||||
},
|
},
|
||||||
|
"stability.sd3-5-large-v1:0": {
|
||||||
|
"max_tokens": 77,
|
||||||
|
"max_input_tokens": 77,
|
||||||
|
"output_cost_per_image": 0.08,
|
||||||
|
"litellm_provider": "bedrock",
|
||||||
|
"mode": "image_generation"
|
||||||
|
},
|
||||||
|
"stability.stable-image-core-v1:0": {
|
||||||
|
"max_tokens": 77,
|
||||||
|
"max_input_tokens": 77,
|
||||||
|
"output_cost_per_image": 0.04,
|
||||||
|
"litellm_provider": "bedrock",
|
||||||
|
"mode": "image_generation"
|
||||||
|
},
|
||||||
|
"stability.stable-image-core-v1:1": {
|
||||||
|
"max_tokens": 77,
|
||||||
|
"max_input_tokens": 77,
|
||||||
|
"output_cost_per_image": 0.04,
|
||||||
|
"litellm_provider": "bedrock",
|
||||||
|
"mode": "image_generation"
|
||||||
|
},
|
||||||
"stability.stable-image-ultra-v1:0": {
|
"stability.stable-image-ultra-v1:0": {
|
||||||
"max_tokens": 77,
|
"max_tokens": 77,
|
||||||
"max_input_tokens": 77,
|
"max_input_tokens": 77,
|
||||||
|
@ -6594,6 +6615,13 @@
|
||||||
"litellm_provider": "bedrock",
|
"litellm_provider": "bedrock",
|
||||||
"mode": "image_generation"
|
"mode": "image_generation"
|
||||||
},
|
},
|
||||||
|
"stability.stable-image-ultra-v1:1": {
|
||||||
|
"max_tokens": 77,
|
||||||
|
"max_input_tokens": 77,
|
||||||
|
"output_cost_per_image": 0.14,
|
||||||
|
"litellm_provider": "bedrock",
|
||||||
|
"mode": "image_generation"
|
||||||
|
},
|
||||||
"sagemaker/meta-textgeneration-llama-2-7b": {
|
"sagemaker/meta-textgeneration-llama-2-7b": {
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
"max_input_tokens": 4096,
|
"max_input_tokens": 4096,
|
||||||
|
|
|
@ -416,6 +416,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||||
- user_allowed_email_subdomain: If specified, only emails from specified subdomain will be allowed to access proxy.
|
- user_allowed_email_subdomain: If specified, only emails from specified subdomain will be allowed to access proxy.
|
||||||
- end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
|
- end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
|
||||||
- public_key_ttl: Default - 600s. TTL for caching public JWT keys.
|
- public_key_ttl: Default - 600s. TTL for caching public JWT keys.
|
||||||
|
- public_allowed_routes: list of allowed routes for authenticated but unknown litellm role jwt tokens.
|
||||||
|
- enforce_rbac: If true, enforce RBAC for all routes.
|
||||||
|
|
||||||
See `auth_checks.py` for the specific routes
|
See `auth_checks.py` for the specific routes
|
||||||
"""
|
"""
|
||||||
|
@ -446,6 +448,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||||
)
|
)
|
||||||
end_user_id_jwt_field: Optional[str] = None
|
end_user_id_jwt_field: Optional[str] = None
|
||||||
public_key_ttl: float = 600
|
public_key_ttl: float = 600
|
||||||
|
public_allowed_routes: List[str] = ["public_routes"]
|
||||||
|
enforce_rbac: bool = False
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
# get the attribute names for this Pydantic model
|
# get the attribute names for this Pydantic model
|
||||||
|
@ -2284,6 +2288,19 @@ class ProxyStateVariables(TypedDict):
|
||||||
UI_TEAM_ID = "litellm-dashboard"
|
UI_TEAM_ID = "litellm-dashboard"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class JWTAuthBuilderResult(TypedDict):
|
||||||
|
is_proxy_admin: bool
|
||||||
|
team_object: Optional[LiteLLM_TeamTable]
|
||||||
|
user_object: Optional[LiteLLM_UserTable]
|
||||||
|
end_user_object: Optional[LiteLLM_EndUserTable]
|
||||||
|
org_object: Optional[LiteLLM_OrganizationTable]
|
||||||
|
token: str
|
||||||
|
team_id: Optional[str]
|
||||||
|
user_id: Optional[str]
|
||||||
|
end_user_id: Optional[str]
|
||||||
|
org_id: Optional[str]
|
||||||
|
|
||||||
class ClientSideFallbackModel(TypedDict, total=False):
|
class ClientSideFallbackModel(TypedDict, total=False):
|
||||||
"""
|
"""
|
||||||
Dictionary passed when client configuring input
|
Dictionary passed when client configuring input
|
||||||
|
|
|
@ -9,7 +9,6 @@ Run checks for:
|
||||||
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import inspect
|
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
||||||
|
@ -22,7 +21,6 @@ from litellm.caching.caching import DualCache
|
||||||
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
DB_CONNECTION_ERROR_TYPES,
|
DB_CONNECTION_ERROR_TYPES,
|
||||||
CommonProxyErrors,
|
|
||||||
LiteLLM_EndUserTable,
|
LiteLLM_EndUserTable,
|
||||||
LiteLLM_JWTAuth,
|
LiteLLM_JWTAuth,
|
||||||
LiteLLM_OrganizationTable,
|
LiteLLM_OrganizationTable,
|
||||||
|
@ -55,33 +53,6 @@ db_cache_expiry = 5 # refresh every 5s
|
||||||
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
||||||
|
|
||||||
|
|
||||||
def _allowed_import_check() -> bool:
|
|
||||||
from litellm.proxy.auth.user_api_key_auth import _user_api_key_auth_builder
|
|
||||||
|
|
||||||
# Get the calling frame
|
|
||||||
caller_frame = inspect.stack()[2]
|
|
||||||
caller_function = caller_frame.function
|
|
||||||
caller_function_callable = caller_frame.frame.f_globals.get(caller_function)
|
|
||||||
|
|
||||||
allowed_function = "_user_api_key_auth_builder"
|
|
||||||
allowed_signature = inspect.signature(_user_api_key_auth_builder)
|
|
||||||
if caller_function_callable is None or not callable(caller_function_callable):
|
|
||||||
raise Exception(f"Caller function {caller_function} is not callable")
|
|
||||||
caller_signature = inspect.signature(caller_function_callable)
|
|
||||||
|
|
||||||
if caller_signature != allowed_signature:
|
|
||||||
raise TypeError(
|
|
||||||
f"The function '{caller_function}' does not match the required signature of 'user_api_key_auth'. {CommonProxyErrors.not_premium_user.value}"
|
|
||||||
)
|
|
||||||
# Check if the caller module is allowed
|
|
||||||
if caller_function != allowed_function:
|
|
||||||
raise ImportError(
|
|
||||||
f"This function can only be imported by '{allowed_function}'. {CommonProxyErrors.not_premium_user.value}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def common_checks( # noqa: PLR0915
|
def common_checks( # noqa: PLR0915
|
||||||
request_body: dict,
|
request_body: dict,
|
||||||
team_object: Optional[LiteLLM_TeamTable],
|
team_object: Optional[LiteLLM_TeamTable],
|
||||||
|
@ -106,7 +77,6 @@ def common_checks( # noqa: PLR0915
|
||||||
9. Check if request body is safe
|
9. Check if request body is safe
|
||||||
10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
|
10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
|
||||||
"""
|
"""
|
||||||
_allowed_import_check()
|
|
||||||
_model = request_body.get("model", None)
|
_model = request_body.get("model", None)
|
||||||
if team_object is not None and team_object.blocked is True:
|
if team_object is not None and team_object.blocked is True:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -844,7 +814,7 @@ async def get_org_object(
|
||||||
user_api_key_cache: DualCache,
|
user_api_key_cache: DualCache,
|
||||||
parent_otel_span: Optional[Span] = None,
|
parent_otel_span: Optional[Span] = None,
|
||||||
proxy_logging_obj: Optional[ProxyLogging] = None,
|
proxy_logging_obj: Optional[ProxyLogging] = None,
|
||||||
):
|
) -> Optional[LiteLLM_OrganizationTable]:
|
||||||
"""
|
"""
|
||||||
- Check if org id in proxy Org Table
|
- Check if org id in proxy Org Table
|
||||||
- if valid, return LiteLLM_OrganizationTable object
|
- if valid, return LiteLLM_OrganizationTable object
|
||||||
|
@ -859,7 +829,7 @@ async def get_org_object(
|
||||||
cached_org_obj = user_api_key_cache.async_get_cache(key="org_id:{}".format(org_id))
|
cached_org_obj = user_api_key_cache.async_get_cache(key="org_id:{}".format(org_id))
|
||||||
if cached_org_obj is not None:
|
if cached_org_obj is not None:
|
||||||
if isinstance(cached_org_obj, dict):
|
if isinstance(cached_org_obj, dict):
|
||||||
return cached_org_obj
|
return LiteLLM_OrganizationTable(**cached_org_obj)
|
||||||
elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
|
elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
|
||||||
return cached_org_obj
|
return cached_org_obj
|
||||||
# else, check db
|
# else, check db
|
||||||
|
|
|
@ -17,7 +17,12 @@ from cryptography.hazmat.primitives import serialization
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.caching.caching import DualCache
|
from litellm.caching.caching import DualCache
|
||||||
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
||||||
from litellm.proxy._types import JWKKeyValue, JWTKeyItem, LiteLLM_JWTAuth
|
from litellm.proxy._types import (
|
||||||
|
JWKKeyValue,
|
||||||
|
JWTKeyItem,
|
||||||
|
LiteLLM_JWTAuth,
|
||||||
|
LitellmUserRoles,
|
||||||
|
)
|
||||||
from litellm.proxy.utils import PrismaClient
|
from litellm.proxy.utils import PrismaClient
|
||||||
|
|
||||||
|
|
||||||
|
@ -54,6 +59,34 @@ class JWTHandler:
|
||||||
parts = token.split(".")
|
parts = token.split(".")
|
||||||
return len(parts) == 3
|
return len(parts) == 3
|
||||||
|
|
||||||
|
def get_rbac_role(self, token: dict) -> Optional[LitellmUserRoles]:
|
||||||
|
"""
|
||||||
|
Returns the RBAC role the token 'belongs' to.
|
||||||
|
|
||||||
|
RBAC roles allowed to make requests:
|
||||||
|
- PROXY_ADMIN: can make requests to all routes
|
||||||
|
- TEAM: can make requests to routes associated with a team
|
||||||
|
- INTERNAL_USER: can make requests to routes associated with a user
|
||||||
|
|
||||||
|
Resolves: https://github.com/BerriAI/litellm/issues/6793
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- PROXY_ADMIN: if token is admin
|
||||||
|
- TEAM: if token is associated with a team
|
||||||
|
- INTERNAL_USER: if token is associated with a user
|
||||||
|
- None: if token is not associated with a team or user
|
||||||
|
"""
|
||||||
|
scopes = self.get_scopes(token=token)
|
||||||
|
is_admin = self.is_admin(scopes=scopes)
|
||||||
|
if is_admin:
|
||||||
|
return LitellmUserRoles.PROXY_ADMIN
|
||||||
|
elif self.get_team_id(token=token, default_value=None) is not None:
|
||||||
|
return LitellmUserRoles.TEAM
|
||||||
|
elif self.get_user_id(token=token, default_value=None) is not None:
|
||||||
|
return LitellmUserRoles.INTERNAL_USER
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def is_admin(self, scopes: list) -> bool:
|
def is_admin(self, scopes: list) -> bool:
|
||||||
if self.litellm_jwtauth.admin_jwt_scope in scopes:
|
if self.litellm_jwtauth.admin_jwt_scope in scopes:
|
||||||
return True
|
return True
|
||||||
|
@ -68,12 +101,14 @@ class JWTHandler:
|
||||||
self, token: dict, default_value: Optional[str]
|
self, token: dict, default_value: Optional[str]
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
|
|
||||||
if self.litellm_jwtauth.end_user_id_jwt_field is not None:
|
if self.litellm_jwtauth.end_user_id_jwt_field is not None:
|
||||||
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
|
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
|
||||||
else:
|
else:
|
||||||
user_id = None
|
user_id = None
|
||||||
except KeyError:
|
except KeyError:
|
||||||
user_id = default_value
|
user_id = default_value
|
||||||
|
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
def is_required_team_id(self) -> bool:
|
def is_required_team_id(self) -> bool:
|
||||||
|
@ -169,6 +204,7 @@ class JWTHandler:
|
||||||
return scopes
|
return scopes
|
||||||
|
|
||||||
async def get_public_key(self, kid: Optional[str]) -> dict:
|
async def get_public_key(self, kid: Optional[str]) -> dict:
|
||||||
|
|
||||||
keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
|
keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
|
||||||
|
|
||||||
if keys_url is None:
|
if keys_url is None:
|
||||||
|
|
|
@ -19,6 +19,7 @@ from fastapi.security.api_key import APIKeyHeader
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||||
from litellm._service_logger import ServiceLogging
|
from litellm._service_logger import ServiceLogging
|
||||||
|
from litellm.caching import DualCache
|
||||||
from litellm.proxy._types import *
|
from litellm.proxy._types import *
|
||||||
from litellm.proxy.auth.auth_checks import (
|
from litellm.proxy.auth.auth_checks import (
|
||||||
_cache_key_object,
|
_cache_key_object,
|
||||||
|
@ -43,12 +44,13 @@ from litellm.proxy.auth.auth_utils import (
|
||||||
route_in_additonal_public_routes,
|
route_in_additonal_public_routes,
|
||||||
should_run_auth_on_pass_through_provider_route,
|
should_run_auth_on_pass_through_provider_route,
|
||||||
)
|
)
|
||||||
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||||
from litellm.proxy.auth.oauth2_check import check_oauth2_token
|
from litellm.proxy.auth.oauth2_check import check_oauth2_token
|
||||||
from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
|
from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
|
||||||
from litellm.proxy.auth.route_checks import RouteChecks
|
from litellm.proxy.auth.route_checks import RouteChecks
|
||||||
from litellm.proxy.auth.service_account_checks import service_account_checks
|
from litellm.proxy.auth.service_account_checks import service_account_checks
|
||||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||||
from litellm.proxy.utils import _to_ns
|
from litellm.proxy.utils import PrismaClient, ProxyLogging, _to_ns
|
||||||
from litellm.types.services import ServiceTypes
|
from litellm.types.services import ServiceTypes
|
||||||
|
|
||||||
user_api_key_service_logger_obj = ServiceLogging() # used for tracking latency on OTEL
|
user_api_key_service_logger_obj = ServiceLogging() # used for tracking latency on OTEL
|
||||||
|
@ -226,6 +228,221 @@ def update_valid_token_with_end_user_params(
|
||||||
return valid_token
|
return valid_token
|
||||||
|
|
||||||
|
|
||||||
|
async def get_global_proxy_spend(
|
||||||
|
litellm_proxy_admin_name: str,
|
||||||
|
user_api_key_cache: DualCache,
|
||||||
|
prisma_client: Optional[PrismaClient],
|
||||||
|
token: str,
|
||||||
|
proxy_logging_obj: ProxyLogging,
|
||||||
|
) -> Optional[float]:
|
||||||
|
global_proxy_spend = None
|
||||||
|
if litellm.max_budget > 0: # user set proxy max budget
|
||||||
|
# check cache
|
||||||
|
global_proxy_spend = await user_api_key_cache.async_get_cache(
|
||||||
|
key="{}:spend".format(litellm_proxy_admin_name)
|
||||||
|
)
|
||||||
|
if global_proxy_spend is None and prisma_client is not None:
|
||||||
|
# get from db
|
||||||
|
sql_query = (
|
||||||
|
"""SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";"""
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await prisma_client.db.query_raw(query=sql_query)
|
||||||
|
|
||||||
|
global_proxy_spend = response[0]["total_spend"]
|
||||||
|
|
||||||
|
await user_api_key_cache.async_set_cache(
|
||||||
|
key="{}:spend".format(litellm_proxy_admin_name),
|
||||||
|
value=global_proxy_spend,
|
||||||
|
)
|
||||||
|
if global_proxy_spend is not None:
|
||||||
|
user_info = CallInfo(
|
||||||
|
user_id=litellm_proxy_admin_name,
|
||||||
|
max_budget=litellm.max_budget,
|
||||||
|
spend=global_proxy_spend,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
asyncio.create_task(
|
||||||
|
proxy_logging_obj.budget_alerts(
|
||||||
|
type="proxy_budget",
|
||||||
|
user_info=user_info,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return global_proxy_spend
|
||||||
|
|
||||||
|
|
||||||
|
def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
|
||||||
|
is_admin = jwt_handler.is_admin(scopes=scopes)
|
||||||
|
if is_admin:
|
||||||
|
return LitellmUserRoles.PROXY_ADMIN
|
||||||
|
else:
|
||||||
|
return LitellmUserRoles.TEAM
|
||||||
|
|
||||||
|
|
||||||
|
async def _jwt_auth_user_api_key_auth_builder(
|
||||||
|
api_key: str,
|
||||||
|
jwt_handler: JWTHandler,
|
||||||
|
route: str,
|
||||||
|
prisma_client: Optional[PrismaClient],
|
||||||
|
user_api_key_cache: DualCache,
|
||||||
|
parent_otel_span: Optional[Span],
|
||||||
|
proxy_logging_obj: ProxyLogging,
|
||||||
|
) -> JWTAuthBuilderResult:
|
||||||
|
|
||||||
|
# check if valid token
|
||||||
|
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
|
||||||
|
|
||||||
|
# check if unmatched token and enforce_rbac is true
|
||||||
|
if (
|
||||||
|
jwt_handler.litellm_jwtauth.enforce_rbac is True
|
||||||
|
and jwt_handler.get_rbac_role(token=jwt_valid_token) is None
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org",
|
||||||
|
)
|
||||||
|
# get scopes
|
||||||
|
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
||||||
|
|
||||||
|
# [OPTIONAL] allowed user email domains
|
||||||
|
valid_user_email: Optional[bool] = None
|
||||||
|
user_email: Optional[str] = None
|
||||||
|
if jwt_handler.is_enforced_email_domain():
|
||||||
|
"""
|
||||||
|
if 'allowed_email_subdomains' is set,
|
||||||
|
|
||||||
|
- checks if token contains 'email' field
|
||||||
|
- checks if 'email' is from an allowed domain
|
||||||
|
"""
|
||||||
|
user_email = jwt_handler.get_user_email(
|
||||||
|
token=jwt_valid_token, default_value=None
|
||||||
|
)
|
||||||
|
if user_email is None:
|
||||||
|
valid_user_email = False
|
||||||
|
else:
|
||||||
|
valid_user_email = jwt_handler.is_allowed_domain(user_email=user_email)
|
||||||
|
|
||||||
|
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
|
||||||
|
user_object = None
|
||||||
|
user_id = jwt_handler.get_user_id(token=jwt_valid_token, default_value=user_email)
|
||||||
|
|
||||||
|
# get org id
|
||||||
|
org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None)
|
||||||
|
# get team id
|
||||||
|
team_id = jwt_handler.get_team_id(token=jwt_valid_token, default_value=None)
|
||||||
|
# get end user id
|
||||||
|
end_user_id = jwt_handler.get_end_user_id(token=jwt_valid_token, default_value=None)
|
||||||
|
|
||||||
|
# check if admin
|
||||||
|
is_admin = jwt_handler.is_admin(scopes=scopes)
|
||||||
|
# if admin return
|
||||||
|
if is_admin:
|
||||||
|
# check allowed admin routes
|
||||||
|
is_allowed = allowed_routes_check(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
user_route=route,
|
||||||
|
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||||
|
)
|
||||||
|
if is_allowed:
|
||||||
|
return JWTAuthBuilderResult(
|
||||||
|
is_proxy_admin=True,
|
||||||
|
team_object=None,
|
||||||
|
user_object=None,
|
||||||
|
end_user_object=None,
|
||||||
|
org_object=None,
|
||||||
|
token=api_key,
|
||||||
|
team_id=team_id,
|
||||||
|
user_id=user_id,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
org_id=org_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
|
||||||
|
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||||
|
raise Exception(
|
||||||
|
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if team_id is None and jwt_handler.is_required_team_id() is True:
|
||||||
|
raise Exception(
|
||||||
|
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
team_object: Optional[LiteLLM_TeamTable] = None
|
||||||
|
if team_id is not None:
|
||||||
|
# check allowed team routes
|
||||||
|
is_allowed = allowed_routes_check(
|
||||||
|
user_role=LitellmUserRoles.TEAM,
|
||||||
|
user_route=route,
|
||||||
|
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||||
|
)
|
||||||
|
if is_allowed is False:
|
||||||
|
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
|
||||||
|
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||||
|
raise Exception(
|
||||||
|
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if team in db
|
||||||
|
team_object = await get_team_object(
|
||||||
|
team_id=team_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
parent_otel_span=parent_otel_span,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
|
||||||
|
|
||||||
|
org_object: Optional[LiteLLM_OrganizationTable] = None
|
||||||
|
if org_id is not None:
|
||||||
|
org_object = await get_org_object(
|
||||||
|
org_id=org_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
parent_otel_span=parent_otel_span,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_id is not None:
|
||||||
|
# get the user object
|
||||||
|
user_object = await get_user_object(
|
||||||
|
user_id=user_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
user_id_upsert=jwt_handler.is_upsert_user_id(
|
||||||
|
valid_user_email=valid_user_email
|
||||||
|
),
|
||||||
|
parent_otel_span=parent_otel_span,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
|
||||||
|
end_user_object = None
|
||||||
|
|
||||||
|
if end_user_id is not None:
|
||||||
|
# get the end-user object
|
||||||
|
end_user_object = await get_end_user_object(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
parent_otel_span=parent_otel_span,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"is_proxy_admin": False,
|
||||||
|
"team_id": team_id,
|
||||||
|
"team_object": team_object,
|
||||||
|
"user_id": user_id,
|
||||||
|
"user_object": user_object,
|
||||||
|
"org_id": org_id,
|
||||||
|
"org_object": org_object,
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
"end_user_object": end_user_object,
|
||||||
|
"token": api_key,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def _user_api_key_auth_builder( # noqa: PLR0915
|
async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
request: Request,
|
request: Request,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
|
@ -361,164 +578,39 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
is_jwt = jwt_handler.is_jwt(token=api_key)
|
is_jwt = jwt_handler.is_jwt(token=api_key)
|
||||||
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
|
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
|
||||||
if is_jwt:
|
if is_jwt:
|
||||||
# check if valid token
|
result = await _jwt_auth_user_api_key_auth_builder(
|
||||||
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
|
api_key=api_key,
|
||||||
# get scopes
|
jwt_handler=jwt_handler,
|
||||||
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
route=route,
|
||||||
|
prisma_client=prisma_client,
|
||||||
# check if admin
|
user_api_key_cache=user_api_key_cache,
|
||||||
is_admin = jwt_handler.is_admin(scopes=scopes)
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
# if admin return
|
parent_otel_span=parent_otel_span,
|
||||||
if is_admin:
|
|
||||||
# check allowed admin routes
|
|
||||||
is_allowed = allowed_routes_check(
|
|
||||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
|
||||||
user_route=route,
|
|
||||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
|
||||||
)
|
)
|
||||||
if is_allowed:
|
|
||||||
|
is_proxy_admin = result["is_proxy_admin"]
|
||||||
|
team_id = result["team_id"]
|
||||||
|
team_object = result["team_object"]
|
||||||
|
user_id = result["user_id"]
|
||||||
|
user_object = result["user_object"]
|
||||||
|
end_user_id = result["end_user_id"]
|
||||||
|
end_user_object = result["end_user_object"]
|
||||||
|
org_id = result["org_id"]
|
||||||
|
token = result["token"]
|
||||||
|
|
||||||
|
global_proxy_spend = await get_global_proxy_spend(
|
||||||
|
litellm_proxy_admin_name=litellm_proxy_admin_name,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
token=token,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_proxy_admin:
|
||||||
return UserAPIKeyAuth(
|
return UserAPIKeyAuth(
|
||||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
allowed_routes: List[Any] = (
|
|
||||||
jwt_handler.litellm_jwtauth.admin_allowed_routes
|
|
||||||
)
|
|
||||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
|
||||||
raise Exception(
|
|
||||||
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# get team id
|
|
||||||
team_id = jwt_handler.get_team_id(
|
|
||||||
token=jwt_valid_token, default_value=None
|
|
||||||
)
|
|
||||||
|
|
||||||
if team_id is None and jwt_handler.is_required_team_id() is True:
|
|
||||||
raise Exception(
|
|
||||||
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
team_object: Optional[LiteLLM_TeamTable] = None
|
|
||||||
if team_id is not None:
|
|
||||||
# check allowed team routes
|
|
||||||
is_allowed = allowed_routes_check(
|
|
||||||
user_role=LitellmUserRoles.TEAM,
|
|
||||||
user_route=route,
|
|
||||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
|
||||||
)
|
|
||||||
if is_allowed is False:
|
|
||||||
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
|
|
||||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
|
||||||
raise Exception(
|
|
||||||
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# check if team in db
|
|
||||||
team_object = await get_team_object(
|
|
||||||
team_id=team_id,
|
|
||||||
prisma_client=prisma_client,
|
|
||||||
user_api_key_cache=user_api_key_cache,
|
|
||||||
parent_otel_span=parent_otel_span,
|
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
|
|
||||||
org_id = jwt_handler.get_org_id(
|
|
||||||
token=jwt_valid_token, default_value=None
|
|
||||||
)
|
|
||||||
if org_id is not None:
|
|
||||||
_ = await get_org_object(
|
|
||||||
org_id=org_id,
|
|
||||||
prisma_client=prisma_client,
|
|
||||||
user_api_key_cache=user_api_key_cache,
|
|
||||||
parent_otel_span=parent_otel_span,
|
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
|
||||||
)
|
|
||||||
# [OPTIONAL] allowed user email domains
|
|
||||||
valid_user_email: Optional[bool] = None
|
|
||||||
user_email: Optional[str] = None
|
|
||||||
if jwt_handler.is_enforced_email_domain():
|
|
||||||
"""
|
|
||||||
if 'allowed_email_subdomains' is set,
|
|
||||||
|
|
||||||
- checks if token contains 'email' field
|
|
||||||
- checks if 'email' is from an allowed domain
|
|
||||||
"""
|
|
||||||
user_email = jwt_handler.get_user_email(
|
|
||||||
token=jwt_valid_token, default_value=None
|
|
||||||
)
|
|
||||||
if user_email is None:
|
|
||||||
valid_user_email = False
|
|
||||||
else:
|
|
||||||
valid_user_email = jwt_handler.is_allowed_domain(
|
|
||||||
user_email=user_email
|
|
||||||
)
|
|
||||||
|
|
||||||
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
|
|
||||||
user_object = None
|
|
||||||
user_id = jwt_handler.get_user_id(
|
|
||||||
token=jwt_valid_token, default_value=user_email
|
|
||||||
)
|
|
||||||
if user_id is not None:
|
|
||||||
# get the user object
|
|
||||||
user_object = await get_user_object(
|
|
||||||
user_id=user_id,
|
|
||||||
prisma_client=prisma_client,
|
|
||||||
user_api_key_cache=user_api_key_cache,
|
|
||||||
user_id_upsert=jwt_handler.is_upsert_user_id(
|
|
||||||
valid_user_email=valid_user_email
|
|
||||||
),
|
|
||||||
parent_otel_span=parent_otel_span,
|
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
|
||||||
)
|
|
||||||
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
|
|
||||||
end_user_object = None
|
|
||||||
end_user_id = jwt_handler.get_end_user_id(
|
|
||||||
token=jwt_valid_token, default_value=None
|
|
||||||
)
|
|
||||||
if end_user_id is not None:
|
|
||||||
# get the end-user object
|
|
||||||
end_user_object = await get_end_user_object(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
prisma_client=prisma_client,
|
|
||||||
user_api_key_cache=user_api_key_cache,
|
|
||||||
parent_otel_span=parent_otel_span,
|
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
global_proxy_spend = None
|
|
||||||
if litellm.max_budget > 0: # user set proxy max budget
|
|
||||||
# check cache
|
|
||||||
global_proxy_spend = await user_api_key_cache.async_get_cache(
|
|
||||||
key="{}:spend".format(litellm_proxy_admin_name)
|
|
||||||
)
|
|
||||||
if global_proxy_spend is None and prisma_client is not None:
|
|
||||||
# get from db
|
|
||||||
sql_query = """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";"""
|
|
||||||
|
|
||||||
response = await prisma_client.db.query_raw(query=sql_query)
|
|
||||||
|
|
||||||
global_proxy_spend = response[0]["total_spend"]
|
|
||||||
|
|
||||||
await user_api_key_cache.async_set_cache(
|
|
||||||
key="{}:spend".format(litellm_proxy_admin_name),
|
|
||||||
value=global_proxy_spend,
|
|
||||||
)
|
|
||||||
if global_proxy_spend is not None:
|
|
||||||
user_info = CallInfo(
|
|
||||||
user_id=litellm_proxy_admin_name,
|
|
||||||
max_budget=litellm.max_budget,
|
|
||||||
spend=global_proxy_spend,
|
|
||||||
token=jwt_valid_token["token"],
|
|
||||||
)
|
|
||||||
asyncio.create_task(
|
|
||||||
proxy_logging_obj.budget_alerts(
|
|
||||||
type="proxy_budget",
|
|
||||||
user_info=user_info,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# run through common checks
|
# run through common checks
|
||||||
_ = common_checks(
|
_ = common_checks(
|
||||||
request_body=request_data,
|
request_body=request_data,
|
||||||
|
@ -534,7 +626,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
# return UserAPIKeyAuth object
|
# return UserAPIKeyAuth object
|
||||||
return UserAPIKeyAuth(
|
return UserAPIKeyAuth(
|
||||||
api_key=None,
|
api_key=None,
|
||||||
team_id=team_object.team_id if team_object is not None else None,
|
team_id=team_id,
|
||||||
team_tpm_limit=(
|
team_tpm_limit=(
|
||||||
team_object.tpm_limit if team_object is not None else None
|
team_object.tpm_limit if team_object is not None else None
|
||||||
),
|
),
|
||||||
|
@ -548,6 +640,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
#### ELSE ####
|
#### ELSE ####
|
||||||
## CHECK PASS-THROUGH ENDPOINTS ##
|
## CHECK PASS-THROUGH ENDPOINTS ##
|
||||||
is_mapped_pass_through_route: bool = False
|
is_mapped_pass_through_route: bool = False
|
||||||
|
|
|
@ -1039,6 +1039,7 @@ async def test_end_user_jwt_auth(monkeypatch):
|
||||||
import json
|
import json
|
||||||
|
|
||||||
monkeypatch.delenv("JWT_AUDIENCE", None)
|
monkeypatch.delenv("JWT_AUDIENCE", None)
|
||||||
|
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
|
||||||
jwt_handler = JWTHandler()
|
jwt_handler = JWTHandler()
|
||||||
|
|
||||||
litellm_jwtauth = LiteLLM_JWTAuth(
|
litellm_jwtauth = LiteLLM_JWTAuth(
|
||||||
|
|
|
@ -794,3 +794,71 @@ async def test_user_api_key_auth_websocket():
|
||||||
assert (
|
assert (
|
||||||
mock_user_api_key_auth.call_args.kwargs["api_key"] == "Bearer some_api_key"
|
mock_user_api_key_auth.call_args.kwargs["api_key"] == "Bearer some_api_key"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("enforce_rbac", [True, False])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypatch):
|
||||||
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||||
|
from litellm.proxy.auth.user_api_key_auth import _jwt_auth_user_api_key_auth_builder
|
||||||
|
from unittest.mock import patch, Mock
|
||||||
|
from litellm.proxy._types import LiteLLM_JWTAuth
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
|
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "my-fake-url")
|
||||||
|
monkeypatch.setenv("JWT_AUDIENCE", "api://LiteLLM_Proxy-dev")
|
||||||
|
|
||||||
|
local_cache = DualCache()
|
||||||
|
|
||||||
|
keys = [
|
||||||
|
{
|
||||||
|
"kty": "RSA",
|
||||||
|
"use": "sig",
|
||||||
|
"kid": "z1rsYHHJ9-8mggt4HsZu8BKkBPw",
|
||||||
|
"x5t": "z1rsYHHJ9-8mggt4HsZu8BKkBPw",
|
||||||
|
"n": "pOe4GbleFDT1u5ioOQjNMmhvkDVoVD9cBKvX7AlErtWA_D6wc1w1iwkd6arYVCPObZbAB4vLSXrlpBSOuP6VYnXw_cTgniv_c82ra-mfqCpM-SbqzZ3sVqlcE_bwxvci_4PrxAW4R85ok12NXyZ2371H3yGevabi35AlVm-bQ24azo1hLK_0DzB6TxsAIOTOcKfIugOfqP-B2R4vR4u6pYftS8MWcxegr9iJ5JNtubI1X2JHpxJhkRoMVwKFna2GXmtzdxLi3yS_GffVCKfTbFMhalbJS1lSmLqhmLZZL-lrQZ6fansTl1vcGcoxnzPTwBkZMks0iVV4yfym_gKBXQ",
|
||||||
|
"e": "AQAB",
|
||||||
|
"x5c": [
|
||||||
|
"MIIC/TCCAeWgAwIBAgIIQk8Qok6pfXkwDQYJKoZIhvcNAQELBQAwLTErMCkGA1UEAxMiYWNjb3VudHMuYWNjZXNzY29udHJvbC53aW5kb3dzLm5ldDAeFw0yNDExMjcwOTA0MzlaFw0yOTExMjcwOTA0MzlaMC0xKzApBgNVBAMTImFjY291bnRzLmFjY2Vzc2NvbnRyb2wud2luZG93cy5uZXQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCk57gZuV4UNPW7mKg5CM0yaG+QNWhUP1wEq9fsCUSu1YD8PrBzXDWLCR3pqthUI85tlsAHi8tJeuWkFI64/pVidfD9xOCeK/9zzatr6Z+oKkz5JurNnexWqVwT9vDG9yL/g+vEBbhHzmiTXY1fJnbfvUffIZ69puLfkCVWb5tDbhrOjWEsr/QPMHpPGwAg5M5wp8i6A5+o/4HZHi9Hi7qlh+1LwxZzF6Cv2Inkk225sjVfYkenEmGRGgxXAoWdrYZea3N3EuLfJL8Z99UIp9NsUyFqVslLWVKYuqGYtlkv6WtBnp9qexOXW9wZyjGfM9PAGRkySzSJVXjJ/Kb+AoFdAgMBAAGjITAfMB0GA1UdDgQWBBSTO5FmUwwGS+1CNqg2uNgjxUjFijANBgkqhkiG9w0BAQsFAAOCAQEAok04z0ICMEHGqDTzx6eD7vvJP8itJTCSz8JcZcGVJofJpViGF3bNnyeSPa7vNDYP1Ps9XBvw3/n2s+yynZ8EwFxMyxCZRCSbLv0N+cAbH3rmZqGcgMJszZVwcFUtXQPTe1ZRyHtEyOB+PVFH7K7obysRVO/cC6EGqIF3pYWzez/dtMaXRAkdTNlz0ko62WoA4eMPwUFCITjW/Jxfxl0BNUbo82PXXKhaeVJb+EgFG5b/pWWPswWmBoQhmD5G1UODvEACHRl/cHsPPqe4YE+6D1/wMno/xqqyGltnk8v0d4TpNcQMn9oM19V+OGgrzWOvvXhvnhqUIVGMsRlyBGNHAw=="
|
||||||
|
],
|
||||||
|
"cloud_instance_name": "microsoftonline.com",
|
||||||
|
"issuer": "https://login.microsoftonline.com/bdfd79b3-8401-47fb-a764-6e595c455b05/v2.0",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
local_cache.set_cache(
|
||||||
|
key="litellm_jwt_auth_keys",
|
||||||
|
value=keys,
|
||||||
|
)
|
||||||
|
|
||||||
|
litellm_jwtauth = LiteLLM_JWTAuth(
|
||||||
|
**{
|
||||||
|
"admin_jwt_scope": "litellm_proxy_endpoints_access",
|
||||||
|
"admin_allowed_routes": ["openai_routes", "info_routes"],
|
||||||
|
"public_key_ttl": 600,
|
||||||
|
"enforce_rbac": enforce_rbac,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
jwt_handler = JWTHandler()
|
||||||
|
jwt_handler.update_environment(
|
||||||
|
prisma_client=None,
|
||||||
|
user_api_key_cache=local_cache,
|
||||||
|
litellm_jwtauth=litellm_jwtauth,
|
||||||
|
leeway=10000000000000,
|
||||||
|
)
|
||||||
|
args = {
|
||||||
|
"api_key": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsIng1dCI6InoxcnNZSEhKOS04bWdndDRIc1p1OEJLa0JQdyIsImtpZCI6InoxcnNZSEhKOS04bWdndDRIc1p1OEJLa0JQdyJ9.eyJhdWQiOiJhcGk6Ly9MaXRlTExNX1Byb3h5LWRldiIsImlzcyI6Imh0dHBzOi8vc3RzLndpbmRvd3MubmV0L2JkZmQ3OWIzLTg0MDEtNDdmYi1hNzY0LTZlNTk1YzQ1NWIwNS8iLCJpYXQiOjE3MzcyNDE3ODEsIm5iZiI6MTczNzI0MTc4MSwiZXhwIjoxNzM3MjQ1NjgxLCJhaW8iOiJrMlJnWUpBNE5hZGg4MGJQdXlyRmxlV1o3dHZiQUE9PSIsImFwcGlkIjoiOGNjZjNkMDItMmNkNi00N2I5LTgxODUtMGVkYjI0YWJjZjY5IiwiYXBwaWRhY3IiOiIxIiwiaWRwIjoiaHR0cHM6Ly9zdHMud2luZG93cy5uZXQvYmRmZDc5YjMtODQwMS00N2ZiLWE3NjQtNmU1OTVjNDU1YjA1LyIsIm9pZCI6IjQ0YTg3YTYzLWFiNTUtNDc4NS1iMmFmLTMzNjllZWM4ZTEzOSIsInJoIjoiMS5BYjBBczNuOXZRR0UtMGVuWkc1WlhFVmJCY0VDbkl6NHJxaE9wZ2E0UGZSZjBsbTlBQUM5QUEuIiwic3ViIjoiNDRhODdhNjMtYWI1NS00Nzg1LWIyYWYtMzM2OWVlYzhlMTM5IiwidGlkIjoiYmRmZDc5YjMtODQwMS00N2ZiLWE3NjQtNmU1OTVjNDU1YjA1IiwidXRpIjoiY3ltNVhlcmhIMHVMSlNZU1JyQmhBQSIsInZlciI6IjEuMCJ9.UooJjM9pS-wgYsExqgHdrYyQhp7NbwAsr7au9dWJaLpsufXeyHJSg-Xd5VJ4RsDVJiDes3jkC7WeoAiaCfzEHpAum-p_aqqLYXf1QIYbi1hLC0m7y_klFcqMp11WbDa9TSTvg-o8q3x2Y5su8X23ymlFih4OP17b7JA6a4_2MybU5QkCEW1tQK6VspuuXzeDHvbfGeGYcIptHFyfttHMHHXRtX1o9bX7gOR_dwFITAXD18T4ZdAN_0y6f1OtVF9TMWQhMXhKU8ahn8TSg_CXmPl9T_1gV3ZWLvVtcdVrWs82fDz3-2lEw28z4bQEr1Z5xoAz7srhx1WEBu_ioAcQiA",
|
||||||
|
"jwt_handler": jwt_handler,
|
||||||
|
"route": "/v1/chat/completions",
|
||||||
|
"prisma_client": None,
|
||||||
|
"user_api_key_cache": Mock(),
|
||||||
|
"parent_otel_span": None,
|
||||||
|
"proxy_logging_obj": Mock(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if enforce_rbac:
|
||||||
|
with pytest.raises(HTTPException):
|
||||||
|
await _jwt_auth_user_api_key_auth_builder(**args)
|
||||||
|
else:
|
||||||
|
await _jwt_auth_user_api_key_auth_builder(**args)
|
||||||
|
|
|
@ -11,6 +11,7 @@ interface DashboardTeamProps {
|
||||||
setProxySettings: React.Dispatch<React.SetStateAction<ProxySettings | null>>;
|
setProxySettings: React.Dispatch<React.SetStateAction<ProxySettings | null>>;
|
||||||
userInfo: UserInfo | null;
|
userInfo: UserInfo | null;
|
||||||
accessToken: string | null;
|
accessToken: string | null;
|
||||||
|
setKeys: React.Dispatch<React.SetStateAction<any | null>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
type TeamInterface = {
|
type TeamInterface = {
|
||||||
|
@ -27,7 +28,8 @@ const DashboardTeam: React.FC<DashboardTeamProps> = ({
|
||||||
proxySettings,
|
proxySettings,
|
||||||
setProxySettings,
|
setProxySettings,
|
||||||
userInfo,
|
userInfo,
|
||||||
accessToken
|
accessToken,
|
||||||
|
setKeys
|
||||||
}) => {
|
}) => {
|
||||||
console.log(`userInfo: ${JSON.stringify(userInfo)}`)
|
console.log(`userInfo: ${JSON.stringify(userInfo)}`)
|
||||||
const defaultTeam: TeamInterface = {
|
const defaultTeam: TeamInterface = {
|
||||||
|
@ -80,7 +82,10 @@ const DashboardTeam: React.FC<DashboardTeamProps> = ({
|
||||||
<SelectItem
|
<SelectItem
|
||||||
key={index}
|
key={index}
|
||||||
value={String(index)}
|
value={String(index)}
|
||||||
onClick={() => setSelectedTeam(team)}
|
onClick={() => {
|
||||||
|
setSelectedTeam(team);
|
||||||
|
setKeys(team["keys"]);
|
||||||
|
}}
|
||||||
>
|
>
|
||||||
{team["team_alias"]}
|
{team["team_alias"]}
|
||||||
</SelectItem>
|
</SelectItem>
|
||||||
|
|
|
@ -206,13 +206,15 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
||||||
|
|
||||||
setUserSpendData(response["user_info"]);
|
setUserSpendData(response["user_info"]);
|
||||||
console.log(`userSpendData: ${JSON.stringify(userSpendData)}`)
|
console.log(`userSpendData: ${JSON.stringify(userSpendData)}`)
|
||||||
setKeys(response["keys"]); // Assuming this is the correct path to your data
|
|
||||||
const teamsArray = [...response["teams"]];
|
const teamsArray = [...response["teams"]];
|
||||||
if (teamsArray.length > 0) {
|
if (teamsArray.length > 0) {
|
||||||
console.log(`response['teams']: ${teamsArray}`);
|
console.log(`response['teams']: ${JSON.stringify(teamsArray)}`);
|
||||||
setSelectedTeam(teamsArray[0]);
|
setSelectedTeam(teamsArray[0]);
|
||||||
|
setKeys(teamsArray[0]["keys"]); // Assuming this is the correct path to your data
|
||||||
} else {
|
} else {
|
||||||
setSelectedTeam(defaultTeam);
|
setSelectedTeam(defaultTeam);
|
||||||
|
setKeys(response["keys"]); // Assuming this is the correct path to your data
|
||||||
}
|
}
|
||||||
sessionStorage.setItem(
|
sessionStorage.setItem(
|
||||||
"userData" + userID,
|
"userData" + userID,
|
||||||
|
@ -261,6 +263,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
||||||
selectedTeam.team_id !== null
|
selectedTeam.team_id !== null
|
||||||
) {
|
) {
|
||||||
let sum = 0;
|
let sum = 0;
|
||||||
|
console.log(`keys: ${JSON.stringify(keys)}`)
|
||||||
for (const key of keys) {
|
for (const key of keys) {
|
||||||
if (
|
if (
|
||||||
selectedTeam.hasOwnProperty("team_id") &&
|
selectedTeam.hasOwnProperty("team_id") &&
|
||||||
|
@ -270,6 +273,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
||||||
sum += key.spend;
|
sum += key.spend;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
console.log(`sum: ${sum}`)
|
||||||
setTeamSpend(sum);
|
setTeamSpend(sum);
|
||||||
} else if (keys !== null) {
|
} else if (keys !== null) {
|
||||||
// sum the keys which don't have team-id set (default team)
|
// sum the keys which don't have team-id set (default team)
|
||||||
|
@ -367,6 +371,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
||||||
setProxySettings={setProxySettings}
|
setProxySettings={setProxySettings}
|
||||||
userInfo={userSpendData}
|
userInfo={userSpendData}
|
||||||
accessToken={accessToken}
|
accessToken={accessToken}
|
||||||
|
setKeys={setKeys}
|
||||||
/>
|
/>
|
||||||
</Col>
|
</Col>
|
||||||
</Grid>
|
</Grid>
|
||||||
|
|
|
@ -114,8 +114,6 @@ const ViewUserSpend: React.FC<ViewUserSpendProps> = ({ userID, userRole, accessT
|
||||||
fetchData();
|
fetchData();
|
||||||
}, [userRole, accessToken, userID]);
|
}, [userRole, accessToken, userID]);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (userSpend !== null) {
|
if (userSpend !== null) {
|
||||||
setSpend(userSpend)
|
setSpend(userSpend)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue