mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
feat: Extend the oauth_token provider to allow for token introspection
This may be desired in order to reject revoked tokens or where opaque tokens are used. Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
parent
87a4b9cb28
commit
7ee1ae0a3d
3 changed files with 251 additions and 13 deletions
|
@ -229,7 +229,7 @@ class AuthenticationConfig(BaseModel):
|
|||
...,
|
||||
description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
|
||||
)
|
||||
config: dict[str, str] = Field(
|
||||
config: dict[str, Any] = Field(
|
||||
...,
|
||||
description="Provider-specific configuration",
|
||||
)
|
||||
|
|
|
@ -5,15 +5,18 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import ssl
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Lock
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -85,7 +88,7 @@ class AuthProviderConfig(BaseModel):
|
|||
"""Base configuration for authentication providers."""
|
||||
|
||||
provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
|
||||
config: dict[str, str] = Field(..., description="Provider-specific configuration")
|
||||
config: dict[str, Any] = Field(..., description="Provider-specific configuration")
|
||||
|
||||
|
||||
class AuthProvider(ABC):
|
||||
|
@ -198,10 +201,21 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
|
|||
return attributes
|
||||
|
||||
|
||||
class OAuth2TokenAuthProviderConfig(BaseModel):
|
||||
class OAuth2JWKSConfig(BaseModel):
|
||||
# The JWKS URI for collecting public keys
|
||||
jwks_uri: str
|
||||
uri: str
|
||||
cache_ttl: int = 3600
|
||||
|
||||
|
||||
class OAuth2IntrospectionConfig(BaseModel):
|
||||
url: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
send_secret_in_body: bool = False
|
||||
tls_cafile: str | None = None
|
||||
|
||||
|
||||
class OAuth2TokenAuthProviderConfig(BaseModel):
|
||||
audience: str = "llama-stack"
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
|
@ -214,6 +228,8 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
|
|||
"namespace": "namespaces",
|
||||
},
|
||||
)
|
||||
jwks: OAuth2JWKSConfig | None
|
||||
introspection: OAuth2IntrospectionConfig | None = None
|
||||
|
||||
@classmethod
|
||||
@field_validator("claims_mapping")
|
||||
|
@ -225,6 +241,14 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
|
|||
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_mode(self) -> Self:
|
||||
if not self.jwks and not self.introspection:
|
||||
raise ValueError("One of jwks or introspection must be configured")
|
||||
if self.jwks and self.introspection:
|
||||
raise ValueError("At present only one of jwks or introspection should be configured")
|
||||
return self
|
||||
|
||||
|
||||
class OAuth2TokenAuthProvider(AuthProvider):
|
||||
"""
|
||||
|
@ -240,8 +264,17 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
self._jwks_lock = Lock()
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
if self.config.jwks:
|
||||
return await self.validate_jwt_token(token, self.config.jwks, scope)
|
||||
if self.config.introspection:
|
||||
return await self.introspect_token(token, self.config.introspection, scope)
|
||||
raise ValueError("One of jwks or introspection must be configured")
|
||||
|
||||
async def validate_jwt_token(
|
||||
self, token: str, config: OAuth2JWKSConfig, scope: dict | None = None
|
||||
) -> TokenValidationResult:
|
||||
"""Validate a token using the JWT token."""
|
||||
await self._refresh_jwks()
|
||||
await self._refresh_jwks(config)
|
||||
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
|
@ -269,14 +302,61 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
access_attributes=access_attributes,
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP client."""
|
||||
async def introspect_token(
|
||||
self, token: str, config: OAuth2IntrospectionConfig, scope: dict | None = None
|
||||
) -> TokenValidationResult:
|
||||
"""Validate a token using token introspection as defined by RFC 7662."""
|
||||
form = {
|
||||
"token": token,
|
||||
}
|
||||
if config.send_secret_in_body:
|
||||
form["client_id"] = config.client_id
|
||||
form["client_secret"] = config.client_secret
|
||||
auth = None
|
||||
else:
|
||||
auth = (config.client_id, config.client_secret)
|
||||
ssl_ctxt = None
|
||||
if config.tls_cafile:
|
||||
ssl_ctxt = ssl.create_default_context(cafile=config.tls_cafile)
|
||||
try:
|
||||
async with httpx.AsyncClient(verify=ssl_ctxt) as client:
|
||||
response = await client.post(
|
||||
config.url,
|
||||
data=form,
|
||||
auth=auth,
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Token introspection failed with status code: {response.status_code}")
|
||||
raise ValueError(f"Token introspection failed: {response.status_code}")
|
||||
|
||||
async def _refresh_jwks(self) -> None:
|
||||
fields = response.json()
|
||||
if not fields["active"]:
|
||||
raise ValueError("Token not active")
|
||||
principal = fields["sub"] or fields["username"]
|
||||
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
|
||||
return TokenValidationResult(
|
||||
principal=principal,
|
||||
access_attributes=access_attributes,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Token introspection request timed out")
|
||||
raise
|
||||
except ValueError:
|
||||
# Re-raise ValueError exceptions to preserve their message
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error during token introspection")
|
||||
raise ValueError("Token introspection error") from e
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
async def _refresh_jwks(self, config: OAuth2JWKSConfig) -> None:
|
||||
async with self._jwks_lock:
|
||||
if time.time() - self._jwks_at > self.config.cache_ttl:
|
||||
if time.time() - self._jwks_at > config.cache_ttl:
|
||||
async with httpx.AsyncClient() as client:
|
||||
res = await client.get(self.config.jwks_uri, timeout=5)
|
||||
res = await client.get(config.uri, timeout=5)
|
||||
res.raise_for_status()
|
||||
jwks_data = res.json()["keys"]
|
||||
updated = {}
|
||||
|
|
|
@ -396,8 +396,10 @@ def oauth2_app():
|
|||
auth_config = AuthProviderConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks_uri": "http://mock-authz-service/token/introspect",
|
||||
"cache_ttl": "3600",
|
||||
"jwks": {
|
||||
"uri": "http://mock-authz-service/token/introspect",
|
||||
"cache_ttl": "3600",
|
||||
},
|
||||
"audience": "llama-stack",
|
||||
},
|
||||
)
|
||||
|
@ -517,3 +519,159 @@ def test_get_attributes_from_claims():
|
|||
|
||||
|
||||
# TODO: add more tests for oauth2 token provider
|
||||
|
||||
|
||||
# oauth token introspection tests
|
||||
@pytest.fixture
|
||||
def mock_introspection_endpoint():
|
||||
return "http://mock-authz-service/token/introspect"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def introspection_app(mock_introspection_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthProviderConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": None,
|
||||
"introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"},
|
||||
},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthProviderConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": None,
|
||||
"introspection": {
|
||||
"url": mock_introspection_endpoint,
|
||||
"client_id": "myclient",
|
||||
"client_secret": "abcdefg",
|
||||
"send_secret_in_body": "true",
|
||||
},
|
||||
"claims_mapping": {
|
||||
"sub": "roles",
|
||||
"scope": "roles",
|
||||
"groups": "teams",
|
||||
"aud": "namespaces",
|
||||
},
|
||||
},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def introspection_client(introspection_app):
|
||||
return TestClient(introspection_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def introspection_client_with_custom_mapping(introspection_app_with_custom_mapping):
|
||||
return TestClient(introspection_app_with_custom_mapping)
|
||||
|
||||
|
||||
def test_missing_auth_header_introspection(introspection_client):
|
||||
response = introspection_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_introspection(introspection_client):
|
||||
response = introspection_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
async def mock_introspection_active(*args, **kwargs):
|
||||
return MockResponse(
|
||||
200,
|
||||
{
|
||||
"active": True,
|
||||
"sub": "my-user",
|
||||
"groups": ["group1", "group2"],
|
||||
"scope": "foo bar",
|
||||
"aud": ["set1", "set2"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def mock_introspection_inactive(*args, **kwargs):
|
||||
return MockResponse(
|
||||
200,
|
||||
{
|
||||
"active": False,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def mock_introspection_invalid(*args, **kwargs):
|
||||
class InvalidResponse:
|
||||
def __init__(self, status_code):
|
||||
self.status_code = status_code
|
||||
|
||||
def json(self):
|
||||
raise ValueError("Not JSON")
|
||||
|
||||
return InvalidResponse(200)
|
||||
|
||||
|
||||
async def mock_introspection_failed(*args, **kwargs):
|
||||
return MockResponse(
|
||||
500,
|
||||
{},
|
||||
)
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_introspection_active)
|
||||
def test_valid_introspection_authentication(introspection_client, valid_api_key):
|
||||
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_introspection_inactive)
|
||||
def test_inactive_introspection_authentication(introspection_client, invalid_api_key):
|
||||
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Token not active" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_introspection_invalid)
|
||||
def test_invalid_introspection_authentication(introspection_client, invalid_api_key):
|
||||
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Not JSON" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_introspection_failed)
|
||||
def test_failed_introspection_authentication(introspection_client, invalid_api_key):
|
||||
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Token introspection failed: 500" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_introspection_active)
|
||||
def test_valid_introspection_with_custom_mapping_authentication(
|
||||
introspection_client_with_custom_mapping, valid_api_key
|
||||
):
|
||||
response = introspection_client_with_custom_mapping.get(
|
||||
"/test", headers={"Authorization": f"Bearer {valid_api_key}"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue