From 87a4b9cb28f8e9f94c40e79a2a8a8738e24aebe1 Mon Sep 17 00:00:00 2001 From: grs Date: Tue, 20 May 2025 13:00:44 -0400 Subject: [PATCH] fix: synchronize concurrent coroutines checking & updating key set (#2215) # What does this PR do? This PR adds a lock to coordinate concurrent coroutines passing through the jwt verification. As _refresh_jwks() was setting _jwks to an empty dict then repopulating it, having multiple coroutines doing this concurrently risks losing keys. The PR also builds the updated dict as a separate object and assigns it to _jwks once completed. This avoids impacting any coroutines using the key set as it is being updated. Signed-off-by: Gordon Sim --- .../distribution/server/auth_providers.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index 4065a65f3..b73fded58 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -7,6 +7,7 @@ import json import time from abc import ABC, abstractmethod +from asyncio import Lock from enum import Enum from urllib.parse import parse_qs @@ -236,6 +237,7 @@ class OAuth2TokenAuthProvider(AuthProvider): self.config = config self._jwks_at: float = 0.0 self._jwks: dict[str, str] = {} + self._jwks_lock = Lock() async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: """Validate a token using the JWT token.""" @@ -271,17 +273,19 @@ class OAuth2TokenAuthProvider(AuthProvider): """Close the HTTP client.""" async def _refresh_jwks(self) -> None: - if time.time() - self._jwks_at > self.config.cache_ttl: - async with httpx.AsyncClient() as client: - res = await client.get(self.config.jwks_uri, timeout=5) - res.raise_for_status() - jwks_data = res.json()["keys"] - self._jwks = {} - for k in jwks_data: - kid = k["kid"] - # Store the entire key object as it may be needed for different algorithms - self._jwks[kid] = k - self._jwks_at = time.time() + async with self._jwks_lock: + if time.time() - self._jwks_at > self.config.cache_ttl: + async with httpx.AsyncClient() as client: + res = await client.get(self.config.jwks_uri, timeout=5) + res.raise_for_status() + jwks_data = res.json()["keys"] + updated = {} + for k in jwks_data: + kid = k["kid"] + # Store the entire key object as it may be needed for different algorithms + updated[kid] = k + self._jwks = updated + self._jwks_at = time.time() class CustomAuthProviderConfig(BaseModel):