feat: fine grained access control policy

This allows a set of rules to be defined for determining access to resources.

Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
Gordon Sim 2025-05-06 18:54:58 +01:00
parent 9623d5d230
commit 01ad876012
20 changed files with 724 additions and 214 deletions

View file

@ -28,6 +28,7 @@ from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.datatypes import (
AccessRule,
AutoRoutedProviderSpec,
Provider,
RoutingTableProviderSpec,
@ -118,6 +119,7 @@ async def resolve_impls(
run_config: StackRunConfig,
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
policy: list[AccessRule],
) -> dict[Api, Any]:
"""
Resolves provider implementations by:
@ -140,7 +142,7 @@ async def resolve_impls(
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy)
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
@ -247,6 +249,7 @@ async def instantiate_providers(
router_apis: set[Api],
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
policy: list[AccessRule],
) -> dict:
"""Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = {}
@ -261,7 +264,7 @@ async def instantiate_providers(
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config)
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
if api_str.startswith("inner-"):
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
@ -312,6 +315,7 @@ async def instantiate_provider(
inner_impls: dict[str, Any],
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
policy: list[AccessRule],
):
provider_spec = provider.spec
if not hasattr(provider_spec, "module"):
@ -336,13 +340,15 @@ async def instantiate_provider(
method = "get_routing_table_impl"
config = None
args = [provider_spec.api, inner_impls, deps, dist_registry]
args = [provider_spec.api, inner_impls, deps, dist_registry, policy]
else:
method = "get_provider_impl"
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider.config)
args = [config, deps]
if "policy" in inspect.signature(getattr(module, method)).parameters:
args.append(policy)
fn = getattr(module, method)
impl = await fn(*args)