mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(handle_jwt.py): support public key caching ttl param
This commit is contained in:
parent
d90f44fe8e
commit
752516df1b
6 changed files with 26 additions and 25 deletions
|
@ -12,7 +12,7 @@ import json
|
|||
import os
|
||||
from litellm.caching import DualCache
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import LiteLLMProxyRoles, LiteLLM_UserTable
|
||||
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from typing import Optional
|
||||
|
||||
|
@ -70,30 +70,30 @@ class JWTHandler:
|
|||
self,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
litellm_proxy_roles: LiteLLMProxyRoles,
|
||||
litellm_jwtauth: LiteLLM_JWTAuth,
|
||||
) -> None:
|
||||
self.prisma_client = prisma_client
|
||||
self.user_api_key_cache = user_api_key_cache
|
||||
self.litellm_proxy_roles = litellm_proxy_roles
|
||||
self.litellm_jwtauth = litellm_jwtauth
|
||||
|
||||
def is_jwt(self, token: str):
|
||||
parts = token.split(".")
|
||||
return len(parts) == 3
|
||||
|
||||
def is_admin(self, scopes: list) -> bool:
|
||||
if self.litellm_proxy_roles.admin_jwt_scope in scopes:
|
||||
if self.litellm_jwtauth.admin_jwt_scope in scopes:
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_team(self, scopes: list) -> bool:
|
||||
if self.litellm_proxy_roles.team_jwt_scope in scopes:
|
||||
if self.litellm_jwtauth.team_jwt_scope in scopes:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str:
|
||||
try:
|
||||
if self.litellm_proxy_roles.team_id_jwt_field is not None:
|
||||
user_id = token[self.litellm_proxy_roles.team_id_jwt_field]
|
||||
if self.litellm_jwtauth.end_user_id_jwt_field is not None:
|
||||
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
|
||||
else:
|
||||
user_id = None
|
||||
except KeyError:
|
||||
|
@ -102,7 +102,7 @@ class JWTHandler:
|
|||
|
||||
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
team_id = token[self.litellm_proxy_roles.team_id_jwt_field]
|
||||
team_id = token[self.litellm_jwtauth.team_id_jwt_field]
|
||||
except KeyError:
|
||||
team_id = default_value
|
||||
return team_id
|
||||
|
@ -137,7 +137,9 @@ class JWTHandler:
|
|||
keys = response.json()["keys"]
|
||||
|
||||
await self.user_api_key_cache.async_set_cache(
|
||||
key="litellm_jwt_auth_keys", value=keys, ttl=600 # cache for 10 mins
|
||||
key="litellm_jwt_auth_keys",
|
||||
value=keys,
|
||||
ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins
|
||||
)
|
||||
else:
|
||||
keys = cached_keys
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue