From 752516df1b0663ef4336280c0b594110fb85c6e5 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 26 Mar 2024 14:32:55 -0700 Subject: [PATCH] fix(handle_jwt.py): support public key caching ttl param --- docs/my-website/docs/proxy/token_auth.md | 2 +- litellm/proxy/_types.py | 3 ++- litellm/proxy/auth/auth_checks.py | 4 ++-- litellm/proxy/auth/handle_jwt.py | 20 +++++++++++--------- litellm/proxy/proxy_server.py | 18 ++++++++---------- litellm/tests/test_jwt.py | 4 ++-- 6 files changed, 26 insertions(+), 25 deletions(-) diff --git a/docs/my-website/docs/proxy/token_auth.md b/docs/my-website/docs/proxy/token_auth.md index 5f1812757..0fbce04db 100644 --- a/docs/my-website/docs/proxy/token_auth.md +++ b/docs/my-website/docs/proxy/token_auth.md @@ -124,7 +124,7 @@ general_settings: ### Allowed LiteLLM scopes ```python -class LiteLLMProxyRoles(LiteLLMBase): +class LiteLLM_JWTAuth(LiteLLMBase): proxy_admin: str = "litellm_proxy_admin" proxy_user: str = "litellm_user" # 👈 Not implemented yet, for JWT-Auth. ``` diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 86d842ed9..4fd1bf3b0 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -90,7 +90,7 @@ class LiteLLMRoutes(enum.Enum): ] -class LiteLLMProxyRoles(LiteLLMBase): +class LiteLLM_JWTAuth(LiteLLMBase): """ A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth. @@ -115,6 +115,7 @@ class LiteLLMProxyRoles(LiteLLMBase): Literal["openai_routes", "info_routes", "management_routes"] ] = ["openai_routes", "info_routes"] end_user_id_jwt_field: Optional[str] = "sub" + public_key_ttl: float = 600 def __init__(self, **kwargs: Any) -> None: # get the attribute names for this Pydantic model diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index f1ef5ca00..b8f7c6e3f 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -11,7 +11,7 @@ Run checks for: from litellm.proxy._types import ( LiteLLM_UserTable, LiteLLM_EndUserTable, - LiteLLMProxyRoles, + LiteLLM_JWTAuth, LiteLLM_TeamTable, LiteLLMRoutes, ) @@ -88,7 +88,7 @@ def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool: def allowed_routes_check( user_role: Literal["proxy_admin", "team"], user_route: str, - litellm_proxy_roles: LiteLLMProxyRoles, + litellm_proxy_roles: LiteLLM_JWTAuth, ) -> bool: """ Check if user -> not admin - allowed to access these routes diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 6cb67b171..08ffc0955 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -12,7 +12,7 @@ import json import os from litellm.caching import DualCache from litellm._logging import verbose_proxy_logger -from litellm.proxy._types import LiteLLMProxyRoles, LiteLLM_UserTable +from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable from litellm.proxy.utils import PrismaClient from typing import Optional @@ -70,30 +70,30 @@ class JWTHandler: self, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, - litellm_proxy_roles: LiteLLMProxyRoles, + litellm_jwtauth: LiteLLM_JWTAuth, ) -> None: self.prisma_client = prisma_client self.user_api_key_cache = user_api_key_cache - self.litellm_proxy_roles = litellm_proxy_roles + self.litellm_jwtauth = litellm_jwtauth def is_jwt(self, token: str): parts = token.split(".") return len(parts) == 3 def is_admin(self, scopes: list) -> bool: - if self.litellm_proxy_roles.admin_jwt_scope in scopes: + if self.litellm_jwtauth.admin_jwt_scope in scopes: return True return False def is_team(self, scopes: list) -> bool: - if self.litellm_proxy_roles.team_jwt_scope in scopes: + if self.litellm_jwtauth.team_jwt_scope in scopes: return True return False def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str: try: - if self.litellm_proxy_roles.team_id_jwt_field is not None: - user_id = token[self.litellm_proxy_roles.team_id_jwt_field] + if self.litellm_jwtauth.end_user_id_jwt_field is not None: + user_id = token[self.litellm_jwtauth.end_user_id_jwt_field] else: user_id = None except KeyError: @@ -102,7 +102,7 @@ class JWTHandler: def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: try: - team_id = token[self.litellm_proxy_roles.team_id_jwt_field] + team_id = token[self.litellm_jwtauth.team_id_jwt_field] except KeyError: team_id = default_value return team_id @@ -137,7 +137,9 @@ class JWTHandler: keys = response.json()["keys"] await self.user_api_key_cache.async_set_cache( - key="litellm_jwt_auth_keys", value=keys, ttl=600 # cache for 10 mins + key="litellm_jwt_auth_keys", + value=keys, + ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins ) else: keys = cached_keys diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index d16b297fe..c2a5c5372 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -378,13 +378,13 @@ async def user_api_key_auth( is_allowed = allowed_routes_check( user_role="proxy_admin", user_route=route, - litellm_proxy_roles=jwt_handler.litellm_proxy_roles, + litellm_proxy_roles=jwt_handler.litellm_jwtauth, ) if is_allowed: return UserAPIKeyAuth() else: allowed_routes = ( - jwt_handler.litellm_proxy_roles.admin_allowed_routes + jwt_handler.litellm_jwtauth.admin_allowed_routes ) actual_routes = get_actual_routes(allowed_routes=allowed_routes) raise Exception( @@ -394,23 +394,23 @@ async def user_api_key_auth( is_team = jwt_handler.is_team(scopes=scopes) if is_team == False: raise Exception( - f"Missing both Admin and Team scopes from token. Either is required. Admin Scope={jwt_handler.litellm_proxy_roles.admin_jwt_scope}, Team Scope={jwt_handler.litellm_proxy_roles.team_jwt_scope}" + f"Missing both Admin and Team scopes from token. Either is required. Admin Scope={jwt_handler.litellm_jwtauth.admin_jwt_scope}, Team Scope={jwt_handler.litellm_jwtauth.team_jwt_scope}" ) # get team id team_id = jwt_handler.get_team_id(token=valid_token, default_value=None) if team_id is None: raise Exception( - f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_proxy_roles.team_id_jwt_field}'" + f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'" ) # check allowed team routes is_allowed = allowed_routes_check( user_role="team", user_route=route, - litellm_proxy_roles=jwt_handler.litellm_proxy_roles, + litellm_proxy_roles=jwt_handler.litellm_jwtauth, ) if is_allowed == False: - allowed_routes = jwt_handler.litellm_proxy_roles.team_allowed_routes + allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes actual_routes = get_actual_routes(allowed_routes=allowed_routes) raise Exception( f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}" @@ -2741,11 +2741,9 @@ async def startup_event(): ## JWT AUTH ## if general_settings.get("litellm_proxy_roles", None) is not None: - litellm_proxy_roles = LiteLLMProxyRoles( - **general_settings["litellm_proxy_roles"] - ) + litellm_proxy_roles = LiteLLM_JWTAuth(**general_settings["litellm_proxy_roles"]) else: - litellm_proxy_roles = LiteLLMProxyRoles() + litellm_proxy_roles = LiteLLM_JWTAuth() jwt_handler.update_environment( prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index 57c7e5c62..ee1c67fdd 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -12,7 +12,7 @@ sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import pytest -from litellm.proxy._types import LiteLLMProxyRoles +from litellm.proxy._types import LiteLLM_JWTAuth from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.caching import DualCache from datetime import datetime, timedelta @@ -32,7 +32,7 @@ def test_load_config_with_custom_role_names(): } } - proxy_roles = LiteLLMProxyRoles( + proxy_roles = LiteLLM_JWTAuth( **config.get("general_settings", {}).get("litellm_proxy_roles", {}) )