mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
feat: fine grained access control policy
This allows a set of rules to be defined for determining access to resources. Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
parent
9623d5d230
commit
01ad876012
20 changed files with 724 additions and 214 deletions
|
@ -1,86 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def check_access(
|
||||
obj_identifier: str,
|
||||
obj_attributes: AccessAttributes | None,
|
||||
user_attributes: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""Check if the current user has access to the given object, based on access attributes.
|
||||
|
||||
Access control algorithm:
|
||||
1. If the resource has no access_attributes, access is GRANTED to all authenticated users
|
||||
2. If the user has no attributes, access is DENIED to any object with access_attributes defined
|
||||
3. For each attribute category in the resource's access_attributes:
|
||||
a. If the user lacks that category, access is DENIED
|
||||
b. If the user has the category but none of the required values, access is DENIED
|
||||
c. If the user has at least one matching value in each required category, access is GRANTED
|
||||
|
||||
Example:
|
||||
# Resource requires:
|
||||
access_attributes = AccessAttributes(
|
||||
roles=["admin", "data-scientist"],
|
||||
teams=["ml-team"]
|
||||
)
|
||||
|
||||
# User has:
|
||||
user_attributes = {
|
||||
"roles": ["data-scientist", "engineer"],
|
||||
"teams": ["ml-team", "infra-team"],
|
||||
"projects": ["llama-3"]
|
||||
}
|
||||
|
||||
# Result: Access GRANTED
|
||||
# - User has the "data-scientist" role (matches one of the required roles)
|
||||
# - AND user is part of the "ml-team" (matches the required team)
|
||||
# - The extra "projects" attribute is ignored
|
||||
|
||||
Args:
|
||||
obj_identifier: The identifier of the resource object to check access for
|
||||
obj_attributes: The access attributes of the resource object
|
||||
user_attributes: The attributes of the current user
|
||||
|
||||
Returns:
|
||||
bool: True if access is granted, False if denied
|
||||
"""
|
||||
# If object has no access attributes, allow access by default
|
||||
if not obj_attributes:
|
||||
return True
|
||||
|
||||
# If no user attributes, deny access to objects with access control
|
||||
if not user_attributes:
|
||||
return False
|
||||
|
||||
dict_attribs = obj_attributes.model_dump(exclude_none=True)
|
||||
if not dict_attribs:
|
||||
return True
|
||||
|
||||
# Check each attribute category (requires ALL categories to match)
|
||||
# TODO: formalize this into a proper ABAC policy
|
||||
for attr_key, required_values in dict_attribs.items():
|
||||
user_values = user_attributes.get(attr_key, [])
|
||||
|
||||
if not user_values:
|
||||
logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
|
||||
return False
|
||||
|
||||
if not any(val in user_values for val in required_values):
|
||||
logger.debug(
|
||||
f"Access denied to {obj_identifier}: "
|
||||
f"no match for attribute '{attr_key}', required one of {required_values}"
|
||||
)
|
||||
return False
|
||||
|
||||
logger.debug(f"Access granted to {obj_identifier}")
|
||||
return True
|
5
llama_stack/distribution/access_control/__init__.py
Normal file
5
llama_stack/distribution/access_control/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
175
llama_stack/distribution/access_control/access_control.py
Normal file
175
llama_stack/distribution/access_control/access_control.py
Normal file
|
@ -0,0 +1,175 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Protocol
|
||||
|
||||
from llama_stack.distribution.request_headers import User
|
||||
|
||||
from .datatypes import (
|
||||
AccessAttributes,
|
||||
AccessRule,
|
||||
Action,
|
||||
AttributeReference,
|
||||
Condition,
|
||||
Scope,
|
||||
)
|
||||
|
||||
|
||||
def matches_resource(resource_scope: str, actual_resource: str) -> bool:
|
||||
if resource_scope == actual_resource:
|
||||
return True
|
||||
return resource_scope.endswith("::*") and actual_resource.startswith(resource_scope[:-1])
|
||||
|
||||
|
||||
def matches_scope(
|
||||
scope: Scope,
|
||||
action: Action,
|
||||
resource: str,
|
||||
user: str | None,
|
||||
) -> bool:
|
||||
if scope.resource and not matches_resource(scope.resource, resource):
|
||||
return False
|
||||
if scope.principal and scope.principal != user:
|
||||
return False
|
||||
return action in scope.actions
|
||||
|
||||
|
||||
def user_in_literal(
|
||||
literal: str,
|
||||
user_attributes: dict[str, list[str]] | None,
|
||||
) -> bool:
|
||||
for qualifier in ["role::", "team::", "project::", "namespace::"]:
|
||||
if literal.startswith(qualifier):
|
||||
if not user_attributes:
|
||||
return False
|
||||
ref = qualifier.replace("::", "s")
|
||||
if ref in user_attributes:
|
||||
value = literal.removeprefix(qualifier)
|
||||
return value in user_attributes[ref]
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def user_in(
|
||||
ref: AttributeReference | str,
|
||||
resource_attributes: AccessAttributes | None,
|
||||
user_attributes: dict[str, list[str]] | None,
|
||||
) -> bool:
|
||||
if not ref.startswith("resource."):
|
||||
return user_in_literal(ref, user_attributes)
|
||||
name = ref.removeprefix("resource.")
|
||||
required = resource_attributes and getattr(resource_attributes, name)
|
||||
if not required:
|
||||
return True
|
||||
if not user_attributes or name not in user_attributes:
|
||||
return False
|
||||
actual = user_attributes[name]
|
||||
for value in required:
|
||||
if value in actual:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def as_list(obj: Any) -> list[Any]:
|
||||
if isinstance(obj, list):
|
||||
return obj
|
||||
return [obj]
|
||||
|
||||
|
||||
def matches_conditions(
|
||||
conditions: list[Condition],
|
||||
resource_attributes: AccessAttributes | None,
|
||||
user_attributes: dict[str, list[str]] | None,
|
||||
) -> bool:
|
||||
for condition in conditions:
|
||||
# must match all conditions
|
||||
if not matches_condition(condition, resource_attributes, user_attributes):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def matches_condition(
|
||||
condition: Condition | list[Condition],
|
||||
resource_attributes: AccessAttributes | None,
|
||||
user_attributes: dict[str, list[str]] | None,
|
||||
) -> bool:
|
||||
if isinstance(condition, list):
|
||||
return matches_conditions(as_list(condition), resource_attributes, user_attributes)
|
||||
if condition.user_in:
|
||||
for ref in as_list(condition.user_in):
|
||||
# if multiple references are specified, all must match
|
||||
if not user_in(ref, resource_attributes, user_attributes):
|
||||
return False
|
||||
return True
|
||||
if condition.user_not_in:
|
||||
for ref in as_list(condition.user_not_in):
|
||||
# if multiple references are specified, none must match
|
||||
if user_in(ref, resource_attributes, user_attributes):
|
||||
return False
|
||||
return True
|
||||
return True
|
||||
|
||||
|
||||
def default_policy() -> list[AccessRule]:
|
||||
# for backwards compatibility, if no rules are provided , assume
|
||||
# full access to all subject to attribute matching rules
|
||||
return [
|
||||
AccessRule(
|
||||
permit=Scope(actions=list(Action)),
|
||||
when=Condition(user_in=list(AttributeReference)),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class ProtectedResource(Protocol):
|
||||
type: str
|
||||
identifier: str
|
||||
access_attributes: AccessAttributes
|
||||
|
||||
|
||||
def is_action_allowed(
|
||||
policy: list[AccessRule],
|
||||
action: Action,
|
||||
resource: ProtectedResource,
|
||||
user: User | None,
|
||||
) -> bool:
|
||||
# If user is not set, assume authentication is not enabled
|
||||
if not user:
|
||||
return True
|
||||
|
||||
if not len(policy):
|
||||
policy = default_policy()
|
||||
|
||||
resource_attributes = AccessAttributes()
|
||||
if hasattr(resource, "access_attributes"):
|
||||
resource_attributes = resource.access_attributes
|
||||
qualified_resource_id = resource.type + "::" + resource.identifier
|
||||
for rule in policy:
|
||||
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
|
||||
if rule.when:
|
||||
if matches_condition(rule.when, resource_attributes, user.attributes):
|
||||
return False
|
||||
elif rule.unless:
|
||||
if not matches_condition(rule.unless, resource_attributes, user.attributes):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
|
||||
if rule.when:
|
||||
if matches_condition(rule.when, resource_attributes, user.attributes):
|
||||
return True
|
||||
elif rule.unless:
|
||||
if not matches_condition(rule.unless, resource_attributes, user.attributes):
|
||||
return True
|
||||
else:
|
||||
return True
|
||||
# assume access is denied unless we find a rule that permits access
|
||||
return False
|
||||
|
||||
|
||||
class AccessDeniedError(RuntimeError):
|
||||
pass
|
136
llama_stack/distribution/access_control/datatypes.py
Normal file
136
llama_stack/distribution/access_control/datatypes.py
Normal file
|
@ -0,0 +1,136 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class AccessAttributes(BaseModel):
|
||||
"""Structured representation of user attributes for access control.
|
||||
|
||||
This model defines a structured approach to representing user attributes
|
||||
with common standard categories for access control.
|
||||
|
||||
Standard attribute categories include:
|
||||
- roles: Role-based attributes (e.g., admin, data-scientist)
|
||||
- teams: Team-based attributes (e.g., ml-team, infra-team)
|
||||
- projects: Project access attributes (e.g., llama-3, customer-insights)
|
||||
- namespaces: Namespace-based access control for resource isolation
|
||||
"""
|
||||
|
||||
# Standard attribute categories - the minimal set we need now
|
||||
roles: list[str] | None = Field(
|
||||
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
|
||||
)
|
||||
|
||||
teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
|
||||
|
||||
projects: list[str] | None = Field(
|
||||
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
|
||||
)
|
||||
|
||||
namespaces: list[str] | None = Field(
|
||||
default=None, description="Namespace-based access control for resource isolation"
|
||||
)
|
||||
|
||||
|
||||
class Action(str, Enum):
|
||||
CREATE = "create"
|
||||
READ = "read"
|
||||
UPDATE = "update"
|
||||
DELETE = "delete"
|
||||
|
||||
|
||||
class Scope(BaseModel):
|
||||
principal: str | None = None
|
||||
actions: Action | list[Action]
|
||||
resource: str | None = None
|
||||
|
||||
|
||||
def _mutually_exclusive(obj, a: str, b: str):
|
||||
if getattr(obj, a) and getattr(obj, b):
|
||||
raise ValueError(f"{a} and {b} are mutually exclusive")
|
||||
|
||||
|
||||
def _require_one_of(obj, a: str, b: str):
|
||||
if not getattr(obj, a) and not getattr(obj, b):
|
||||
raise ValueError(f"on of {a} or {b} is required")
|
||||
|
||||
|
||||
class AttributeReference(str, Enum):
|
||||
RESOURCE_ROLES = "resource.roles"
|
||||
RESOURCE_TEAMS = "resource.teams"
|
||||
RESOURCE_PROJECTS = "resource.projects"
|
||||
RESOURCE_NAMESPACES = "resource.namespaces"
|
||||
|
||||
|
||||
class Condition(BaseModel):
|
||||
user_in: AttributeReference | list[AttributeReference] | str | None = None
|
||||
user_not_in: AttributeReference | list[AttributeReference] | str | None = None
|
||||
|
||||
|
||||
class AccessRule(BaseModel):
|
||||
"""Access rule based loosely on cedar policy language
|
||||
|
||||
A rule defines a list of action either to permit or to forbid. It may specify a
|
||||
principal or a resource that must match for the rule to take effect. The resource
|
||||
to match should be specified in the form of a type qualified identifier, e.g.
|
||||
model::my-model or vector_db::some-db, or a wildcard for all resources of a type,
|
||||
e.g. model::*. If the principal or resource are not specified, they will match all
|
||||
requests.
|
||||
|
||||
A rule may also specify a condition, either a 'when' or an 'unless', with additional
|
||||
constraints as to where the rule applies. The constraints at present are whether the
|
||||
user requesting access is in or not in some set. This set can either be a particular
|
||||
set of attributes on the resource e.g. resource.roles or a literal value of some
|
||||
notion of group, e.g. role::admin or namespace::foo.
|
||||
|
||||
Rules are tested in order to find a match. If a match is found, the request is
|
||||
permitted or forbidden depending on the type of rule. If no match is found, the
|
||||
request is denied. If no rules are specified, a rule that allows any action as
|
||||
long as the resource attributes match the user attributes is added
|
||||
(i.e. the previous behaviour is the default).
|
||||
|
||||
Some examples in yaml:
|
||||
|
||||
- permit:
|
||||
principal: user-1
|
||||
actions: [create, read, delete]
|
||||
resource: model::*
|
||||
description: user-1 has full access to all models
|
||||
- permit:
|
||||
principal: user-2
|
||||
actions: [read]
|
||||
resource: model::model-1
|
||||
description: user-2 has read access to model-1 only
|
||||
- permit:
|
||||
actions: [read]
|
||||
when:
|
||||
user_in: resource.namespaces
|
||||
description: any user has read access to any resource with matching attributes
|
||||
- forbid:
|
||||
actions: [create, read, delete]
|
||||
resource: vector_db::*
|
||||
unless:
|
||||
user_in: role::admin
|
||||
description: only user with admin role can use vector_db resources
|
||||
|
||||
"""
|
||||
|
||||
permit: Scope | None = None
|
||||
forbid: Scope | None = None
|
||||
when: Condition | list[Condition] | None = None
|
||||
unless: Condition | list[Condition] | None = None
|
||||
description: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_rule_format(self) -> Self:
|
||||
_require_one_of(self, "permit", "forbid")
|
||||
_mutually_exclusive(self, "permit", "forbid")
|
||||
_mutually_exclusive(self, "when", "unless")
|
||||
return self
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.shields import Shield, ShieldInput
|
|||
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
|
||||
|
@ -35,35 +36,6 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
|||
RoutingKey = str | list[str]
|
||||
|
||||
|
||||
class AccessAttributes(BaseModel):
|
||||
"""Structured representation of user attributes for access control.
|
||||
|
||||
This model defines a structured approach to representing user attributes
|
||||
with common standard categories for access control.
|
||||
|
||||
Standard attribute categories include:
|
||||
- roles: Role-based attributes (e.g., admin, data-scientist)
|
||||
- teams: Team-based attributes (e.g., ml-team, infra-team)
|
||||
- projects: Project access attributes (e.g., llama-3, customer-insights)
|
||||
- namespaces: Namespace-based access control for resource isolation
|
||||
"""
|
||||
|
||||
# Standard attribute categories - the minimal set we need now
|
||||
roles: list[str] | None = Field(
|
||||
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
|
||||
)
|
||||
|
||||
teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
|
||||
|
||||
projects: list[str] | None = Field(
|
||||
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
|
||||
)
|
||||
|
||||
namespaces: list[str] | None = Field(
|
||||
default=None, description="Namespace-based access control for resource isolation"
|
||||
)
|
||||
|
||||
|
||||
class ResourceWithACL(Resource):
|
||||
"""Extension of Resource that adds attribute-based access control capabilities.
|
||||
|
||||
|
@ -234,6 +206,7 @@ class AuthenticationConfig(BaseModel):
|
|||
...,
|
||||
description="Provider-specific configuration",
|
||||
)
|
||||
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
|
||||
|
||||
|
||||
class AuthenticationRequiredError(Exception):
|
||||
|
|
|
@ -18,15 +18,23 @@ log = logging.getLogger(__name__)
|
|||
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
||||
|
||||
|
||||
class User:
|
||||
principal: str
|
||||
# further attributes that may be used for access control decisions
|
||||
attributes: dict[str, list[str]]
|
||||
|
||||
def __init__(self, principal: str, attributes: dict[str, list[str]]):
|
||||
self.principal = principal
|
||||
self.attributes = attributes
|
||||
|
||||
|
||||
class RequestProviderDataContext(AbstractContextManager):
|
||||
"""Context manager for request provider data"""
|
||||
|
||||
def __init__(
|
||||
self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None
|
||||
):
|
||||
def __init__(self, provider_data: dict[str, Any] | None = None, user: User | None = None):
|
||||
self.provider_data = provider_data or {}
|
||||
if auth_attributes:
|
||||
self.provider_data["__auth_attributes"] = auth_attributes
|
||||
if user:
|
||||
self.provider_data["__authenticated_user"] = user
|
||||
|
||||
self.token = None
|
||||
|
||||
|
@ -95,9 +103,9 @@ def request_provider_data_context(
|
|||
return RequestProviderDataContext(provider_data, auth_attributes)
|
||||
|
||||
|
||||
def get_auth_attributes() -> dict[str, list[str]] | None:
|
||||
def get_authenticated_user() -> User | None:
|
||||
"""Helper to retrieve auth attributes from the provider data context"""
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
if not provider_data:
|
||||
return None
|
||||
return provider_data.get("__auth_attributes")
|
||||
return provider_data.get("__authenticated_user")
|
||||
|
|
|
@ -28,6 +28,7 @@ from llama_stack.apis.vector_dbs import VectorDBs
|
|||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.client import get_client_impl
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AccessRule,
|
||||
AutoRoutedProviderSpec,
|
||||
Provider,
|
||||
RoutingTableProviderSpec,
|
||||
|
@ -118,6 +119,7 @@ async def resolve_impls(
|
|||
run_config: StackRunConfig,
|
||||
provider_registry: ProviderRegistry,
|
||||
dist_registry: DistributionRegistry,
|
||||
policy: list[AccessRule],
|
||||
) -> dict[Api, Any]:
|
||||
"""
|
||||
Resolves provider implementations by:
|
||||
|
@ -140,7 +142,7 @@ async def resolve_impls(
|
|||
|
||||
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
|
||||
|
||||
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config)
|
||||
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy)
|
||||
|
||||
|
||||
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
|
||||
|
@ -247,6 +249,7 @@ async def instantiate_providers(
|
|||
router_apis: set[Api],
|
||||
dist_registry: DistributionRegistry,
|
||||
run_config: StackRunConfig,
|
||||
policy: list[AccessRule],
|
||||
) -> dict:
|
||||
"""Instantiates providers asynchronously while managing dependencies."""
|
||||
impls: dict[Api, Any] = {}
|
||||
|
@ -261,7 +264,7 @@ async def instantiate_providers(
|
|||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
|
||||
|
||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config)
|
||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
|
||||
|
||||
if api_str.startswith("inner-"):
|
||||
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
|
||||
|
@ -312,6 +315,7 @@ async def instantiate_provider(
|
|||
inner_impls: dict[str, Any],
|
||||
dist_registry: DistributionRegistry,
|
||||
run_config: StackRunConfig,
|
||||
policy: list[AccessRule],
|
||||
):
|
||||
provider_spec = provider.spec
|
||||
if not hasattr(provider_spec, "module"):
|
||||
|
@ -336,13 +340,15 @@ async def instantiate_provider(
|
|||
method = "get_routing_table_impl"
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, inner_impls, deps, dist_registry]
|
||||
args = [provider_spec.api, inner_impls, deps, dist_registry, policy]
|
||||
else:
|
||||
method = "get_provider_impl"
|
||||
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider.config)
|
||||
args = [config, deps]
|
||||
if "policy" in inspect.signature(getattr(module, method)).parameters:
|
||||
args.append(policy)
|
||||
|
||||
fn = getattr(module, method)
|
||||
impl = await fn(*args)
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutedProtocol
|
||||
from llama_stack.distribution.datatypes import AccessRule, RoutedProtocol
|
||||
from llama_stack.distribution.stack import StackRunConfig
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
@ -18,6 +18,7 @@ async def get_routing_table_impl(
|
|||
impls_by_provider_id: dict[str, RoutedProtocol],
|
||||
_deps,
|
||||
dist_registry: DistributionRegistry,
|
||||
policy: list[AccessRule],
|
||||
) -> Any:
|
||||
from ..routing_tables.benchmarks import BenchmarksRoutingTable
|
||||
from ..routing_tables.datasets import DatasetsRoutingTable
|
||||
|
@ -40,7 +41,7 @@ async def get_routing_table_impl(
|
|||
if api.value not in api_to_tables:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
|
||||
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
|
|
@ -8,14 +8,15 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.distribution.access_control import check_access
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AccessAttributes,
|
||||
AccessRule,
|
||||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
RoutedProtocol,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||
from llama_stack.distribution.request_headers import get_authenticated_user
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
@ -73,9 +74,11 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
self,
|
||||
impls_by_provider_id: dict[str, RoutedProtocol],
|
||||
dist_registry: DistributionRegistry,
|
||||
policy: list[AccessRule],
|
||||
) -> None:
|
||||
self.impls_by_provider_id = impls_by_provider_id
|
||||
self.dist_registry = dist_registry
|
||||
self.policy = policy
|
||||
|
||||
async def initialize(self) -> None:
|
||||
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
|
||||
|
@ -166,13 +169,15 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
return None
|
||||
|
||||
# Check if user has permission to access this object
|
||||
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
|
||||
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
|
||||
if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
|
||||
logger.debug(f"Access denied to {type} '{identifier}'")
|
||||
return None
|
||||
|
||||
return obj
|
||||
|
||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||
if not is_action_allowed(self.policy, "delete", obj, get_authenticated_user()):
|
||||
raise AccessDeniedError()
|
||||
await self.dist_registry.delete(obj.type, obj.identifier)
|
||||
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
|
||||
|
||||
|
@ -187,11 +192,12 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
# If object supports access control but no attributes set, use creator's attributes
|
||||
if not obj.access_attributes:
|
||||
creator_attributes = get_auth_attributes()
|
||||
if creator_attributes:
|
||||
obj.access_attributes = AccessAttributes(**creator_attributes)
|
||||
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
|
||||
creator = get_authenticated_user()
|
||||
if not is_action_allowed(self.policy, "create", obj, creator):
|
||||
raise AccessDeniedError()
|
||||
if creator and creator.attributes:
|
||||
obj.access_attributes = AccessAttributes(**creator.attributes)
|
||||
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
|
||||
|
||||
registered_obj = await register_object_with_provider(obj, p)
|
||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||
|
@ -210,9 +216,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
# Apply attribute-based access control filtering
|
||||
if filtered_objs:
|
||||
filtered_objs = [
|
||||
obj
|
||||
for obj in filtered_objs
|
||||
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
|
||||
obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
|
||||
]
|
||||
|
||||
return filtered_objs
|
||||
|
|
|
@ -30,10 +30,7 @@ from pydantic import BaseModel, ValidationError
|
|||
|
||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.server.endpoints import (
|
||||
find_matching_endpoint,
|
||||
|
@ -213,11 +210,13 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
|||
async def endpoint(request: Request, **kwargs):
|
||||
# Get auth attributes from the request scope
|
||||
user_attributes = request.scope.get("user_attributes", {})
|
||||
principal = request.scope.get("principal", "")
|
||||
user = User(principal, user_attributes)
|
||||
|
||||
await log_request_pre_validation(request)
|
||||
|
||||
# Use context manager with both provider data and auth attributes
|
||||
with request_provider_data_context(request.headers, user_attributes):
|
||||
with request_provider_data_context(request.headers, user):
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
||||
try:
|
||||
|
|
|
@ -223,7 +223,10 @@ async def construct_stack(
|
|||
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
|
||||
) -> dict[Api, Any]:
|
||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
|
||||
policy = run_config.server.auth.access_policy if run_config.server.auth else []
|
||||
impls = await resolve_impls(
|
||||
run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
|
||||
)
|
||||
|
||||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, run_config)
|
||||
|
|
|
@ -6,12 +6,12 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.datatypes import AccessRule, Api
|
||||
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
|
||||
from .agents import MetaReferenceAgentsImpl
|
||||
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
|
@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
|
|||
deps[Api.safety],
|
||||
deps[Api.tool_runtime],
|
||||
deps[Api.tool_groups],
|
||||
policy,
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -60,6 +60,7 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.datatypes import AccessRule
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
|
@ -96,13 +97,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
vector_io_api: VectorIO,
|
||||
persistence_store: KVStore,
|
||||
created_at: str,
|
||||
policy: list[AccessRule],
|
||||
):
|
||||
self.agent_id = agent_id
|
||||
self.agent_config = agent_config
|
||||
self.inference_api = inference_api
|
||||
self.safety_api = safety_api
|
||||
self.vector_io_api = vector_io_api
|
||||
self.storage = AgentPersistence(agent_id, persistence_store)
|
||||
self.storage = AgentPersistence(agent_id, persistence_store, policy)
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.created_at = created_at
|
||||
|
|
|
@ -40,6 +40,7 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.datatypes import AccessRule
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
@ -61,6 +62,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
safety_api: Safety,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
tool_groups_api: ToolGroups,
|
||||
policy: list[AccessRule],
|
||||
):
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
|
@ -71,6 +73,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
self.in_memory_store = InmemoryKVStoreImpl()
|
||||
self.openai_responses_impl: OpenAIResponsesImpl | None = None
|
||||
self.policy = policy
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
||||
|
@ -129,6 +132,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
||||
),
|
||||
created_at=agent_info.created_at,
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
async def create_agent_session(
|
||||
|
|
|
@ -10,9 +10,9 @@ import uuid
|
|||
from datetime import datetime, timezone
|
||||
|
||||
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||
from llama_stack.distribution.access_control import check_access
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule
|
||||
from llama_stack.distribution.request_headers import get_authenticated_user
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -23,6 +23,8 @@ class AgentSessionInfo(Session):
|
|||
vector_db_id: str | None = None
|
||||
started_at: datetime
|
||||
access_attributes: AccessAttributes | None = None
|
||||
identifier: str | None = None
|
||||
type: str = "session"
|
||||
|
||||
|
||||
class AgentInfo(AgentConfig):
|
||||
|
@ -30,15 +32,17 @@ class AgentInfo(AgentConfig):
|
|||
|
||||
|
||||
class AgentPersistence:
|
||||
def __init__(self, agent_id: str, kvstore: KVStore):
|
||||
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
|
||||
self.agent_id = agent_id
|
||||
self.kvstore = kvstore
|
||||
self.policy = policy
|
||||
|
||||
async def create_session(self, name: str) -> str:
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Get current user's auth attributes for new sessions
|
||||
auth_attributes = get_auth_attributes()
|
||||
user = get_authenticated_user()
|
||||
auth_attributes = user and user.attributes
|
||||
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
|
||||
|
||||
session_info = AgentSessionInfo(
|
||||
|
@ -47,7 +51,10 @@ class AgentPersistence:
|
|||
started_at=datetime.now(timezone.utc),
|
||||
access_attributes=access_attributes,
|
||||
turns=[],
|
||||
identifier=name, # should this be qualified in any way?
|
||||
)
|
||||
if not is_action_allowed(self.policy, "create", session_info, user):
|
||||
raise AccessDeniedError()
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
|
@ -76,7 +83,7 @@ class AgentPersistence:
|
|||
if not hasattr(session_info, "access_attributes"):
|
||||
return True
|
||||
|
||||
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
|
||||
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
|
||||
|
||||
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
||||
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
||||
|
|
|
@ -121,7 +121,7 @@ class ToolGroupsImpl(Impl):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_models_routing_table(cached_disk_dist_registry):
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry)
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple models and verify listing
|
||||
|
@ -163,7 +163,7 @@ async def test_models_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shields_routing_table(cached_disk_dist_registry):
|
||||
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry)
|
||||
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple shields and verify listing
|
||||
|
@ -179,14 +179,14 @@ async def test_shields_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry)
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry)
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_providere",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
@ -209,7 +209,7 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
|||
|
||||
|
||||
async def test_datasets_routing_table(cached_disk_dist_registry):
|
||||
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry)
|
||||
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple datasets and verify listing
|
||||
|
@ -235,7 +235,7 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry)
|
||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple scoring functions and verify listing
|
||||
|
@ -261,7 +261,7 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
||||
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry)
|
||||
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple benchmarks and verify listing
|
||||
|
@ -279,7 +279,7 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_groups_routing_table(cached_disk_dist_registry):
|
||||
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry)
|
||||
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple tool groups and verify listing
|
||||
|
|
|
@ -59,6 +59,7 @@ async def agents_impl(config, mock_apis):
|
|||
mock_apis["safety_api"],
|
||||
mock_apis["tool_runtime_api"],
|
||||
mock_apis["tool_groups_api"],
|
||||
{},
|
||||
)
|
||||
await impl.initialize()
|
||||
yield impl
|
||||
|
|
|
@ -13,23 +13,24 @@ import pytest
|
|||
from llama_stack.apis.agents import Turn
|
||||
from llama_stack.apis.inference import CompletionMessage, StopReason
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.distribution.request_headers import User
|
||||
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_setup(sqlite_kvstore):
|
||||
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore)
|
||||
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
|
||||
yield agent_persistence
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
|
||||
async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
||||
# Set creator's attributes for the session
|
||||
creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
|
||||
mock_get_auth_attributes.return_value = creator_attributes
|
||||
mock_get_authenticated_user.return_value = User("test_user", creator_attributes)
|
||||
|
||||
# Create a session
|
||||
session_id = await agent_persistence.create_session("Test Session")
|
||||
|
@ -43,8 +44,8 @@ async def test_session_creation_with_access_attributes(mock_get_auth_attributes,
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
|
||||
async def test_session_access_control(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||
async def test_session_access_control(mock_get_authenticated_user, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
||||
# Create a session with specific access attributes
|
||||
|
@ -55,6 +56,7 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
|
|||
started_at=datetime.now(),
|
||||
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
|
||||
turns=[],
|
||||
identifier="Restricted Session",
|
||||
)
|
||||
|
||||
await agent_persistence.kvstore.set(
|
||||
|
@ -63,20 +65,22 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
|
|||
)
|
||||
|
||||
# User with matching attributes can access
|
||||
mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"testuser", {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
|
||||
)
|
||||
retrieved_session = await agent_persistence.get_session_info(session_id)
|
||||
assert retrieved_session is not None
|
||||
assert retrieved_session.session_id == session_id
|
||||
|
||||
# User without matching attributes cannot access
|
||||
mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]}
|
||||
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"], "teams": ["other-team"]})
|
||||
retrieved_session = await agent_persistence.get_session_info(session_id)
|
||||
assert retrieved_session is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
|
||||
async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||
async def test_turn_access_control(mock_get_authenticated_user, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
||||
# Create a session with restricted access
|
||||
|
@ -87,6 +91,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
|||
started_at=datetime.now(),
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
turns=[],
|
||||
identifier="Restricted Session",
|
||||
)
|
||||
|
||||
await agent_persistence.kvstore.set(
|
||||
|
@ -109,7 +114,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
|||
)
|
||||
|
||||
# Admin can add turn
|
||||
mock_get_auth_attributes.return_value = {"roles": ["admin"]}
|
||||
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
|
||||
await agent_persistence.add_turn_to_session(session_id, turn)
|
||||
|
||||
# Admin can get turn
|
||||
|
@ -118,7 +123,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
|||
assert retrieved_turn.turn_id == turn_id
|
||||
|
||||
# Regular user cannot get turn
|
||||
mock_get_auth_attributes.return_value = {"roles": ["user"]}
|
||||
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
|
||||
with pytest.raises(ValueError):
|
||||
await agent_persistence.get_session_turn(session_id, turn_id)
|
||||
|
||||
|
@ -128,8 +133,8 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
|
||||
async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
||||
# Create a session with restricted access
|
||||
|
@ -140,6 +145,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
|
|||
started_at=datetime.now(),
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
turns=[],
|
||||
identifier="Restricted Session",
|
||||
)
|
||||
|
||||
await agent_persistence.kvstore.set(
|
||||
|
@ -150,7 +156,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
|
|||
turn_id = str(uuid.uuid4())
|
||||
|
||||
# Admin user can set inference iterations
|
||||
mock_get_auth_attributes.return_value = {"roles": ["admin"]}
|
||||
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
|
||||
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)
|
||||
|
||||
# Admin user can get inference iterations
|
||||
|
@ -158,7 +164,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
|
|||
assert infer_iters == 5
|
||||
|
||||
# Regular user cannot get inference iterations
|
||||
mock_get_auth_attributes.return_value = {"roles": ["user"]}
|
||||
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
|
||||
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
|
||||
assert infer_iters is None
|
||||
|
||||
|
|
|
@ -7,10 +7,14 @@
|
|||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.distribution.datatypes import AccessAttributes, AccessRule, ModelWithACL
|
||||
from llama_stack.distribution.request_headers import User
|
||||
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
|
||||
|
||||
|
||||
|
@ -32,13 +36,14 @@ async def test_setup(cached_disk_dist_registry):
|
|||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"test_provider": mock_inference},
|
||||
dist_registry=cached_disk_dist_registry,
|
||||
policy={},
|
||||
)
|
||||
yield cached_disk_dist_registry, routing_table
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
||||
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
identifier="model-public",
|
||||
|
@ -64,7 +69,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
|||
await registry.register(model_admin_only)
|
||||
await registry.register(model_data_scientist)
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
|
||||
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["admin"], "teams": ["management"]})
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 2
|
||||
|
||||
|
@ -75,7 +80,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
|||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-data-scientist")
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
|
||||
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["other-team"]})
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 1
|
||||
assert all_models.data[0].identifier == "model-public"
|
||||
|
@ -86,7 +91,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
|||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-data-scientist")
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]}
|
||||
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["ml-team"]})
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 2
|
||||
model_ids = [m.identifier for m in all_models.data]
|
||||
|
@ -102,8 +107,8 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
||||
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
identifier="model-updates",
|
||||
|
@ -112,28 +117,37 @@ async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
|
|||
model_type=ModelType.llm,
|
||||
)
|
||||
await registry.register(model_public)
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["user"],
|
||||
}
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"test-user",
|
||||
{
|
||||
"roles": ["user"],
|
||||
},
|
||||
)
|
||||
model = await routing_table.get_model("model-updates")
|
||||
assert model.identifier == "model-updates"
|
||||
model_public.access_attributes = AccessAttributes(roles=["admin"])
|
||||
await registry.update(model_public)
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["user"],
|
||||
}
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"test-user",
|
||||
{
|
||||
"roles": ["user"],
|
||||
},
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-updates")
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["admin"],
|
||||
}
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"test-user",
|
||||
{
|
||||
"roles": ["admin"],
|
||||
},
|
||||
)
|
||||
model = await routing_table.get_model("model-updates")
|
||||
assert model.identifier == "model-updates"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
||||
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model = ModelWithACL(
|
||||
identifier="model-empty-attrs",
|
||||
|
@ -143,9 +157,12 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
|
|||
access_attributes=AccessAttributes(),
|
||||
)
|
||||
await registry.register(model)
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": [],
|
||||
}
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"test-user",
|
||||
{
|
||||
"roles": [],
|
||||
},
|
||||
)
|
||||
result = await routing_table.get_model("model-empty-attrs")
|
||||
assert result.identifier == "model-empty-attrs"
|
||||
all_models = await routing_table.list_models()
|
||||
|
@ -154,8 +171,8 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
||||
async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
identifier="model-public-2",
|
||||
|
@ -172,7 +189,7 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
|
|||
)
|
||||
await registry.register(model_public)
|
||||
await registry.register(model_restricted)
|
||||
mock_get_auth_attributes.return_value = None
|
||||
mock_get_authenticated_user.return_value = User("test-user", None)
|
||||
model = await routing_table.get_model("model-public-2")
|
||||
assert model.identifier == "model-public-2"
|
||||
|
||||
|
@ -185,14 +202,14 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
||||
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup):
|
||||
"""Test that newly created resources inherit access attributes from their creator."""
|
||||
registry, routing_table = test_setup
|
||||
|
||||
# Set creator's attributes
|
||||
creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
|
||||
mock_get_auth_attributes.return_value = creator_attributes
|
||||
mock_get_authenticated_user.return_value = User("test-user", creator_attributes)
|
||||
|
||||
# Create model without explicit access attributes
|
||||
model = ModelWithACL(
|
||||
|
@ -211,15 +228,262 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup)
|
|||
assert registered_model.access_attributes.projects == ["llama-3"]
|
||||
|
||||
# Verify another user without matching attributes can't access it
|
||||
mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]}
|
||||
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]})
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("auto-access-model")
|
||||
|
||||
# But a user with matching attributes can
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["data-scientist", "engineer"],
|
||||
"teams": ["ml-team", "platform-team"],
|
||||
"projects": ["llama-3"],
|
||||
}
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"test-user",
|
||||
{
|
||||
"roles": ["data-scientist", "engineer"],
|
||||
"teams": ["ml-team", "platform-team"],
|
||||
"projects": ["llama-3"],
|
||||
},
|
||||
)
|
||||
model = await routing_table.get_model("auto-access-model")
|
||||
assert model.identifier == "auto-access-model"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_setup_with_access_policy(cached_disk_dist_registry):
|
||||
mock_inference = Mock()
|
||||
mock_inference.__provider_spec__ = MagicMock()
|
||||
mock_inference.__provider_spec__.api = Api.inference
|
||||
mock_inference.register_model = AsyncMock(side_effect=_return_model)
|
||||
mock_inference.unregister_model = AsyncMock(side_effect=_return_model)
|
||||
|
||||
config = """
|
||||
- permit:
|
||||
principal: user-1
|
||||
actions: [create, read, delete]
|
||||
description: user-1 has full access to all models
|
||||
- permit:
|
||||
principal: user-2
|
||||
actions: [read]
|
||||
resource: model::model-1
|
||||
description: user-2 has read access to model-1 only
|
||||
- permit:
|
||||
principal: user-3
|
||||
actions: [read]
|
||||
resource: model::model-2
|
||||
description: user-3 has read access to model-2 only
|
||||
- forbid:
|
||||
actions: [create, read, delete]
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"test_provider": mock_inference},
|
||||
dist_registry=cached_disk_dist_registry,
|
||||
policy=policy,
|
||||
)
|
||||
yield routing_table
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy):
|
||||
routing_table = test_setup_with_access_policy
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"user-1",
|
||||
{
|
||||
"roles": ["admin"],
|
||||
"projects": ["foo", "bar"],
|
||||
},
|
||||
)
|
||||
await routing_table.register_model("model-1", provider_id="test_provider")
|
||||
await routing_table.register_model("model-2", provider_id="test_provider")
|
||||
await routing_table.register_model("model-3", provider_id="test_provider")
|
||||
model = await routing_table.get_model("model-1")
|
||||
assert model.identifier == "model-1"
|
||||
model = await routing_table.get_model("model-2")
|
||||
assert model.identifier == "model-2"
|
||||
model = await routing_table.get_model("model-3")
|
||||
assert model.identifier == "model-3"
|
||||
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"user-2",
|
||||
{
|
||||
"roles": ["user"],
|
||||
"projects": ["foo"],
|
||||
},
|
||||
)
|
||||
model = await routing_table.get_model("model-1")
|
||||
assert model.identifier == "model-1"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-2")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-3")
|
||||
with pytest.raises(AccessDeniedError):
|
||||
await routing_table.register_model("model-4", provider_id="test_provider")
|
||||
with pytest.raises(AccessDeniedError):
|
||||
await routing_table.unregister_model("model-1")
|
||||
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"user-3",
|
||||
{
|
||||
"roles": ["user"],
|
||||
"projects": ["bar"],
|
||||
},
|
||||
)
|
||||
model = await routing_table.get_model("model-2")
|
||||
assert model.identifier == "model-2"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-1")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-3")
|
||||
with pytest.raises(AccessDeniedError):
|
||||
await routing_table.register_model("model-5", provider_id="test_provider")
|
||||
with pytest.raises(AccessDeniedError):
|
||||
await routing_table.unregister_model("model-2")
|
||||
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"user-1",
|
||||
{
|
||||
"roles": ["admin"],
|
||||
"projects": ["foo", "bar"],
|
||||
},
|
||||
)
|
||||
await routing_table.unregister_model("model-3")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-3")
|
||||
|
||||
|
||||
def test_permit_when():
|
||||
config = """
|
||||
- permit:
|
||||
principal: user-1
|
||||
actions: [read]
|
||||
when:
|
||||
user_in: resource.namespaces
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
)
|
||||
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
|
||||
|
||||
|
||||
def test_permit_unless():
|
||||
config = """
|
||||
- permit:
|
||||
principal: user-1
|
||||
actions: [read]
|
||||
resource: model::*
|
||||
unless:
|
||||
- user_not_in: resource.namespaces
|
||||
- user_in: resource.teams
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
)
|
||||
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
|
||||
|
||||
|
||||
def test_forbid_when():
|
||||
config = """
|
||||
- forbid:
|
||||
principal: user-1
|
||||
actions: [read]
|
||||
when:
|
||||
user_in: resource.namespaces
|
||||
- permit:
|
||||
actions: [read]
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
)
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
|
||||
|
||||
|
||||
def test_forbid_unless():
|
||||
config = """
|
||||
- forbid:
|
||||
principal: user-1
|
||||
actions: [read]
|
||||
unless:
|
||||
user_in: resource.namespaces
|
||||
- permit:
|
||||
actions: [read]
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
)
|
||||
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
|
||||
|
||||
|
||||
def test_condition_with_literal():
|
||||
config = """
|
||||
- permit:
|
||||
actions: [read]
|
||||
when:
|
||||
user_in: role::admin
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
)
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-4", None))
|
||||
|
||||
|
||||
def test_condition_with_unrecognised_literal():
|
||||
config = """
|
||||
- permit:
|
||||
actions: [read]
|
||||
when:
|
||||
user_in: whatever
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
)
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-2", None))
|
||||
|
||||
|
||||
def test_empty_condition():
|
||||
config = """
|
||||
- permit:
|
||||
actions: [read]
|
||||
when: {}
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
assert is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||
assert is_action_allowed(policy, "read", model, User("user-2", None))
|
||||
|
|
|
@ -100,9 +100,10 @@ async def test_resolve_impls_basic():
|
|||
add_protocol_methods(SampleImpl, Inference)
|
||||
|
||||
mock_module.get_provider_impl = AsyncMock(return_value=impl)
|
||||
mock_module.get_provider_impl.__text_signature__ = "()"
|
||||
sys.modules["test_module"] = mock_module
|
||||
|
||||
impls = await resolve_impls(run_config, provider_registry, dist_registry)
|
||||
impls = await resolve_impls(run_config, provider_registry, dist_registry, policy={})
|
||||
|
||||
assert Api.inference in impls
|
||||
assert isinstance(impls[Api.inference], InferenceRouter)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue