fix: synchronize concurrent coroutines checking & updating key set (#2215)

# What does this PR do?

This PR adds a lock to coordinate concurrent coroutines passing through
the jwt verification. As _refresh_jwks() was setting _jwks to an empty
dict then repopulating it, having multiple coroutines doing this
concurrently risks losing keys. The PR also builds the updated dict as a
separate object and assigns it to _jwks once completed. This avoids
impacting any coroutines using the key set as it is being updated.

Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
grs 2025-05-20 13:00:44 -04:00 committed by GitHub
parent 3339844fda
commit 87a4b9cb28
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -7,6 +7,7 @@
import json import json
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from asyncio import Lock
from enum import Enum from enum import Enum
from urllib.parse import parse_qs from urllib.parse import parse_qs
@ -236,6 +237,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
self.config = config self.config = config
self._jwks_at: float = 0.0 self._jwks_at: float = 0.0
self._jwks: dict[str, str] = {} self._jwks: dict[str, str] = {}
self._jwks_lock = Lock()
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using the JWT token.""" """Validate a token using the JWT token."""
@ -271,17 +273,19 @@ class OAuth2TokenAuthProvider(AuthProvider):
"""Close the HTTP client.""" """Close the HTTP client."""
async def _refresh_jwks(self) -> None: async def _refresh_jwks(self) -> None:
if time.time() - self._jwks_at > self.config.cache_ttl: async with self._jwks_lock:
async with httpx.AsyncClient() as client: if time.time() - self._jwks_at > self.config.cache_ttl:
res = await client.get(self.config.jwks_uri, timeout=5) async with httpx.AsyncClient() as client:
res.raise_for_status() res = await client.get(self.config.jwks_uri, timeout=5)
jwks_data = res.json()["keys"] res.raise_for_status()
self._jwks = {} jwks_data = res.json()["keys"]
for k in jwks_data: updated = {}
kid = k["kid"] for k in jwks_data:
# Store the entire key object as it may be needed for different algorithms kid = k["kid"]
self._jwks[kid] = k # Store the entire key object as it may be needed for different algorithms
self._jwks_at = time.time() updated[kid] = k
self._jwks = updated
self._jwks_at = time.time()
class CustomAuthProviderConfig(BaseModel): class CustomAuthProviderConfig(BaseModel):