mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
Merge pull request #2687 from BerriAI/litellm_jwt_auth_fixes_2
Litellm jwt auth fixes
This commit is contained in:
commit
f15ba10170
5 changed files with 300 additions and 43 deletions
|
@ -67,17 +67,21 @@ class JWTHandler:
|
|||
self.http_handler = HTTPHandler()
|
||||
|
||||
def update_environment(
|
||||
self, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache
|
||||
self,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
litellm_proxy_roles: LiteLLMProxyRoles,
|
||||
) -> None:
|
||||
self.prisma_client = prisma_client
|
||||
self.user_api_key_cache = user_api_key_cache
|
||||
self.litellm_proxy_roles = litellm_proxy_roles
|
||||
|
||||
def is_jwt(self, token: str):
|
||||
parts = token.split(".")
|
||||
return len(parts) == 3
|
||||
|
||||
def is_admin(self, scopes: list) -> bool:
|
||||
if LiteLLMProxyRoles.PROXY_ADMIN.value in scopes:
|
||||
if self.litellm_proxy_roles.proxy_admin in scopes:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -90,7 +94,7 @@ class JWTHandler:
|
|||
|
||||
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
team_id = token["azp"]
|
||||
team_id = token["client_id"]
|
||||
except KeyError:
|
||||
team_id = default_value
|
||||
return team_id
|
||||
|
@ -130,58 +134,94 @@ class JWTHandler:
|
|||
|
||||
def get_scopes(self, token: dict) -> list:
|
||||
try:
|
||||
# Assuming the scopes are stored in 'scope' claim and are space-separated
|
||||
scopes = token["scope"].split()
|
||||
if isinstance(token["scope"], str):
|
||||
# Assuming the scopes are stored in 'scope' claim and are space-separated
|
||||
scopes = token["scope"].split()
|
||||
elif isinstance(token["scope"], list):
|
||||
scopes = token["scope"]
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unmapped scope type - {type(token['scope'])}. Supported types - list, str."
|
||||
)
|
||||
except KeyError:
|
||||
scopes = []
|
||||
return scopes
|
||||
|
||||
async def auth_jwt(self, token: str) -> dict:
|
||||
from jwt.algorithms import RSAAlgorithm
|
||||
|
||||
async def get_public_key(self, kid: Optional[str]) -> dict:
|
||||
keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
|
||||
|
||||
if keys_url is None:
|
||||
raise Exception("Missing JWT Public Key URL from environment.")
|
||||
|
||||
response = await self.http_handler.get(keys_url)
|
||||
cached_keys = await self.user_api_key_cache.async_get_cache(
|
||||
"litellm_jwt_auth_keys"
|
||||
)
|
||||
if cached_keys is None:
|
||||
response = await self.http_handler.get(keys_url)
|
||||
|
||||
keys = response.json()["keys"]
|
||||
keys = response.json()["keys"]
|
||||
|
||||
await self.user_api_key_cache.async_set_cache(
|
||||
key="litellm_jwt_auth_keys", value=keys, ttl=600 # cache for 10 mins
|
||||
)
|
||||
else:
|
||||
keys = cached_keys
|
||||
|
||||
public_key: Optional[dict] = None
|
||||
|
||||
if len(keys) == 1:
|
||||
public_key = keys[0]
|
||||
elif len(keys) > 1:
|
||||
for key in keys:
|
||||
if kid is not None and key["kid"] == kid:
|
||||
public_key = key
|
||||
|
||||
if public_key is None:
|
||||
raise Exception(
|
||||
f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}"
|
||||
)
|
||||
|
||||
return public_key
|
||||
|
||||
async def auth_jwt(self, token: str) -> dict:
|
||||
from jwt.algorithms import RSAAlgorithm
|
||||
|
||||
header = jwt.get_unverified_header(token)
|
||||
|
||||
verbose_proxy_logger.debug("header: %s", header)
|
||||
|
||||
if "kid" in header:
|
||||
kid = header["kid"]
|
||||
else:
|
||||
raise Exception(f"Expected 'kid' in header. header={header}.")
|
||||
kid = header.get("kid", None)
|
||||
|
||||
for key in keys:
|
||||
if key["kid"] == kid:
|
||||
jwk = {
|
||||
"kty": key["kty"],
|
||||
"kid": key["kid"],
|
||||
"n": key["n"],
|
||||
"e": key["e"],
|
||||
}
|
||||
public_key = RSAAlgorithm.from_jwk(json.dumps(jwk))
|
||||
public_key = await self.get_public_key(kid=kid)
|
||||
|
||||
try:
|
||||
# decode the token using the public key
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
public_key, # type: ignore
|
||||
algorithms=["RS256"],
|
||||
audience="account",
|
||||
)
|
||||
return payload
|
||||
if public_key is not None and isinstance(public_key, dict):
|
||||
jwk = {}
|
||||
if "kty" in public_key:
|
||||
jwk["kty"] = public_key["kty"]
|
||||
if "kid" in public_key:
|
||||
jwk["kid"] = public_key["kid"]
|
||||
if "n" in public_key:
|
||||
jwk["n"] = public_key["n"]
|
||||
if "e" in public_key:
|
||||
jwk["e"] = public_key["e"]
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
# the token is expired, do something to refresh it
|
||||
raise Exception("Token Expired")
|
||||
except Exception as e:
|
||||
raise Exception(f"Validation fails: {str(e)}")
|
||||
public_key_rsa = RSAAlgorithm.from_jwk(json.dumps(jwk))
|
||||
|
||||
try:
|
||||
# decode the token using the public key
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
public_key_rsa, # type: ignore
|
||||
algorithms=["RS256"],
|
||||
options={"verify_aud": False},
|
||||
)
|
||||
return payload
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
# the token is expired, do something to refresh it
|
||||
raise Exception("Token Expired")
|
||||
except Exception as e:
|
||||
raise Exception(f"Validation fails: {str(e)}")
|
||||
|
||||
raise Exception("Invalid JWT Submitted")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue