mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
test(test_jwt.py): add unit tests for jwt auth integration
This commit is contained in:
parent
03b8444d3c
commit
2e4e97a48f
2 changed files with 169 additions and 12 deletions
|
@ -147,9 +147,7 @@ class JWTHandler:
|
|||
scopes = []
|
||||
return scopes
|
||||
|
||||
async def auth_jwt(self, token: str) -> dict:
|
||||
from jwt.algorithms import RSAAlgorithm
|
||||
|
||||
async def get_public_key(self, kid: Optional[str]) -> dict:
|
||||
keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
|
||||
|
||||
if keys_url is None:
|
||||
|
@ -169,19 +167,33 @@ class JWTHandler:
|
|||
else:
|
||||
keys = cached_keys
|
||||
|
||||
public_key: Optional[dict] = None
|
||||
|
||||
if len(keys) == 1:
|
||||
public_key = keys[0]
|
||||
elif len(keys) > 1:
|
||||
for key in keys:
|
||||
if kid is not None and key["kid"] == kid:
|
||||
public_key = key
|
||||
|
||||
if public_key is None:
|
||||
raise Exception(
|
||||
f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}"
|
||||
)
|
||||
|
||||
return public_key
|
||||
|
||||
async def auth_jwt(self, token: str) -> dict:
|
||||
from jwt.algorithms import RSAAlgorithm
|
||||
|
||||
header = jwt.get_unverified_header(token)
|
||||
|
||||
verbose_proxy_logger.debug(f"header: {header}")
|
||||
|
||||
kid = header.get("kid", None)
|
||||
|
||||
public_key = None
|
||||
if len(keys) == 1:
|
||||
public_key = public_key
|
||||
elif len(keys) > 1:
|
||||
for key in keys:
|
||||
if key["kid"] == kid:
|
||||
public_key = key
|
||||
public_key = await self.get_public_key(kid=kid)
|
||||
|
||||
if public_key is not None and isinstance(public_key, dict):
|
||||
jwk = {}
|
||||
if "kty" in public_key:
|
||||
|
@ -193,13 +205,13 @@ class JWTHandler:
|
|||
if "e" in public_key:
|
||||
jwk["e"] = public_key["e"]
|
||||
|
||||
public_key = RSAAlgorithm.from_jwk(json.dumps(jwk))
|
||||
public_key_rsa = RSAAlgorithm.from_jwk(json.dumps(jwk))
|
||||
|
||||
try:
|
||||
# decode the token using the public key
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
public_key, # type: ignore
|
||||
public_key_rsa, # type: ignore
|
||||
algorithms=["RS256"],
|
||||
options={"verify_aud": False},
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue