diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index 3feb17c28..4065a65f3 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -12,7 +12,7 @@ from urllib.parse import parse_qs import httpx from jose import jwt -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.log import get_logger @@ -77,6 +77,7 @@ class AuthProviderType(str, Enum): KUBERNETES = "kubernetes" CUSTOM = "custom" + OAUTH2_TOKEN = "oauth2_token" class AuthProviderConfig(BaseModel): @@ -177,24 +178,61 @@ class KubernetesAuthProvider(AuthProvider): self._client = None -JWT_AUDIENCE = "llama-stack" +def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes: + attributes = AccessAttributes() + for claim_key, attribute_key in mapping.items(): + if claim_key not in claims or not hasattr(attributes, attribute_key): + continue + claim = claims[claim_key] + if isinstance(claim, list): + values = claim + else: + values = claim.split() + + current = getattr(attributes, attribute_key) + if current: + current.extend(values) + else: + setattr(attributes, attribute_key, values) + return attributes -class JWTAuthProviderConfig(BaseModel): +class OAuth2TokenAuthProviderConfig(BaseModel): # The JWKS URI for collecting public keys jwks_uri: str - algorithm: str = "RS256" cache_ttl: int = 3600 + audience: str = "llama-stack" + claims_mapping: dict[str, str] = Field( + default_factory=lambda: { + "sub": "roles", + "username": "roles", + "groups": "teams", + "team": "teams", + "project": "projects", + "tenant": "namespaces", + "namespace": "namespaces", + }, + ) + + @classmethod + @field_validator("claims_mapping") + def validate_claims_mapping(cls, v): + for key, value in v.items(): + if not value: + raise ValueError(f"claims_mapping value cannot be empty: {key}") + if value not in AccessAttributes.model_fields: + raise ValueError(f"claims_mapping value is not a valid attribute: {value}") + return v -class JWTAuthProvider(AuthProvider): +class OAuth2TokenAuthProvider(AuthProvider): """ JWT token authentication provider that validates a JWT token and extracts access attributes. This should be the standard authentication provider for most use cases. """ - def __init__(self, config: JWTAuthProviderConfig): + def __init__(self, config: OAuth2TokenAuthProviderConfig): self.config = config self._jwks_at: float = 0.0 self._jwks: dict[str, str] = {} @@ -204,15 +242,17 @@ class JWTAuthProvider(AuthProvider): await self._refresh_jwks() try: - kid = jwt.get_unverified_header(token)["kid"] + header = jwt.get_unverified_header(token) + kid = header["kid"] if kid not in self._jwks: raise ValueError(f"Unknown key ID: {kid}") - key = self._jwks[kid] + key_data = self._jwks[kid] + algorithm = header.get("alg", "RS256") claims = jwt.decode( token, - key, - algorithms=[self.config.algorithm], - audience=JWT_AUDIENCE, + key_data, + algorithms=[algorithm], + audience=self.config.audience, options={"verify_exp": True}, ) except Exception as exc: @@ -220,29 +260,11 @@ class JWTAuthProvider(AuthProvider): # There are other standard claims, the most relevant of which is `scope`. # We should incorporate these into the access attributes. - principal = f"{claims['iss']}:{claims['sub']}" - - teams = claims.get("teams", []) - if not teams: - if team := claims.get("team", claims.get("team_id")): - teams = [team] - projects = claims.get("projects", []) - if not projects: - if project := claims.get("project", claims.get("project_id")): - projects = [project] - namespaces = claims.get("namespaces", []) - if not namespaces: - if namespace := claims.get("namespace", claims.get("tenant")): - namespaces = [namespace] - + principal = claims["sub"] + access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping) return TokenValidationResult( principal=principal, - access_attributes=AccessAttributes( - roles=claims.get("groups", claims.get("roles", [])), # Okta / Auth0 - teams=teams, - projects=projects, - namespaces=namespaces, - ), + access_attributes=access_attributes, ) async def close(self): @@ -250,10 +272,15 @@ class JWTAuthProvider(AuthProvider): async def _refresh_jwks(self) -> None: if time.time() - self._jwks_at > self.config.cache_ttl: - with httpx.AsyncClient() as client: + async with httpx.AsyncClient() as client: res = await client.get(self.config.jwks_uri, timeout=5) res.raise_for_status() - self._jwks = {k["kid"]: k for k in res.json()["keys"]} + jwks_data = res.json()["keys"] + self._jwks = {} + for k in jwks_data: + kid = k["kid"] + # Store the entire key object as it may be needed for different algorithms + self._jwks[kid] = k self._jwks_at = time.time() @@ -340,8 +367,8 @@ def create_auth_provider(config: AuthProviderConfig) -> AuthProvider: return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config)) elif provider_type == "custom": return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config)) - elif provider_type == "jwt": - return JWTAuthProvider(JWTAuthProviderConfig.model_validate(config.config)) + elif provider_type == "oauth2_token": + return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config)) else: supported_providers = ", ".join([t.value for t in AuthProviderType]) raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}") diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index f46574fd9..4a60d43c6 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -27,6 +27,10 @@ class MockResponse: def json(self): return self._json_data + def raise_for_status(self): + if self.status_code != 200: + raise Exception(f"HTTP error: {self.status_code}") + @pytest.fixture def mock_auth_endpoint(): @@ -379,3 +383,88 @@ async def test_k8s_middleware_no_attributes(mock_k8s_middleware, mock_scope): assert attributes["roles"] == ["admin"] mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send) + + +# oauth2 token provider tests + + +@pytest.fixture +def oauth2_app(): + app = FastAPI() + auth_config = AuthProviderConfig( + provider_type=AuthProviderType.OAUTH2_TOKEN, + config={ + "jwks_uri": "http://mock-authz-service/token/introspect", + "cache_ttl": "3600", + "audience": "llama-stack", + }, + ) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + + @app.get("/test") + def test_endpoint(): + return {"message": "Authentication successful"} + + return app + + +@pytest.fixture +def oauth2_client(oauth2_app): + return TestClient(oauth2_app) + + +def test_missing_auth_header_oauth2(oauth2_client): + response = oauth2_client.get("/test") + assert response.status_code == 401 + assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + + +def test_invalid_auth_header_format_oauth2(oauth2_client): + response = oauth2_client.get("/test", headers={"Authorization": "InvalidFormat token123"}) + assert response.status_code == 401 + assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + + +async def mock_jwks_response(*args, **kwargs): + return MockResponse( + 200, + { + "keys": [ + { + "kid": "1234567890", + "kty": "oct", + "alg": "HS256", + "use": "sig", + "k": "MTIzNDU2Nzg5MA", # Base64-encoded "1234567890" + } + ] + }, + ) + + +@pytest.fixture +def jwt_token_valid(): + from jose import jwt + + # correctly signed jwt token with "kid" in header + return jwt.encode( + { + "sub": "my-user", + "groups": ["group1", "group2"], + "scope": "foo bar", + "aud": "llama-stack", + }, + key="1234567890", + algorithm="HS256", + headers={"kid": "1234567890"}, + ) + + +@patch("httpx.AsyncClient.get", new=mock_jwks_response) +def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid): + response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) + assert response.status_code == 200 + assert response.json() == {"message": "Authentication successful"} + + +# TODO: add more tests for oauth2 token provider