mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat(server): add attribute based access control for resources (#1703)
This PR introduces a way to implement Attribute Based Access Control (ABAC) for the Llama Stack server. The rough design is: - https://github.com/meta-llama/llama-stack/pull/1626 added a way for the Llama Stack server to query an authenticator - We build upon that and expect "access attributes" as part of the response. These attributes indicate the scopes available for the request. - We use these attributes to perform access control for registered resources as well as for constructing the default access control policies for newly created resources. - By default, if you support authentication but don't return access attributes, we will add a unique namespace pointing to the API_KEY. That way, all resources by default will be scoped to API_KEYs. An important aspect of this design is that Llama Stack stays out of the business of credential management or the CRUD for attributes. How you manage your namespaces or projects is entirely up to you. The design only implements access control checks for the metadata / book-keeping information that the Stack tracks. ### Limitations - Currently, read vs. write vs. admin permissions aren't made explicit, but this can be easily extended by adding appropriate attributes to the `AccessAttributes` data structure. - This design does not apply to agent instances since they are not considered resources the Stack knows about. Agent instances are completely within the scope of the Agents API provider. ### Test Plan Added unit tests, existing integration tests
This commit is contained in:
parent
c4e1b8d094
commit
01a25d9744
10 changed files with 890 additions and 45 deletions
|
@ -41,11 +41,22 @@ from llama_stack.apis.tools import (
|
|||
ToolHost,
|
||||
)
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
from llama_stack.distribution.access_control import check_access
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AccessAttributes,
|
||||
BenchmarkWithACL,
|
||||
DatasetWithACL,
|
||||
ModelWithACL,
|
||||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
RoutedProtocol,
|
||||
ScoringFnWithACL,
|
||||
ShieldWithACL,
|
||||
ToolGroupWithACL,
|
||||
ToolWithACL,
|
||||
VectorDBWithACL,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
||||
|
@ -186,6 +197,11 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if not obj:
|
||||
return None
|
||||
|
||||
# Check if user has permission to access this object
|
||||
if not check_access(obj, get_auth_attributes()):
|
||||
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
|
||||
return None
|
||||
|
||||
return obj
|
||||
|
||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||
|
@ -202,6 +218,13 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
# If object supports access control but no attributes set, use creator's attributes
|
||||
if not obj.access_attributes:
|
||||
creator_attributes = get_auth_attributes()
|
||||
if creator_attributes:
|
||||
obj.access_attributes = AccessAttributes(**creator_attributes)
|
||||
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
|
||||
|
||||
registered_obj = await register_object_with_provider(obj, p)
|
||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||
if obj.type == ResourceType.model.value:
|
||||
|
@ -214,7 +237,13 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type == type]
|
||||
filtered_objs = [obj for obj in objs if obj.type == type]
|
||||
|
||||
# Apply attribute-based access control filtering
|
||||
if filtered_objs:
|
||||
filtered_objs = [obj for obj in filtered_objs if check_access(obj, get_auth_attributes())]
|
||||
|
||||
return filtered_objs
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
|
@ -251,7 +280,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
model_type = ModelType.llm
|
||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||
model = Model(
|
||||
model = ModelWithACL(
|
||||
identifier=model_id,
|
||||
provider_resource_id=provider_model_id,
|
||||
provider_id=provider_id,
|
||||
|
@ -297,7 +326,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
)
|
||||
if params is None:
|
||||
params = {}
|
||||
shield = Shield(
|
||||
shield = ShieldWithACL(
|
||||
identifier=shield_id,
|
||||
provider_resource_id=provider_shield_id,
|
||||
provider_id=provider_id,
|
||||
|
@ -351,7 +380,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
}
|
||||
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
|
||||
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
return vector_db
|
||||
|
||||
|
@ -405,7 +434,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
dataset = Dataset(
|
||||
dataset = DatasetWithACL(
|
||||
identifier=dataset_id,
|
||||
provider_resource_id=provider_dataset_id,
|
||||
provider_id=provider_id,
|
||||
|
@ -452,7 +481,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
scoring_fn = ScoringFn(
|
||||
scoring_fn = ScoringFnWithACL(
|
||||
identifier=scoring_fn_id,
|
||||
description=description,
|
||||
return_type=return_type,
|
||||
|
@ -494,7 +523,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
)
|
||||
if provider_benchmark_id is None:
|
||||
provider_benchmark_id = benchmark_id
|
||||
benchmark = Benchmark(
|
||||
benchmark = BenchmarkWithACL(
|
||||
identifier=benchmark_id,
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions=scoring_functions,
|
||||
|
@ -537,7 +566,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
|
||||
for tool_def in tool_defs:
|
||||
tools.append(
|
||||
Tool(
|
||||
ToolWithACL(
|
||||
identifier=tool_def.name,
|
||||
toolgroup_id=toolgroup_id,
|
||||
description=tool_def.description or "",
|
||||
|
@ -562,7 +591,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
await self.register_object(tool)
|
||||
|
||||
await self.dist_registry.register(
|
||||
ToolGroup(
|
||||
ToolGroupWithACL(
|
||||
identifier=toolgroup_id,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=toolgroup_id,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue