mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 00:34:44 +00:00
Merge branch 'main' into patch-metadata
This commit is contained in:
commit
5a807da6af
6 changed files with 2551 additions and 2257 deletions
|
@ -93,7 +93,7 @@ class AuthenticationMiddleware:
|
||||||
|
|
||||||
# Validate token and get access attributes
|
# Validate token and get access attributes
|
||||||
try:
|
try:
|
||||||
access_attributes = await self.auth_provider.validate_token(token, scope)
|
validation_result = await self.auth_provider.validate_token(token, scope)
|
||||||
except httpx.TimeoutException:
|
except httpx.TimeoutException:
|
||||||
logger.exception("Authentication request timed out")
|
logger.exception("Authentication request timed out")
|
||||||
return await self._send_auth_error(send, "Authentication service timeout")
|
return await self._send_auth_error(send, "Authentication service timeout")
|
||||||
|
@ -105,17 +105,20 @@ class AuthenticationMiddleware:
|
||||||
return await self._send_auth_error(send, "Authentication service error")
|
return await self._send_auth_error(send, "Authentication service error")
|
||||||
|
|
||||||
# Store attributes in request scope for access control
|
# Store attributes in request scope for access control
|
||||||
if access_attributes:
|
if validation_result.access_attributes:
|
||||||
user_attributes = access_attributes.model_dump(exclude_none=True)
|
user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
|
||||||
else:
|
else:
|
||||||
logger.warning("No access attributes, setting namespace to token by default")
|
logger.warning("No access attributes, setting namespace to token by default")
|
||||||
user_attributes = {
|
user_attributes = {
|
||||||
"namespaces": [token],
|
"roles": [token],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Store attributes in request scope
|
# Store attributes in request scope
|
||||||
scope["user_attributes"] = user_attributes
|
scope["user_attributes"] = user_attributes
|
||||||
logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes")
|
scope["principal"] = validation_result.principal
|
||||||
|
logger.debug(
|
||||||
|
f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
|
||||||
|
)
|
||||||
|
|
||||||
return await self.app(scope, receive, send)
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
|
|
@ -5,12 +5,14 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from urllib.parse import parse_qs
|
from urllib.parse import parse_qs
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from pydantic import BaseModel, Field
|
from jose import jwt
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import AccessAttributes
|
from llama_stack.distribution.datatypes import AccessAttributes
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
@ -18,9 +20,11 @@ from llama_stack.log import get_logger
|
||||||
logger = get_logger(name=__name__, category="auth")
|
logger = get_logger(name=__name__, category="auth")
|
||||||
|
|
||||||
|
|
||||||
class AuthResponse(BaseModel):
|
class TokenValidationResult(BaseModel):
|
||||||
"""The format of the authentication response from the auth endpoint."""
|
principal: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="The principal (username or persistent identifier) of the authenticated user",
|
||||||
|
)
|
||||||
access_attributes: AccessAttributes | None = Field(
|
access_attributes: AccessAttributes | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="""
|
description="""
|
||||||
|
@ -43,6 +47,10 @@ class AuthResponse(BaseModel):
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthResponse(TokenValidationResult):
|
||||||
|
"""The format of the authentication response from the auth endpoint."""
|
||||||
|
|
||||||
message: str | None = Field(
|
message: str | None = Field(
|
||||||
default=None, description="Optional message providing additional context about the authentication result."
|
default=None, description="Optional message providing additional context about the authentication result."
|
||||||
)
|
)
|
||||||
|
@ -69,6 +77,7 @@ class AuthProviderType(str, Enum):
|
||||||
|
|
||||||
KUBERNETES = "kubernetes"
|
KUBERNETES = "kubernetes"
|
||||||
CUSTOM = "custom"
|
CUSTOM = "custom"
|
||||||
|
OAUTH2_TOKEN = "oauth2_token"
|
||||||
|
|
||||||
|
|
||||||
class AuthProviderConfig(BaseModel):
|
class AuthProviderConfig(BaseModel):
|
||||||
|
@ -82,7 +91,7 @@ class AuthProvider(ABC):
|
||||||
"""Abstract base class for authentication providers."""
|
"""Abstract base class for authentication providers."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
|
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||||
"""Validate a token and return access attributes."""
|
"""Validate a token and return access attributes."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -92,12 +101,16 @@ class AuthProvider(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class KubernetesAuthProviderConfig(BaseModel):
    """Settings consumed by the Kubernetes authentication provider."""

    # Assigned to the Kubernetes client configuration's ``host``.
    api_server_url: str
    # Optional CA bundle path; when it is unset the provider configures the
    # client with ``verify_ssl`` set to False.
    ca_cert_path: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
class KubernetesAuthProvider(AuthProvider):
|
class KubernetesAuthProvider(AuthProvider):
|
||||||
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
|
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
|
||||||
|
|
||||||
def __init__(self, config: dict[str, str]):
|
def __init__(self, config: KubernetesAuthProviderConfig):
|
||||||
self.api_server_url = config["api_server_url"]
|
self.config = config
|
||||||
self.ca_cert_path = config.get("ca_cert_path")
|
|
||||||
self._client = None
|
self._client = None
|
||||||
|
|
||||||
async def _get_client(self):
|
async def _get_client(self):
|
||||||
|
@ -110,16 +123,16 @@ class KubernetesAuthProvider(AuthProvider):
|
||||||
|
|
||||||
# Configure the client
|
# Configure the client
|
||||||
configuration = client.Configuration()
|
configuration = client.Configuration()
|
||||||
configuration.host = self.api_server_url
|
configuration.host = self.config.api_server_url
|
||||||
if self.ca_cert_path:
|
if self.config.ca_cert_path:
|
||||||
configuration.ssl_ca_cert = self.ca_cert_path
|
configuration.ssl_ca_cert = self.config.ca_cert_path
|
||||||
configuration.verify_ssl = bool(self.ca_cert_path)
|
configuration.verify_ssl = bool(self.config.ca_cert_path)
|
||||||
|
|
||||||
# Create API client
|
# Create API client
|
||||||
self._client = ApiClient(configuration)
|
self._client = ApiClient(configuration)
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
|
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||||
"""Validate a Kubernetes token and return access attributes."""
|
"""Validate a Kubernetes token and return access attributes."""
|
||||||
try:
|
try:
|
||||||
client = await self._get_client()
|
client = await self._get_client()
|
||||||
|
@ -146,9 +159,12 @@ class KubernetesAuthProvider(AuthProvider):
|
||||||
username = payload.get("sub", "")
|
username = payload.get("sub", "")
|
||||||
groups = payload.get("groups", [])
|
groups = payload.get("groups", [])
|
||||||
|
|
||||||
return AccessAttributes(
|
return TokenValidationResult(
|
||||||
roles=[username], # Use username as a role
|
principal=username,
|
||||||
teams=groups, # Use Kubernetes groups as teams
|
access_attributes=AccessAttributes(
|
||||||
|
roles=[username], # Use username as a role
|
||||||
|
teams=groups, # Use Kubernetes groups as teams
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -162,18 +178,125 @@ class KubernetesAuthProvider(AuthProvider):
|
||||||
self._client = None
|
self._client = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_attributes_from_claims(claims: dict[str, str | list[str]], mapping: dict[str, str]) -> AccessAttributes:
    """Translate token claims into access attributes.

    Args:
        claims: Decoded token claims. A value may be a list of entries or a
            single whitespace-separated string (e.g. ``"foo bar"``).
        mapping: Maps a claim name to the name of the ``AccessAttributes``
            field that should receive that claim's values. Several claims may
            target the same field; their values are concatenated.

    Returns:
        A new ``AccessAttributes`` populated from the mapped claims. Claims
        that are absent, or that map to an attribute ``AccessAttributes`` does
        not have, are ignored.
    """
    attributes = AccessAttributes()
    for claim_key, attribute_key in mapping.items():
        if claim_key not in claims or not hasattr(attributes, attribute_key):
            continue

        claim = claims[claim_key]
        if claim is None:
            # A null claim carries no values.
            continue
        if isinstance(claim, list):
            # Copy: the list is stored on the attributes object and may later
            # be extended by another claim mapped to the same attribute; the
            # caller's claims dict must not be mutated through that alias.
            values = list(claim)
        elif isinstance(claim, str):
            values = claim.split()
        else:
            # Non-string scalars (e.g. a numeric tenant id) have no split().
            values = [str(claim)]

        current = getattr(attributes, attribute_key)
        if current:
            current.extend(values)
        else:
            setattr(attributes, attribute_key, values)
    return attributes
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2TokenAuthProviderConfig(BaseModel):
    """Configuration for the OAuth2 token (JWT) authentication provider."""

    # The JWKS URI for collecting public keys
    jwks_uri: str
    # Seconds a fetched key set is reused before it is fetched again.
    cache_ttl: int = 3600
    # Value passed as the expected audience when decoding a token.
    audience: str = "llama-stack"
    # Maps a token claim name to the AccessAttributes field that receives it.
    claims_mapping: dict[str, str] = Field(
        default_factory=lambda: {
            "sub": "roles",
            "username": "roles",
            "groups": "teams",
            "team": "teams",
            "project": "projects",
            "tenant": "namespaces",
            "namespace": "namespaces",
        },
    )

    # Decorator order matters: `field_validator` must be outermost so that the
    # object left in the class namespace is the one Pydantic collects. With
    # `classmethod` on the outside the validator is not registered.
    @field_validator("claims_mapping")
    @classmethod
    def validate_claims_mapping(cls, v):
        """Reject mappings whose target is empty or not an AccessAttributes field.

        Raises:
            ValueError: If any mapping value is empty or is not a field name
                of ``AccessAttributes``.
        """
        for key, value in v.items():
            if not value:
                raise ValueError(f"claims_mapping value cannot be empty: {key}")
            if value not in AccessAttributes.model_fields:
                raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
        return v
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2TokenAuthProvider(AuthProvider):
    """
    JWT token authentication provider that validates a JWT token and extracts access attributes.

    This should be the standard authentication provider for most use cases.

    Signing keys are fetched from ``config.jwks_uri`` and cached in memory,
    indexed by key ID, for ``config.cache_ttl`` seconds.
    """

    def __init__(self, config: OAuth2TokenAuthProviderConfig):
        self.config = config
        # Timestamp (time.time()) of the last successful JWKS fetch; 0.0 forces
        # a fetch on first use.
        self._jwks_at: float = 0.0
        # Key ID -> full JWK object as returned by the JWKS endpoint.
        self._jwks: dict[str, dict] = {}

    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
        """Validate a JWT and return its principal and access attributes.

        Args:
            token: The encoded JWT.
            scope: Unused by this provider.

        Returns:
            A ``TokenValidationResult`` whose principal is the ``sub`` claim and
            whose access attributes are derived via ``config.claims_mapping``.

        Raises:
            ValueError: If the token cannot be decoded or verified.
        """
        await self._refresh_jwks()

        try:
            header = jwt.get_unverified_header(token)
            kid = header["kid"]
            if kid not in self._jwks:
                raise ValueError(f"Unknown key ID: {kid}")
            key_data = self._jwks[kid]
            # Prefer the algorithm declared on the trusted key over the one in
            # the (still unverified) token header, so the token cannot choose
            # how it is verified. Fall back to the header only when the key
            # does not declare one.
            algorithm = key_data.get("alg") or header.get("alg", "RS256")
            claims = jwt.decode(
                token,
                key_data,
                algorithms=[algorithm],
                audience=self.config.audience,
                options={"verify_exp": True},
            )
        except Exception as exc:
            # Never echo the bearer token: this message is surfaced to callers
            # and may end up in logs.
            raise ValueError("Invalid JWT token") from exc

        # There are other standard claims, the most relevant of which is `scope`.
        # We should incorporate these into the access attributes.
        principal = claims["sub"]
        access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
        return TokenValidationResult(
            principal=principal,
            access_attributes=access_attributes,
        )

    async def close(self):
        """Release provider resources.

        Nothing to do: the HTTP client used for JWKS retrieval is created and
        closed within each refresh.
        """

    async def _refresh_jwks(self) -> None:
        """Re-fetch the key set when the cached copy is older than ``cache_ttl``."""
        if time.time() - self._jwks_at <= self.config.cache_ttl:
            return

        async with httpx.AsyncClient() as client:
            res = await client.get(self.config.jwks_uri, timeout=5)
            res.raise_for_status()
            jwks_data = res.json()["keys"]

        # Build the complete index first and swap it in afterwards, so that a
        # malformed entry cannot leave a cleared or half-populated cache.
        # Store the entire key object as it may be needed for different algorithms
        self._jwks = {k["kid"]: k for k in jwks_data}
        self._jwks_at = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
class CustomAuthProviderConfig(BaseModel):
    """Settings for the custom (external endpoint) authentication provider."""

    # URL the provider POSTs the authentication request to. Required.
    endpoint: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class CustomAuthProvider(AuthProvider):
|
class CustomAuthProvider(AuthProvider):
|
||||||
"""Custom authentication provider that uses an external endpoint."""
|
"""Custom authentication provider that uses an external endpoint."""
|
||||||
|
|
||||||
def __init__(self, config: dict[str, str]):
|
def __init__(self, config: CustomAuthProviderConfig):
|
||||||
self.endpoint = config["endpoint"]
|
self.config = config
|
||||||
self._client = None
|
self._client = None
|
||||||
|
|
||||||
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
|
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||||
"""Validate a token using the custom authentication endpoint."""
|
"""Validate a token using the custom authentication endpoint."""
|
||||||
if not self.endpoint:
|
|
||||||
raise ValueError("Authentication endpoint not configured")
|
|
||||||
|
|
||||||
if scope is None:
|
if scope is None:
|
||||||
scope = {}
|
scope = {}
|
||||||
|
|
||||||
|
@ -202,7 +325,7 @@ class CustomAuthProvider(AuthProvider):
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
self.endpoint,
|
self.config.endpoint,
|
||||||
json=auth_request.model_dump(),
|
json=auth_request.model_dump(),
|
||||||
timeout=10.0, # Add a reasonable timeout
|
timeout=10.0, # Add a reasonable timeout
|
||||||
)
|
)
|
||||||
|
@ -214,19 +337,7 @@ class CustomAuthProvider(AuthProvider):
|
||||||
try:
|
try:
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
auth_response = AuthResponse(**response_data)
|
auth_response = AuthResponse(**response_data)
|
||||||
|
return auth_response
|
||||||
# Store attributes in request scope for access control
|
|
||||||
if auth_response.access_attributes:
|
|
||||||
return auth_response.access_attributes
|
|
||||||
else:
|
|
||||||
logger.warning("No access attributes, setting namespace to api_key by default")
|
|
||||||
user_attributes = {
|
|
||||||
"namespaces": [token],
|
|
||||||
}
|
|
||||||
|
|
||||||
scope["user_attributes"] = user_attributes
|
|
||||||
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
|
|
||||||
return auth_response.access_attributes
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error parsing authentication response")
|
logger.exception("Error parsing authentication response")
|
||||||
raise ValueError("Invalid authentication response format") from e
|
raise ValueError("Invalid authentication response format") from e
|
||||||
|
@ -253,9 +364,11 @@ def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
|
||||||
provider_type = config.provider_type.lower()
|
provider_type = config.provider_type.lower()
|
||||||
|
|
||||||
if provider_type == "kubernetes":
|
if provider_type == "kubernetes":
|
||||||
return KubernetesAuthProvider(config.config)
|
return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
|
||||||
elif provider_type == "custom":
|
elif provider_type == "custom":
|
||||||
return CustomAuthProvider(config.config)
|
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
|
||||||
|
elif provider_type == "oauth2_token":
|
||||||
|
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
|
||||||
else:
|
else:
|
||||||
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
||||||
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
||||||
|
|
|
@ -31,6 +31,7 @@ dependencies = [
|
||||||
"openai>=1.66",
|
"openai>=1.66",
|
||||||
"prompt-toolkit",
|
"prompt-toolkit",
|
||||||
"python-dotenv",
|
"python-dotenv",
|
||||||
|
"python-jose",
|
||||||
"pydantic>=2",
|
"pydantic>=2",
|
||||||
"requests",
|
"requests",
|
||||||
"rich",
|
"rich",
|
||||||
|
|
|
@ -11,6 +11,7 @@ click==8.1.8
|
||||||
colorama==0.4.6 ; sys_platform == 'win32'
|
colorama==0.4.6 ; sys_platform == 'win32'
|
||||||
distro==1.9.0
|
distro==1.9.0
|
||||||
durationpy==0.9
|
durationpy==0.9
|
||||||
|
ecdsa==0.19.1
|
||||||
exceptiongroup==1.2.2 ; python_full_version < '3.11'
|
exceptiongroup==1.2.2 ; python_full_version < '3.11'
|
||||||
filelock==3.17.0
|
filelock==3.17.0
|
||||||
fire==0.7.0
|
fire==0.7.0
|
||||||
|
@ -39,14 +40,15 @@ pandas==2.2.3
|
||||||
pillow==11.1.0
|
pillow==11.1.0
|
||||||
prompt-toolkit==3.0.50
|
prompt-toolkit==3.0.50
|
||||||
pyaml==25.1.0
|
pyaml==25.1.0
|
||||||
pyasn1==0.6.1
|
pyasn1==0.4.8
|
||||||
pyasn1-modules==0.4.2
|
pyasn1-modules==0.4.1
|
||||||
pycryptodomex==3.21.0
|
pycryptodomex==3.21.0
|
||||||
pydantic==2.10.6
|
pydantic==2.10.6
|
||||||
pydantic-core==2.27.2
|
pydantic-core==2.27.2
|
||||||
pygments==2.19.1
|
pygments==2.19.1
|
||||||
python-dateutil==2.9.0.post0
|
python-dateutil==2.9.0.post0
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
|
python-jose==3.4.0
|
||||||
pytz==2025.1
|
pytz==2025.1
|
||||||
pyyaml==6.0.2
|
pyyaml==6.0.2
|
||||||
referencing==0.36.2
|
referencing==0.36.2
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import base64
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -12,7 +13,12 @@ from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import AccessAttributes
|
from llama_stack.distribution.datatypes import AccessAttributes
|
||||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||||
from llama_stack.distribution.server.auth_providers import AuthProviderConfig, AuthProviderType
|
from llama_stack.distribution.server.auth_providers import (
|
||||||
|
AuthProviderConfig,
|
||||||
|
AuthProviderType,
|
||||||
|
TokenValidationResult,
|
||||||
|
get_attributes_from_claims,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MockResponse:
|
class MockResponse:
|
||||||
|
@ -23,6 +29,10 @@ class MockResponse:
|
||||||
def json(self):
|
def json(self):
|
||||||
return self._json_data
|
return self._json_data
|
||||||
|
|
||||||
|
def raise_for_status(self):
|
||||||
|
if self.status_code != 200:
|
||||||
|
raise Exception(f"HTTP error: {self.status_code}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_auth_endpoint():
|
def mock_auth_endpoint():
|
||||||
|
@ -130,6 +140,7 @@ async def mock_post_success(*args, **kwargs):
|
||||||
200,
|
200,
|
||||||
{
|
{
|
||||||
"message": "Authentication successful",
|
"message": "Authentication successful",
|
||||||
|
"principal": "test-principal",
|
||||||
"access_attributes": {
|
"access_attributes": {
|
||||||
"roles": ["admin", "user"],
|
"roles": ["admin", "user"],
|
||||||
"teams": ["ml-team", "nlp-team"],
|
"teams": ["ml-team", "nlp-team"],
|
||||||
|
@ -223,6 +234,7 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
|
||||||
200,
|
200,
|
||||||
{
|
{
|
||||||
"message": "Authentication successful",
|
"message": "Authentication successful",
|
||||||
|
"principal": "test-principal",
|
||||||
"access_attributes": {
|
"access_attributes": {
|
||||||
"roles": ["admin", "user"],
|
"roles": ["admin", "user"],
|
||||||
"teams": ["ml-team", "nlp-team"],
|
"teams": ["ml-team", "nlp-team"],
|
||||||
|
@ -268,8 +280,8 @@ async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope):
|
||||||
|
|
||||||
assert "user_attributes" in mock_scope
|
assert "user_attributes" in mock_scope
|
||||||
attributes = mock_scope["user_attributes"]
|
attributes = mock_scope["user_attributes"]
|
||||||
assert "namespaces" in attributes
|
assert "roles" in attributes
|
||||||
assert attributes["namespaces"] == ["test.jwt.token"]
|
assert attributes["roles"] == ["test.jwt.token"]
|
||||||
|
|
||||||
|
|
||||||
# Kubernetes Tests
|
# Kubernetes Tests
|
||||||
|
@ -296,8 +308,11 @@ def test_valid_k8s_authentication(mock_api_client, k8s_client, valid_token):
|
||||||
|
|
||||||
# Mock the token validation to return valid access attributes
|
# Mock the token validation to return valid access attributes
|
||||||
with patch("llama_stack.distribution.server.auth_providers.KubernetesAuthProvider.validate_token") as mock_validate:
|
with patch("llama_stack.distribution.server.auth_providers.KubernetesAuthProvider.validate_token") as mock_validate:
|
||||||
mock_validate.return_value = AccessAttributes(
|
mock_validate.return_value = TokenValidationResult(
|
||||||
roles=["admin"], teams=["ml-team"], projects=["llama-3"], namespaces=["research"]
|
principal="test-principal",
|
||||||
|
access_attributes=AccessAttributes(
|
||||||
|
roles=["admin"], teams=["ml-team"], projects=["llama-3"], namespaces=["research"]
|
||||||
|
),
|
||||||
)
|
)
|
||||||
response = k8s_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
response = k8s_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
@ -370,3 +385,135 @@ async def test_k8s_middleware_no_attributes(mock_k8s_middleware, mock_scope):
|
||||||
assert attributes["roles"] == ["admin"]
|
assert attributes["roles"] == ["admin"]
|
||||||
|
|
||||||
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
||||||
|
|
||||||
|
|
||||||
|
# oauth2 token provider tests
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def oauth2_app():
    """FastAPI app with a single /test route behind the oauth2_token auth middleware."""
    provider_settings = {
        "jwks_uri": "http://mock-authz-service/token/introspect",
        "cache_ttl": "3600",
        "audience": "llama-stack",
    }
    application = FastAPI()
    application.add_middleware(
        AuthenticationMiddleware,
        auth_config=AuthProviderConfig(
            provider_type=AuthProviderType.OAUTH2_TOKEN,
            config=provider_settings,
        ),
    )

    @application.get("/test")
    def test_endpoint():
        return {"message": "Authentication successful"}

    return application
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def oauth2_client(oauth2_app):
    """Test client bound to the OAuth2-protected application."""
    client = TestClient(oauth2_app)
    return client
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_auth_header_oauth2(oauth2_client):
    """A request carrying no Authorization header is answered with 401."""
    resp = oauth2_client.get("/test")
    assert resp.status_code == 401
    error_message = resp.json()["error"]["message"]
    assert "Missing or invalid Authorization header" in error_message
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_auth_header_format_oauth2(oauth2_client):
    """An Authorization header that is not a Bearer credential is answered with 401."""
    bad_headers = {"Authorization": "InvalidFormat token123"}
    resp = oauth2_client.get("/test", headers=bad_headers)
    assert resp.status_code == 401
    error_message = resp.json()["error"]["message"]
    assert "Missing or invalid Authorization header" in error_message
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_jwks_response(*args, **kwargs):
    """Stand-in for ``httpx.AsyncClient.get`` that serves one symmetric (HS256) JWK."""
    shared_secret_key = {
        "kid": "1234567890",
        "kty": "oct",
        "alg": "HS256",
        "use": "sig",
        "k": base64.b64encode(b"foobarbaz").decode(),
    }
    payload = {"keys": [shared_secret_key]}
    return MockResponse(200, payload)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def jwt_token_valid():
    """HS256 token signed with the secret and key ID that the mocked JWKS serves."""
    from jose import jwt

    token_claims = {
        "sub": "my-user",
        "groups": ["group1", "group2"],
        "scope": "foo bar",
        "aud": "llama-stack",
    }
    token_headers = {"kid": "1234567890"}
    return jwt.encode(token_claims, key="foobarbaz", algorithm="HS256", headers=token_headers)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid):
    """A correctly signed token reaches the protected endpoint."""
    auth_headers = {"Authorization": f"Bearer {jwt_token_valid}"}
    resp = oauth2_client.get("/test", headers=auth_headers)
    assert resp.status_code == 200
    assert resp.json() == {"message": "Authentication successful"}
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
def test_invalid_oauth2_authentication(oauth2_client, invalid_token):
    """A token that fails JWT validation is answered with 401."""
    auth_headers = {"Authorization": f"Bearer {invalid_token}"}
    resp = oauth2_client.get("/test", headers=auth_headers)
    assert resp.status_code == 401
    error_message = resp.json()["error"]["message"]
    assert "Invalid JWT token" in error_message
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_attributes_from_claims():
    """Claims are mapped onto access attributes, merging claims that share a target."""
    # Unmapped claims (scope, aud) are ignored; list claims are taken as-is.
    basic_claims = {
        "sub": "my-user",
        "groups": ["group1", "group2"],
        "scope": "foo bar",
        "aud": "llama-stack",
    }
    basic_mapping = {"sub": "roles", "groups": "teams"}
    attributes = get_attributes_from_claims(basic_claims, basic_mapping)
    assert attributes.roles == ["my-user"]
    assert attributes.teams == ["group1", "group2"]

    # A single string claim becomes a one-element list.
    tenant_claims = {
        "sub": "my-user",
        "tenant": "my-tenant",
    }
    tenant_mapping = {"sub": "roles", "tenant": "namespaces"}
    attributes = get_attributes_from_claims(tenant_claims, tenant_mapping)
    assert attributes.roles == ["my-user"]
    assert attributes.namespaces == ["my-tenant"]

    # Several claims mapped to the same attribute are combined.
    merged_claims = {
        "sub": "my-user",
        "username": "my-username",
        "tenant": "my-tenant",
        "groups": ["group1", "group2"],
        "team": "my-team",
    }
    merged_mapping = {
        "sub": "roles",
        "tenant": "namespaces",
        "username": "roles",
        "team": "teams",
        "groups": "teams",
    }
    attributes = get_attributes_from_claims(merged_claims, merged_mapping)
    assert set(attributes.roles) == {"my-user", "my-username"}
    assert set(attributes.teams) == {"my-team", "group1", "group2"}
    assert attributes.namespaces == ["my-tenant"]
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: add more tests for oauth2 token provider
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue