diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index bbcd10ada..d4e5834f2 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -38,8 +38,8 @@ class LiteLLMBase(BaseModel): class LiteLLMProxyRoles(LiteLLMBase): - PROXY_ADMIN: str = "litellm_proxy_admin" - PROXY_USER: str = "litellm_user" + proxy_admin: str = "litellm_proxy_admin" + proxy_user: str = "litellm_user" class LiteLLMPromptInjectionParams(LiteLLMBase): diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 2d7aa3d4b..c8eb7e838 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -81,7 +81,7 @@ class JWTHandler: return len(parts) == 3 def is_admin(self, scopes: list) -> bool: - if self.litellm_proxy_roles.PROXY_ADMIN in scopes: + if self.litellm_proxy_roles.proxy_admin in scopes: return True return False @@ -94,7 +94,7 @@ class JWTHandler: def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: try: - team_id = token["azp"] + team_id = token["client_id"] except KeyError: team_id = default_value return team_id diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 413d7aa4f..7d92d413e 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2711,7 +2711,11 @@ async def startup_event(): ## JWT AUTH ## jwt_handler.update_environment( - prisma_client=prisma_client, user_api_key_cache=user_api_key_cache + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + litellm_proxy_roles=LiteLLMProxyRoles( + **general_settings.get("litellm_proxy_roles", {}) + ), ) if use_background_health_checks: diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py new file mode 100644 index 000000000..a2c9e4e4a --- /dev/null +++ b/litellm/tests/test_jwt.py @@ -0,0 +1,34 @@ +#### What this tests #### +# Unit tests for JWT-Auth + +import sys, os, asyncio, time, random +import traceback +from dotenv import load_dotenv + +load_dotenv() +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest +from litellm.proxy._types import LiteLLMProxyRoles + + +def test_load_config_with_custom_role_names(): + config = { + "general_settings": { + "litellm_proxy_roles": {"proxy_admin": "litellm-proxy-admin"} + } + } + + proxy_roles = LiteLLMProxyRoles( + **config.get("general_settings", {}).get("litellm_proxy_roles", {}) + ) + + print(f"proxy_roles: {proxy_roles}") + + assert proxy_roles.proxy_admin == "litellm-proxy-admin" + + +# test_load_config_with_custom_role_names()