fix(handle_jwt.py): support public key caching ttl param

This commit is contained in:
Krrish Dholakia 2024-03-26 14:32:55 -07:00
parent d90f44fe8e
commit 752516df1b
6 changed files with 26 additions and 25 deletions

View file

@ -124,7 +124,7 @@ general_settings:
### Allowed LiteLLM scopes ### Allowed LiteLLM scopes
```python ```python
class LiteLLMProxyRoles(LiteLLMBase): class LiteLLM_JWTAuth(LiteLLMBase):
proxy_admin: str = "litellm_proxy_admin" proxy_admin: str = "litellm_proxy_admin"
proxy_user: str = "litellm_user" # 👈 Not implemented yet, for JWT-Auth. proxy_user: str = "litellm_user" # 👈 Not implemented yet, for JWT-Auth.
``` ```

View file

@ -90,7 +90,7 @@ class LiteLLMRoutes(enum.Enum):
] ]
class LiteLLMProxyRoles(LiteLLMBase): class LiteLLM_JWTAuth(LiteLLMBase):
""" """
A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth. A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.
@ -115,6 +115,7 @@ class LiteLLMProxyRoles(LiteLLMBase):
Literal["openai_routes", "info_routes", "management_routes"] Literal["openai_routes", "info_routes", "management_routes"]
] = ["openai_routes", "info_routes"] ] = ["openai_routes", "info_routes"]
end_user_id_jwt_field: Optional[str] = "sub" end_user_id_jwt_field: Optional[str] = "sub"
public_key_ttl: float = 600
def __init__(self, **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
# get the attribute names for this Pydantic model # get the attribute names for this Pydantic model

View file

@ -11,7 +11,7 @@ Run checks for:
from litellm.proxy._types import ( from litellm.proxy._types import (
LiteLLM_UserTable, LiteLLM_UserTable,
LiteLLM_EndUserTable, LiteLLM_EndUserTable,
LiteLLMProxyRoles, LiteLLM_JWTAuth,
LiteLLM_TeamTable, LiteLLM_TeamTable,
LiteLLMRoutes, LiteLLMRoutes,
) )
@ -88,7 +88,7 @@ def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
def allowed_routes_check( def allowed_routes_check(
user_role: Literal["proxy_admin", "team"], user_role: Literal["proxy_admin", "team"],
user_route: str, user_route: str,
litellm_proxy_roles: LiteLLMProxyRoles, litellm_proxy_roles: LiteLLM_JWTAuth,
) -> bool: ) -> bool:
""" """
Check if user -> not admin - allowed to access these routes Check if user -> not admin - allowed to access these routes

View file

@ -12,7 +12,7 @@ import json
import os import os
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import LiteLLMProxyRoles, LiteLLM_UserTable from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable
from litellm.proxy.utils import PrismaClient from litellm.proxy.utils import PrismaClient
from typing import Optional from typing import Optional
@ -70,30 +70,30 @@ class JWTHandler:
self, self,
prisma_client: Optional[PrismaClient], prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache, user_api_key_cache: DualCache,
litellm_proxy_roles: LiteLLMProxyRoles, litellm_jwtauth: LiteLLM_JWTAuth,
) -> None: ) -> None:
self.prisma_client = prisma_client self.prisma_client = prisma_client
self.user_api_key_cache = user_api_key_cache self.user_api_key_cache = user_api_key_cache
self.litellm_proxy_roles = litellm_proxy_roles self.litellm_jwtauth = litellm_jwtauth
def is_jwt(self, token: str): def is_jwt(self, token: str):
parts = token.split(".") parts = token.split(".")
return len(parts) == 3 return len(parts) == 3
def is_admin(self, scopes: list) -> bool: def is_admin(self, scopes: list) -> bool:
if self.litellm_proxy_roles.admin_jwt_scope in scopes: if self.litellm_jwtauth.admin_jwt_scope in scopes:
return True return True
return False return False
def is_team(self, scopes: list) -> bool: def is_team(self, scopes: list) -> bool:
if self.litellm_proxy_roles.team_jwt_scope in scopes: if self.litellm_jwtauth.team_jwt_scope in scopes:
return True return True
return False return False
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str: def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str:
try: try:
if self.litellm_proxy_roles.team_id_jwt_field is not None: if self.litellm_jwtauth.end_user_id_jwt_field is not None:
user_id = token[self.litellm_proxy_roles.team_id_jwt_field] user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
else: else:
user_id = None user_id = None
except KeyError: except KeyError:
@ -102,7 +102,7 @@ class JWTHandler:
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try: try:
team_id = token[self.litellm_proxy_roles.team_id_jwt_field] team_id = token[self.litellm_jwtauth.team_id_jwt_field]
except KeyError: except KeyError:
team_id = default_value team_id = default_value
return team_id return team_id
@ -137,7 +137,9 @@ class JWTHandler:
keys = response.json()["keys"] keys = response.json()["keys"]
await self.user_api_key_cache.async_set_cache( await self.user_api_key_cache.async_set_cache(
key="litellm_jwt_auth_keys", value=keys, ttl=600 # cache for 10 mins key="litellm_jwt_auth_keys",
value=keys,
ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins
) )
else: else:
keys = cached_keys keys = cached_keys

View file

@ -378,13 +378,13 @@ async def user_api_key_auth(
is_allowed = allowed_routes_check( is_allowed = allowed_routes_check(
user_role="proxy_admin", user_role="proxy_admin",
user_route=route, user_route=route,
litellm_proxy_roles=jwt_handler.litellm_proxy_roles, litellm_proxy_roles=jwt_handler.litellm_jwtauth,
) )
if is_allowed: if is_allowed:
return UserAPIKeyAuth() return UserAPIKeyAuth()
else: else:
allowed_routes = ( allowed_routes = (
jwt_handler.litellm_proxy_roles.admin_allowed_routes jwt_handler.litellm_jwtauth.admin_allowed_routes
) )
actual_routes = get_actual_routes(allowed_routes=allowed_routes) actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception( raise Exception(
@ -394,23 +394,23 @@ async def user_api_key_auth(
is_team = jwt_handler.is_team(scopes=scopes) is_team = jwt_handler.is_team(scopes=scopes)
if is_team == False: if is_team == False:
raise Exception( raise Exception(
f"Missing both Admin and Team scopes from token. Either is required. Admin Scope={jwt_handler.litellm_proxy_roles.admin_jwt_scope}, Team Scope={jwt_handler.litellm_proxy_roles.team_jwt_scope}" f"Missing both Admin and Team scopes from token. Either is required. Admin Scope={jwt_handler.litellm_jwtauth.admin_jwt_scope}, Team Scope={jwt_handler.litellm_jwtauth.team_jwt_scope}"
) )
# get team id # get team id
team_id = jwt_handler.get_team_id(token=valid_token, default_value=None) team_id = jwt_handler.get_team_id(token=valid_token, default_value=None)
if team_id is None: if team_id is None:
raise Exception( raise Exception(
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_proxy_roles.team_id_jwt_field}'" f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
) )
# check allowed team routes # check allowed team routes
is_allowed = allowed_routes_check( is_allowed = allowed_routes_check(
user_role="team", user_role="team",
user_route=route, user_route=route,
litellm_proxy_roles=jwt_handler.litellm_proxy_roles, litellm_proxy_roles=jwt_handler.litellm_jwtauth,
) )
if is_allowed == False: if is_allowed == False:
allowed_routes = jwt_handler.litellm_proxy_roles.team_allowed_routes allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes
actual_routes = get_actual_routes(allowed_routes=allowed_routes) actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception( raise Exception(
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}" f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
@ -2741,11 +2741,9 @@ async def startup_event():
## JWT AUTH ## ## JWT AUTH ##
if general_settings.get("litellm_proxy_roles", None) is not None: if general_settings.get("litellm_proxy_roles", None) is not None:
litellm_proxy_roles = LiteLLMProxyRoles( litellm_proxy_roles = LiteLLM_JWTAuth(**general_settings["litellm_proxy_roles"])
**general_settings["litellm_proxy_roles"]
)
else: else:
litellm_proxy_roles = LiteLLMProxyRoles() litellm_proxy_roles = LiteLLM_JWTAuth()
jwt_handler.update_environment( jwt_handler.update_environment(
prisma_client=prisma_client, prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache, user_api_key_cache=user_api_key_cache,

View file

@ -12,7 +12,7 @@ sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import pytest import pytest
from litellm.proxy._types import LiteLLMProxyRoles from litellm.proxy._types import LiteLLM_JWTAuth
from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.caching import DualCache from litellm.caching import DualCache
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -32,7 +32,7 @@ def test_load_config_with_custom_role_names():
} }
} }
proxy_roles = LiteLLMProxyRoles( proxy_roles = LiteLLM_JWTAuth(
**config.get("general_settings", {}).get("litellm_proxy_roles", {}) **config.get("general_settings", {}).get("litellm_proxy_roles", {})
) )