diff --git a/llama_stack/core/server/auth_providers.py b/llama_stack/core/server/auth_providers.py index b4ea9cdaa..05a21c8d4 100644 --- a/llama_stack/core/server/auth_providers.py +++ b/llama_stack/core/server/auth_providers.py @@ -105,46 +105,51 @@ class OAuth2TokenAuthProvider(AuthProvider): return await self.introspect_token(token, scope) raise ValueError("One of jwks or introspection must be configured") - async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User: - """Validate a token using the JWT token.""" - if self.config.jwks is None: - raise ValueError("JWKS is not configured") - try: - if self._jwks_client is None: - ssl_context = None - if not self.config.verify_tls: - # Disable SSL verification if verify_tls is False - ssl_context = ssl.create_default_context() - ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE - elif self.config.tls_cafile: - # Use custom CA file if provided - ssl_context = ssl.create_default_context( - cafile=self.config.tls_cafile.as_posix(), - ) + def _get_jwks_client(self) -> jwt.PyJWKClient: + if self._jwks_client is None: + ssl_context = None + if not self.config.verify_tls: + # Disable SSL verification if verify_tls is False + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + elif self.config.tls_cafile: + # Use custom CA file if provided + ssl_context = ssl.create_default_context( + cafile=self.config.tls_cafile.as_posix(), + ) # If verify_tls is True and no tls_cafile, ssl_context remains None (use system defaults) - # Prepare headers for JWKS request - this is needed for Kubernetes to authenticate - # to the JWK endpoint - headers = {} - if self.config.jwks.token: - headers["Authorization"] = f"Bearer {self.config.jwks.token}" + # Prepare headers for JWKS request - this is needed for Kubernetes to authenticate + # to the JWK endpoint, we must use the token in the config to authenticate + headers = {} + if self.config.jwks and self.config.jwks.token: + headers["Authorization"] = f"Bearer {self.config.jwks.token}" - # Create PyJWKClient with SSL context if supported - self._jwks_client = jwt.PyJWKClient( - self.config.jwks.uri, - cache_keys=True, - max_cached_keys=10, - lifespan=self.config.jwks.key_recheck_period, # Use configurable period - headers=headers, - ssl_context=ssl_context, - ) + self._jwks_client = jwt.PyJWKClient( + self.config.jwks.uri if self.config.jwks else None, + cache_keys=True, + max_cached_keys=10, + lifespan=self.config.jwks.key_recheck_period if self.config.jwks else None, + headers=headers, + ssl_context=ssl_context, + ) + return self._jwks_client - # Get the signing key from the JWT token - signing_key = self._jwks_client.get_signing_key_from_jwt(token) - - # Get the algorithm from the JWT token + async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User: + """Validate a token using the JWT token.""" + try: + jwks_client: jwt.PyJWKClient = self._get_jwks_client() + signing_key = jwks_client.get_signing_key_from_jwt(token) algorithm = jwt.get_unverified_header(token)["alg"] + claims = jwt.decode( + token, + signing_key.key, + algorithms=[algorithm], + audience=self.config.audience, + issuer=self.config.issuer, + options={"verify_exp": True, "verify_aud": True, "verify_iss": True}, + ) # Decode and verify the JWT claims = jwt.decode(