fix(handle_jwt.py): support public key caching ttl param

This commit is contained in:
Krrish Dholakia 2024-03-26 14:32:55 -07:00
parent d90f44fe8e
commit 752516df1b
6 changed files with 26 additions and 25 deletions

View file

@ -12,7 +12,7 @@ import json
import os
from litellm.caching import DualCache
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import LiteLLMProxyRoles, LiteLLM_UserTable
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable
from litellm.proxy.utils import PrismaClient
from typing import Optional
@ -70,30 +70,30 @@ class JWTHandler:
self,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
litellm_proxy_roles: LiteLLMProxyRoles,
litellm_jwtauth: LiteLLM_JWTAuth,
) -> None:
self.prisma_client = prisma_client
self.user_api_key_cache = user_api_key_cache
self.litellm_proxy_roles = litellm_proxy_roles
self.litellm_jwtauth = litellm_jwtauth
def is_jwt(self, token: str):
parts = token.split(".")
return len(parts) == 3
def is_admin(self, scopes: list) -> bool:
if self.litellm_proxy_roles.admin_jwt_scope in scopes:
if self.litellm_jwtauth.admin_jwt_scope in scopes:
return True
return False
def is_team(self, scopes: list) -> bool:
if self.litellm_proxy_roles.team_jwt_scope in scopes:
if self.litellm_jwtauth.team_jwt_scope in scopes:
return True
return False
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str:
try:
if self.litellm_proxy_roles.team_id_jwt_field is not None:
user_id = token[self.litellm_proxy_roles.team_id_jwt_field]
if self.litellm_jwtauth.end_user_id_jwt_field is not None:
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
else:
user_id = None
except KeyError:
@ -102,7 +102,7 @@ class JWTHandler:
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try:
team_id = token[self.litellm_proxy_roles.team_id_jwt_field]
team_id = token[self.litellm_jwtauth.team_id_jwt_field]
except KeyError:
team_id = default_value
return team_id
@ -137,7 +137,9 @@ class JWTHandler:
keys = response.json()["keys"]
await self.user_api_key_cache.async_set_cache(
key="litellm_jwt_auth_keys", value=keys, ttl=600 # cache for 10 mins
key="litellm_jwt_auth_keys",
value=keys,
ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins
)
else:
keys = cached_keys