mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 02:58:21 +00:00
feat(server): add attribute based access control for resources
This commit is contained in:
parent
7c0448456e
commit
b937a49436
8 changed files with 862 additions and 35 deletions
|
@ -14,6 +14,7 @@ from llama_stack.apis.datasets import Dataset, DatasetInput
|
||||||
from llama_stack.apis.eval import Eval
|
from llama_stack.apis.eval import Eval
|
||||||
from llama_stack.apis.inference import Inference
|
from llama_stack.apis.inference import Inference
|
||||||
from llama_stack.apis.models import Model, ModelInput
|
from llama_stack.apis.models import Model, ModelInput
|
||||||
|
from llama_stack.apis.resource import Resource
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.scoring import Scoring
|
from llama_stack.apis.scoring import Scoring
|
||||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||||
|
@ -31,28 +32,137 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
||||||
RoutingKey = Union[str, List[str]]
|
RoutingKey = Union[str, List[str]]
|
||||||
|
|
||||||
|
|
||||||
|
class AccessAttributes(BaseModel):
|
||||||
|
"""Structured representation of user attributes for access control.
|
||||||
|
|
||||||
|
This model defines a structured approach to representing user attributes
|
||||||
|
with common standard categories for access control.
|
||||||
|
|
||||||
|
Standard attribute categories include:
|
||||||
|
- roles: Role-based attributes (e.g., admin, data-scientist)
|
||||||
|
- teams: Team-based attributes (e.g., ml-team, infra-team)
|
||||||
|
- projects: Project access attributes (e.g., llama-3, customer-insights)
|
||||||
|
- namespaces: Namespace-based access control for resource isolation
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Standard attribute categories - the minimal set we need now
|
||||||
|
roles: Optional[List[str]] = Field(
|
||||||
|
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
|
||||||
|
)
|
||||||
|
|
||||||
|
teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
|
||||||
|
|
||||||
|
projects: Optional[List[str]] = Field(
|
||||||
|
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
|
||||||
|
)
|
||||||
|
|
||||||
|
namespaces: Optional[List[str]] = Field(
|
||||||
|
default=None, description="Namespace-based access control for resource isolation"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceWithACL(Resource):
|
||||||
|
"""Extension of Resource that adds attribute-based access control capabilities.
|
||||||
|
|
||||||
|
This class adds an optional access_attributes field that allows fine-grained control
|
||||||
|
over which users can access each resource. When attributes are defined, a user must have
|
||||||
|
matching attributes to access the resource.
|
||||||
|
|
||||||
|
Attribute Matching Algorithm:
|
||||||
|
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
|
||||||
|
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
|
||||||
|
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
|
||||||
|
4. Within each category, ANY value match is sufficient (OR relationship within a category)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
# Resource visible to everyone (no access control)
|
||||||
|
model = Model(identifier="llama-2", ...)
|
||||||
|
|
||||||
|
# Resource visible only to admins
|
||||||
|
model = Model(
|
||||||
|
identifier="gpt-4",
|
||||||
|
access_attributes=AccessAttributes(roles=["admin"])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resource visible to data scientists on the ML team
|
||||||
|
model = Model(
|
||||||
|
identifier="private-model",
|
||||||
|
access_attributes=AccessAttributes(
|
||||||
|
roles=["data-scientist", "researcher"],
|
||||||
|
teams=["ml-team"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# ^ User must have at least one of the roles AND be on the ml-team
|
||||||
|
|
||||||
|
# Resource visible to users with specific project access
|
||||||
|
vector_db = VectorDB(
|
||||||
|
identifier="customer-embeddings",
|
||||||
|
access_attributes=AccessAttributes(
|
||||||
|
projects=["customer-insights"],
|
||||||
|
namespaces=["confidential"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# ^ User must have access to the customer-insights project AND have confidential namespace
|
||||||
|
"""
|
||||||
|
|
||||||
|
access_attributes: Optional[AccessAttributes] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Use the extended Resource for all routable objects
|
||||||
|
class ModelWithACL(Model, ResourceWithACL):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ShieldWithACL(Shield, ResourceWithACL):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class VectorDBWithACL(VectorDB, ResourceWithACL):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetWithACL(Dataset, ResourceWithACL):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ScoringFnWithACL(ScoringFn, ResourceWithACL):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkWithACL(Benchmark, ResourceWithACL):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ToolWithACL(Tool, ResourceWithACL):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ToolGroupWithACL(ToolGroup, ResourceWithACL):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
RoutableObject = Union[
|
RoutableObject = Union[
|
||||||
Model,
|
ModelWithACL,
|
||||||
Shield,
|
ShieldWithACL,
|
||||||
VectorDB,
|
VectorDBWithACL,
|
||||||
Dataset,
|
DatasetWithACL,
|
||||||
ScoringFn,
|
ScoringFnWithACL,
|
||||||
Benchmark,
|
BenchmarkWithACL,
|
||||||
Tool,
|
ToolWithACL,
|
||||||
ToolGroup,
|
ToolGroupWithACL,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
RoutableObjectWithProvider = Annotated[
|
RoutableObjectWithProvider = Annotated[
|
||||||
Union[
|
Union[
|
||||||
Model,
|
ModelWithACL,
|
||||||
Shield,
|
ShieldWithACL,
|
||||||
VectorDB,
|
VectorDBWithACL,
|
||||||
Dataset,
|
DatasetWithACL,
|
||||||
ScoringFn,
|
ScoringFnWithACL,
|
||||||
Benchmark,
|
BenchmarkWithACL,
|
||||||
Tool,
|
ToolWithACL,
|
||||||
ToolGroup,
|
ToolGroupWithACL,
|
||||||
],
|
],
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
|
|
|
@ -7,21 +7,29 @@
|
||||||
import contextvars
|
import contextvars
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, ContextManager, Dict, Optional
|
from typing import Any, ContextManager, Dict, List, Optional
|
||||||
|
|
||||||
from .utils.dynamic import instantiate_class_type
|
from .utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Context variable for request provider data
|
# Context variable for request provider data and auth attributes
|
||||||
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
||||||
|
|
||||||
|
|
||||||
class RequestProviderDataContext(ContextManager):
|
class RequestProviderDataContext(ContextManager):
|
||||||
"""Context manager for request provider data"""
|
"""Context manager for request provider data"""
|
||||||
|
|
||||||
def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
|
def __init__(
|
||||||
self.provider_data = provider_data
|
self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
|
||||||
|
):
|
||||||
|
# Initialize with either provider_data or create a new dict
|
||||||
|
self.provider_data = provider_data or {}
|
||||||
|
|
||||||
|
# Add auth attributes under a special key if provided
|
||||||
|
if auth_attributes:
|
||||||
|
self.provider_data["__auth_attributes"] = auth_attributes
|
||||||
|
|
||||||
self.token = None
|
self.token = None
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
@ -80,7 +88,17 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, A
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
|
def request_provider_data_context(
|
||||||
"""Context manager that sets request provider data from headers for the duration of the context"""
|
headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
|
||||||
|
) -> ContextManager:
|
||||||
|
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
|
||||||
provider_data = parse_request_provider_data(headers)
|
provider_data = parse_request_provider_data(headers)
|
||||||
return RequestProviderDataContext(provider_data)
|
return RequestProviderDataContext(provider_data, auth_attributes)
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
|
||||||
|
"""Helper to retrieve auth attributes from the provider data context"""
|
||||||
|
provider_data = PROVIDER_DATA_VAR.get()
|
||||||
|
if not provider_data:
|
||||||
|
return None
|
||||||
|
return provider_data.get("__auth_attributes")
|
||||||
|
|
|
@ -40,10 +40,12 @@ from llama_stack.apis.tools import (
|
||||||
)
|
)
|
||||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
|
AccessAttributes,
|
||||||
RoutableObject,
|
RoutableObject,
|
||||||
RoutableObjectWithProvider,
|
RoutableObjectWithProvider,
|
||||||
RoutedProtocol,
|
RoutedProtocol,
|
||||||
)
|
)
|
||||||
|
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||||
from llama_stack.distribution.store import DistributionRegistry
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||||
|
|
||||||
|
@ -184,6 +186,11 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
if not obj:
|
if not obj:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Check if user has permission to access this object
|
||||||
|
if not self._check_access(obj):
|
||||||
|
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
|
||||||
|
return None
|
||||||
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||||
|
@ -200,6 +207,13 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
|
|
||||||
p = self.impls_by_provider_id[obj.provider_id]
|
p = self.impls_by_provider_id[obj.provider_id]
|
||||||
|
|
||||||
|
# If object supports access control but no attributes set, use creator's attributes
|
||||||
|
if not obj.access_attributes:
|
||||||
|
creator_attributes = get_auth_attributes()
|
||||||
|
if creator_attributes:
|
||||||
|
obj.access_attributes = AccessAttributes(**creator_attributes)
|
||||||
|
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
|
||||||
|
|
||||||
registered_obj = await register_object_with_provider(obj, p)
|
registered_obj = await register_object_with_provider(obj, p)
|
||||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||||
if obj.type == ResourceType.model.value:
|
if obj.type == ResourceType.model.value:
|
||||||
|
@ -212,7 +226,89 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
|
|
||||||
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
||||||
objs = await self.dist_registry.get_all()
|
objs = await self.dist_registry.get_all()
|
||||||
return [obj for obj in objs if obj.type == type]
|
filtered_objs = [obj for obj in objs if obj.type == type]
|
||||||
|
|
||||||
|
# Apply attribute-based access control filtering
|
||||||
|
if filtered_objs:
|
||||||
|
filtered_objs = [obj for obj in filtered_objs if self._check_access(obj)]
|
||||||
|
|
||||||
|
return filtered_objs
|
||||||
|
|
||||||
|
def _check_access(self, obj: RoutableObjectWithProvider) -> bool:
|
||||||
|
"""Check if the current user has access to the given object, based on access attributes.
|
||||||
|
|
||||||
|
Access control algorithm:
|
||||||
|
1. If the resource has no access_attributes, access is GRANTED to all authenticated users
|
||||||
|
2. If the user has no attributes, access is DENIED to any object with access_attributes defined
|
||||||
|
3. For each attribute category in the resource's access_attributes:
|
||||||
|
a. If the user lacks that category, access is DENIED
|
||||||
|
b. If the user has the category but none of the required values, access is DENIED
|
||||||
|
c. If the user has at least one matching value in each required category, access is GRANTED
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Resource requires:
|
||||||
|
access_attributes = AccessAttributes(
|
||||||
|
roles=["admin", "data-scientist"],
|
||||||
|
teams=["ml-team"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# User has:
|
||||||
|
user_attributes = {
|
||||||
|
"roles": ["data-scientist", "engineer"],
|
||||||
|
"teams": ["ml-team", "infra-team"],
|
||||||
|
"projects": ["llama-3"]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Result: Access GRANTED
|
||||||
|
# - User has the "data-scientist" role (matches one of the required roles)
|
||||||
|
# - AND user is part of the "ml-team" (matches the required team)
|
||||||
|
# - The extra "projects" attribute is ignored
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: The resource object to check access for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if access is granted, False if denied
|
||||||
|
"""
|
||||||
|
# If object has no access attributes, allow access by default
|
||||||
|
if not hasattr(obj, "access_attributes") or not obj.access_attributes:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Get user attributes from context
|
||||||
|
user_attributes = get_auth_attributes()
|
||||||
|
|
||||||
|
# If no user attributes, deny access to objects with access control
|
||||||
|
if not user_attributes:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Convert AccessAttributes to dictionary for checking
|
||||||
|
obj_attributes = obj.access_attributes.model_dump(exclude_none=True)
|
||||||
|
|
||||||
|
# If the model_dump is empty (all fields are None), allow access
|
||||||
|
if not obj_attributes:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check each attribute category (requires ALL categories to match)
|
||||||
|
for attr_key, required_values in obj_attributes.items():
|
||||||
|
user_values = user_attributes.get(attr_key, [])
|
||||||
|
|
||||||
|
# No values for this category in user attributes
|
||||||
|
if not user_values:
|
||||||
|
logger.debug(
|
||||||
|
f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# None of the values in this category match (need at least one match per category)
|
||||||
|
if not any(val in user_values for val in required_values):
|
||||||
|
logger.debug(
|
||||||
|
f"Access denied to {obj.type} '{obj.identifier}': "
|
||||||
|
f"no match for attribute '{attr_key}', required one of {required_values}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.debug(f"Access granted to {obj.type} '{obj.identifier}'")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
|
|
|
@ -5,16 +5,118 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from typing import Dict, List, Optional
|
||||||
from urllib.parse import parse_qs
|
from urllib.parse import parse_qs
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import AccessAttributes
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="auth")
|
logger = get_logger(name=__name__, category="auth")
|
||||||
|
|
||||||
|
|
||||||
|
class AuthRequestContext(BaseModel):
|
||||||
|
path: str = Field(description="The path of the request being authenticated")
|
||||||
|
|
||||||
|
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
||||||
|
|
||||||
|
params: Dict[str, List[str]] = Field(
|
||||||
|
description="Query parameters from the original request, parsed as dictionary of lists"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthRequest(BaseModel):
|
||||||
|
api_key: str = Field(description="The API key extracted from the Authorization header")
|
||||||
|
|
||||||
|
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
|
||||||
|
|
||||||
|
|
||||||
|
class AuthResponse(BaseModel):
|
||||||
|
"""The format of the authentication response from the auth endpoint."""
|
||||||
|
|
||||||
|
access_attributes: Optional[AccessAttributes] = Field(
|
||||||
|
default=None,
|
||||||
|
description="""
|
||||||
|
Structured user attributes for attribute-based access control.
|
||||||
|
|
||||||
|
These attributes determine which resources the user can access.
|
||||||
|
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
|
||||||
|
Each attribute category contains a list of values that the user has for that category.
|
||||||
|
During access control checks, these values are compared against resource requirements.
|
||||||
|
|
||||||
|
Example with standard categories:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"roles": ["admin", "data-scientist"],
|
||||||
|
"teams": ["ml-team"],
|
||||||
|
"projects": ["llama-3"],
|
||||||
|
"namespaces": ["research"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
message: Optional[str] = Field(
|
||||||
|
default=None, description="Optional message providing additional context about the authentication result."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationMiddleware:
|
class AuthenticationMiddleware:
|
||||||
|
"""Middleware that authenticates requests using an external auth endpoint.
|
||||||
|
|
||||||
|
This middleware:
|
||||||
|
1. Extracts the Bearer token from the Authorization header
|
||||||
|
2. Sends it to the configured auth endpoint along with request details
|
||||||
|
3. Validates the response and extracts user attributes
|
||||||
|
4. Makes these attributes available to the route handlers for access control
|
||||||
|
|
||||||
|
Authentication Request Format:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"api_key": "the-api-key-extracted-from-auth-header",
|
||||||
|
"request": {
|
||||||
|
"path": "/models/list",
|
||||||
|
"headers": {
|
||||||
|
"content-type": "application/json",
|
||||||
|
"user-agent": "..."
|
||||||
|
// All headers except Authorization
|
||||||
|
},
|
||||||
|
"params": {
|
||||||
|
"limit": ["100"],
|
||||||
|
"offset": ["0"]
|
||||||
|
// Query parameters as key -> list of values
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected Auth Endpoint Response Format:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"access_attributes": { // Structured attribute format
|
||||||
|
"roles": ["admin", "user"],
|
||||||
|
"teams": ["ml-team", "nlp-team"],
|
||||||
|
"projects": ["llama-3", "project-x"],
|
||||||
|
"namespaces": ["research"]
|
||||||
|
},
|
||||||
|
"message": "Optional message about auth result"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Attribute-Based Access Control:
|
||||||
|
The attributes returned by the auth endpoint are used to determine which
|
||||||
|
resources the user can access. Resources can specify required attributes
|
||||||
|
using the access_attributes field. For a user to access a resource:
|
||||||
|
|
||||||
|
1. All attribute categories specified in the resource must be present in the user's attributes
|
||||||
|
2. For each category, the user must have at least one matching value
|
||||||
|
|
||||||
|
If the auth endpoint doesn't return any attributes, the user will only be able to
|
||||||
|
access resources that don't have access_attributes defined.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, app, auth_endpoint):
|
def __init__(self, app, auth_endpoint):
|
||||||
self.app = app
|
self.app = app
|
||||||
self.auth_endpoint = auth_endpoint
|
self.auth_endpoint = auth_endpoint
|
||||||
|
@ -32,25 +134,57 @@ class AuthenticationMiddleware:
|
||||||
path = scope.get("path", "")
|
path = scope.get("path", "")
|
||||||
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
|
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
|
||||||
|
|
||||||
|
# Remove sensitive headers
|
||||||
|
if "authorization" in request_headers:
|
||||||
|
del request_headers["authorization"]
|
||||||
|
|
||||||
query_string = scope.get("query_string", b"").decode()
|
query_string = scope.get("query_string", b"").decode()
|
||||||
params = parse_qs(query_string)
|
params = parse_qs(query_string)
|
||||||
|
|
||||||
auth_data = {
|
# Build the auth request model
|
||||||
"api_key": api_key,
|
auth_request = AuthRequest(
|
||||||
"request": {
|
api_key=api_key,
|
||||||
"path": path,
|
request=AuthRequestContext(
|
||||||
"headers": request_headers,
|
path=path,
|
||||||
"params": params,
|
headers=request_headers,
|
||||||
},
|
params=params,
|
||||||
}
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# Validate with authentication endpoint
|
# Validate with authentication endpoint
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(self.auth_endpoint, json=auth_data)
|
response = await client.post(
|
||||||
|
self.auth_endpoint,
|
||||||
|
json=auth_request.model_dump(),
|
||||||
|
timeout=10.0, # Add a reasonable timeout
|
||||||
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
logger.warning(f"Authentication failed: {response.status_code}")
|
logger.warning(f"Authentication failed: {response.status_code}")
|
||||||
return await self._send_auth_error(send, "Authentication failed")
|
return await self._send_auth_error(send, "Authentication failed")
|
||||||
|
|
||||||
|
# Parse and validate the auth response
|
||||||
|
try:
|
||||||
|
response_data = response.json()
|
||||||
|
|
||||||
|
auth_response = AuthResponse(**response_data)
|
||||||
|
|
||||||
|
# Store attributes in request scope for access control
|
||||||
|
if auth_response.access_attributes:
|
||||||
|
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
|
||||||
|
scope["user_attributes"] = user_attributes
|
||||||
|
else:
|
||||||
|
logger.warning("Authentication response did not contain any attributes")
|
||||||
|
scope["user_attributes"] = {}
|
||||||
|
|
||||||
|
# Log authentication success with attribute details
|
||||||
|
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error parsing authentication response")
|
||||||
|
return await self._send_auth_error(send, "Invalid authentication response format")
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.exception("Authentication request timed out")
|
||||||
|
return await self._send_auth_error(send, "Authentication service timeout")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error during authentication")
|
logger.exception("Error during authentication")
|
||||||
return await self._send_auth_error(send, "Authentication service error")
|
return await self._send_auth_error(send, "Authentication service error")
|
||||||
|
|
|
@ -28,7 +28,9 @@ from typing_extensions import Annotated
|
||||||
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
|
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
from llama_stack.distribution.request_headers import (
|
from llama_stack.distribution.request_headers import (
|
||||||
|
AUTH_ATTRIBUTES_VAR,
|
||||||
PROVIDER_DATA_VAR,
|
PROVIDER_DATA_VAR,
|
||||||
|
auth_attributes_context,
|
||||||
request_provider_data_context,
|
request_provider_data_context,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.resolver import InvalidProviderError
|
from llama_stack.distribution.resolver import InvalidProviderError
|
||||||
|
@ -179,8 +181,11 @@ async def sse_generator(event_gen):
|
||||||
|
|
||||||
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||||
async def endpoint(request: Request, **kwargs):
|
async def endpoint(request: Request, **kwargs):
|
||||||
# Use context manager for request provider data
|
# Get auth attributes from the request scope
|
||||||
with request_provider_data_context(request.headers):
|
user_attributes = request.scope.get("user_attributes", {})
|
||||||
|
|
||||||
|
# Use context manager with both provider data and auth attributes
|
||||||
|
with request_provider_data_context(request.headers, user_attributes):
|
||||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
143
tests/unit/registry/test_registry_acl.py
Normal file
143
tests/unit/registry/test_registry_acl.py
Normal file
|
@ -0,0 +1,143 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.models import ModelType
|
||||||
|
from llama_stack.distribution.datatypes import ModelWithACL
|
||||||
|
from llama_stack.distribution.server.auth import AccessAttributes
|
||||||
|
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
||||||
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def kvstore():
|
||||||
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
db_path = os.path.join(temp_dir, "test_registry_acl.db")
|
||||||
|
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
|
||||||
|
kvstore = SqliteKVStoreImpl(kvstore_config)
|
||||||
|
await kvstore.initialize()
|
||||||
|
yield kvstore
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def registry(kvstore):
|
||||||
|
registry = CachedDiskDistributionRegistry(kvstore)
|
||||||
|
await registry.initialize()
|
||||||
|
return registry
|
||||||
|
|
||||||
|
|
||||||
|
async def test_registry_cache_with_acl(registry):
|
||||||
|
model = ModelWithACL(
|
||||||
|
identifier="model-acl",
|
||||||
|
provider_id="test-provider",
|
||||||
|
provider_resource_id="model-acl-resource",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
success = await registry.register(model)
|
||||||
|
assert success
|
||||||
|
|
||||||
|
cached_model = registry.get_cached("model", "model-acl")
|
||||||
|
assert cached_model is not None
|
||||||
|
assert cached_model.identifier == "model-acl"
|
||||||
|
assert cached_model.access_attributes.roles == ["admin"]
|
||||||
|
assert cached_model.access_attributes.teams == ["ai-team"]
|
||||||
|
|
||||||
|
fetched_model = await registry.get("model", "model-acl")
|
||||||
|
assert fetched_model is not None
|
||||||
|
assert fetched_model.identifier == "model-acl"
|
||||||
|
assert fetched_model.access_attributes.roles == ["admin"]
|
||||||
|
|
||||||
|
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
|
||||||
|
await registry.update(model)
|
||||||
|
|
||||||
|
updated_cached = registry.get_cached("model", "model-acl")
|
||||||
|
assert updated_cached is not None
|
||||||
|
assert updated_cached.access_attributes.roles == ["admin", "user"]
|
||||||
|
assert updated_cached.access_attributes.projects == ["project-x"]
|
||||||
|
assert updated_cached.access_attributes.teams is None
|
||||||
|
|
||||||
|
new_registry = CachedDiskDistributionRegistry(registry._kvstore)
|
||||||
|
await new_registry.initialize()
|
||||||
|
|
||||||
|
new_model = await new_registry.get("model", "model-acl")
|
||||||
|
assert new_model is not None
|
||||||
|
assert new_model.identifier == "model-acl"
|
||||||
|
assert new_model.access_attributes.roles == ["admin", "user"]
|
||||||
|
assert new_model.access_attributes.projects == ["project-x"]
|
||||||
|
assert new_model.access_attributes.teams is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_registry_empty_acl(registry):
|
||||||
|
model = ModelWithACL(
|
||||||
|
identifier="model-empty-acl",
|
||||||
|
provider_id="test-provider",
|
||||||
|
provider_resource_id="model-resource",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
access_attributes=AccessAttributes(),
|
||||||
|
)
|
||||||
|
|
||||||
|
await registry.register(model)
|
||||||
|
|
||||||
|
cached_model = registry.get_cached("model", "model-empty-acl")
|
||||||
|
assert cached_model is not None
|
||||||
|
assert cached_model.access_attributes is not None
|
||||||
|
assert cached_model.access_attributes.roles is None
|
||||||
|
assert cached_model.access_attributes.teams is None
|
||||||
|
assert cached_model.access_attributes.projects is None
|
||||||
|
assert cached_model.access_attributes.namespaces is None
|
||||||
|
|
||||||
|
all_models = await registry.get_all()
|
||||||
|
assert len(all_models) == 1
|
||||||
|
|
||||||
|
model = ModelWithACL(
|
||||||
|
identifier="model-no-acl",
|
||||||
|
provider_id="test-provider",
|
||||||
|
provider_resource_id="model-resource-2",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
await registry.register(model)
|
||||||
|
|
||||||
|
cached_model = registry.get_cached("model", "model-no-acl")
|
||||||
|
assert cached_model is not None
|
||||||
|
assert cached_model.access_attributes is None
|
||||||
|
|
||||||
|
all_models = await registry.get_all()
|
||||||
|
assert len(all_models) == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_registry_serialization(registry):
|
||||||
|
model = ModelWithACL(
|
||||||
|
identifier="model-serialize",
|
||||||
|
provider_id="test-provider",
|
||||||
|
provider_resource_id="model-resource",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
access_attributes=AccessAttributes(
|
||||||
|
roles=["admin", "researcher"],
|
||||||
|
teams=["ai-team", "ml-team"],
|
||||||
|
projects=["project-a", "project-b"],
|
||||||
|
namespaces=["prod", "staging"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
await registry.register(model)
|
||||||
|
registry.cache.clear()
|
||||||
|
|
||||||
|
loaded_model = await registry.get("model", "model-serialize")
|
||||||
|
|
||||||
|
assert loaded_model.access_attributes.roles == ["admin", "researcher"]
|
||||||
|
assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"]
|
||||||
|
assert loaded_model.access_attributes.projects == ["project-a", "project-b"]
|
||||||
|
assert loaded_model.access_attributes.namespaces == ["prod", "staging"]
|
236
tests/unit/server/test_access_control.py
Normal file
236
tests/unit/server/test_access_control.py
Normal file
|
@ -0,0 +1,236 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from llama_stack.apis.models import ModelType
|
||||||
|
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
|
||||||
|
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
|
||||||
|
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
||||||
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncMock(MagicMock):
|
||||||
|
async def __call__(self, *args, **kwargs):
|
||||||
|
return super(AsyncMock, self).__call__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _return_model(model):
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_setup():
|
||||||
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
db_path = os.path.join(temp_dir, "test_access_control.db")
|
||||||
|
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
|
||||||
|
kvstore = SqliteKVStoreImpl(kvstore_config)
|
||||||
|
registry = CachedDiskDistributionRegistry(kvstore)
|
||||||
|
await registry.initialize()
|
||||||
|
|
||||||
|
mock_inference = Mock()
|
||||||
|
mock_inference.__provider_spec__ = MagicMock()
|
||||||
|
mock_inference.__provider_spec__.api = "inference"
|
||||||
|
mock_inference.register_model = AsyncMock(side_effect=_return_model)
|
||||||
|
routing_table = ModelsRoutingTable(
|
||||||
|
impls_by_provider_id={"test_provider": mock_inference},
|
||||||
|
dist_registry=registry,
|
||||||
|
)
|
||||||
|
yield registry, routing_table
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||||
|
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
||||||
|
registry, routing_table = test_setup
|
||||||
|
model_public = ModelWithACL(
|
||||||
|
identifier="model-public",
|
||||||
|
provider_id="test_provider",
|
||||||
|
provider_resource_id="model-public",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
model_admin_only = ModelWithACL(
|
||||||
|
identifier="model-admin",
|
||||||
|
provider_id="test_provider",
|
||||||
|
provider_resource_id="model-admin",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
access_attributes=AccessAttributes(roles=["admin"]),
|
||||||
|
)
|
||||||
|
model_data_scientist = ModelWithACL(
|
||||||
|
identifier="model-data-scientist",
|
||||||
|
provider_id="test_provider",
|
||||||
|
provider_resource_id="model-data-scientist",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]),
|
||||||
|
)
|
||||||
|
await registry.register(model_public)
|
||||||
|
await registry.register(model_admin_only)
|
||||||
|
await registry.register(model_data_scientist)
|
||||||
|
|
||||||
|
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
|
||||||
|
all_models = await routing_table.list_models()
|
||||||
|
assert len(all_models.data) == 3
|
||||||
|
|
||||||
|
model = await routing_table.get_model("model-public")
|
||||||
|
assert model.identifier == "model-public"
|
||||||
|
model = await routing_table.get_model("model-admin")
|
||||||
|
assert model.identifier == "model-admin"
|
||||||
|
model = await routing_table.get_model("model-data-scientist")
|
||||||
|
assert model.identifier == "model-data-scientist"
|
||||||
|
|
||||||
|
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
|
||||||
|
all_models = await routing_table.list_models()
|
||||||
|
assert len(all_models.data) == 1
|
||||||
|
assert all_models.data[0].identifier == "model-public"
|
||||||
|
model = await routing_table.get_model("model-public")
|
||||||
|
assert model.identifier == "model-public"
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await routing_table.get_model("model-admin")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await routing_table.get_model("model-data-scientist")
|
||||||
|
|
||||||
|
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]}
|
||||||
|
all_models = await routing_table.list_models()
|
||||||
|
assert len(all_models.data) == 2
|
||||||
|
model_ids = [m.identifier for m in all_models.data]
|
||||||
|
assert "model-public" in model_ids
|
||||||
|
assert "model-data-scientist" in model_ids
|
||||||
|
assert "model-admin" not in model_ids
|
||||||
|
model = await routing_table.get_model("model-public")
|
||||||
|
assert model.identifier == "model-public"
|
||||||
|
model = await routing_table.get_model("model-data-scientist")
|
||||||
|
assert model.identifier == "model-data-scientist"
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await routing_table.get_model("model-admin")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||||
|
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
|
||||||
|
registry, routing_table = test_setup
|
||||||
|
model_public = ModelWithACL(
|
||||||
|
identifier="model-updates",
|
||||||
|
provider_id="test_provider",
|
||||||
|
provider_resource_id="model-updates",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
await registry.register(model_public)
|
||||||
|
mock_get_auth_attributes.return_value = {
|
||||||
|
"roles": ["user"],
|
||||||
|
}
|
||||||
|
model = await routing_table.get_model("model-updates")
|
||||||
|
assert model.identifier == "model-updates"
|
||||||
|
model_public.access_attributes = AccessAttributes(roles=["admin"])
|
||||||
|
await registry.update(model_public)
|
||||||
|
mock_get_auth_attributes.return_value = {
|
||||||
|
"roles": ["user"],
|
||||||
|
}
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await routing_table.get_model("model-updates")
|
||||||
|
mock_get_auth_attributes.return_value = {
|
||||||
|
"roles": ["admin"],
|
||||||
|
}
|
||||||
|
model = await routing_table.get_model("model-updates")
|
||||||
|
assert model.identifier == "model-updates"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||||
|
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
|
||||||
|
registry, routing_table = test_setup
|
||||||
|
model = ModelWithACL(
|
||||||
|
identifier="model-empty-attrs",
|
||||||
|
provider_id="test_provider",
|
||||||
|
provider_resource_id="model-empty-attrs",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
access_attributes=AccessAttributes(),
|
||||||
|
)
|
||||||
|
await registry.register(model)
|
||||||
|
mock_get_auth_attributes.return_value = {}
|
||||||
|
result = await routing_table.get_model("model-empty-attrs")
|
||||||
|
assert result.identifier == "model-empty-attrs"
|
||||||
|
all_models = await routing_table.list_models()
|
||||||
|
model_ids = [m.identifier for m in all_models.data]
|
||||||
|
assert "model-empty-attrs" in model_ids
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||||
|
async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
|
||||||
|
registry, routing_table = test_setup
|
||||||
|
model_public = ModelWithACL(
|
||||||
|
identifier="model-public-2",
|
||||||
|
provider_id="test_provider",
|
||||||
|
provider_resource_id="model-public-2",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
model_restricted = ModelWithACL(
|
||||||
|
identifier="model-restricted",
|
||||||
|
provider_id="test_provider",
|
||||||
|
provider_resource_id="model-restricted",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
access_attributes=AccessAttributes(roles=["admin"]),
|
||||||
|
)
|
||||||
|
await registry.register(model_public)
|
||||||
|
await registry.register(model_restricted)
|
||||||
|
mock_get_auth_attributes.return_value = None
|
||||||
|
model = await routing_table.get_model("model-public-2")
|
||||||
|
assert model.identifier == "model-public-2"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await routing_table.get_model("model-restricted")
|
||||||
|
|
||||||
|
all_models = await routing_table.list_models()
|
||||||
|
assert len(all_models.data) == 1
|
||||||
|
assert all_models.data[0].identifier == "model-public-2"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||||
|
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
|
||||||
|
"""Test that newly created resources inherit access attributes from their creator."""
|
||||||
|
registry, routing_table = test_setup
|
||||||
|
|
||||||
|
# Set creator's attributes
|
||||||
|
creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
|
||||||
|
mock_get_auth_attributes.return_value = creator_attributes
|
||||||
|
|
||||||
|
# Create model without explicit access attributes
|
||||||
|
model = ModelWithACL(
|
||||||
|
identifier="auto-access-model",
|
||||||
|
provider_id="test_provider",
|
||||||
|
provider_resource_id="auto-access-model",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
await routing_table.register_object(model)
|
||||||
|
|
||||||
|
# Verify the model got creator's attributes
|
||||||
|
registered_model = await routing_table.get_model("auto-access-model")
|
||||||
|
assert registered_model.access_attributes is not None
|
||||||
|
assert registered_model.access_attributes.roles == ["data-scientist"]
|
||||||
|
assert registered_model.access_attributes.teams == ["ml-team"]
|
||||||
|
assert registered_model.access_attributes.projects == ["llama-3"]
|
||||||
|
|
||||||
|
# Verify another user without matching attributes can't access it
|
||||||
|
mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]}
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await routing_table.get_model("auto-access-model")
|
||||||
|
|
||||||
|
# But a user with matching attributes can
|
||||||
|
mock_get_auth_attributes.return_value = {
|
||||||
|
"roles": ["data-scientist", "engineer"],
|
||||||
|
"teams": ["ml-team", "platform-team"],
|
||||||
|
}
|
||||||
|
model = await routing_table.get_model("auto-access-model")
|
||||||
|
assert model.identifier == "auto-access-model"
|
|
@ -13,6 +13,15 @@ from fastapi.testclient import TestClient
|
||||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
class MockResponse:
|
||||||
|
def __init__(self, status_code, json_data):
|
||||||
|
self.status_code = status_code
|
||||||
|
self._json_data = json_data
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
return self._json_data
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_auth_endpoint():
|
def mock_auth_endpoint():
|
||||||
return "http://mock-auth-service/validate"
|
return "http://mock-auth-service/validate"
|
||||||
|
@ -45,6 +54,26 @@ def client(app):
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_scope():
|
||||||
|
return {
|
||||||
|
"type": "http",
|
||||||
|
"path": "/models/list",
|
||||||
|
"headers": [
|
||||||
|
(b"content-type", b"application/json"),
|
||||||
|
(b"authorization", b"Bearer test-api-key"),
|
||||||
|
(b"user-agent", b"test-user-agent"),
|
||||||
|
],
|
||||||
|
"query_string": b"limit=100&offset=0",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_middleware(mock_auth_endpoint):
|
||||||
|
mock_app = AsyncMock()
|
||||||
|
return AuthenticationMiddleware(mock_app, mock_auth_endpoint), mock_app
|
||||||
|
|
||||||
|
|
||||||
async def mock_post_success(*args, **kwargs):
|
async def mock_post_success(*args, **kwargs):
|
||||||
mock_response = AsyncMock()
|
mock_response = AsyncMock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
|
@ -122,3 +151,59 @@ def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
|
||||||
assert "authorization" in payload["request"]["headers"]
|
assert "authorization" in payload["request"]["headers"]
|
||||||
assert "param1" in payload["request"]["params"]
|
assert "param1" in payload["request"]["params"]
|
||||||
assert "param2" in payload["request"]["params"]
|
assert "param2" in payload["request"]["params"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_middleware_with_access_attributes(mock_middleware, mock_scope):
|
||||||
|
middleware, mock_app = mock_middleware
|
||||||
|
mock_receive = AsyncMock()
|
||||||
|
mock_send = AsyncMock()
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_client:
|
||||||
|
mock_client_instance = AsyncMock()
|
||||||
|
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||||
|
|
||||||
|
mock_client_instance.post.return_value = MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"access_attributes": {
|
||||||
|
"roles": ["admin", "user"],
|
||||||
|
"teams": ["ml-team"],
|
||||||
|
"projects": ["project-x", "project-y"],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await middleware(mock_scope, mock_receive, mock_send)
|
||||||
|
|
||||||
|
assert "user_attributes" in mock_scope
|
||||||
|
assert mock_scope["user_attributes"]["roles"] == ["admin", "user"]
|
||||||
|
assert mock_scope["user_attributes"]["teams"] == ["ml-team"]
|
||||||
|
assert mock_scope["user_attributes"]["projects"] == ["project-x", "project-y"]
|
||||||
|
|
||||||
|
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_middleware_no_attributes(mock_middleware, mock_scope):
|
||||||
|
"""Test middleware behavior with no access attributes"""
|
||||||
|
middleware, mock_app = mock_middleware
|
||||||
|
mock_receive = AsyncMock()
|
||||||
|
mock_send = AsyncMock()
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_client:
|
||||||
|
mock_client_instance = AsyncMock()
|
||||||
|
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||||
|
|
||||||
|
mock_client_instance.post.return_value = MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"message": "Authentication successful"
|
||||||
|
# No access_attributes
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await middleware(mock_scope, mock_receive, mock_send)
|
||||||
|
|
||||||
|
assert "user_attributes" in mock_scope
|
||||||
|
assert mock_scope["user_attributes"] == {}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue