From e2e15ebb6c271bec5cd03b1f6d7561b992514fc5 Mon Sep 17 00:00:00 2001 From: grs Date: Fri, 13 Jun 2025 04:13:41 -0400 Subject: [PATCH] feat(auth): allow token to be provided for use against jwks endpoint (#2394) Though the jwks endpoint does not usually require authentication, it does in a kubernetes cluster. While the cluster can be configured to allow anonymous access to that endpoint, this avoids the need to do so. --- .github/workflows/integration-auth-tests.yml | 26 +------ docs/source/distributions/configuration.md | 77 ++++++++++--------- .../distribution/server/auth_providers.py | 6 +- tests/unit/server/test_auth.py | 50 ++++++++++++ 4 files changed, 99 insertions(+), 60 deletions(-) diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index a3a746246..e0f3ff2e8 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -52,30 +52,7 @@ jobs: run: | kubectl create namespace llama-stack kubectl create serviceaccount llama-stack-auth -n llama-stack - kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token - cat <> $GITHUB_ENV echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV echo "KUBERNETES_AUDIENCE=$(kubectl create token llama-stack-auth -n llama-stack --duration=1h | cut -d. -f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV + echo "TOKEN=$(cat llama-stack-auth-token)" >> $GITHUB_ENV - name: Set Kube Auth Config and run server env: @@ -101,7 +79,7 @@ jobs: EOF yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml - yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml + yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml cat $run_dir/run.yaml nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 & diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index de99b6576..a48083055 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -56,10 +56,10 @@ shields: [] server: port: 8321 auth: - provider_type: "kubernetes" + provider_type: "oauth2_token" config: - api_server_url: "https://kubernetes.default.svc" - ca_cert_path: "/path/to/ca.crt" + jwks: + uri: "https://my-token-issuing-svc.com/jwks" ``` Let's break this down into the different sections. The first section specifies the set of APIs that the stack server will serve: @@ -132,16 +132,52 @@ The server supports multiple authentication providers: #### OAuth 2.0/OpenID Connect Provider with Kubernetes -The Kubernetes cluster must be configured to use a service account for authentication. +The server can be configured to use service account tokens for authorization, validating these against the Kubernetes API server, e.g.: +```yaml +server: + auth: + provider_type: "oauth2_token" + config: + jwks: + uri: "https://kubernetes.default.svc:8443/openid/v1/jwks" + token: "${env.TOKEN:}" + key_recheck_period: 3600 + tls_cafile: "/path/to/ca.crt" + issuer: "https://kubernetes.default.svc" + audience: "https://kubernetes.default.svc" +``` + +To find your cluster's jwks uri (from which the public key(s) to verify the token signature are obtained), run: +``` +kubectl get --raw /.well-known/openid-configuration| jq -r .jwks_uri +``` + +For the tls_cafile, you can use the CA certificate of the OIDC provider: +```bash +kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}' +``` + +For the issuer, you can use the OIDC provider's URL: +```bash +kubectl get --raw /.well-known/openid-configuration| jq .issuer +``` + +The audience can be obtained from a token, e.g. run: +```bash +kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud +``` + +The jwks token is used to authorize access to the jwks endpoint. You can obtain a token by running: ```bash kubectl create namespace llama-stack kubectl create serviceaccount llama-stack-auth -n llama-stack -kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token +export TOKEN=$(cat llama-stack-auth-token) ``` -Make sure the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests +Alternatively, you can configure the jwks endpoint to allow anonymous access. To do this, make sure +the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests and that the correct RoleBinding is created to allow the service account to access the necessary resources. If that is not the case, you can create a RoleBinding for the service account to access the necessary resources: @@ -175,35 +211,6 @@ And then apply the configuration: kubectl apply -f allow-anonymous-openid.yaml ``` -Validates tokens against the Kubernetes API server through the OIDC provider: -```yaml -server: - auth: - provider_type: "oauth2_token" - config: - jwks: - uri: "https://kubernetes.default.svc" - key_recheck_period: 3600 - tls_cafile: "/path/to/ca.crt" - issuer: "https://kubernetes.default.svc" - audience: "https://kubernetes.default.svc" -``` - -To find your cluster's audience, run: -```bash -kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud -``` - -For the issuer, you can use the OIDC provider's URL: -```bash -kubectl get --raw /.well-known/openid-configuration| jq .issuer -``` - -For the tls_cafile, you can use the CA certificate of the OIDC provider: -```bash -kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}' -``` - The provider extracts user information from the JWT token: - Username from the `sub` claim becomes a role - Kubernetes groups become teams diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index 942ff8a18..98e51c25a 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -84,6 +84,7 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) class OAuth2JWKSConfig(BaseModel): # The JWKS URI for collecting public keys uri: str + token: str | None = Field(default=None, description="token to authorise access to jwks") key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates") @@ -246,9 +247,12 @@ class OAuth2TokenAuthProvider(AuthProvider): if self.config.jwks is None: raise ValueError("JWKS is not configured") if time.time() - self._jwks_at > self.config.jwks.key_recheck_period: + headers = {} + if self.config.jwks.token: + headers["Authorization"] = f"Bearer {self.config.jwks.token}" verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls async with httpx.AsyncClient(verify=verify) as client: - res = await client.get(self.config.jwks.uri, timeout=5) + res = await client.get(self.config.jwks.uri, timeout=5, headers=headers) res.raise_for_status() jwks_data = res.json()["keys"] updated = {} diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index e159aefd1..4410048c5 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -345,6 +345,56 @@ def test_invalid_oauth2_authentication(oauth2_client, invalid_token): assert "Invalid JWT token" in response.json()["error"]["message"] +async def mock_auth_jwks_response(*args, **kwargs): + if "headers" not in kwargs or "Authorization" not in kwargs["headers"]: + return MockResponse(401, {}) + authz = kwargs["headers"]["Authorization"] + if authz != "Bearer my-jwks-token": + return MockResponse(401, {}) + return await mock_jwks_response(args, kwargs) + + +@pytest.fixture +def oauth2_app_with_jwks_token(): + app = FastAPI() + auth_config = AuthenticationConfig( + provider_type=AuthProviderType.OAUTH2_TOKEN, + config={ + "jwks": { + "uri": "http://mock-authz-service/token/introspect", + "key_recheck_period": "3600", + "token": "my-jwks-token", + }, + "audience": "llama-stack", + }, + ) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + + @app.get("/test") + def test_endpoint(): + return {"message": "Authentication successful"} + + return app + + +@pytest.fixture +def oauth2_client_with_jwks_token(oauth2_app_with_jwks_token): + return TestClient(oauth2_app_with_jwks_token) + + +@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response) +def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid): + response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) + assert response.status_code == 401 + + +@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response) +def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid): + response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) + assert response.status_code == 200 + assert response.json() == {"message": "Authentication successful"} + + def test_get_attributes_from_claims(): claims = { "sub": "my-user",