mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
feat: Extend the oauth_token provider to allow for token introspection
This may be desired in order to reject revoked tokens or where opaque tokens are used. Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
parent
87a4b9cb28
commit
7ee1ae0a3d
3 changed files with 251 additions and 13 deletions
|
@ -229,7 +229,7 @@ class AuthenticationConfig(BaseModel):
|
||||||
...,
|
...,
|
||||||
description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
|
description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
|
||||||
)
|
)
|
||||||
config: dict[str, str] = Field(
|
config: dict[str, Any] = Field(
|
||||||
...,
|
...,
|
||||||
description="Provider-specific configuration",
|
description="Provider-specific configuration",
|
||||||
)
|
)
|
||||||
|
|
|
@ -5,15 +5,18 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import ssl
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from asyncio import Lock
|
from asyncio import Lock
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
from urllib.parse import parse_qs
|
from urllib.parse import parse_qs
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import AccessAttributes
|
from llama_stack.distribution.datatypes import AccessAttributes
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
@ -85,7 +88,7 @@ class AuthProviderConfig(BaseModel):
|
||||||
"""Base configuration for authentication providers."""
|
"""Base configuration for authentication providers."""
|
||||||
|
|
||||||
provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
|
provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
|
||||||
config: dict[str, str] = Field(..., description="Provider-specific configuration")
|
config: dict[str, Any] = Field(..., description="Provider-specific configuration")
|
||||||
|
|
||||||
|
|
||||||
class AuthProvider(ABC):
|
class AuthProvider(ABC):
|
||||||
|
@ -198,10 +201,21 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
|
||||||
return attributes
|
return attributes
|
||||||
|
|
||||||
|
|
||||||
class OAuth2TokenAuthProviderConfig(BaseModel):
|
class OAuth2JWKSConfig(BaseModel):
|
||||||
# The JWKS URI for collecting public keys
|
# The JWKS URI for collecting public keys
|
||||||
jwks_uri: str
|
uri: str
|
||||||
cache_ttl: int = 3600
|
cache_ttl: int = 3600
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2IntrospectionConfig(BaseModel):
|
||||||
|
url: str
|
||||||
|
client_id: str
|
||||||
|
client_secret: str
|
||||||
|
send_secret_in_body: bool = False
|
||||||
|
tls_cafile: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2TokenAuthProviderConfig(BaseModel):
|
||||||
audience: str = "llama-stack"
|
audience: str = "llama-stack"
|
||||||
claims_mapping: dict[str, str] = Field(
|
claims_mapping: dict[str, str] = Field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
|
@ -214,6 +228,8 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
|
||||||
"namespace": "namespaces",
|
"namespace": "namespaces",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
jwks: OAuth2JWKSConfig | None
|
||||||
|
introspection: OAuth2IntrospectionConfig | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@field_validator("claims_mapping")
|
@field_validator("claims_mapping")
|
||||||
|
@ -225,6 +241,14 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
|
||||||
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
|
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_mode(self) -> Self:
|
||||||
|
if not self.jwks and not self.introspection:
|
||||||
|
raise ValueError("One of jwks or introspection must be configured")
|
||||||
|
if self.jwks and self.introspection:
|
||||||
|
raise ValueError("At present only one of jwks or introspection should be configured")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class OAuth2TokenAuthProvider(AuthProvider):
|
class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
"""
|
"""
|
||||||
|
@ -240,8 +264,17 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
self._jwks_lock = Lock()
|
self._jwks_lock = Lock()
|
||||||
|
|
||||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||||
|
if self.config.jwks:
|
||||||
|
return await self.validate_jwt_token(token, self.config.jwks, scope)
|
||||||
|
if self.config.introspection:
|
||||||
|
return await self.introspect_token(token, self.config.introspection, scope)
|
||||||
|
raise ValueError("One of jwks or introspection must be configured")
|
||||||
|
|
||||||
|
async def validate_jwt_token(
|
||||||
|
self, token: str, config: OAuth2JWKSConfig, scope: dict | None = None
|
||||||
|
) -> TokenValidationResult:
|
||||||
"""Validate a token using the JWT token."""
|
"""Validate a token using the JWT token."""
|
||||||
await self._refresh_jwks()
|
await self._refresh_jwks(config)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
header = jwt.get_unverified_header(token)
|
header = jwt.get_unverified_header(token)
|
||||||
|
@ -269,14 +302,61 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
access_attributes=access_attributes,
|
access_attributes=access_attributes,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def close(self):
|
async def introspect_token(
|
||||||
"""Close the HTTP client."""
|
self, token: str, config: OAuth2IntrospectionConfig, scope: dict | None = None
|
||||||
|
) -> TokenValidationResult:
|
||||||
|
"""Validate a token using token introspection as defined by RFC 7662."""
|
||||||
|
form = {
|
||||||
|
"token": token,
|
||||||
|
}
|
||||||
|
if config.send_secret_in_body:
|
||||||
|
form["client_id"] = config.client_id
|
||||||
|
form["client_secret"] = config.client_secret
|
||||||
|
auth = None
|
||||||
|
else:
|
||||||
|
auth = (config.client_id, config.client_secret)
|
||||||
|
ssl_ctxt = None
|
||||||
|
if config.tls_cafile:
|
||||||
|
ssl_ctxt = ssl.create_default_context(cafile=config.tls_cafile)
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(verify=ssl_ctxt) as client:
|
||||||
|
response = await client.post(
|
||||||
|
config.url,
|
||||||
|
data=form,
|
||||||
|
auth=auth,
|
||||||
|
timeout=10.0, # Add a reasonable timeout
|
||||||
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.warning(f"Token introspection failed with status code: {response.status_code}")
|
||||||
|
raise ValueError(f"Token introspection failed: {response.status_code}")
|
||||||
|
|
||||||
async def _refresh_jwks(self) -> None:
|
fields = response.json()
|
||||||
|
if not fields["active"]:
|
||||||
|
raise ValueError("Token not active")
|
||||||
|
principal = fields["sub"] or fields["username"]
|
||||||
|
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
|
||||||
|
return TokenValidationResult(
|
||||||
|
principal=principal,
|
||||||
|
access_attributes=access_attributes,
|
||||||
|
)
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.exception("Token introspection request timed out")
|
||||||
|
raise
|
||||||
|
except ValueError:
|
||||||
|
# Re-raise ValueError exceptions to preserve their message
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error during token introspection")
|
||||||
|
raise ValueError("Token introspection error") from e
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _refresh_jwks(self, config: OAuth2JWKSConfig) -> None:
|
||||||
async with self._jwks_lock:
|
async with self._jwks_lock:
|
||||||
if time.time() - self._jwks_at > self.config.cache_ttl:
|
if time.time() - self._jwks_at > config.cache_ttl:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
res = await client.get(self.config.jwks_uri, timeout=5)
|
res = await client.get(config.uri, timeout=5)
|
||||||
res.raise_for_status()
|
res.raise_for_status()
|
||||||
jwks_data = res.json()["keys"]
|
jwks_data = res.json()["keys"]
|
||||||
updated = {}
|
updated = {}
|
||||||
|
|
|
@ -396,8 +396,10 @@ def oauth2_app():
|
||||||
auth_config = AuthProviderConfig(
|
auth_config = AuthProviderConfig(
|
||||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||||
config={
|
config={
|
||||||
"jwks_uri": "http://mock-authz-service/token/introspect",
|
"jwks": {
|
||||||
"cache_ttl": "3600",
|
"uri": "http://mock-authz-service/token/introspect",
|
||||||
|
"cache_ttl": "3600",
|
||||||
|
},
|
||||||
"audience": "llama-stack",
|
"audience": "llama-stack",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -517,3 +519,159 @@ def test_get_attributes_from_claims():
|
||||||
|
|
||||||
|
|
||||||
# TODO: add more tests for oauth2 token provider
|
# TODO: add more tests for oauth2 token provider
|
||||||
|
|
||||||
|
|
||||||
|
# oauth token introspection tests
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_introspection_endpoint():
|
||||||
|
return "http://mock-authz-service/token/introspect"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def introspection_app(mock_introspection_endpoint):
|
||||||
|
app = FastAPI()
|
||||||
|
auth_config = AuthProviderConfig(
|
||||||
|
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||||
|
config={
|
||||||
|
"jwks": None,
|
||||||
|
"introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
def test_endpoint():
|
||||||
|
return {"message": "Authentication successful"}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
|
||||||
|
app = FastAPI()
|
||||||
|
auth_config = AuthProviderConfig(
|
||||||
|
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||||
|
config={
|
||||||
|
"jwks": None,
|
||||||
|
"introspection": {
|
||||||
|
"url": mock_introspection_endpoint,
|
||||||
|
"client_id": "myclient",
|
||||||
|
"client_secret": "abcdefg",
|
||||||
|
"send_secret_in_body": "true",
|
||||||
|
},
|
||||||
|
"claims_mapping": {
|
||||||
|
"sub": "roles",
|
||||||
|
"scope": "roles",
|
||||||
|
"groups": "teams",
|
||||||
|
"aud": "namespaces",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
def test_endpoint():
|
||||||
|
return {"message": "Authentication successful"}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def introspection_client(introspection_app):
|
||||||
|
return TestClient(introspection_app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def introspection_client_with_custom_mapping(introspection_app_with_custom_mapping):
|
||||||
|
return TestClient(introspection_app_with_custom_mapping)
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_auth_header_introspection(introspection_client):
|
||||||
|
response = introspection_client.get("/test")
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_auth_header_format_introspection(introspection_client):
|
||||||
|
response = introspection_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_introspection_active(*args, **kwargs):
|
||||||
|
return MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"active": True,
|
||||||
|
"sub": "my-user",
|
||||||
|
"groups": ["group1", "group2"],
|
||||||
|
"scope": "foo bar",
|
||||||
|
"aud": ["set1", "set2"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_introspection_inactive(*args, **kwargs):
|
||||||
|
return MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"active": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_introspection_invalid(*args, **kwargs):
|
||||||
|
class InvalidResponse:
|
||||||
|
def __init__(self, status_code):
|
||||||
|
self.status_code = status_code
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
raise ValueError("Not JSON")
|
||||||
|
|
||||||
|
return InvalidResponse(200)
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_introspection_failed(*args, **kwargs):
|
||||||
|
return MockResponse(
|
||||||
|
500,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient.post", new=mock_introspection_active)
|
||||||
|
def test_valid_introspection_authentication(introspection_client, valid_api_key):
|
||||||
|
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"message": "Authentication successful"}
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient.post", new=mock_introspection_inactive)
|
||||||
|
def test_inactive_introspection_authentication(introspection_client, invalid_api_key):
|
||||||
|
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Token not active" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient.post", new=mock_introspection_invalid)
|
||||||
|
def test_invalid_introspection_authentication(introspection_client, invalid_api_key):
|
||||||
|
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Not JSON" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient.post", new=mock_introspection_failed)
|
||||||
|
def test_failed_introspection_authentication(introspection_client, invalid_api_key):
|
||||||
|
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Token introspection failed: 500" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient.post", new=mock_introspection_active)
|
||||||
|
def test_valid_introspection_with_custom_mapping_authentication(
|
||||||
|
introspection_client_with_custom_mapping, valid_api_key
|
||||||
|
):
|
||||||
|
response = introspection_client_with_custom_mapping.get(
|
||||||
|
"/test", headers={"Authorization": f"Bearer {valid_api_key}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"message": "Authentication successful"}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue