diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index 82a76ad32..994bd1dec 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -23,7 +23,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - auth-provider: [kubernetes] + auth-provider: [oauth2_token] fail-fast: false # we want to run all tests regardless of failure steps: @@ -47,29 +47,53 @@ jobs: uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19 - name: Start minikube - if: ${{ matrix.auth-provider == 'kubernetes' }} + if: ${{ matrix.auth-provider == 'oauth2_token' }} run: | minikube start kubectl get pods -A - name: Configure Kube Auth - if: ${{ matrix.auth-provider == 'kubernetes' }} + if: ${{ matrix.auth-provider == 'oauth2_token' }} run: | kubectl create namespace llama-stack kubectl create serviceaccount llama-stack-auth -n llama-stack kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token + cat <> $GITHUB_ENV + echo "KUBERNETES_API_SERVER_URL=$(kubectl get --raw /.well-known/openid-configuration| jq -r .jwks_uri)" >> $GITHUB_ENV echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV + echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV + echo "KUBERNETES_AUDIENCE=$(kubectl create token default --duration=1h | cut -d. 
-f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV - name: Set Kube Auth Config and run server env: INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" - if: ${{ matrix.auth-provider == 'kubernetes' }} + if: ${{ matrix.auth-provider == 'oauth2_token' }} run: | run_dir=$(mktemp -d) cat <<'EOF' > $run_dir/run.yaml @@ -81,7 +105,8 @@ jobs: port: 8321 EOF yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml - yq eval '.server.auth.config = {"api_server_url": "${{ env.KUBERNETES_API_SERVER_URL }}", "ca_cert_path": "${{ env.KUBERNETES_CA_CERT_PATH }}"}' -i $run_dir/run.yaml + yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml + yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml cat $run_dir/run.yaml source .venv/bin/activate diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 7a42f503a..77b52a621 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -118,11 +118,6 @@ server: port: 8321 # Port to listen on (default: 8321) tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS - auth: # Optional: Authentication configuration - provider_type: "kubernetes" # Type of auth provider - config: # Provider-specific configuration - api_server_url: "https://kubernetes.default.svc" - ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate ``` ### Authentication Configuration @@ -135,7 +130,7 @@ Authorization: Bearer The server supports multiple authentication providers: -#### Kubernetes Provider +#### OAuth 2.0/OpenID Connect Provider with Kubernetes The Kubernetes cluster must be configured to use a service account for authentication. 
@@ -146,14 +141,67 @@ kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --se kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token ``` -Validates tokens against the Kubernetes API server: +Make sure the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests +and that anonymous users are authorized to read the JWKS endpoint (`/openid/v1/jwks`). If that is +not the case, you can create a ClusterRole and ClusterRoleBinding that grant `system:anonymous` +read access to that endpoint: + +```yaml +# allow-anonymous-openid.yaml +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: allow-anonymous-openid +rules: +- nonResourceURLs: ["/openid/v1/jwks"] + verbs: ["get"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: allow-anonymous-openid +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: ClusterRole + name: allow-anonymous-openid +subjects: +- kind: User + name: system:anonymous + apiGroup: rbac.authorization.k8s.io +``` + +And then apply the configuration: +```bash +kubectl apply -f allow-anonymous-openid.yaml +``` + +Validates tokens against the Kubernetes API server through the OIDC provider: ```yaml server: auth: - provider_type: "kubernetes" + provider_type: "oauth2_token" config: - api_server_url: "https://kubernetes.default.svc" # URL of the Kubernetes API server - ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate + jwks: + uri: "https://kubernetes.default.svc/openid/v1/jwks" + cache_ttl: 3600 + tls_cafile: "/path/to/ca.crt" + issuer: "https://kubernetes.default.svc" + audience: "https://kubernetes.default.svc" +``` + +To find your cluster's audience, run: +```bash +kubectl create token default --duration=1h | cut -d. 
-f2 | base64 -d | jq .aud +``` + +For the issuer, you can use the OIDC provider's URL: +```bash +kubectl get --raw /.well-known/openid-configuration| jq .issuer +``` + +For the tls_cafile, you can use the CA certificate of the OIDC provider: +```bash +kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}' ``` The provider extracts user information from the JWT token: diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index ca3664828..eb790ad93 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -220,14 +220,14 @@ class LoggingConfig(BaseModel): class AuthProviderType(str, Enum): """Supported authentication provider types.""" - KUBERNETES = "kubernetes" + OAUTH2_TOKEN = "oauth2_token" CUSTOM = "custom" class AuthenticationConfig(BaseModel): provider_type: AuthProviderType = Field( ..., - description="Type of authentication provider (e.g., 'kubernetes', 'custom')", + description="Type of authentication provider", ) config: dict[str, Any] = Field( ..., diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py index 67acffe3e..fb26b49a7 100644 --- a/llama_stack/distribution/server/auth.py +++ b/llama_stack/distribution/server/auth.py @@ -8,7 +8,8 @@ import json import httpx -from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider +from llama_stack.distribution.datatypes import AuthenticationConfig +from llama_stack.distribution.server.auth_providers import create_auth_provider from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") @@ -77,7 +78,7 @@ class AuthenticationMiddleware: access resources that don't have access_attributes defined. 
""" - def __init__(self, app, auth_config: AuthProviderConfig): + def __init__(self, app, auth_config: AuthenticationConfig): self.app = app self.auth_provider = create_auth_provider(auth_config) diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index baab75eca..39f258c3b 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -4,13 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json import ssl import time from abc import ABC, abstractmethod from asyncio import Lock -from enum import Enum -from typing import Any +from pathlib import Path from urllib.parse import parse_qs import httpx @@ -18,7 +16,7 @@ from jose import jwt from pydantic import BaseModel, Field, field_validator, model_validator from typing_extensions import Self -from llama_stack.distribution.datatypes import AccessAttributes +from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") @@ -76,21 +74,6 @@ class AuthRequest(BaseModel): request: AuthRequestContext = Field(description="Context information about the request being authenticated") -class AuthProviderType(str, Enum): - """Supported authentication provider types.""" - - KUBERNETES = "kubernetes" - CUSTOM = "custom" - OAUTH2_TOKEN = "oauth2_token" - - -class AuthProviderConfig(BaseModel): - """Base configuration for authentication providers.""" - - provider_type: AuthProviderType = Field(..., description="Type of authentication provider") - config: dict[str, Any] = Field(..., description="Provider-specific configuration") - - class AuthProvider(ABC): """Abstract base class for authentication providers.""" @@ -105,83 +88,6 @@ class AuthProvider(ABC): pass -class 
KubernetesAuthProviderConfig(BaseModel): - api_server_url: str - ca_cert_path: str | None = None - - -class KubernetesAuthProvider(AuthProvider): - """Kubernetes authentication provider that validates tokens against the Kubernetes API server.""" - - def __init__(self, config: KubernetesAuthProviderConfig): - self.config = config - self._client = None - - async def _get_client(self): - """Get or create a Kubernetes client.""" - if self._client is None: - # kubernetes-client has not async support, see: - # https://github.com/kubernetes-client/python/issues/323 - from kubernetes import client - from kubernetes.client import ApiClient - - # Configure the client - configuration = client.Configuration() - configuration.host = self.config.api_server_url - if self.config.ca_cert_path: - configuration.ssl_ca_cert = self.config.ca_cert_path - configuration.verify_ssl = bool(self.config.ca_cert_path) - - # Create API client - self._client = ApiClient(configuration) - return self._client - - async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: - """Validate a Kubernetes token and return access attributes.""" - try: - client = await self._get_client() - - # Set the token in the client - client.set_default_header("Authorization", f"Bearer {token}") - - # Make a request to validate the token - # We use the /api endpoint which requires authentication - from kubernetes.client import CoreV1Api - - api = CoreV1Api(client) - api.get_api_resources(_request_timeout=3.0) # Set timeout for this specific request - - # If we get here, the token is valid - # Extract user info from the token claims - import base64 - - # Decode the token (without verification since we've already validated it) - token_parts = token.split(".") - payload = json.loads(base64.b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4))) - - # Extract user information from the token - username = payload.get("sub", "") - groups = payload.get("groups", []) - - return 
TokenValidationResult( - principal=username, - access_attributes=AccessAttributes( - roles=[username], # Use username as a role - teams=groups, # Use Kubernetes groups as teams - ), - ) - - except Exception as e: - logger.exception("Failed to validate Kubernetes token") - raise ValueError("Invalid or expired token") from e - - async def close(self): - """Close the HTTP client.""" - if self._client: - self._client.close() - self._client = None - - def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes: attributes = AccessAttributes() for claim_key, attribute_key in mapping.items(): @@ -212,11 +118,13 @@ class OAuth2IntrospectionConfig(BaseModel): client_id: str client_secret: str send_secret_in_body: bool = False - tls_cafile: str | None = None class OAuth2TokenAuthProviderConfig(BaseModel): audience: str = "llama-stack" + verify_tls: bool = True + tls_cafile: Path | None = None + issuer: str | None = Field(default=None, description="The OIDC issuer URL.") claims_mapping: dict[str, str] = Field( default_factory=lambda: { "sub": "roles", @@ -265,16 +173,14 @@ class OAuth2TokenAuthProvider(AuthProvider): async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: if self.config.jwks: - return await self.validate_jwt_token(token, self.config.jwks, scope) + return await self.validate_jwt_token(token, scope) if self.config.introspection: - return await self.introspect_token(token, self.config.introspection, scope) + return await self.introspect_token(token, scope) raise ValueError("One of jwks or introspection must be configured") - async def validate_jwt_token( - self, token: str, config: OAuth2JWKSConfig, scope: dict | None = None - ) -> TokenValidationResult: + async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: """Validate a token using the JWT token.""" - await self._refresh_jwks(config) + await self._refresh_jwks() try: header = 
jwt.get_unverified_header(token) @@ -288,7 +194,7 @@ class OAuth2TokenAuthProvider(AuthProvider): key_data, algorithms=[algorithm], audience=self.config.audience, - options={"verify_exp": True}, + issuer=self.config.issuer, ) except Exception as exc: raise ValueError(f"Invalid JWT token: {token}") from exc @@ -302,26 +208,27 @@ class OAuth2TokenAuthProvider(AuthProvider): access_attributes=access_attributes, ) - async def introspect_token( - self, token: str, config: OAuth2IntrospectionConfig, scope: dict | None = None - ) -> TokenValidationResult: + async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: """Validate a token using token introspection as defined by RFC 7662.""" form = { "token": token, } - if config.send_secret_in_body: - form["client_id"] = config.client_id - form["client_secret"] = config.client_secret + if self.config.introspection is None: + raise ValueError("Introspection is not configured") + + if self.config.introspection.send_secret_in_body: + form["client_id"] = self.config.introspection.client_id + form["client_secret"] = self.config.introspection.client_secret auth = None else: - auth = (config.client_id, config.client_secret) + auth = (self.config.introspection.client_id, self.config.introspection.client_secret) ssl_ctxt = None - if config.tls_cafile: - ssl_ctxt = ssl.create_default_context(cafile=config.tls_cafile) + if self.config.tls_cafile: + ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix()) try: async with httpx.AsyncClient(verify=ssl_ctxt) as client: response = await client.post( - config.url, + self.config.introspection.url, data=form, auth=auth, timeout=10.0, # Add a reasonable timeout @@ -352,11 +259,24 @@ class OAuth2TokenAuthProvider(AuthProvider): async def close(self): pass - async def _refresh_jwks(self, config: OAuth2JWKSConfig) -> None: + async def _refresh_jwks(self) -> None: + """ + Refresh the JWKS cache. 
+ + This is a simple cache that expires after a certain amount of time (defined by `cache_ttl`). + If the cache is expired, we refresh the JWKS from the JWKS URI. + + Notes: for Kubernetes which doesn't fully implement the OIDC protocol: + * It doesn't have user authentication flows + * It doesn't have refresh tokens + """ async with self._jwks_lock: - if time.time() - self._jwks_at > config.cache_ttl: - async with httpx.AsyncClient() as client: - res = await client.get(config.uri, timeout=5) + if self.config.jwks is None: + raise ValueError("JWKS is not configured") + if time.time() - self._jwks_at > self.config.jwks.cache_ttl: + verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls + async with httpx.AsyncClient(verify=verify) as client: + res = await client.get(self.config.jwks.uri, timeout=5) res.raise_for_status() jwks_data = res.json()["keys"] updated = {} @@ -443,13 +363,11 @@ class CustomAuthProvider(AuthProvider): self._client = None -def create_auth_provider(config: AuthProviderConfig) -> AuthProvider: +def create_auth_provider(config: AuthenticationConfig) -> AuthProvider: """Factory function to create the appropriate auth provider.""" provider_type = config.provider_type.lower() - if provider_type == "kubernetes": - return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config)) - elif provider_type == "custom": + if provider_type == "custom": return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config)) elif provider_type == "oauth2_token": return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config)) diff --git a/pyproject.toml b/pyproject.toml index a41830e64..8b922bafb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,6 @@ dependencies = [ "tiktoken", "pillow", "h11>=0.16.0", - "kubernetes", ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index 6dfcc1024..2fe72c803 100644 --- 
a/requirements.txt +++ b/requirements.txt @@ -4,19 +4,16 @@ annotated-types==0.7.0 anyio==4.8.0 attrs==25.1.0 blobfile==3.0.0 -cachetools==5.5.2 certifi==2025.1.31 charset-normalizer==3.4.1 click==8.1.8 colorama==0.4.6 ; sys_platform == 'win32' distro==1.9.0 -durationpy==0.9 ecdsa==0.19.1 exceptiongroup==1.2.2 ; python_full_version < '3.11' filelock==3.17.0 fire==0.7.0 fsspec==2024.12.0 -google-auth==2.38.0 h11==0.16.0 httpcore==1.0.9 httpx==0.28.1 @@ -26,14 +23,12 @@ jinja2==3.1.6 jiter==0.8.2 jsonschema==4.23.0 jsonschema-specifications==2024.10.1 -kubernetes==32.0.1 llama-stack-client==0.2.7 lxml==5.3.1 markdown-it-py==3.0.0 markupsafe==3.0.2 mdurl==0.1.2 numpy==2.2.3 -oauthlib==3.2.2 openai==1.71.0 packaging==24.2 pandas==2.2.3 @@ -41,7 +36,6 @@ pillow==11.1.0 prompt-toolkit==3.0.50 pyaml==25.1.0 pyasn1==0.4.8 -pyasn1-modules==0.4.1 pycryptodomex==3.21.0 pydantic==2.10.6 pydantic-core==2.27.2 @@ -54,7 +48,6 @@ pyyaml==6.0.2 referencing==0.36.2 regex==2024.11.6 requests==2.32.3 -requests-oauthlib==2.0.0 rich==13.9.4 rpds-py==0.22.3 rsa==4.9 @@ -68,4 +61,3 @@ typing-extensions==4.12.2 tzdata==2025.1 urllib3==2.3.0 wcwidth==0.2.13 -websocket-client==1.8.0 diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index 56458c0e7..94c486f18 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -11,12 +11,10 @@ import pytest from fastapi import FastAPI from fastapi.testclient import TestClient -from llama_stack.distribution.datatypes import AccessAttributes +from llama_stack.distribution.datatypes import AuthenticationConfig from llama_stack.distribution.server.auth import AuthenticationMiddleware from llama_stack.distribution.server.auth_providers import ( - AuthProviderConfig, AuthProviderType, - TokenValidationResult, get_attributes_from_claims, ) @@ -62,7 +60,7 @@ def invalid_token(): @pytest.fixture def http_app(mock_auth_endpoint): app = FastAPI() - auth_config = AuthProviderConfig( + auth_config = 
AuthenticationConfig( provider_type=AuthProviderType.CUSTOM, config={"endpoint": mock_auth_endpoint}, ) @@ -78,7 +76,7 @@ def http_app(mock_auth_endpoint): @pytest.fixture def k8s_app(): app = FastAPI() - auth_config = AuthProviderConfig( + auth_config = AuthenticationConfig( provider_type=AuthProviderType.KUBERNETES, config={"api_server_url": "https://kubernetes.default.svc"}, ) @@ -118,7 +116,7 @@ def mock_scope(): @pytest.fixture def mock_http_middleware(mock_auth_endpoint): mock_app = AsyncMock() - auth_config = AuthProviderConfig( + auth_config = AuthenticationConfig( provider_type=AuthProviderType.CUSTOM, config={"endpoint": mock_auth_endpoint}, ) @@ -128,7 +126,7 @@ def mock_http_middleware(mock_auth_endpoint): @pytest.fixture def mock_k8s_middleware(): mock_app = AsyncMock() - auth_config = AuthProviderConfig( + auth_config = AuthenticationConfig( provider_type=AuthProviderType.KUBERNETES, config={"api_server_url": "https://kubernetes.default.svc"}, ) @@ -284,116 +282,13 @@ async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope): assert attributes["roles"] == ["test.jwt.token"] -# Kubernetes Tests -def test_missing_auth_header_k8s(k8s_client): - response = k8s_client.get("/test") - assert response.status_code == 401 - assert "Missing or invalid Authorization header" in response.json()["error"]["message"] - - -def test_invalid_auth_header_format_k8s(k8s_client): - response = k8s_client.get("/test", headers={"Authorization": "InvalidFormat token123"}) - assert response.status_code == 401 - assert "Missing or invalid Authorization header" in response.json()["error"]["message"] - - -@patch("kubernetes.client.ApiClient") -def test_valid_k8s_authentication(mock_api_client, k8s_client, valid_token): - # Mock the Kubernetes client - mock_client = AsyncMock() - mock_api_client.return_value = mock_client - - # Mock successful token validation - mock_client.set_default_header = AsyncMock() - - # Mock the token validation to return valid access 
attributes - with patch("llama_stack.distribution.server.auth_providers.KubernetesAuthProvider.validate_token") as mock_validate: - mock_validate.return_value = TokenValidationResult( - principal="test-principal", - access_attributes=AccessAttributes( - roles=["admin"], teams=["ml-team"], projects=["llama-3"], namespaces=["research"] - ), - ) - response = k8s_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"}) - assert response.status_code == 200 - assert response.json() == {"message": "Authentication successful"} - - -@patch("kubernetes.client.ApiClient") -def test_invalid_k8s_authentication(mock_api_client, k8s_client, invalid_token): - # Mock the Kubernetes client - mock_client = AsyncMock() - mock_api_client.return_value = mock_client - - # Mock failed token validation by raising an exception - with patch("llama_stack.distribution.server.auth_providers.KubernetesAuthProvider.validate_token") as mock_validate: - mock_validate.side_effect = ValueError("Invalid or expired token") - response = k8s_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"}) - assert response.status_code == 401 - assert "Invalid or expired token" in response.json()["error"]["message"] - - -@pytest.mark.asyncio -async def test_k8s_middleware_with_access_attributes(mock_k8s_middleware, mock_scope): - middleware, mock_app = mock_k8s_middleware - mock_receive = AsyncMock() - mock_send = AsyncMock() - - with patch("kubernetes.client.ApiClient") as mock_api_client: - mock_client = AsyncMock() - mock_api_client.return_value = mock_client - - # Mock successful token validation - mock_client.set_default_header = AsyncMock() - - # Mock token payload with access attributes - mock_token_parts = ["header", "eyJzdWIiOiJhZG1pbiIsImdyb3VwcyI6WyJtbC10ZWFtIl19", "signature"] - mock_scope["headers"][1] = (b"authorization", f"Bearer {'.'.join(mock_token_parts)}".encode()) - - await middleware(mock_scope, mock_receive, mock_send) - - assert "user_attributes" in 
mock_scope - assert mock_scope["user_attributes"]["roles"] == ["admin"] - assert mock_scope["user_attributes"]["teams"] == ["ml-team"] - - mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send) - - -@pytest.mark.asyncio -async def test_k8s_middleware_no_attributes(mock_k8s_middleware, mock_scope): - """Test middleware behavior with no access attributes""" - middleware, mock_app = mock_k8s_middleware - mock_receive = AsyncMock() - mock_send = AsyncMock() - - with patch("kubernetes.client.ApiClient") as mock_api_client: - mock_client = AsyncMock() - mock_api_client.return_value = mock_client - - # Mock successful token validation - mock_client.set_default_header = AsyncMock() - - # Mock token payload without access attributes - mock_token_parts = ["header", "eyJzdWIiOiJhZG1pbiJ9", "signature"] - mock_scope["headers"][1] = (b"authorization", f"Bearer {'.'.join(mock_token_parts)}".encode()) - - await middleware(mock_scope, mock_receive, mock_send) - - assert "user_attributes" in mock_scope - attributes = mock_scope["user_attributes"] - assert "roles" in attributes - assert attributes["roles"] == ["admin"] - - mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send) - - # oauth2 token provider tests @pytest.fixture def oauth2_app(): app = FastAPI() - auth_config = AuthProviderConfig( + auth_config = AuthenticationConfig( provider_type=AuthProviderType.OAUTH2_TOKEN, config={ "jwks": { @@ -530,7 +425,7 @@ def mock_introspection_endpoint(): @pytest.fixture def introspection_app(mock_introspection_endpoint): app = FastAPI() - auth_config = AuthProviderConfig( + auth_config = AuthenticationConfig( provider_type=AuthProviderType.OAUTH2_TOKEN, config={ "jwks": None, @@ -549,7 +444,7 @@ def introspection_app(mock_introspection_endpoint): @pytest.fixture def introspection_app_with_custom_mapping(mock_introspection_endpoint): app = FastAPI() - auth_config = AuthProviderConfig( + auth_config = AuthenticationConfig( 
provider_type=AuthProviderType.OAUTH2_TOKEN, config={ "jwks": None, diff --git a/uv.lock b/uv.lock index c30e2c4c1..a987ddc9e 100644 --- a/uv.lock +++ b/uv.lock @@ -676,15 +676,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/d7/9322c609343d929e75e7e5e6255e614fcc67572cfd083959cdef3b7aad79/docutils-0.21.2-py3-none-any.whl", hash = "sha256:dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2", size = 587408 }, ] -[[package]] -name = "durationpy" -version = "0.9" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/31/e9/f49c4e7fccb77fa5c43c2480e09a857a78b41e7331a75e128ed5df45c56b/durationpy-0.9.tar.gz", hash = "sha256:fd3feb0a69a0057d582ef643c355c40d2fa1c942191f914d12203b1a01ac722a", size = 3186 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4c/a3/ac312faeceffd2d8f86bc6dcb5c401188ba5a01bc88e69bed97578a0dfcd/durationpy-0.9-py3-none-any.whl", hash = "sha256:e65359a7af5cedad07fb77a2dd3f390f8eb0b74cb845589fa6c057086834dd38", size = 3461 }, -] - [[package]] name = "ecdsa" version = "0.19.1" @@ -863,20 +854,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1d/9a/4114a9057db2f1462d5c8f8390ab7383925fe1ac012eaa42402ad65c2963/GitPython-3.1.44-py3-none-any.whl", hash = "sha256:9e0e10cda9bed1ee64bc9a6de50e7e38a9c9943241cd7f585f6df3ed28011110", size = 207599 }, ] -[[package]] -name = "google-auth" -version = "2.38.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cachetools" }, - { name = "pyasn1-modules" }, - { name = "rsa" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c6/eb/d504ba1daf190af6b204a9d4714d457462b486043744901a6eeea711f913/google_auth-2.38.0.tar.gz", hash = "sha256:8285113607d3b80a3f1543b75962447ba8a09fe85783432a784fdeef6ac094c4", size = 270866 } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/9d/47/603554949a37bca5b7f894d51896a9c534b9eab808e2520a748e081669d0/google_auth-2.38.0-py2.py3-none-any.whl", hash = "sha256:e7dae6694313f434a2727bf2906f27ad259bae090d7aa896590d86feec3d9d4a", size = 210770 }, -] - [[package]] name = "googleapis-common-protos" version = "1.67.0" @@ -1324,28 +1301,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c9/fb/108ecd1fe961941959ad0ee4e12ee7b8b1477247f30b1fdfd83ceaf017f0/jupyter_core-5.7.2-py3-none-any.whl", hash = "sha256:4f7315d2f6b4bcf2e3e7cb6e46772eba760ae459cd1f59d29eb57b0a01bd7409", size = 28965 }, ] -[[package]] -name = "kubernetes" -version = "32.0.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi" }, - { name = "durationpy" }, - { name = "google-auth" }, - { name = "oauthlib" }, - { name = "python-dateutil" }, - { name = "pyyaml" }, - { name = "requests" }, - { name = "requests-oauthlib" }, - { name = "six" }, - { name = "urllib3" }, - { name = "websocket-client" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b7/e8/0598f0e8b4af37cd9b10d8b87386cf3173cb8045d834ab5f6ec347a758b3/kubernetes-32.0.1.tar.gz", hash = "sha256:42f43d49abd437ada79a79a16bd48a604d3471a117a8347e87db693f2ba0ba28", size = 946691 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/08/10/9f8af3e6f569685ce3af7faab51c8dd9d93b9c38eba339ca31c746119447/kubernetes-32.0.1-py2.py3-none-any.whl", hash = "sha256:35282ab8493b938b08ab5526c7ce66588232df00ef5e1dbe88a419107dc10998", size = 1988070 }, -] - [[package]] name = "levenshtein" version = "0.27.1" @@ -1441,7 +1396,6 @@ dependencies = [ { name = "huggingface-hub" }, { name = "jinja2" }, { name = "jsonschema" }, - { name = "kubernetes" }, { name = "llama-stack-client" }, { name = "openai" }, { name = "pillow" }, @@ -1546,7 +1500,6 @@ requires-dist = [ { name = "jinja2", specifier = ">=3.1.6" }, { name = "jinja2", marker = "extra == 'codegen'", specifier = ">=3.1.6" }, { name = 
"jsonschema" }, - { name = "kubernetes" }, { name = "llama-stack-client", specifier = ">=0.2.7" }, { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.7" }, { name = "mcp", marker = "extra == 'test'" }, @@ -1624,9 +1577,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cd/6b/31c07396c5b3010668e4eb38061a96ffacb47ec4b14d8aeb64c13856c485/llama_stack_client-0.2.7.tar.gz", hash = "sha256:11aee11fdd5e0e8caad07c0cce9c4d88640938844372e7e3453a91ea0757fcb3", size = 259273, upload-time = "2025-05-16T20:31:39.221Z" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/6b/31c07396c5b3010668e4eb38061a96ffacb47ec4b14d8aeb64c13856c485/llama_stack_client-0.2.7.tar.gz", hash = "sha256:11aee11fdd5e0e8caad07c0cce9c4d88640938844372e7e3453a91ea0757fcb3", size = 259273 } wheels = [ - { url = "https://files.pythonhosted.org/packages/ac/69/6a5f4683afe355500df4376fdcbfb2fc1e6a0c3bcea5ff8f6114773a9acf/llama_stack_client-0.2.7-py3-none-any.whl", hash = "sha256:78b3f2abdb1770c7b1270a9c0ef58402a988401c564d2e6c83588779ac6fc38d", size = 292727, upload-time = "2025-05-16T20:31:37.587Z" }, + { url = "https://files.pythonhosted.org/packages/ac/69/6a5f4683afe355500df4376fdcbfb2fc1e6a0c3bcea5ff8f6114773a9acf/llama_stack_client-0.2.7-py3-none-any.whl", hash = "sha256:78b3f2abdb1770c7b1270a9c0ef58402a988401c564d2e6c83588779ac6fc38d", size = 292727 }, ] [[package]] @@ -2087,15 +2040,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/17/7f/d322a4125405920401450118dbdc52e0384026bd669939484670ce8b2ab9/numpy-2.2.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:783145835458e60fa97afac25d511d00a1eca94d4a8f3ace9fe2043003c678e4", size = 12839607 }, ] -[[package]] -name = "oauthlib" -version = "3.2.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = 
"https://files.pythonhosted.org/packages/6d/fa/fbf4001037904031639e6bfbfc02badfc7e12f137a8afa254df6c4c8a670/oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918", size = 177352 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/80/cab10959dc1faead58dc8384a781dfbf93cb4d33d50988f7a69f1b7c9bbe/oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca", size = 151688 }, -] - [[package]] name = "openai" version = "1.71.0" @@ -2608,18 +2552,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/1e/a94a8d635fa3ce4cfc7f506003548d0a2447ae76fd5ca53932970fe3053f/pyasn1-0.4.8-py2.py3-none-any.whl", hash = "sha256:39c7e2ec30515947ff4e87fb6f456dfc6e84857d34be479c9d4a4ba4bf46aa5d", size = 77145 }, ] -[[package]] -name = "pyasn1-modules" -version = "0.4.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyasn1" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/1d/67/6afbf0d507f73c32d21084a79946bfcfca5fbc62a72057e9c23797a737c9/pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c", size = 310028 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/77/89/bc88a6711935ba795a679ea6ebee07e128050d6382eaa35a0a47c8032bdc/pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd", size = 181537 }, -] - [[package]] name = "pycparser" version = "2.22" @@ -2875,9 +2807,9 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973, upload-time = "2025-05-05T19:44:34.99Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973 } wheels = [ - { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" }, + { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382 }, ] [[package]] @@ -3256,19 +3188,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 }, ] -[[package]] -name = "requests-oauthlib" -version = "2.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "oauthlib" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/42/f2/05f29bc3913aea15eb670be136045bf5c5bbf4b99ecb839da9b422bb2c85/requests-oauthlib-2.0.0.tar.gz", hash = "sha256:b3dffaebd884d8cd778494369603a9e7b58d29111bf6b41bdc2dcd87203af4e9", size = 55650 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/5d/63d4ae3b9daea098d5d6f5da83984853c1bbacd5dc826764b249fe119d24/requests_oauthlib-2.0.0-py2.py3-none-any.whl", hash = "sha256:7dd8a5c40426b779b0868c404bdef9768deccf22749cde15852df527e6269b36", size = 24179 }, -] - [[package]] name = "rich" version = "13.9.4" @@ -4323,15 +4242,6 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 }, ] -[[package]] -name = "websocket-client" -version = "1.8.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e6/30/fba0d96b4b5fbf5948ed3f4681f7da2f9f64512e1d303f94b4cc174c24a5/websocket_client-1.8.0.tar.gz", hash = "sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da", size = 54648 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5a/84/44687a29792a70e111c5c477230a72c4b957d88d16141199bf9acb7537a3/websocket_client-1.8.0-py3-none-any.whl", hash = "sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526", size = 58826 }, -] - [[package]] name = "websockets" version = "15.0"