mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 12:06:04 +00:00
fix: put the client initialization into its own method
Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
dbd6e2be06
commit
c3142758be
1 changed files with 40 additions and 35 deletions
|
|
@ -105,46 +105,51 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
return await self.introspect_token(token, scope)
|
return await self.introspect_token(token, scope)
|
||||||
raise ValueError("One of jwks or introspection must be configured")
|
raise ValueError("One of jwks or introspection must be configured")
|
||||||
|
|
||||||
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
|
def _get_jwks_client(self) -> jwt.PyJWKClient:
|
||||||
"""Validate a token using the JWT token."""
|
if self._jwks_client is None:
|
||||||
if self.config.jwks is None:
|
ssl_context = None
|
||||||
raise ValueError("JWKS is not configured")
|
if not self.config.verify_tls:
|
||||||
try:
|
# Disable SSL verification if verify_tls is False
|
||||||
if self._jwks_client is None:
|
ssl_context = ssl.create_default_context()
|
||||||
ssl_context = None
|
ssl_context.check_hostname = False
|
||||||
if not self.config.verify_tls:
|
ssl_context.verify_mode = ssl.CERT_NONE
|
||||||
# Disable SSL verification if verify_tls is False
|
elif self.config.tls_cafile:
|
||||||
ssl_context = ssl.create_default_context()
|
# Use custom CA file if provided
|
||||||
ssl_context.check_hostname = False
|
ssl_context = ssl.create_default_context(
|
||||||
ssl_context.verify_mode = ssl.CERT_NONE
|
cafile=self.config.tls_cafile.as_posix(),
|
||||||
elif self.config.tls_cafile:
|
)
|
||||||
# Use custom CA file if provided
|
|
||||||
ssl_context = ssl.create_default_context(
|
|
||||||
cafile=self.config.tls_cafile.as_posix(),
|
|
||||||
)
|
|
||||||
# If verify_tls is True and no tls_cafile, ssl_context remains None (use system defaults)
|
# If verify_tls is True and no tls_cafile, ssl_context remains None (use system defaults)
|
||||||
|
|
||||||
# Prepare headers for JWKS request - this is needed for Kubernetes to authenticate
|
# Prepare headers for JWKS request - this is needed for Kubernetes to authenticate
|
||||||
# to the JWK endpoint
|
# to the JWK endpoint, we must use the token in the config to authenticate
|
||||||
headers = {}
|
headers = {}
|
||||||
if self.config.jwks.token:
|
if self.config.jwks and self.config.jwks.token:
|
||||||
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
|
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
|
||||||
|
|
||||||
# Create PyJWKClient with SSL context if supported
|
self._jwks_client = jwt.PyJWKClient(
|
||||||
self._jwks_client = jwt.PyJWKClient(
|
self.config.jwks.uri if self.config.jwks else None,
|
||||||
self.config.jwks.uri,
|
cache_keys=True,
|
||||||
cache_keys=True,
|
max_cached_keys=10,
|
||||||
max_cached_keys=10,
|
lifespan=self.config.jwks.key_recheck_period if self.config.jwks else None,
|
||||||
lifespan=self.config.jwks.key_recheck_period, # Use configurable period
|
headers=headers,
|
||||||
headers=headers,
|
ssl_context=ssl_context,
|
||||||
ssl_context=ssl_context,
|
)
|
||||||
)
|
return self._jwks_client
|
||||||
|
|
||||||
# Get the signing key from the JWT token
|
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
signing_key = self._jwks_client.get_signing_key_from_jwt(token)
|
"""Validate a token using the JWT token."""
|
||||||
|
try:
|
||||||
# Get the algorithm from the JWT token
|
jwks_client: jwt.PyJWKClient = self._get_jwks_client()
|
||||||
|
signing_key = jwks_client.get_signing_key_from_jwt(token)
|
||||||
algorithm = jwt.get_unverified_header(token)["alg"]
|
algorithm = jwt.get_unverified_header(token)["alg"]
|
||||||
|
claims = jwt.decode(
|
||||||
|
token,
|
||||||
|
signing_key.key,
|
||||||
|
algorithms=[algorithm],
|
||||||
|
audience=self.config.audience,
|
||||||
|
issuer=self.config.issuer,
|
||||||
|
options={"verify_exp": True, "verify_aud": True, "verify_iss": True},
|
||||||
|
)
|
||||||
|
|
||||||
# Decode and verify the JWT
|
# Decode and verify the JWT
|
||||||
claims = jwt.decode(
|
claims = jwt.decode(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue