forked from phoenix-oss/llama-stack-mirror
fix: synchronize concurrent coroutines checking & updating key set (#2215)
# What does this PR do? This PR adds a lock to coordinate concurrent coroutines passing through the jwt verification. As _refresh_jwks() was setting _jwks to an empty dict then repopulating it, having multiple coroutines doing this concurrently risks losing keys. The PR also builds the updated dict as a separate object and assigns it to _jwks once completed. This avoids impacting any coroutines using the key set as it is being updated. Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
parent
3339844fda
commit
87a4b9cb28
1 changed files with 15 additions and 11 deletions
|
@ -7,6 +7,7 @@
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from asyncio import Lock
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from urllib.parse import parse_qs
|
from urllib.parse import parse_qs
|
||||||
|
|
||||||
|
@ -236,6 +237,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._jwks_at: float = 0.0
|
self._jwks_at: float = 0.0
|
||||||
self._jwks: dict[str, str] = {}
|
self._jwks: dict[str, str] = {}
|
||||||
|
self._jwks_lock = Lock()
|
||||||
|
|
||||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||||
"""Validate a token using the JWT token."""
|
"""Validate a token using the JWT token."""
|
||||||
|
@ -271,17 +273,19 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
"""Close the HTTP client."""
|
"""Close the HTTP client."""
|
||||||
|
|
||||||
async def _refresh_jwks(self) -> None:
|
async def _refresh_jwks(self) -> None:
|
||||||
if time.time() - self._jwks_at > self.config.cache_ttl:
|
async with self._jwks_lock:
|
||||||
async with httpx.AsyncClient() as client:
|
if time.time() - self._jwks_at > self.config.cache_ttl:
|
||||||
res = await client.get(self.config.jwks_uri, timeout=5)
|
async with httpx.AsyncClient() as client:
|
||||||
res.raise_for_status()
|
res = await client.get(self.config.jwks_uri, timeout=5)
|
||||||
jwks_data = res.json()["keys"]
|
res.raise_for_status()
|
||||||
self._jwks = {}
|
jwks_data = res.json()["keys"]
|
||||||
for k in jwks_data:
|
updated = {}
|
||||||
kid = k["kid"]
|
for k in jwks_data:
|
||||||
# Store the entire key object as it may be needed for different algorithms
|
kid = k["kid"]
|
||||||
self._jwks[kid] = k
|
# Store the entire key object as it may be needed for different algorithms
|
||||||
self._jwks_at = time.time()
|
updated[kid] = k
|
||||||
|
self._jwks = updated
|
||||||
|
self._jwks_at = time.time()
|
||||||
|
|
||||||
|
|
||||||
class CustomAuthProviderConfig(BaseModel):
|
class CustomAuthProviderConfig(BaseModel):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue