diff --git a/llama_stack/core/server/auth_providers.py b/llama_stack/core/server/auth_providers.py index 8e4547adf..8267daf09 100644 --- a/llama_stack/core/server/auth_providers.py +++ b/llama_stack/core/server/auth_providers.py @@ -109,25 +109,31 @@ class OAuth2TokenAuthProvider(AuthProvider): """Validate a token using the JWT token.""" if self.config.jwks is None: raise ValueError("JWKS is not configured") - try: - # Initialize PyJWKClient if not already done if self._jwks_client is None: + ssl_context = None + if self.config.tls_cafile: + ssl_context = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix()) + self._jwks_client = jwt.PyJWKClient( self.config.jwks.uri, cache_keys=True, max_cached_keys=10, - lifespan=3600, # 1 hour cache + lifespan=self.config.jwks.key_recheck_period, # Use configurable period + ssl_context=ssl_context, ) # Get the signing key from the JWT token signing_key = self._jwks_client.get_signing_key_from_jwt(token) + # Get the algorithm from the JWT token + algorithm = jwt.get_unverified_header(token)["alg"] + # Decode and verify the JWT claims = jwt.decode( token, signing_key.key, - algorithms=["RS256", "HS256", "ES256"], # Common algorithms + algorithms=[algorithm], audience=self.config.audience, issuer=self.config.issuer, options={"verify_exp": True, "verify_aud": True, "verify_iss": True}, diff --git a/pyproject.toml b/pyproject.toml index ecc367cfd..82ee8af43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "openai>=1.107", # for expires_after support "prompt-toolkit", "python-dotenv", - "PyJWT>=2.8.0", + "pyjwt[crypto]>=2.8.0", # Pull crypto to support RS256 for jwt. "pydantic>=2.11.9", "rich", "starlette", diff --git a/uv.lock b/uv.lock index 0c97f3204..724bda075 100644 --- a/uv.lock +++ b/uv.lock @@ -1775,7 +1775,7 @@ dependencies = [ { name = "pillow" }, { name = "prompt-toolkit" }, { name = "pydantic" }, - { name = "pyjwt" }, + { name = "pyjwt", extra = ["crypto"] }, { name = "python-dotenv" }, { name = "python-multipart" }, { name = "rich" }, @@ -1898,7 +1898,7 @@ requires-dist = [ { name = "pillow" }, { name = "prompt-toolkit" }, { name = "pydantic", specifier = ">=2.11.9" }, - { name = "pyjwt", specifier = ">=2.8.0" }, + { name = "pyjwt", extras = ["crypto"], specifier = ">=2.8.0" }, { name = "python-dotenv" }, { name = "python-multipart", specifier = ">=0.0.20" }, { name = "rich" }, @@ -3555,6 +3555,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, ] +[package.optional-dependencies] +crypto = [ + { name = "cryptography" }, +] + [[package]] name = "pymilvus" version = "2.6.1"