test(test_jwt.py): add unit tests for jwt auth integration

This commit is contained in:
Krrish Dholakia 2024-03-25 13:24:39 -07:00
parent 03b8444d3c
commit 2e4e97a48f
2 changed files with 169 additions and 12 deletions

View file

@ -147,9 +147,7 @@ class JWTHandler:
scopes = []
return scopes
async def auth_jwt(self, token: str) -> dict:
from jwt.algorithms import RSAAlgorithm
async def get_public_key(self, kid: Optional[str]) -> dict:
keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
if keys_url is None:
@ -169,19 +167,33 @@ class JWTHandler:
else:
keys = cached_keys
public_key: Optional[dict] = None
if len(keys) == 1:
public_key = keys[0]
elif len(keys) > 1:
for key in keys:
if kid is not None and key["kid"] == kid:
public_key = key
if public_key is None:
raise Exception(
f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}"
)
return public_key
async def auth_jwt(self, token: str) -> dict:
from jwt.algorithms import RSAAlgorithm
header = jwt.get_unverified_header(token)
verbose_proxy_logger.debug(f"header: {header}")
kid = header.get("kid", None)
public_key = None
if len(keys) == 1:
public_key = public_key
elif len(keys) > 1:
for key in keys:
if key["kid"] == kid:
public_key = key
public_key = await self.get_public_key(kid=kid)
if public_key is not None and isinstance(public_key, dict):
jwk = {}
if "kty" in public_key:
@ -193,13 +205,13 @@ class JWTHandler:
if "e" in public_key:
jwk["e"] = public_key["e"]
public_key = RSAAlgorithm.from_jwk(json.dumps(jwk))
public_key_rsa = RSAAlgorithm.from_jwk(json.dumps(jwk))
try:
# decode the token using the public key
payload = jwt.decode(
token,
public_key, # type: ignore
public_key_rsa, # type: ignore
algorithms=["RS256"],
options={"verify_aud": False},
)