diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index 4065a65f3..b73fded58 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -7,6 +7,7 @@ import json import time from abc import ABC, abstractmethod +from asyncio import Lock from enum import Enum from urllib.parse import parse_qs @@ -236,6 +237,7 @@ class OAuth2TokenAuthProvider(AuthProvider): self.config = config self._jwks_at: float = 0.0 self._jwks: dict[str, str] = {} + self._jwks_lock = Lock() async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: """Validate a token using the JWT token.""" @@ -271,17 +273,19 @@ class OAuth2TokenAuthProvider(AuthProvider): """Close the HTTP client.""" async def _refresh_jwks(self) -> None: - if time.time() - self._jwks_at > self.config.cache_ttl: - async with httpx.AsyncClient() as client: - res = await client.get(self.config.jwks_uri, timeout=5) - res.raise_for_status() - jwks_data = res.json()["keys"] - self._jwks = {} - for k in jwks_data: - kid = k["kid"] - # Store the entire key object as it may be needed for different algorithms - self._jwks[kid] = k - self._jwks_at = time.time() + async with self._jwks_lock: + if time.time() - self._jwks_at > self.config.cache_ttl: + async with httpx.AsyncClient() as client: + res = await client.get(self.config.jwks_uri, timeout=5) + res.raise_for_status() + jwks_data = res.json()["keys"] + updated = {} + for k in jwks_data: + kid = k["kid"] + # Store the entire key object as it may be needed for different algorithms + updated[kid] = k + self._jwks = updated + self._jwks_at = time.time() class CustomAuthProviderConfig(BaseModel):