forked from phoenix-oss/llama-stack-mirror
fix: synchronize concurrent coroutines checking & updating key set (#2215)
# What does this PR do? This PR adds a lock to coordinate concurrent coroutines passing through the jwt verification. As _refresh_jwks() was setting _jwks to an empty dict then repopulating it, having multiple coroutines doing this concurrently risks losing keys. The PR also builds the updated dict as a separate object and assigns it to _jwks once completed. This avoids impacting any coroutines using the key set as it is being updated. Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
parent
3339844fda
commit
87a4b9cb28
1 changed files with 15 additions and 11 deletions
|
@ -7,6 +7,7 @@
|
|||
import json
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Lock
|
||||
from enum import Enum
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
|
@ -236,6 +237,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
self.config = config
|
||||
self._jwks_at: float = 0.0
|
||||
self._jwks: dict[str, str] = {}
|
||||
self._jwks_lock = Lock()
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
"""Validate a token using the JWT token."""
|
||||
|
@ -271,17 +273,19 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
"""Close the HTTP client."""
|
||||
|
||||
async def _refresh_jwks(self) -> None:
|
||||
if time.time() - self._jwks_at > self.config.cache_ttl:
|
||||
async with httpx.AsyncClient() as client:
|
||||
res = await client.get(self.config.jwks_uri, timeout=5)
|
||||
res.raise_for_status()
|
||||
jwks_data = res.json()["keys"]
|
||||
self._jwks = {}
|
||||
for k in jwks_data:
|
||||
kid = k["kid"]
|
||||
# Store the entire key object as it may be needed for different algorithms
|
||||
self._jwks[kid] = k
|
||||
self._jwks_at = time.time()
|
||||
async with self._jwks_lock:
|
||||
if time.time() - self._jwks_at > self.config.cache_ttl:
|
||||
async with httpx.AsyncClient() as client:
|
||||
res = await client.get(self.config.jwks_uri, timeout=5)
|
||||
res.raise_for_status()
|
||||
jwks_data = res.json()["keys"]
|
||||
updated = {}
|
||||
for k in jwks_data:
|
||||
kid = k["kid"]
|
||||
# Store the entire key object as it may be needed for different algorithms
|
||||
updated[kid] = k
|
||||
self._jwks = updated
|
||||
self._jwks_at = time.time()
|
||||
|
||||
|
||||
class CustomAuthProviderConfig(BaseModel):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue