diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 446a88ca0..be5629ba1 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -229,7 +229,7 @@ class AuthenticationConfig(BaseModel): ..., description="Type of authentication provider (e.g., 'kubernetes', 'custom')", ) - config: dict[str, str] = Field( + config: dict[str, Any] = Field( ..., description="Provider-specific configuration", ) diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index b73fded58..baab75eca 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -5,15 +5,18 @@ # the root directory of this source tree. import json +import ssl import time from abc import ABC, abstractmethod from asyncio import Lock from enum import Enum +from typing import Any from urllib.parse import parse_qs import httpx from jose import jwt -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator +from typing_extensions import Self from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.log import get_logger @@ -85,7 +88,7 @@ class AuthProviderConfig(BaseModel): """Base configuration for authentication providers.""" provider_type: AuthProviderType = Field(..., description="Type of authentication provider") - config: dict[str, str] = Field(..., description="Provider-specific configuration") + config: dict[str, Any] = Field(..., description="Provider-specific configuration") class AuthProvider(ABC): @@ -198,10 +201,21 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) return attributes -class OAuth2TokenAuthProviderConfig(BaseModel): +class OAuth2JWKSConfig(BaseModel): # The JWKS URI for collecting public keys - jwks_uri: str + uri: str cache_ttl: int = 3600 + + +class OAuth2IntrospectionConfig(BaseModel): + url: str + client_id: str + client_secret: str + send_secret_in_body: bool = False + tls_cafile: str | None = None + + +class OAuth2TokenAuthProviderConfig(BaseModel): audience: str = "llama-stack" claims_mapping: dict[str, str] = Field( default_factory=lambda: { @@ -214,6 +228,8 @@ class OAuth2TokenAuthProviderConfig(BaseModel): "namespace": "namespaces", }, ) + jwks: OAuth2JWKSConfig | None + introspection: OAuth2IntrospectionConfig | None = None @classmethod @field_validator("claims_mapping") @@ -225,6 +241,14 @@ class OAuth2TokenAuthProviderConfig(BaseModel): raise ValueError(f"claims_mapping value is not a valid attribute: {value}") return v + @model_validator(mode="after") + def validate_mode(self) -> Self: + if not self.jwks and not self.introspection: + raise ValueError("One of jwks or introspection must be configured") + if self.jwks and self.introspection: + raise ValueError("At present only one of jwks or introspection should be configured") + return self + class OAuth2TokenAuthProvider(AuthProvider): """ @@ -240,8 +264,17 @@ class OAuth2TokenAuthProvider(AuthProvider): self._jwks_lock = Lock() async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + if self.config.jwks: + return await self.validate_jwt_token(token, self.config.jwks, scope) + if self.config.introspection: + return await self.introspect_token(token, self.config.introspection, scope) + raise ValueError("One of jwks or introspection must be configured") + + async def validate_jwt_token( + self, token: str, config: OAuth2JWKSConfig, scope: dict | None = None + ) -> TokenValidationResult: """Validate a token using the JWT token.""" - await self._refresh_jwks() + await self._refresh_jwks(config) try: header = jwt.get_unverified_header(token) @@ -269,14 +302,61 @@ class OAuth2TokenAuthProvider(AuthProvider): access_attributes=access_attributes, ) - async def close(self): - """Close the HTTP client.""" + async def introspect_token( + self, token: str, config: OAuth2IntrospectionConfig, scope: dict | None = None + ) -> TokenValidationResult: + """Validate a token using token introspection as defined by RFC 7662.""" + form = { + "token": token, + } + if config.send_secret_in_body: + form["client_id"] = config.client_id + form["client_secret"] = config.client_secret + auth = None + else: + auth = (config.client_id, config.client_secret) + ssl_ctxt = None + if config.tls_cafile: + ssl_ctxt = ssl.create_default_context(cafile=config.tls_cafile) + try: + async with httpx.AsyncClient(verify=ssl_ctxt) as client: + response = await client.post( + config.url, + data=form, + auth=auth, + timeout=10.0, # Add a reasonable timeout + ) + if response.status_code != 200: + logger.warning(f"Token introspection failed with status code: {response.status_code}") + raise ValueError(f"Token introspection failed: {response.status_code}") - async def _refresh_jwks(self) -> None: + fields = response.json() + if not fields["active"]: + raise ValueError("Token not active") + principal = fields["sub"] or fields["username"] + access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping) + return TokenValidationResult( + principal=principal, + access_attributes=access_attributes, + ) + except httpx.TimeoutException: + logger.exception("Token introspection request timed out") + raise + except ValueError: + # Re-raise ValueError exceptions to preserve their message + raise + except Exception as e: + logger.exception("Error during token introspection") + raise ValueError("Token introspection error") from e + + async def close(self): + pass + + async def _refresh_jwks(self, config: OAuth2JWKSConfig) -> None: async with self._jwks_lock: - if time.time() - self._jwks_at > self.config.cache_ttl: + if time.time() - self._jwks_at > config.cache_ttl: async with httpx.AsyncClient() as client: - res = await client.get(self.config.jwks_uri, timeout=5) + res = await client.get(config.uri, timeout=5) res.raise_for_status() jwks_data = res.json()["keys"] updated = {} diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index f15ca9de4..56458c0e7 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -396,8 +396,10 @@ def oauth2_app(): auth_config = AuthProviderConfig( provider_type=AuthProviderType.OAUTH2_TOKEN, config={ - "jwks_uri": "http://mock-authz-service/token/introspect", - "cache_ttl": "3600", + "jwks": { + "uri": "http://mock-authz-service/token/introspect", + "cache_ttl": "3600", + }, "audience": "llama-stack", }, ) @@ -517,3 +519,159 @@ def test_get_attributes_from_claims(): # TODO: add more tests for oauth2 token provider + + +# oauth token introspection tests +@pytest.fixture +def mock_introspection_endpoint(): + return "http://mock-authz-service/token/introspect" + + +@pytest.fixture +def introspection_app(mock_introspection_endpoint): + app = FastAPI() + auth_config = AuthProviderConfig( + provider_type=AuthProviderType.OAUTH2_TOKEN, + config={ + "jwks": None, + "introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"}, + }, + ) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + + @app.get("/test") + def test_endpoint(): + return {"message": "Authentication successful"} + + return app + + +@pytest.fixture +def introspection_app_with_custom_mapping(mock_introspection_endpoint): + app = FastAPI() + auth_config = AuthProviderConfig( + provider_type=AuthProviderType.OAUTH2_TOKEN, + config={ + "jwks": None, + "introspection": { + "url": mock_introspection_endpoint, + "client_id": "myclient", + "client_secret": "abcdefg", + "send_secret_in_body": "true", + }, + "claims_mapping": { + "sub": "roles", + "scope": "roles", + "groups": "teams", + "aud": "namespaces", + }, + }, + ) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + + @app.get("/test") + def test_endpoint(): + return {"message": "Authentication successful"} + + return app + + +@pytest.fixture +def introspection_client(introspection_app): + return TestClient(introspection_app) + + +@pytest.fixture +def introspection_client_with_custom_mapping(introspection_app_with_custom_mapping): + return TestClient(introspection_app_with_custom_mapping) + + +def test_missing_auth_header_introspection(introspection_client): + response = introspection_client.get("/test") + assert response.status_code == 401 + assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + + +def test_invalid_auth_header_format_introspection(introspection_client): + response = introspection_client.get("/test", headers={"Authorization": "InvalidFormat token123"}) + assert response.status_code == 401 + assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + + +async def mock_introspection_active(*args, **kwargs): + return MockResponse( + 200, + { + "active": True, + "sub": "my-user", + "groups": ["group1", "group2"], + "scope": "foo bar", + "aud": ["set1", "set2"], + }, + ) + + +async def mock_introspection_inactive(*args, **kwargs): + return MockResponse( + 200, + { + "active": False, + }, + ) + + +async def mock_introspection_invalid(*args, **kwargs): + class InvalidResponse: + def __init__(self, status_code): + self.status_code = status_code + + def json(self): + raise ValueError("Not JSON") + + return InvalidResponse(200) + + +async def mock_introspection_failed(*args, **kwargs): + return MockResponse( + 500, + {}, + ) + + +@patch("httpx.AsyncClient.post", new=mock_introspection_active) +def test_valid_introspection_authentication(introspection_client, valid_api_key): + response = introspection_client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"}) + assert response.status_code == 200 + assert response.json() == {"message": "Authentication successful"} + + +@patch("httpx.AsyncClient.post", new=mock_introspection_inactive) +def test_inactive_introspection_authentication(introspection_client, invalid_api_key): + response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"}) + assert response.status_code == 401 + assert "Token not active" in response.json()["error"]["message"] + + +@patch("httpx.AsyncClient.post", new=mock_introspection_invalid) +def test_invalid_introspection_authentication(introspection_client, invalid_api_key): + response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"}) + assert response.status_code == 401 + assert "Not JSON" in response.json()["error"]["message"] + + +@patch("httpx.AsyncClient.post", new=mock_introspection_failed) +def test_failed_introspection_authentication(introspection_client, invalid_api_key): + response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"}) + assert response.status_code == 401 + assert "Token introspection failed: 500" in response.json()["error"]["message"] + + +@patch("httpx.AsyncClient.post", new=mock_introspection_active) +def test_valid_introspection_with_custom_mapping_authentication( + introspection_client_with_custom_mapping, valid_api_key +): + response = introspection_client_with_custom_mapping.get( + "/test", headers={"Authorization": f"Bearer {valid_api_key}"} + ) + assert response.status_code == 200 + assert response.json() == {"message": "Authentication successful"}