diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index d23049056b..bbcd10ada3 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -14,11 +14,6 @@ def hash_token(token: str): return hashed_token -class LiteLLMProxyRoles(enum.Enum): - PROXY_ADMIN = "litellm_proxy_admin" - USER = "litellm_user" - - class LiteLLMBase(BaseModel): """ Implements default functions, all pydantic objects should have. @@ -42,6 +37,11 @@ class LiteLLMBase(BaseModel): protected_namespaces = () +class LiteLLMProxyRoles(LiteLLMBase): + PROXY_ADMIN: str = "litellm_proxy_admin" + PROXY_USER: str = "litellm_user" + + class LiteLLMPromptInjectionParams(LiteLLMBase): heuristics_check: bool = False vector_db_check: bool = False diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 83effab7c3..2d7aa3d4b7 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -67,17 +67,21 @@ class JWTHandler: self.http_handler = HTTPHandler() def update_environment( - self, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache + self, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + litellm_proxy_roles: LiteLLMProxyRoles, ) -> None: self.prisma_client = prisma_client self.user_api_key_cache = user_api_key_cache + self.litellm_proxy_roles = litellm_proxy_roles def is_jwt(self, token: str): parts = token.split(".") return len(parts) == 3 def is_admin(self, scopes: list) -> bool: - if LiteLLMProxyRoles.PROXY_ADMIN.value in scopes: + if self.litellm_proxy_roles.PROXY_ADMIN in scopes: return True return False