fix: synchronize concurrent coroutines checking key set

Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
Gordon Sim 2025-05-20 13:02:31 +01:00
parent 6d20b720b8
commit b1ab9dce81

View file

@ -7,6 +7,7 @@
import json import json
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from asyncio import Lock
from enum import Enum from enum import Enum
from urllib.parse import parse_qs from urllib.parse import parse_qs
@ -236,6 +237,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
self.config = config self.config = config
self._jwks_at: float = 0.0 self._jwks_at: float = 0.0
self._jwks: dict[str, str] = {} self._jwks: dict[str, str] = {}
self._jwks_lock = Lock()
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using the JWT token.""" """Validate a token using the JWT token."""
@ -271,17 +273,19 @@ class OAuth2TokenAuthProvider(AuthProvider):
"""Close the HTTP client.""" """Close the HTTP client."""
async def _refresh_jwks(self) -> None: async def _refresh_jwks(self) -> None:
if time.time() - self._jwks_at > self.config.cache_ttl: async with self._jwks_lock:
async with httpx.AsyncClient() as client: if time.time() - self._jwks_at > self.config.cache_ttl:
res = await client.get(self.config.jwks_uri, timeout=5) async with httpx.AsyncClient() as client:
res.raise_for_status() res = await client.get(self.config.jwks_uri, timeout=5)
jwks_data = res.json()["keys"] res.raise_for_status()
self._jwks = {} jwks_data = res.json()["keys"]
for k in jwks_data: updated = {}
kid = k["kid"] for k in jwks_data:
# Store the entire key object as it may be needed for different algorithms kid = k["kid"]
self._jwks[kid] = k # Store the entire key object as it may be needed for different algorithms
self._jwks_at = time.time() updated[kid] = k
self._jwks = updated
self._jwks_at = time.time()
class CustomAuthProviderConfig(BaseModel): class CustomAuthProviderConfig(BaseModel):