fix: synchronize concurrent coroutines checking & updating key set (#2215)

# What does this PR do?

This PR adds a lock to coordinate concurrent coroutines passing through
the jwt verification. As _refresh_jwks() was setting _jwks to an empty
dict then repopulating it, having multiple coroutines doing this
concurrently risks losing keys. The PR also builds the updated dict as a
separate object and assigns it to _jwks once completed. This avoids
impacting any coroutines using the key set as it is being updated.

Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
grs 2025-05-20 13:00:44 -04:00 committed by GitHub
parent 3339844fda
commit 87a4b9cb28
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -7,6 +7,7 @@
import json
import time
from abc import ABC, abstractmethod
from asyncio import Lock
from enum import Enum
from urllib.parse import parse_qs
@ -236,6 +237,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
self.config = config
self._jwks_at: float = 0.0
self._jwks: dict[str, str] = {}
self._jwks_lock = Lock()
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using the JWT token."""
@ -271,17 +273,19 @@ class OAuth2TokenAuthProvider(AuthProvider):
"""Close the HTTP client."""
async def _refresh_jwks(self) -> None:
if time.time() - self._jwks_at > self.config.cache_ttl:
async with httpx.AsyncClient() as client:
res = await client.get(self.config.jwks_uri, timeout=5)
res.raise_for_status()
jwks_data = res.json()["keys"]
self._jwks = {}
for k in jwks_data:
kid = k["kid"]
# Store the entire key object as it may be needed for different algorithms
self._jwks[kid] = k
self._jwks_at = time.time()
async with self._jwks_lock:
if time.time() - self._jwks_at > self.config.cache_ttl:
async with httpx.AsyncClient() as client:
res = await client.get(self.config.jwks_uri, timeout=5)
res.raise_for_status()
jwks_data = res.json()["keys"]
updated = {}
for k in jwks_data:
kid = k["kid"]
# Store the entire key object as it may be needed for different algorithms
updated[kid] = k
self._jwks = updated
self._jwks_at = time.time()
class CustomAuthProviderConfig(BaseModel):