mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: fine grained access control policy (#2264)
This allows a set of rules to be defined for determining access to resources. The rules are (loosely) based on the cedar policy format. A rule defines a list of action either to permit or to forbid. It may specify a principal or a resource that must match for the rule to take effect. It may also specify a condition, either a 'when' or an 'unless', with additional constraints as to where the rule applies. A list of rules is held for each type to be protected and tried in order to find a match. If a match is found, the request is permitted or forbidden depening on the type of rule. If no match is found, the request is denied. If no rules are specified for a given type, a rule that allows any action as long as the resource attributes match the user attributes is added (i.e. the previous behaviour is the default. Some examples in yaml: ``` model: - permit: principal: user-1 actions: [create, read, delete] comment: user-1 has full access to all models - permit: principal: user-2 actions: [read] resource: model-1 comment: user-2 has read access to model-1 only - permit: actions: [read] when: user_in: resource.namespaces comment: any user has read access to models with matching attributes vector_db: - forbid: actions: [create, read, delete] unless: user_in: role::admin comment: only user with admin role can use vector_db resources ``` --------- Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
parent
8bee2954be
commit
7c1998db25
32 changed files with 956 additions and 450 deletions
|
@ -1,86 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import AccessAttributes
|
|
||||||
from llama_stack.log import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__, category="core")
|
|
||||||
|
|
||||||
|
|
||||||
def check_access(
|
|
||||||
obj_identifier: str,
|
|
||||||
obj_attributes: AccessAttributes | None,
|
|
||||||
user_attributes: dict[str, Any] | None = None,
|
|
||||||
) -> bool:
|
|
||||||
"""Check if the current user has access to the given object, based on access attributes.
|
|
||||||
|
|
||||||
Access control algorithm:
|
|
||||||
1. If the resource has no access_attributes, access is GRANTED to all authenticated users
|
|
||||||
2. If the user has no attributes, access is DENIED to any object with access_attributes defined
|
|
||||||
3. For each attribute category in the resource's access_attributes:
|
|
||||||
a. If the user lacks that category, access is DENIED
|
|
||||||
b. If the user has the category but none of the required values, access is DENIED
|
|
||||||
c. If the user has at least one matching value in each required category, access is GRANTED
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# Resource requires:
|
|
||||||
access_attributes = AccessAttributes(
|
|
||||||
roles=["admin", "data-scientist"],
|
|
||||||
teams=["ml-team"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# User has:
|
|
||||||
user_attributes = {
|
|
||||||
"roles": ["data-scientist", "engineer"],
|
|
||||||
"teams": ["ml-team", "infra-team"],
|
|
||||||
"projects": ["llama-3"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Result: Access GRANTED
|
|
||||||
# - User has the "data-scientist" role (matches one of the required roles)
|
|
||||||
# - AND user is part of the "ml-team" (matches the required team)
|
|
||||||
# - The extra "projects" attribute is ignored
|
|
||||||
|
|
||||||
Args:
|
|
||||||
obj_identifier: The identifier of the resource object to check access for
|
|
||||||
obj_attributes: The access attributes of the resource object
|
|
||||||
user_attributes: The attributes of the current user
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if access is granted, False if denied
|
|
||||||
"""
|
|
||||||
# If object has no access attributes, allow access by default
|
|
||||||
if not obj_attributes:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# If no user attributes, deny access to objects with access control
|
|
||||||
if not user_attributes:
|
|
||||||
return False
|
|
||||||
|
|
||||||
dict_attribs = obj_attributes.model_dump(exclude_none=True)
|
|
||||||
if not dict_attribs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check each attribute category (requires ALL categories to match)
|
|
||||||
# TODO: formalize this into a proper ABAC policy
|
|
||||||
for attr_key, required_values in dict_attribs.items():
|
|
||||||
user_values = user_attributes.get(attr_key, [])
|
|
||||||
|
|
||||||
if not user_values:
|
|
||||||
logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not any(val in user_values for val in required_values):
|
|
||||||
logger.debug(
|
|
||||||
f"Access denied to {obj_identifier}: "
|
|
||||||
f"no match for attribute '{attr_key}', required one of {required_values}"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
logger.debug(f"Access granted to {obj_identifier}")
|
|
||||||
return True
|
|
5
llama_stack/distribution/access_control/__init__.py
Normal file
5
llama_stack/distribution/access_control/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
109
llama_stack/distribution/access_control/access_control.py
Normal file
109
llama_stack/distribution/access_control/access_control.py
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import User
|
||||||
|
|
||||||
|
from .conditions import (
|
||||||
|
Condition,
|
||||||
|
ProtectedResource,
|
||||||
|
parse_conditions,
|
||||||
|
)
|
||||||
|
from .datatypes import (
|
||||||
|
AccessRule,
|
||||||
|
Action,
|
||||||
|
Scope,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def matches_resource(resource_scope: str, actual_resource: str) -> bool:
|
||||||
|
if resource_scope == actual_resource:
|
||||||
|
return True
|
||||||
|
return resource_scope.endswith("::*") and actual_resource.startswith(resource_scope[:-1])
|
||||||
|
|
||||||
|
|
||||||
|
def matches_scope(
|
||||||
|
scope: Scope,
|
||||||
|
action: Action,
|
||||||
|
resource: str,
|
||||||
|
user: str | None,
|
||||||
|
) -> bool:
|
||||||
|
if scope.resource and not matches_resource(scope.resource, resource):
|
||||||
|
return False
|
||||||
|
if scope.principal and scope.principal != user:
|
||||||
|
return False
|
||||||
|
return action in scope.actions
|
||||||
|
|
||||||
|
|
||||||
|
def as_list(obj: Any) -> list[Any]:
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return obj
|
||||||
|
return [obj]
|
||||||
|
|
||||||
|
|
||||||
|
def matches_conditions(
|
||||||
|
conditions: list[Condition],
|
||||||
|
resource: ProtectedResource,
|
||||||
|
user: User,
|
||||||
|
) -> bool:
|
||||||
|
for condition in conditions:
|
||||||
|
# must match all conditions
|
||||||
|
if not condition.matches(resource, user):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def default_policy() -> list[AccessRule]:
|
||||||
|
# for backwards compatibility, if no rules are provided, assume
|
||||||
|
# full access subject to previous attribute matching rules
|
||||||
|
return [
|
||||||
|
AccessRule(
|
||||||
|
permit=Scope(actions=list(Action)),
|
||||||
|
when=["user in owners " + name for name in ["roles", "teams", "projects", "namespaces"]],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def is_action_allowed(
|
||||||
|
policy: list[AccessRule],
|
||||||
|
action: Action,
|
||||||
|
resource: ProtectedResource,
|
||||||
|
user: User | None,
|
||||||
|
) -> bool:
|
||||||
|
# If user is not set, assume authentication is not enabled
|
||||||
|
if not user:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not len(policy):
|
||||||
|
policy = default_policy()
|
||||||
|
|
||||||
|
qualified_resource_id = resource.type + "::" + resource.identifier
|
||||||
|
for rule in policy:
|
||||||
|
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
|
||||||
|
if rule.when:
|
||||||
|
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
|
||||||
|
return False
|
||||||
|
elif rule.unless:
|
||||||
|
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
|
||||||
|
if rule.when:
|
||||||
|
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
|
||||||
|
return True
|
||||||
|
elif rule.unless:
|
||||||
|
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
# assume access is denied unless we find a rule that permits access
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class AccessDeniedError(RuntimeError):
|
||||||
|
pass
|
129
llama_stack/distribution/access_control/conditions.py
Normal file
129
llama_stack/distribution/access_control/conditions.py
Normal file
|
@ -0,0 +1,129 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
|
||||||
|
class User(Protocol):
|
||||||
|
principal: str
|
||||||
|
attributes: dict[str, list[str]] | None
|
||||||
|
|
||||||
|
|
||||||
|
class ProtectedResource(Protocol):
|
||||||
|
type: str
|
||||||
|
identifier: str
|
||||||
|
owner: User
|
||||||
|
|
||||||
|
|
||||||
|
class Condition(Protocol):
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
class UserInOwnersList:
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def owners_values(self, resource: ProtectedResource) -> list[str] | None:
|
||||||
|
if (
|
||||||
|
hasattr(resource, "owner")
|
||||||
|
and resource.owner
|
||||||
|
and resource.owner.attributes
|
||||||
|
and self.name in resource.owner.attributes
|
||||||
|
):
|
||||||
|
return resource.owner.attributes[self.name]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
required = self.owners_values(resource)
|
||||||
|
if not required:
|
||||||
|
return True
|
||||||
|
if not user.attributes or self.name not in user.attributes or not user.attributes[self.name]:
|
||||||
|
return False
|
||||||
|
user_values = user.attributes[self.name]
|
||||||
|
for value in required:
|
||||||
|
if value in user_values:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"user in owners {self.name}"
|
||||||
|
|
||||||
|
|
||||||
|
class UserNotInOwnersList(UserInOwnersList):
|
||||||
|
def __init__(self, name: str):
|
||||||
|
super().__init__(name)
|
||||||
|
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
return not super().matches(resource, user)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"user not in owners {self.name}"
|
||||||
|
|
||||||
|
|
||||||
|
class UserWithValueInList:
|
||||||
|
def __init__(self, name: str, value: str):
|
||||||
|
self.name = name
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
if user.attributes and self.name in user.attributes:
|
||||||
|
return self.value in user.attributes[self.name]
|
||||||
|
print(f"User does not have {self.value} in {self.name}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"user with {self.value} in {self.name}"
|
||||||
|
|
||||||
|
|
||||||
|
class UserWithValueNotInList(UserWithValueInList):
|
||||||
|
def __init__(self, name: str, value: str):
|
||||||
|
super().__init__(name, value)
|
||||||
|
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
return not super().matches(resource, user)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"user with {self.value} not in {self.name}"
|
||||||
|
|
||||||
|
|
||||||
|
class UserIsOwner:
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
return resource.owner.principal == user.principal if resource.owner else False
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "user is owner"
|
||||||
|
|
||||||
|
|
||||||
|
class UserIsNotOwner:
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
return not resource.owner or resource.owner.principal != user.principal
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "user is not owner"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_condition(condition: str) -> Condition:
|
||||||
|
words = condition.split()
|
||||||
|
match words:
|
||||||
|
case ["user", "is", "owner"]:
|
||||||
|
return UserIsOwner()
|
||||||
|
case ["user", "is", "not", "owner"]:
|
||||||
|
return UserIsNotOwner()
|
||||||
|
case ["user", "with", value, "in", name]:
|
||||||
|
return UserWithValueInList(name, value)
|
||||||
|
case ["user", "with", value, "not", "in", name]:
|
||||||
|
return UserWithValueNotInList(name, value)
|
||||||
|
case ["user", "in", "owners", name]:
|
||||||
|
return UserInOwnersList(name)
|
||||||
|
case ["user", "not", "in", "owners", name]:
|
||||||
|
return UserNotInOwnersList(name)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid condition: {condition}")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_conditions(conditions: list[str]) -> list[Condition]:
|
||||||
|
return [parse_condition(c) for c in conditions]
|
107
llama_stack/distribution/access_control/datatypes.py
Normal file
107
llama_stack/distribution/access_control/datatypes.py
Normal file
|
@ -0,0 +1,107 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from pydantic import BaseModel, model_validator
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from .conditions import parse_conditions
|
||||||
|
|
||||||
|
|
||||||
|
class Action(str, Enum):
|
||||||
|
CREATE = "create"
|
||||||
|
READ = "read"
|
||||||
|
UPDATE = "update"
|
||||||
|
DELETE = "delete"
|
||||||
|
|
||||||
|
|
||||||
|
class Scope(BaseModel):
|
||||||
|
principal: str | None = None
|
||||||
|
actions: Action | list[Action]
|
||||||
|
resource: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _mutually_exclusive(obj, a: str, b: str):
|
||||||
|
if getattr(obj, a) and getattr(obj, b):
|
||||||
|
raise ValueError(f"{a} and {b} are mutually exclusive")
|
||||||
|
|
||||||
|
|
||||||
|
def _require_one_of(obj, a: str, b: str):
|
||||||
|
if not getattr(obj, a) and not getattr(obj, b):
|
||||||
|
raise ValueError(f"on of {a} or {b} is required")
|
||||||
|
|
||||||
|
|
||||||
|
class AccessRule(BaseModel):
|
||||||
|
"""Access rule based loosely on cedar policy language
|
||||||
|
|
||||||
|
A rule defines a list of action either to permit or to forbid. It may specify a
|
||||||
|
principal or a resource that must match for the rule to take effect. The resource
|
||||||
|
to match should be specified in the form of a type qualified identifier, e.g.
|
||||||
|
model::my-model or vector_db::some-db, or a wildcard for all resources of a type,
|
||||||
|
e.g. model::*. If the principal or resource are not specified, they will match all
|
||||||
|
requests.
|
||||||
|
|
||||||
|
A rule may also specify a condition, either a 'when' or an 'unless', with additional
|
||||||
|
constraints as to where the rule applies. The constraints supported at present are:
|
||||||
|
|
||||||
|
- 'user with <attr-value> in <attr-name>'
|
||||||
|
- 'user with <attr-value> not in <attr-name>'
|
||||||
|
- 'user is owner'
|
||||||
|
- 'user is not owner'
|
||||||
|
- 'user in owners <attr-name>'
|
||||||
|
- 'user not in owners <attr-name>'
|
||||||
|
|
||||||
|
Rules are tested in order to find a match. If a match is found, the request is
|
||||||
|
permitted or forbidden depending on the type of rule. If no match is found, the
|
||||||
|
request is denied. If no rules are specified, a rule that allows any action as
|
||||||
|
long as the resource attributes match the user attributes is added
|
||||||
|
(i.e. the previous behaviour is the default).
|
||||||
|
|
||||||
|
Some examples in yaml:
|
||||||
|
|
||||||
|
- permit:
|
||||||
|
principal: user-1
|
||||||
|
actions: [create, read, delete]
|
||||||
|
resource: model::*
|
||||||
|
description: user-1 has full access to all models
|
||||||
|
- permit:
|
||||||
|
principal: user-2
|
||||||
|
actions: [read]
|
||||||
|
resource: model::model-1
|
||||||
|
description: user-2 has read access to model-1 only
|
||||||
|
- permit:
|
||||||
|
actions: [read]
|
||||||
|
when: user in owner teams
|
||||||
|
description: any user has read access to any resource created by a member of their team
|
||||||
|
- forbid:
|
||||||
|
actions: [create, read, delete]
|
||||||
|
resource: vector_db::*
|
||||||
|
unless: user with admin in roles
|
||||||
|
description: only user with admin role can use vector_db resources
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
permit: Scope | None = None
|
||||||
|
forbid: Scope | None = None
|
||||||
|
when: str | list[str] | None = None
|
||||||
|
unless: str | list[str] | None = None
|
||||||
|
description: str | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_rule_format(self) -> Self:
|
||||||
|
_require_one_of(self, "permit", "forbid")
|
||||||
|
_mutually_exclusive(self, "permit", "forbid")
|
||||||
|
_mutually_exclusive(self, "when", "unless")
|
||||||
|
if isinstance(self.when, list):
|
||||||
|
parse_conditions(self.when)
|
||||||
|
elif self.when:
|
||||||
|
parse_conditions([self.when])
|
||||||
|
if isinstance(self.unless, list):
|
||||||
|
parse_conditions(self.unless)
|
||||||
|
elif self.unless:
|
||||||
|
parse_conditions([self.unless])
|
||||||
|
return self
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.shields import Shield, ShieldInput
|
||||||
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
||||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
|
from llama_stack.distribution.access_control.datatypes import AccessRule
|
||||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
|
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
|
||||||
|
@ -35,126 +36,66 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
||||||
RoutingKey = str | list[str]
|
RoutingKey = str | list[str]
|
||||||
|
|
||||||
|
|
||||||
class AccessAttributes(BaseModel):
|
class User(BaseModel):
|
||||||
"""Structured representation of user attributes for access control.
|
principal: str
|
||||||
|
# further attributes that may be used for access control decisions
|
||||||
|
attributes: dict[str, list[str]] | None = None
|
||||||
|
|
||||||
This model defines a structured approach to representing user attributes
|
def __init__(self, principal: str, attributes: dict[str, list[str]] | None):
|
||||||
with common standard categories for access control.
|
super().__init__(principal=principal, attributes=attributes)
|
||||||
|
|
||||||
Standard attribute categories include:
|
|
||||||
- roles: Role-based attributes (e.g., admin, data-scientist)
|
|
||||||
- teams: Team-based attributes (e.g., ml-team, infra-team)
|
|
||||||
- projects: Project access attributes (e.g., llama-3, customer-insights)
|
|
||||||
- namespaces: Namespace-based access control for resource isolation
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Standard attribute categories - the minimal set we need now
|
|
||||||
roles: list[str] | None = Field(
|
|
||||||
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
|
|
||||||
)
|
|
||||||
|
|
||||||
teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
|
|
||||||
|
|
||||||
projects: list[str] | None = Field(
|
|
||||||
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
|
|
||||||
)
|
|
||||||
|
|
||||||
namespaces: list[str] | None = Field(
|
|
||||||
default=None, description="Namespace-based access control for resource isolation"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ResourceWithACL(Resource):
|
class ResourceWithOwner(Resource):
|
||||||
"""Extension of Resource that adds attribute-based access control capabilities.
|
"""Extension of Resource that adds an optional owner, i.e. the user that created the
|
||||||
|
resource. This can be used to constrain access to the resource."""
|
||||||
|
|
||||||
This class adds an optional access_attributes field that allows fine-grained control
|
owner: User | None = None
|
||||||
over which users can access each resource. When attributes are defined, a user must have
|
|
||||||
matching attributes to access the resource.
|
|
||||||
|
|
||||||
Attribute Matching Algorithm:
|
|
||||||
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
|
|
||||||
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
|
|
||||||
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
|
|
||||||
4. Within each category, ANY value match is sufficient (OR relationship within a category)
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
# Resource visible to everyone (no access control)
|
|
||||||
model = Model(identifier="llama-2", ...)
|
|
||||||
|
|
||||||
# Resource visible only to admins
|
|
||||||
model = Model(
|
|
||||||
identifier="gpt-4",
|
|
||||||
access_attributes=AccessAttributes(roles=["admin"])
|
|
||||||
)
|
|
||||||
|
|
||||||
# Resource visible to data scientists on the ML team
|
|
||||||
model = Model(
|
|
||||||
identifier="private-model",
|
|
||||||
access_attributes=AccessAttributes(
|
|
||||||
roles=["data-scientist", "researcher"],
|
|
||||||
teams=["ml-team"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# ^ User must have at least one of the roles AND be on the ml-team
|
|
||||||
|
|
||||||
# Resource visible to users with specific project access
|
|
||||||
vector_db = VectorDB(
|
|
||||||
identifier="customer-embeddings",
|
|
||||||
access_attributes=AccessAttributes(
|
|
||||||
projects=["customer-insights"],
|
|
||||||
namespaces=["confidential"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# ^ User must have access to the customer-insights project AND have confidential namespace
|
|
||||||
"""
|
|
||||||
|
|
||||||
access_attributes: AccessAttributes | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# Use the extended Resource for all routable objects
|
# Use the extended Resource for all routable objects
|
||||||
class ModelWithACL(Model, ResourceWithACL):
|
class ModelWithOwner(Model, ResourceWithOwner):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ShieldWithACL(Shield, ResourceWithACL):
|
class ShieldWithOwner(Shield, ResourceWithOwner):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class VectorDBWithACL(VectorDB, ResourceWithACL):
|
class VectorDBWithOwner(VectorDB, ResourceWithOwner):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DatasetWithACL(Dataset, ResourceWithACL):
|
class DatasetWithOwner(Dataset, ResourceWithOwner):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ScoringFnWithACL(ScoringFn, ResourceWithACL):
|
class ScoringFnWithOwner(ScoringFn, ResourceWithOwner):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkWithACL(Benchmark, ResourceWithACL):
|
class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ToolWithACL(Tool, ResourceWithACL):
|
class ToolWithOwner(Tool, ResourceWithOwner):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ToolGroupWithACL(ToolGroup, ResourceWithACL):
|
class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
|
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
|
||||||
|
|
||||||
RoutableObjectWithProvider = Annotated[
|
RoutableObjectWithProvider = Annotated[
|
||||||
ModelWithACL
|
ModelWithOwner
|
||||||
| ShieldWithACL
|
| ShieldWithOwner
|
||||||
| VectorDBWithACL
|
| VectorDBWithOwner
|
||||||
| DatasetWithACL
|
| DatasetWithOwner
|
||||||
| ScoringFnWithACL
|
| ScoringFnWithOwner
|
||||||
| BenchmarkWithACL
|
| BenchmarkWithOwner
|
||||||
| ToolWithACL
|
| ToolWithOwner
|
||||||
| ToolGroupWithACL,
|
| ToolGroupWithOwner,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -234,6 +175,7 @@ class AuthenticationConfig(BaseModel):
|
||||||
...,
|
...,
|
||||||
description="Provider-specific configuration",
|
description="Provider-specific configuration",
|
||||||
)
|
)
|
||||||
|
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationRequiredError(Exception):
|
class AuthenticationRequiredError(Exception):
|
||||||
|
|
|
@ -10,6 +10,8 @@ import logging
|
||||||
from contextlib import AbstractContextManager
|
from contextlib import AbstractContextManager
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import User
|
||||||
|
|
||||||
from .utils.dynamic import instantiate_class_type
|
from .utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
@ -21,12 +23,10 @@ PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
||||||
class RequestProviderDataContext(AbstractContextManager):
|
class RequestProviderDataContext(AbstractContextManager):
|
||||||
"""Context manager for request provider data"""
|
"""Context manager for request provider data"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, provider_data: dict[str, Any] | None = None, user: User | None = None):
|
||||||
self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None
|
|
||||||
):
|
|
||||||
self.provider_data = provider_data or {}
|
self.provider_data = provider_data or {}
|
||||||
if auth_attributes:
|
if user:
|
||||||
self.provider_data["__auth_attributes"] = auth_attributes
|
self.provider_data["__authenticated_user"] = user
|
||||||
|
|
||||||
self.token = None
|
self.token = None
|
||||||
|
|
||||||
|
@ -95,9 +95,9 @@ def request_provider_data_context(
|
||||||
return RequestProviderDataContext(provider_data, auth_attributes)
|
return RequestProviderDataContext(provider_data, auth_attributes)
|
||||||
|
|
||||||
|
|
||||||
def get_auth_attributes() -> dict[str, list[str]] | None:
|
def get_authenticated_user() -> User | None:
|
||||||
"""Helper to retrieve auth attributes from the provider data context"""
|
"""Helper to retrieve auth attributes from the provider data context"""
|
||||||
provider_data = PROVIDER_DATA_VAR.get()
|
provider_data = PROVIDER_DATA_VAR.get()
|
||||||
if not provider_data:
|
if not provider_data:
|
||||||
return None
|
return None
|
||||||
return provider_data.get("__auth_attributes")
|
return provider_data.get("__authenticated_user")
|
||||||
|
|
|
@ -28,6 +28,7 @@ from llama_stack.apis.vector_dbs import VectorDBs
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.distribution.client import get_client_impl
|
from llama_stack.distribution.client import get_client_impl
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
|
AccessRule,
|
||||||
AutoRoutedProviderSpec,
|
AutoRoutedProviderSpec,
|
||||||
Provider,
|
Provider,
|
||||||
RoutingTableProviderSpec,
|
RoutingTableProviderSpec,
|
||||||
|
@ -118,6 +119,7 @@ async def resolve_impls(
|
||||||
run_config: StackRunConfig,
|
run_config: StackRunConfig,
|
||||||
provider_registry: ProviderRegistry,
|
provider_registry: ProviderRegistry,
|
||||||
dist_registry: DistributionRegistry,
|
dist_registry: DistributionRegistry,
|
||||||
|
policy: list[AccessRule],
|
||||||
) -> dict[Api, Any]:
|
) -> dict[Api, Any]:
|
||||||
"""
|
"""
|
||||||
Resolves provider implementations by:
|
Resolves provider implementations by:
|
||||||
|
@ -140,7 +142,7 @@ async def resolve_impls(
|
||||||
|
|
||||||
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
|
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
|
||||||
|
|
||||||
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config)
|
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy)
|
||||||
|
|
||||||
|
|
||||||
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
|
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
|
||||||
|
@ -247,6 +249,7 @@ async def instantiate_providers(
|
||||||
router_apis: set[Api],
|
router_apis: set[Api],
|
||||||
dist_registry: DistributionRegistry,
|
dist_registry: DistributionRegistry,
|
||||||
run_config: StackRunConfig,
|
run_config: StackRunConfig,
|
||||||
|
policy: list[AccessRule],
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Instantiates providers asynchronously while managing dependencies."""
|
"""Instantiates providers asynchronously while managing dependencies."""
|
||||||
impls: dict[Api, Any] = {}
|
impls: dict[Api, Any] = {}
|
||||||
|
@ -261,7 +264,7 @@ async def instantiate_providers(
|
||||||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||||
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
|
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
|
||||||
|
|
||||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config)
|
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
|
||||||
|
|
||||||
if api_str.startswith("inner-"):
|
if api_str.startswith("inner-"):
|
||||||
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
|
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
|
||||||
|
@ -312,6 +315,7 @@ async def instantiate_provider(
|
||||||
inner_impls: dict[str, Any],
|
inner_impls: dict[str, Any],
|
||||||
dist_registry: DistributionRegistry,
|
dist_registry: DistributionRegistry,
|
||||||
run_config: StackRunConfig,
|
run_config: StackRunConfig,
|
||||||
|
policy: list[AccessRule],
|
||||||
):
|
):
|
||||||
provider_spec = provider.spec
|
provider_spec = provider.spec
|
||||||
if not hasattr(provider_spec, "module"):
|
if not hasattr(provider_spec, "module"):
|
||||||
|
@ -336,13 +340,15 @@ async def instantiate_provider(
|
||||||
method = "get_routing_table_impl"
|
method = "get_routing_table_impl"
|
||||||
|
|
||||||
config = None
|
config = None
|
||||||
args = [provider_spec.api, inner_impls, deps, dist_registry]
|
args = [provider_spec.api, inner_impls, deps, dist_registry, policy]
|
||||||
else:
|
else:
|
||||||
method = "get_provider_impl"
|
method = "get_provider_impl"
|
||||||
|
|
||||||
config_type = instantiate_class_type(provider_spec.config_class)
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
config = config_type(**provider.config)
|
config = config_type(**provider.config)
|
||||||
args = [config, deps]
|
args = [config, deps]
|
||||||
|
if "policy" in inspect.signature(getattr(module, method)).parameters:
|
||||||
|
args.append(policy)
|
||||||
|
|
||||||
fn = getattr(module, method)
|
fn = getattr(module, method)
|
||||||
impl = await fn(*args)
|
impl = await fn(*args)
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import RoutedProtocol
|
from llama_stack.distribution.datatypes import AccessRule, RoutedProtocol
|
||||||
from llama_stack.distribution.stack import StackRunConfig
|
from llama_stack.distribution.stack import StackRunConfig
|
||||||
from llama_stack.distribution.store import DistributionRegistry
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||||
|
@ -18,6 +18,7 @@ async def get_routing_table_impl(
|
||||||
impls_by_provider_id: dict[str, RoutedProtocol],
|
impls_by_provider_id: dict[str, RoutedProtocol],
|
||||||
_deps,
|
_deps,
|
||||||
dist_registry: DistributionRegistry,
|
dist_registry: DistributionRegistry,
|
||||||
|
policy: list[AccessRule],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
from ..routing_tables.benchmarks import BenchmarksRoutingTable
|
from ..routing_tables.benchmarks import BenchmarksRoutingTable
|
||||||
from ..routing_tables.datasets import DatasetsRoutingTable
|
from ..routing_tables.datasets import DatasetsRoutingTable
|
||||||
|
@ -40,7 +41,7 @@ async def get_routing_table_impl(
|
||||||
if api.value not in api_to_tables:
|
if api.value not in api_to_tables:
|
||||||
raise ValueError(f"API {api.value} not found in router map")
|
raise ValueError(f"API {api.value} not found in router map")
|
||||||
|
|
||||||
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
|
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
|
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
BenchmarkWithACL,
|
BenchmarkWithOwner,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
||||||
)
|
)
|
||||||
if provider_benchmark_id is None:
|
if provider_benchmark_id is None:
|
||||||
provider_benchmark_id = benchmark_id
|
provider_benchmark_id = benchmark_id
|
||||||
benchmark = BenchmarkWithACL(
|
benchmark = BenchmarkWithOwner(
|
||||||
identifier=benchmark_id,
|
identifier=benchmark_id,
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
scoring_functions=scoring_functions,
|
scoring_functions=scoring_functions,
|
||||||
|
|
|
@ -8,14 +8,14 @@ from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.resource import ResourceType
|
from llama_stack.apis.resource import ResourceType
|
||||||
from llama_stack.apis.scoring_functions import ScoringFn
|
from llama_stack.apis.scoring_functions import ScoringFn
|
||||||
from llama_stack.distribution.access_control import check_access
|
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
AccessAttributes,
|
AccessRule,
|
||||||
RoutableObject,
|
RoutableObject,
|
||||||
RoutableObjectWithProvider,
|
RoutableObjectWithProvider,
|
||||||
RoutedProtocol,
|
RoutedProtocol,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
from llama_stack.distribution.request_headers import get_authenticated_user
|
||||||
from llama_stack.distribution.store import DistributionRegistry
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||||
|
@ -73,9 +73,11 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
self,
|
self,
|
||||||
impls_by_provider_id: dict[str, RoutedProtocol],
|
impls_by_provider_id: dict[str, RoutedProtocol],
|
||||||
dist_registry: DistributionRegistry,
|
dist_registry: DistributionRegistry,
|
||||||
|
policy: list[AccessRule],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.impls_by_provider_id = impls_by_provider_id
|
self.impls_by_provider_id = impls_by_provider_id
|
||||||
self.dist_registry = dist_registry
|
self.dist_registry = dist_registry
|
||||||
|
self.policy = policy
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
|
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
|
||||||
|
@ -166,13 +168,15 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Check if user has permission to access this object
|
# Check if user has permission to access this object
|
||||||
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
|
if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
|
||||||
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
|
logger.debug(f"Access denied to {type} '{identifier}'")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||||
|
if not is_action_allowed(self.policy, "delete", obj, get_authenticated_user()):
|
||||||
|
raise AccessDeniedError()
|
||||||
await self.dist_registry.delete(obj.type, obj.identifier)
|
await self.dist_registry.delete(obj.type, obj.identifier)
|
||||||
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
|
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
|
||||||
|
|
||||||
|
@ -187,11 +191,12 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
p = self.impls_by_provider_id[obj.provider_id]
|
p = self.impls_by_provider_id[obj.provider_id]
|
||||||
|
|
||||||
# If object supports access control but no attributes set, use creator's attributes
|
# If object supports access control but no attributes set, use creator's attributes
|
||||||
if not obj.access_attributes:
|
creator = get_authenticated_user()
|
||||||
creator_attributes = get_auth_attributes()
|
if not is_action_allowed(self.policy, "create", obj, creator):
|
||||||
if creator_attributes:
|
raise AccessDeniedError()
|
||||||
obj.access_attributes = AccessAttributes(**creator_attributes)
|
if creator:
|
||||||
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
|
obj.owner = creator
|
||||||
|
logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
|
||||||
|
|
||||||
registered_obj = await register_object_with_provider(obj, p)
|
registered_obj = await register_object_with_provider(obj, p)
|
||||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||||
|
@ -210,9 +215,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
# Apply attribute-based access control filtering
|
# Apply attribute-based access control filtering
|
||||||
if filtered_objs:
|
if filtered_objs:
|
||||||
filtered_objs = [
|
filtered_objs = [
|
||||||
obj
|
obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
|
||||||
for obj in filtered_objs
|
|
||||||
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return filtered_objs
|
return filtered_objs
|
||||||
|
|
|
@ -19,7 +19,7 @@ from llama_stack.apis.datasets import (
|
||||||
)
|
)
|
||||||
from llama_stack.apis.resource import ResourceType
|
from llama_stack.apis.resource import ResourceType
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
DatasetWithACL,
|
DatasetWithOwner,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
@ -74,7 +74,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
dataset = DatasetWithACL(
|
dataset = DatasetWithOwner(
|
||||||
identifier=dataset_id,
|
identifier=dataset_id,
|
||||||
provider_resource_id=provider_dataset_id,
|
provider_resource_id=provider_dataset_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
|
|
|
@ -9,7 +9,7 @@ from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
ModelWithACL,
|
ModelWithOwner,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
model_type = ModelType.llm
|
model_type = ModelType.llm
|
||||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier=model_id,
|
identifier=model_id,
|
||||||
provider_resource_id=provider_model_id,
|
provider_resource_id=provider_model_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
|
|
|
@ -13,7 +13,7 @@ from llama_stack.apis.scoring_functions import (
|
||||||
ScoringFunctions,
|
ScoringFunctions,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
ScoringFnWithACL,
|
ScoringFnWithOwner,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||||
)
|
)
|
||||||
scoring_fn = ScoringFnWithACL(
|
scoring_fn = ScoringFnWithOwner(
|
||||||
identifier=scoring_fn_id,
|
identifier=scoring_fn_id,
|
||||||
description=description,
|
description=description,
|
||||||
return_type=return_type,
|
return_type=return_type,
|
||||||
|
|
|
@ -9,7 +9,7 @@ from typing import Any
|
||||||
from llama_stack.apis.resource import ResourceType
|
from llama_stack.apis.resource import ResourceType
|
||||||
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
|
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
ShieldWithACL,
|
ShieldWithOwner,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
)
|
)
|
||||||
if params is None:
|
if params is None:
|
||||||
params = {}
|
params = {}
|
||||||
shield = ShieldWithACL(
|
shield = ShieldWithOwner(
|
||||||
identifier=shield_id,
|
identifier=shield_id,
|
||||||
provider_resource_id=provider_shield_id,
|
provider_resource_id=provider_shield_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||||
from llama_stack.distribution.datatypes import ToolGroupWithACL
|
from llama_stack.distribution.datatypes import ToolGroupWithOwner
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl
|
from .common import CommonRoutingTableImpl
|
||||||
|
@ -106,7 +106,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
mcp_endpoint: URL | None = None,
|
mcp_endpoint: URL | None = None,
|
||||||
args: dict[str, Any] | None = None,
|
args: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
toolgroup = ToolGroupWithACL(
|
toolgroup = ToolGroupWithOwner(
|
||||||
identifier=toolgroup_id,
|
identifier=toolgroup_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
provider_resource_id=toolgroup_id,
|
provider_resource_id=toolgroup_id,
|
||||||
|
|
|
@ -10,7 +10,7 @@ from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.apis.resource import ResourceType
|
from llama_stack.apis.resource import ResourceType
|
||||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
VectorDBWithACL,
|
VectorDBWithOwner,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
"embedding_model": embedding_model,
|
"embedding_model": embedding_model,
|
||||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||||
}
|
}
|
||||||
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
|
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
||||||
await self.register_object(vector_db)
|
await self.register_object(vector_db)
|
||||||
return vector_db
|
return vector_db
|
||||||
|
|
||||||
|
|
|
@ -105,24 +105,16 @@ class AuthenticationMiddleware:
|
||||||
logger.exception("Error during authentication")
|
logger.exception("Error during authentication")
|
||||||
return await self._send_auth_error(send, "Authentication service error")
|
return await self._send_auth_error(send, "Authentication service error")
|
||||||
|
|
||||||
# Store attributes in request scope for access control
|
|
||||||
if validation_result.access_attributes:
|
|
||||||
user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
|
|
||||||
else:
|
|
||||||
logger.warning("No access attributes, setting namespace to token by default")
|
|
||||||
user_attributes = {
|
|
||||||
"roles": [token],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
|
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
|
||||||
# can identify the requester and enforce per-client rate limits.
|
# can identify the requester and enforce per-client rate limits.
|
||||||
scope["authenticated_client_id"] = token
|
scope["authenticated_client_id"] = token
|
||||||
|
|
||||||
# Store attributes in request scope
|
# Store attributes in request scope
|
||||||
scope["user_attributes"] = user_attributes
|
|
||||||
scope["principal"] = validation_result.principal
|
scope["principal"] = validation_result.principal
|
||||||
|
if validation_result.attributes:
|
||||||
|
scope["user_attributes"] = validation_result.attributes
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
|
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self.app(scope, receive, send)
|
return await self.app(scope, receive, send)
|
||||||
|
|
|
@ -16,43 +16,18 @@ from jose import jwt
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType
|
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="auth")
|
logger = get_logger(name=__name__, category="auth")
|
||||||
|
|
||||||
|
|
||||||
class TokenValidationResult(BaseModel):
|
class AuthResponse(BaseModel):
|
||||||
principal: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="The principal (username or persistent identifier) of the authenticated user",
|
|
||||||
)
|
|
||||||
access_attributes: AccessAttributes | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="""
|
|
||||||
Structured user attributes for attribute-based access control.
|
|
||||||
|
|
||||||
These attributes determine which resources the user can access.
|
|
||||||
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
|
|
||||||
Each attribute category contains a list of values that the user has for that category.
|
|
||||||
During access control checks, these values are compared against resource requirements.
|
|
||||||
|
|
||||||
Example with standard categories:
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"roles": ["admin", "data-scientist"],
|
|
||||||
"teams": ["ml-team"],
|
|
||||||
"projects": ["llama-3"],
|
|
||||||
"namespaces": ["research"]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AuthResponse(TokenValidationResult):
|
|
||||||
"""The format of the authentication response from the auth endpoint."""
|
"""The format of the authentication response from the auth endpoint."""
|
||||||
|
|
||||||
|
principal: str
|
||||||
|
# further attributes that may be used for access control decisions
|
||||||
|
attributes: dict[str, list[str]] | None = None
|
||||||
message: str | None = Field(
|
message: str | None = Field(
|
||||||
default=None, description="Optional message providing additional context about the authentication result."
|
default=None, description="Optional message providing additional context about the authentication result."
|
||||||
)
|
)
|
||||||
|
@ -78,7 +53,7 @@ class AuthProvider(ABC):
|
||||||
"""Abstract base class for authentication providers."""
|
"""Abstract base class for authentication providers."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
"""Validate a token and return access attributes."""
|
"""Validate a token and return access attributes."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -88,10 +63,10 @@ class AuthProvider(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
|
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
|
||||||
attributes = AccessAttributes()
|
attributes: dict[str, list[str]] = {}
|
||||||
for claim_key, attribute_key in mapping.items():
|
for claim_key, attribute_key in mapping.items():
|
||||||
if claim_key not in claims or not hasattr(attributes, attribute_key):
|
if claim_key not in claims:
|
||||||
continue
|
continue
|
||||||
claim = claims[claim_key]
|
claim = claims[claim_key]
|
||||||
if isinstance(claim, list):
|
if isinstance(claim, list):
|
||||||
|
@ -99,11 +74,10 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
|
||||||
else:
|
else:
|
||||||
values = claim.split()
|
values = claim.split()
|
||||||
|
|
||||||
current = getattr(attributes, attribute_key)
|
if attribute_key in attributes:
|
||||||
if current:
|
attributes[attribute_key].extend(values)
|
||||||
current.extend(values)
|
|
||||||
else:
|
else:
|
||||||
setattr(attributes, attribute_key, values)
|
attributes[attribute_key] = values
|
||||||
return attributes
|
return attributes
|
||||||
|
|
||||||
|
|
||||||
|
@ -145,8 +119,6 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
|
||||||
for key, value in v.items():
|
for key, value in v.items():
|
||||||
if not value:
|
if not value:
|
||||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||||
if value not in AccessAttributes.model_fields:
|
|
||||||
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
|
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
|
@ -171,14 +143,14 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
self._jwks: dict[str, str] = {}
|
self._jwks: dict[str, str] = {}
|
||||||
self._jwks_lock = Lock()
|
self._jwks_lock = Lock()
|
||||||
|
|
||||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
if self.config.jwks:
|
if self.config.jwks:
|
||||||
return await self.validate_jwt_token(token, scope)
|
return await self.validate_jwt_token(token, scope)
|
||||||
if self.config.introspection:
|
if self.config.introspection:
|
||||||
return await self.introspect_token(token, scope)
|
return await self.introspect_token(token, scope)
|
||||||
raise ValueError("One of jwks or introspection must be configured")
|
raise ValueError("One of jwks or introspection must be configured")
|
||||||
|
|
||||||
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
"""Validate a token using the JWT token."""
|
"""Validate a token using the JWT token."""
|
||||||
await self._refresh_jwks()
|
await self._refresh_jwks()
|
||||||
|
|
||||||
|
@ -203,12 +175,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
# We should incorporate these into the access attributes.
|
# We should incorporate these into the access attributes.
|
||||||
principal = claims["sub"]
|
principal = claims["sub"]
|
||||||
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
|
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
|
||||||
return TokenValidationResult(
|
return User(
|
||||||
principal=principal,
|
principal=principal,
|
||||||
access_attributes=access_attributes,
|
attributes=access_attributes,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
async def introspect_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
"""Validate a token using token introspection as defined by RFC 7662."""
|
"""Validate a token using token introspection as defined by RFC 7662."""
|
||||||
form = {
|
form = {
|
||||||
"token": token,
|
"token": token,
|
||||||
|
@ -242,9 +214,9 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
raise ValueError("Token not active")
|
raise ValueError("Token not active")
|
||||||
principal = fields["sub"] or fields["username"]
|
principal = fields["sub"] or fields["username"]
|
||||||
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
|
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
|
||||||
return TokenValidationResult(
|
return User(
|
||||||
principal=principal,
|
principal=principal,
|
||||||
access_attributes=access_attributes,
|
attributes=access_attributes,
|
||||||
)
|
)
|
||||||
except httpx.TimeoutException:
|
except httpx.TimeoutException:
|
||||||
logger.exception("Token introspection request timed out")
|
logger.exception("Token introspection request timed out")
|
||||||
|
@ -299,7 +271,7 @@ class CustomAuthProvider(AuthProvider):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._client = None
|
self._client = None
|
||||||
|
|
||||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
"""Validate a token using the custom authentication endpoint."""
|
"""Validate a token using the custom authentication endpoint."""
|
||||||
if scope is None:
|
if scope is None:
|
||||||
scope = {}
|
scope = {}
|
||||||
|
@ -333,6 +305,7 @@ class CustomAuthProvider(AuthProvider):
|
||||||
json=auth_request.model_dump(),
|
json=auth_request.model_dump(),
|
||||||
timeout=10.0, # Add a reasonable timeout
|
timeout=10.0, # Add a reasonable timeout
|
||||||
)
|
)
|
||||||
|
print("MADE CALL")
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
logger.warning(f"Authentication failed with status code: {response.status_code}")
|
logger.warning(f"Authentication failed with status code: {response.status_code}")
|
||||||
raise ValueError(f"Authentication failed: {response.status_code}")
|
raise ValueError(f"Authentication failed: {response.status_code}")
|
||||||
|
@ -341,7 +314,7 @@ class CustomAuthProvider(AuthProvider):
|
||||||
try:
|
try:
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
auth_response = AuthResponse(**response_data)
|
auth_response = AuthResponse(**response_data)
|
||||||
return auth_response
|
return User(auth_response.principal, auth_response.attributes)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error parsing authentication response")
|
logger.exception("Error parsing authentication response")
|
||||||
raise ValueError("Invalid authentication response format") from e
|
raise ValueError("Invalid authentication response format") from e
|
||||||
|
|
|
@ -33,10 +33,7 @@ from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
|
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
from llama_stack.distribution.request_headers import (
|
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
|
||||||
PROVIDER_DATA_VAR,
|
|
||||||
request_provider_data_context,
|
|
||||||
)
|
|
||||||
from llama_stack.distribution.resolver import InvalidProviderError
|
from llama_stack.distribution.resolver import InvalidProviderError
|
||||||
from llama_stack.distribution.server.routes import (
|
from llama_stack.distribution.server.routes import (
|
||||||
find_matching_route,
|
find_matching_route,
|
||||||
|
@ -217,11 +214,13 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
||||||
async def route_handler(request: Request, **kwargs):
|
async def route_handler(request: Request, **kwargs):
|
||||||
# Get auth attributes from the request scope
|
# Get auth attributes from the request scope
|
||||||
user_attributes = request.scope.get("user_attributes", {})
|
user_attributes = request.scope.get("user_attributes", {})
|
||||||
|
principal = request.scope.get("principal", "")
|
||||||
|
user = User(principal, user_attributes)
|
||||||
|
|
||||||
await log_request_pre_validation(request)
|
await log_request_pre_validation(request)
|
||||||
|
|
||||||
# Use context manager with both provider data and auth attributes
|
# Use context manager with both provider data and auth attributes
|
||||||
with request_provider_data_context(request.headers, user_attributes):
|
with request_provider_data_context(request.headers, user):
|
||||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -223,7 +223,10 @@ async def construct_stack(
|
||||||
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
|
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
|
||||||
) -> dict[Api, Any]:
|
) -> dict[Api, Any]:
|
||||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||||
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
|
policy = run_config.server.auth.access_policy if run_config.server.auth else []
|
||||||
|
impls = await resolve_impls(
|
||||||
|
run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
|
||||||
|
)
|
||||||
|
|
||||||
# Add internal implementations after all other providers are resolved
|
# Add internal implementations after all other providers are resolved
|
||||||
add_internal_implementations(impls, run_config)
|
add_internal_implementations(impls, run_config)
|
||||||
|
|
|
@ -6,12 +6,12 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api
|
from llama_stack.distribution.datatypes import AccessRule, Api
|
||||||
|
|
||||||
from .config import MetaReferenceAgentsImplConfig
|
from .config import MetaReferenceAgentsImplConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
|
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
|
||||||
from .agents import MetaReferenceAgentsImpl
|
from .agents import MetaReferenceAgentsImpl
|
||||||
|
|
||||||
impl = MetaReferenceAgentsImpl(
|
impl = MetaReferenceAgentsImpl(
|
||||||
|
@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
|
||||||
deps[Api.safety],
|
deps[Api.safety],
|
||||||
deps[Api.tool_runtime],
|
deps[Api.tool_runtime],
|
||||||
deps[Api.tool_groups],
|
deps[Api.tool_groups],
|
||||||
|
policy,
|
||||||
)
|
)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -60,6 +60,7 @@ from llama_stack.apis.inference import (
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
|
from llama_stack.distribution.datatypes import AccessRule
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.models.llama.datatypes import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
|
@ -96,13 +97,14 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
vector_io_api: VectorIO,
|
vector_io_api: VectorIO,
|
||||||
persistence_store: KVStore,
|
persistence_store: KVStore,
|
||||||
created_at: str,
|
created_at: str,
|
||||||
|
policy: list[AccessRule],
|
||||||
):
|
):
|
||||||
self.agent_id = agent_id
|
self.agent_id = agent_id
|
||||||
self.agent_config = agent_config
|
self.agent_config = agent_config
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
self.safety_api = safety_api
|
self.safety_api = safety_api
|
||||||
self.vector_io_api = vector_io_api
|
self.vector_io_api = vector_io_api
|
||||||
self.storage = AgentPersistence(agent_id, persistence_store)
|
self.storage = AgentPersistence(agent_id, persistence_store, policy)
|
||||||
self.tool_runtime_api = tool_runtime_api
|
self.tool_runtime_api = tool_runtime_api
|
||||||
self.tool_groups_api = tool_groups_api
|
self.tool_groups_api = tool_groups_api
|
||||||
self.created_at = created_at
|
self.created_at = created_at
|
||||||
|
|
|
@ -41,6 +41,7 @@ from llama_stack.apis.inference import (
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
|
from llama_stack.distribution.datatypes import AccessRule
|
||||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||||
from llama_stack.providers.utils.pagination import paginate_records
|
from llama_stack.providers.utils.pagination import paginate_records
|
||||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||||
|
@ -62,6 +63,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
safety_api: Safety,
|
safety_api: Safety,
|
||||||
tool_runtime_api: ToolRuntime,
|
tool_runtime_api: ToolRuntime,
|
||||||
tool_groups_api: ToolGroups,
|
tool_groups_api: ToolGroups,
|
||||||
|
policy: list[AccessRule],
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
|
@ -72,6 +74,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
|
|
||||||
self.in_memory_store = InmemoryKVStoreImpl()
|
self.in_memory_store = InmemoryKVStoreImpl()
|
||||||
self.openai_responses_impl: OpenAIResponsesImpl | None = None
|
self.openai_responses_impl: OpenAIResponsesImpl | None = None
|
||||||
|
self.policy = policy
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
||||||
|
@ -130,6 +133,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
||||||
),
|
),
|
||||||
created_at=agent_info.created_at,
|
created_at=agent_info.created_at,
|
||||||
|
policy=self.policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_agent_session(
|
async def create_agent_session(
|
||||||
|
|
|
@ -10,9 +10,10 @@ import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||||
from llama_stack.distribution.access_control import check_access
|
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||||
from llama_stack.distribution.datatypes import AccessAttributes
|
from llama_stack.distribution.access_control.datatypes import AccessRule
|
||||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
from llama_stack.distribution.datatypes import User
|
||||||
|
from llama_stack.distribution.request_headers import get_authenticated_user
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
@ -22,7 +23,9 @@ class AgentSessionInfo(Session):
|
||||||
# TODO: is this used anywhere?
|
# TODO: is this used anywhere?
|
||||||
vector_db_id: str | None = None
|
vector_db_id: str | None = None
|
||||||
started_at: datetime
|
started_at: datetime
|
||||||
access_attributes: AccessAttributes | None = None
|
owner: User | None = None
|
||||||
|
identifier: str | None = None
|
||||||
|
type: str = "session"
|
||||||
|
|
||||||
|
|
||||||
class AgentInfo(AgentConfig):
|
class AgentInfo(AgentConfig):
|
||||||
|
@ -30,24 +33,27 @@ class AgentInfo(AgentConfig):
|
||||||
|
|
||||||
|
|
||||||
class AgentPersistence:
|
class AgentPersistence:
|
||||||
def __init__(self, agent_id: str, kvstore: KVStore):
|
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
|
||||||
self.agent_id = agent_id
|
self.agent_id = agent_id
|
||||||
self.kvstore = kvstore
|
self.kvstore = kvstore
|
||||||
|
self.policy = policy
|
||||||
|
|
||||||
async def create_session(self, name: str) -> str:
|
async def create_session(self, name: str) -> str:
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# Get current user's auth attributes for new sessions
|
# Get current user's auth attributes for new sessions
|
||||||
auth_attributes = get_auth_attributes()
|
user = get_authenticated_user()
|
||||||
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
|
|
||||||
|
|
||||||
session_info = AgentSessionInfo(
|
session_info = AgentSessionInfo(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
session_name=name,
|
session_name=name,
|
||||||
started_at=datetime.now(timezone.utc),
|
started_at=datetime.now(timezone.utc),
|
||||||
access_attributes=access_attributes,
|
owner=user,
|
||||||
turns=[],
|
turns=[],
|
||||||
|
identifier=name, # should this be qualified in any way?
|
||||||
)
|
)
|
||||||
|
if not is_action_allowed(self.policy, "create", session_info, user):
|
||||||
|
raise AccessDeniedError()
|
||||||
|
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
key=f"session:{self.agent_id}:{session_id}",
|
key=f"session:{self.agent_id}:{session_id}",
|
||||||
|
@ -73,10 +79,10 @@ class AgentPersistence:
|
||||||
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
|
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
|
||||||
"""Check if current user has access to the session."""
|
"""Check if current user has access to the session."""
|
||||||
# Handle backward compatibility for old sessions without access control
|
# Handle backward compatibility for old sessions without access control
|
||||||
if not hasattr(session_info, "access_attributes"):
|
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
|
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
|
||||||
|
|
||||||
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
||||||
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
||||||
|
|
|
@ -121,7 +121,7 @@ class ToolGroupsImpl(Impl):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_models_routing_table(cached_disk_dist_registry):
|
async def test_models_routing_table(cached_disk_dist_registry):
|
||||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry)
|
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
# Register multiple models and verify listing
|
# Register multiple models and verify listing
|
||||||
|
@ -163,7 +163,7 @@ async def test_models_routing_table(cached_disk_dist_registry):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_shields_routing_table(cached_disk_dist_registry):
|
async def test_shields_routing_table(cached_disk_dist_registry):
|
||||||
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry)
|
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
# Register multiple shields and verify listing
|
# Register multiple shields and verify listing
|
||||||
|
@ -179,14 +179,14 @@ async def test_shields_routing_table(cached_disk_dist_registry):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry)
|
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry)
|
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
await m_table.initialize()
|
await m_table.initialize()
|
||||||
await m_table.register_model(
|
await m_table.register_model(
|
||||||
model_id="test-model",
|
model_id="test-model",
|
||||||
provider_id="test_providere",
|
provider_id="test_provider",
|
||||||
metadata={"embedding_dimension": 128},
|
metadata={"embedding_dimension": 128},
|
||||||
model_type=ModelType.embedding,
|
model_type=ModelType.embedding,
|
||||||
)
|
)
|
||||||
|
@ -209,7 +209,7 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||||
|
|
||||||
|
|
||||||
async def test_datasets_routing_table(cached_disk_dist_registry):
|
async def test_datasets_routing_table(cached_disk_dist_registry):
|
||||||
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry)
|
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
# Register multiple datasets and verify listing
|
# Register multiple datasets and verify listing
|
||||||
|
@ -235,7 +235,7 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
||||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry)
|
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
# Register multiple scoring functions and verify listing
|
# Register multiple scoring functions and verify listing
|
||||||
|
@ -261,7 +261,7 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
||||||
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry)
|
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
# Register multiple benchmarks and verify listing
|
# Register multiple benchmarks and verify listing
|
||||||
|
@ -279,7 +279,7 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_tool_groups_routing_table(cached_disk_dist_registry):
|
async def test_tool_groups_routing_table(cached_disk_dist_registry):
|
||||||
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry)
|
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
# Register multiple tool groups and verify listing
|
# Register multiple tool groups and verify listing
|
||||||
|
|
|
@ -59,6 +59,7 @@ async def agents_impl(config, mock_apis):
|
||||||
mock_apis["safety_api"],
|
mock_apis["safety_api"],
|
||||||
mock_apis["tool_runtime_api"],
|
mock_apis["tool_runtime_api"],
|
||||||
mock_apis["tool_groups_api"],
|
mock_apis["tool_groups_api"],
|
||||||
|
{},
|
||||||
)
|
)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
yield impl
|
yield impl
|
||||||
|
|
|
@ -12,24 +12,24 @@ import pytest
|
||||||
|
|
||||||
from llama_stack.apis.agents import Turn
|
from llama_stack.apis.agents import Turn
|
||||||
from llama_stack.apis.inference import CompletionMessage, StopReason
|
from llama_stack.apis.inference import CompletionMessage, StopReason
|
||||||
from llama_stack.distribution.datatypes import AccessAttributes
|
from llama_stack.distribution.datatypes import User
|
||||||
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
|
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def test_setup(sqlite_kvstore):
|
async def test_setup(sqlite_kvstore):
|
||||||
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore)
|
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
|
||||||
yield agent_persistence
|
yield agent_persistence
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
|
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||||
async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup):
|
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
|
||||||
agent_persistence = test_setup
|
agent_persistence = test_setup
|
||||||
|
|
||||||
# Set creator's attributes for the session
|
# Set creator's attributes for the session
|
||||||
creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
|
creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
|
||||||
mock_get_auth_attributes.return_value = creator_attributes
|
mock_get_authenticated_user.return_value = User("test_user", creator_attributes)
|
||||||
|
|
||||||
# Create a session
|
# Create a session
|
||||||
session_id = await agent_persistence.create_session("Test Session")
|
session_id = await agent_persistence.create_session("Test Session")
|
||||||
|
@ -37,14 +37,15 @@ async def test_session_creation_with_access_attributes(mock_get_auth_attributes,
|
||||||
# Get the session and verify access attributes were set
|
# Get the session and verify access attributes were set
|
||||||
session_info = await agent_persistence.get_session_info(session_id)
|
session_info = await agent_persistence.get_session_info(session_id)
|
||||||
assert session_info is not None
|
assert session_info is not None
|
||||||
assert session_info.access_attributes is not None
|
assert session_info.owner is not None
|
||||||
assert session_info.access_attributes.roles == ["researcher"]
|
assert session_info.owner.attributes is not None
|
||||||
assert session_info.access_attributes.teams == ["ai-team"]
|
assert session_info.owner.attributes["roles"] == ["researcher"]
|
||||||
|
assert session_info.owner.attributes["teams"] == ["ai-team"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
|
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||||
async def test_session_access_control(mock_get_auth_attributes, test_setup):
|
async def test_session_access_control(mock_get_authenticated_user, test_setup):
|
||||||
agent_persistence = test_setup
|
agent_persistence = test_setup
|
||||||
|
|
||||||
# Create a session with specific access attributes
|
# Create a session with specific access attributes
|
||||||
|
@ -53,8 +54,9 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
session_name="Restricted Session",
|
session_name="Restricted Session",
|
||||||
started_at=datetime.now(),
|
started_at=datetime.now(),
|
||||||
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
|
owner=User("someone", {"roles": ["admin"], "teams": ["security-team"]}),
|
||||||
turns=[],
|
turns=[],
|
||||||
|
identifier="Restricted Session",
|
||||||
)
|
)
|
||||||
|
|
||||||
await agent_persistence.kvstore.set(
|
await agent_persistence.kvstore.set(
|
||||||
|
@ -63,20 +65,22 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
|
||||||
)
|
)
|
||||||
|
|
||||||
# User with matching attributes can access
|
# User with matching attributes can access
|
||||||
mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
|
mock_get_authenticated_user.return_value = User(
|
||||||
|
"testuser", {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
|
||||||
|
)
|
||||||
retrieved_session = await agent_persistence.get_session_info(session_id)
|
retrieved_session = await agent_persistence.get_session_info(session_id)
|
||||||
assert retrieved_session is not None
|
assert retrieved_session is not None
|
||||||
assert retrieved_session.session_id == session_id
|
assert retrieved_session.session_id == session_id
|
||||||
|
|
||||||
# User without matching attributes cannot access
|
# User without matching attributes cannot access
|
||||||
mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]}
|
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"], "teams": ["other-team"]})
|
||||||
retrieved_session = await agent_persistence.get_session_info(session_id)
|
retrieved_session = await agent_persistence.get_session_info(session_id)
|
||||||
assert retrieved_session is None
|
assert retrieved_session is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
|
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||||
async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
async def test_turn_access_control(mock_get_authenticated_user, test_setup):
|
||||||
agent_persistence = test_setup
|
agent_persistence = test_setup
|
||||||
|
|
||||||
# Create a session with restricted access
|
# Create a session with restricted access
|
||||||
|
@ -85,8 +89,9 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
session_name="Restricted Session",
|
session_name="Restricted Session",
|
||||||
started_at=datetime.now(),
|
started_at=datetime.now(),
|
||||||
access_attributes=AccessAttributes(roles=["admin"]),
|
owner=User("someone", {"roles": ["admin"]}),
|
||||||
turns=[],
|
turns=[],
|
||||||
|
identifier="Restricted Session",
|
||||||
)
|
)
|
||||||
|
|
||||||
await agent_persistence.kvstore.set(
|
await agent_persistence.kvstore.set(
|
||||||
|
@ -109,7 +114,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Admin can add turn
|
# Admin can add turn
|
||||||
mock_get_auth_attributes.return_value = {"roles": ["admin"]}
|
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
|
||||||
await agent_persistence.add_turn_to_session(session_id, turn)
|
await agent_persistence.add_turn_to_session(session_id, turn)
|
||||||
|
|
||||||
# Admin can get turn
|
# Admin can get turn
|
||||||
|
@ -118,7 +123,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
||||||
assert retrieved_turn.turn_id == turn_id
|
assert retrieved_turn.turn_id == turn_id
|
||||||
|
|
||||||
# Regular user cannot get turn
|
# Regular user cannot get turn
|
||||||
mock_get_auth_attributes.return_value = {"roles": ["user"]}
|
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await agent_persistence.get_session_turn(session_id, turn_id)
|
await agent_persistence.get_session_turn(session_id, turn_id)
|
||||||
|
|
||||||
|
@ -128,8 +133,8 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
|
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||||
async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup):
|
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
|
||||||
agent_persistence = test_setup
|
agent_persistence = test_setup
|
||||||
|
|
||||||
# Create a session with restricted access
|
# Create a session with restricted access
|
||||||
|
@ -138,8 +143,9 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
session_name="Restricted Session",
|
session_name="Restricted Session",
|
||||||
started_at=datetime.now(),
|
started_at=datetime.now(),
|
||||||
access_attributes=AccessAttributes(roles=["admin"]),
|
owner=User("someone", {"roles": ["admin"]}),
|
||||||
turns=[],
|
turns=[],
|
||||||
|
identifier="Restricted Session",
|
||||||
)
|
)
|
||||||
|
|
||||||
await agent_persistence.kvstore.set(
|
await agent_persistence.kvstore.set(
|
||||||
|
@ -150,7 +156,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
|
||||||
turn_id = str(uuid.uuid4())
|
turn_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# Admin user can set inference iterations
|
# Admin user can set inference iterations
|
||||||
mock_get_auth_attributes.return_value = {"roles": ["admin"]}
|
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
|
||||||
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)
|
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)
|
||||||
|
|
||||||
# Admin user can get inference iterations
|
# Admin user can get inference iterations
|
||||||
|
@ -158,7 +164,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
|
||||||
assert infer_iters == 5
|
assert infer_iters == 5
|
||||||
|
|
||||||
# Regular user cannot get inference iterations
|
# Regular user cannot get inference iterations
|
||||||
mock_get_auth_attributes.return_value = {"roles": ["user"]}
|
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
|
||||||
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
|
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
|
||||||
assert infer_iters is None
|
assert infer_iters is None
|
||||||
|
|
||||||
|
|
|
@ -8,19 +8,18 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.models import ModelType
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.distribution.datatypes import ModelWithACL
|
from llama_stack.distribution.datatypes import ModelWithOwner, User
|
||||||
from llama_stack.distribution.server.auth_providers import AccessAttributes
|
|
||||||
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_registry_cache_with_acl(cached_disk_dist_registry):
|
async def test_registry_cache_with_acl(cached_disk_dist_registry):
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="model-acl",
|
identifier="model-acl",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
provider_resource_id="model-acl-resource",
|
provider_resource_id="model-acl-resource",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
|
owner=User("testuser", {"roles": ["admin"], "teams": ["ai-team"]}),
|
||||||
)
|
)
|
||||||
|
|
||||||
success = await cached_disk_dist_registry.register(model)
|
success = await cached_disk_dist_registry.register(model)
|
||||||
|
@ -29,22 +28,14 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
|
||||||
cached_model = cached_disk_dist_registry.get_cached("model", "model-acl")
|
cached_model = cached_disk_dist_registry.get_cached("model", "model-acl")
|
||||||
assert cached_model is not None
|
assert cached_model is not None
|
||||||
assert cached_model.identifier == "model-acl"
|
assert cached_model.identifier == "model-acl"
|
||||||
assert cached_model.access_attributes.roles == ["admin"]
|
assert cached_model.owner.principal == "testuser"
|
||||||
assert cached_model.access_attributes.teams == ["ai-team"]
|
assert cached_model.owner.attributes["roles"] == ["admin"]
|
||||||
|
assert cached_model.owner.attributes["teams"] == ["ai-team"]
|
||||||
|
|
||||||
fetched_model = await cached_disk_dist_registry.get("model", "model-acl")
|
fetched_model = await cached_disk_dist_registry.get("model", "model-acl")
|
||||||
assert fetched_model is not None
|
assert fetched_model is not None
|
||||||
assert fetched_model.identifier == "model-acl"
|
assert fetched_model.identifier == "model-acl"
|
||||||
assert fetched_model.access_attributes.roles == ["admin"]
|
assert fetched_model.owner.attributes["roles"] == ["admin"]
|
||||||
|
|
||||||
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
|
|
||||||
await cached_disk_dist_registry.update(model)
|
|
||||||
|
|
||||||
updated_cached = cached_disk_dist_registry.get_cached("model", "model-acl")
|
|
||||||
assert updated_cached is not None
|
|
||||||
assert updated_cached.access_attributes.roles == ["admin", "user"]
|
|
||||||
assert updated_cached.access_attributes.projects == ["project-x"]
|
|
||||||
assert updated_cached.access_attributes.teams is None
|
|
||||||
|
|
||||||
new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
|
new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
|
||||||
await new_registry.initialize()
|
await new_registry.initialize()
|
||||||
|
@ -52,35 +43,32 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
|
||||||
new_model = await new_registry.get("model", "model-acl")
|
new_model = await new_registry.get("model", "model-acl")
|
||||||
assert new_model is not None
|
assert new_model is not None
|
||||||
assert new_model.identifier == "model-acl"
|
assert new_model.identifier == "model-acl"
|
||||||
assert new_model.access_attributes.roles == ["admin", "user"]
|
assert new_model.owner.principal == "testuser"
|
||||||
assert new_model.access_attributes.projects == ["project-x"]
|
assert new_model.owner.attributes["roles"] == ["admin"]
|
||||||
assert new_model.access_attributes.teams is None
|
assert new_model.owner.attributes["teams"] == ["ai-team"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_registry_empty_acl(cached_disk_dist_registry):
|
async def test_registry_empty_acl(cached_disk_dist_registry):
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="model-empty-acl",
|
identifier="model-empty-acl",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
provider_resource_id="model-resource",
|
provider_resource_id="model-resource",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(),
|
owner=User("testuser", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
await cached_disk_dist_registry.register(model)
|
await cached_disk_dist_registry.register(model)
|
||||||
|
|
||||||
cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl")
|
cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl")
|
||||||
assert cached_model is not None
|
assert cached_model is not None
|
||||||
assert cached_model.access_attributes is not None
|
assert cached_model.owner is not None
|
||||||
assert cached_model.access_attributes.roles is None
|
assert cached_model.owner.attributes is None
|
||||||
assert cached_model.access_attributes.teams is None
|
|
||||||
assert cached_model.access_attributes.projects is None
|
|
||||||
assert cached_model.access_attributes.namespaces is None
|
|
||||||
|
|
||||||
all_models = await cached_disk_dist_registry.get_all()
|
all_models = await cached_disk_dist_registry.get_all()
|
||||||
assert len(all_models) == 1
|
assert len(all_models) == 1
|
||||||
|
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="model-no-acl",
|
identifier="model-no-acl",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
provider_resource_id="model-resource-2",
|
provider_resource_id="model-resource-2",
|
||||||
|
@ -91,7 +79,7 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
|
||||||
|
|
||||||
cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl")
|
cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl")
|
||||||
assert cached_model is not None
|
assert cached_model is not None
|
||||||
assert cached_model.access_attributes is None
|
assert cached_model.owner is None
|
||||||
|
|
||||||
all_models = await cached_disk_dist_registry.get_all()
|
all_models = await cached_disk_dist_registry.get_all()
|
||||||
assert len(all_models) == 2
|
assert len(all_models) == 2
|
||||||
|
@ -99,19 +87,19 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_registry_serialization(cached_disk_dist_registry):
|
async def test_registry_serialization(cached_disk_dist_registry):
|
||||||
attributes = AccessAttributes(
|
attributes = {
|
||||||
roles=["admin", "researcher"],
|
"roles": ["admin", "researcher"],
|
||||||
teams=["ai-team", "ml-team"],
|
"teams": ["ai-team", "ml-team"],
|
||||||
projects=["project-a", "project-b"],
|
"projects": ["project-a", "project-b"],
|
||||||
namespaces=["prod", "staging"],
|
"namespaces": ["prod", "staging"],
|
||||||
)
|
}
|
||||||
|
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="model-serialize",
|
identifier="model-serialize",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
provider_resource_id="model-resource",
|
provider_resource_id="model-resource",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=attributes,
|
owner=User("bob", attributes),
|
||||||
)
|
)
|
||||||
|
|
||||||
await cached_disk_dist_registry.register(model)
|
await cached_disk_dist_registry.register(model)
|
||||||
|
@ -122,7 +110,7 @@ async def test_registry_serialization(cached_disk_dist_registry):
|
||||||
loaded_model = await new_registry.get("model", "model-serialize")
|
loaded_model = await new_registry.get("model", "model-serialize")
|
||||||
assert loaded_model is not None
|
assert loaded_model is not None
|
||||||
|
|
||||||
assert loaded_model.access_attributes.roles == ["admin", "researcher"]
|
assert loaded_model.owner.attributes["roles"] == ["admin", "researcher"]
|
||||||
assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"]
|
assert loaded_model.owner.attributes["teams"] == ["ai-team", "ml-team"]
|
||||||
assert loaded_model.access_attributes.projects == ["project-a", "project-b"]
|
assert loaded_model.owner.attributes["projects"] == ["project-a", "project-b"]
|
||||||
assert loaded_model.access_attributes.namespaces == ["prod", "staging"]
|
assert loaded_model.owner.attributes["namespaces"] == ["prod", "staging"]
|
||||||
|
|
|
@ -7,10 +7,13 @@
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from pydantic import TypeAdapter, ValidationError
|
||||||
|
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
from llama_stack.apis.models import ModelType
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
|
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||||
|
from llama_stack.distribution.datatypes import AccessRule, ModelWithOwner, User
|
||||||
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
|
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,39 +35,40 @@ async def test_setup(cached_disk_dist_registry):
|
||||||
routing_table = ModelsRoutingTable(
|
routing_table = ModelsRoutingTable(
|
||||||
impls_by_provider_id={"test_provider": mock_inference},
|
impls_by_provider_id={"test_provider": mock_inference},
|
||||||
dist_registry=cached_disk_dist_registry,
|
dist_registry=cached_disk_dist_registry,
|
||||||
|
policy={},
|
||||||
)
|
)
|
||||||
yield cached_disk_dist_registry, routing_table
|
yield cached_disk_dist_registry, routing_table
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||||
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
|
||||||
registry, routing_table = test_setup
|
registry, routing_table = test_setup
|
||||||
model_public = ModelWithACL(
|
model_public = ModelWithOwner(
|
||||||
identifier="model-public",
|
identifier="model-public",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
provider_resource_id="model-public",
|
provider_resource_id="model-public",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
)
|
)
|
||||||
model_admin_only = ModelWithACL(
|
model_admin_only = ModelWithOwner(
|
||||||
identifier="model-admin",
|
identifier="model-admin",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
provider_resource_id="model-admin",
|
provider_resource_id="model-admin",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(roles=["admin"]),
|
owner=User("testuser", {"roles": ["admin"]}),
|
||||||
)
|
)
|
||||||
model_data_scientist = ModelWithACL(
|
model_data_scientist = ModelWithOwner(
|
||||||
identifier="model-data-scientist",
|
identifier="model-data-scientist",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
provider_resource_id="model-data-scientist",
|
provider_resource_id="model-data-scientist",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]),
|
owner=User("testuser", {"roles": ["data-scientist", "researcher"], "teams": ["ml-team"]}),
|
||||||
)
|
)
|
||||||
await registry.register(model_public)
|
await registry.register(model_public)
|
||||||
await registry.register(model_admin_only)
|
await registry.register(model_admin_only)
|
||||||
await registry.register(model_data_scientist)
|
await registry.register(model_data_scientist)
|
||||||
|
|
||||||
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
|
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["admin"], "teams": ["management"]})
|
||||||
all_models = await routing_table.list_models()
|
all_models = await routing_table.list_models()
|
||||||
assert len(all_models.data) == 2
|
assert len(all_models.data) == 2
|
||||||
|
|
||||||
|
@ -75,7 +79,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await routing_table.get_model("model-data-scientist")
|
await routing_table.get_model("model-data-scientist")
|
||||||
|
|
||||||
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
|
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["other-team"]})
|
||||||
all_models = await routing_table.list_models()
|
all_models = await routing_table.list_models()
|
||||||
assert len(all_models.data) == 1
|
assert len(all_models.data) == 1
|
||||||
assert all_models.data[0].identifier == "model-public"
|
assert all_models.data[0].identifier == "model-public"
|
||||||
|
@ -86,7 +90,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await routing_table.get_model("model-data-scientist")
|
await routing_table.get_model("model-data-scientist")
|
||||||
|
|
||||||
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]}
|
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["ml-team"]})
|
||||||
all_models = await routing_table.list_models()
|
all_models = await routing_table.list_models()
|
||||||
assert len(all_models.data) == 2
|
assert len(all_models.data) == 2
|
||||||
model_ids = [m.identifier for m in all_models.data]
|
model_ids = [m.identifier for m in all_models.data]
|
||||||
|
@ -102,50 +106,62 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||||
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
|
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
|
||||||
registry, routing_table = test_setup
|
registry, routing_table = test_setup
|
||||||
model_public = ModelWithACL(
|
model_public = ModelWithOwner(
|
||||||
identifier="model-updates",
|
identifier="model-updates",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
provider_resource_id="model-updates",
|
provider_resource_id="model-updates",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
)
|
)
|
||||||
await registry.register(model_public)
|
await registry.register(model_public)
|
||||||
mock_get_auth_attributes.return_value = {
|
mock_get_authenticated_user.return_value = User(
|
||||||
"roles": ["user"],
|
"test-user",
|
||||||
}
|
{
|
||||||
|
"roles": ["user"],
|
||||||
|
},
|
||||||
|
)
|
||||||
model = await routing_table.get_model("model-updates")
|
model = await routing_table.get_model("model-updates")
|
||||||
assert model.identifier == "model-updates"
|
assert model.identifier == "model-updates"
|
||||||
model_public.access_attributes = AccessAttributes(roles=["admin"])
|
model_public.owner = User("testuser", {"roles": ["admin"]})
|
||||||
await registry.update(model_public)
|
await registry.update(model_public)
|
||||||
mock_get_auth_attributes.return_value = {
|
mock_get_authenticated_user.return_value = User(
|
||||||
"roles": ["user"],
|
"test-user",
|
||||||
}
|
{
|
||||||
|
"roles": ["user"],
|
||||||
|
},
|
||||||
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await routing_table.get_model("model-updates")
|
await routing_table.get_model("model-updates")
|
||||||
mock_get_auth_attributes.return_value = {
|
mock_get_authenticated_user.return_value = User(
|
||||||
"roles": ["admin"],
|
"test-user",
|
||||||
}
|
{
|
||||||
|
"roles": ["admin"],
|
||||||
|
},
|
||||||
|
)
|
||||||
model = await routing_table.get_model("model-updates")
|
model = await routing_table.get_model("model-updates")
|
||||||
assert model.identifier == "model-updates"
|
assert model.identifier == "model-updates"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||||
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
|
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
|
||||||
registry, routing_table = test_setup
|
registry, routing_table = test_setup
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="model-empty-attrs",
|
identifier="model-empty-attrs",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
provider_resource_id="model-empty-attrs",
|
provider_resource_id="model-empty-attrs",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(),
|
owner=User("testuser", {}),
|
||||||
)
|
)
|
||||||
await registry.register(model)
|
await registry.register(model)
|
||||||
mock_get_auth_attributes.return_value = {
|
mock_get_authenticated_user.return_value = User(
|
||||||
"roles": [],
|
"test-user",
|
||||||
}
|
{
|
||||||
|
"roles": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
result = await routing_table.get_model("model-empty-attrs")
|
result = await routing_table.get_model("model-empty-attrs")
|
||||||
assert result.identifier == "model-empty-attrs"
|
assert result.identifier == "model-empty-attrs"
|
||||||
all_models = await routing_table.list_models()
|
all_models = await routing_table.list_models()
|
||||||
|
@ -154,25 +170,25 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||||
async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
|
async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
|
||||||
registry, routing_table = test_setup
|
registry, routing_table = test_setup
|
||||||
model_public = ModelWithACL(
|
model_public = ModelWithOwner(
|
||||||
identifier="model-public-2",
|
identifier="model-public-2",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
provider_resource_id="model-public-2",
|
provider_resource_id="model-public-2",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
)
|
)
|
||||||
model_restricted = ModelWithACL(
|
model_restricted = ModelWithOwner(
|
||||||
identifier="model-restricted",
|
identifier="model-restricted",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
provider_resource_id="model-restricted",
|
provider_resource_id="model-restricted",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(roles=["admin"]),
|
owner=User("testuser", {"roles": ["admin"]}),
|
||||||
)
|
)
|
||||||
await registry.register(model_public)
|
await registry.register(model_public)
|
||||||
await registry.register(model_restricted)
|
await registry.register(model_restricted)
|
||||||
mock_get_auth_attributes.return_value = None
|
mock_get_authenticated_user.return_value = User("test-user", None)
|
||||||
model = await routing_table.get_model("model-public-2")
|
model = await routing_table.get_model("model-public-2")
|
||||||
assert model.identifier == "model-public-2"
|
assert model.identifier == "model-public-2"
|
||||||
|
|
||||||
|
@ -185,17 +201,17 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||||
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
|
async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup):
|
||||||
"""Test that newly created resources inherit access attributes from their creator."""
|
"""Test that newly created resources inherit access attributes from their creator."""
|
||||||
registry, routing_table = test_setup
|
registry, routing_table = test_setup
|
||||||
|
|
||||||
# Set creator's attributes
|
# Set creator's attributes
|
||||||
creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
|
creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
|
||||||
mock_get_auth_attributes.return_value = creator_attributes
|
mock_get_authenticated_user.return_value = User("test-user", creator_attributes)
|
||||||
|
|
||||||
# Create model without explicit access attributes
|
# Create model without explicit access attributes
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="auto-access-model",
|
identifier="auto-access-model",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
provider_resource_id="auto-access-model",
|
provider_resource_id="auto-access-model",
|
||||||
|
@ -205,21 +221,346 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup)
|
||||||
|
|
||||||
# Verify the model got creator's attributes
|
# Verify the model got creator's attributes
|
||||||
registered_model = await routing_table.get_model("auto-access-model")
|
registered_model = await routing_table.get_model("auto-access-model")
|
||||||
assert registered_model.access_attributes is not None
|
assert registered_model.owner is not None
|
||||||
assert registered_model.access_attributes.roles == ["data-scientist"]
|
assert registered_model.owner.attributes is not None
|
||||||
assert registered_model.access_attributes.teams == ["ml-team"]
|
assert registered_model.owner.attributes["roles"] == ["data-scientist"]
|
||||||
assert registered_model.access_attributes.projects == ["llama-3"]
|
assert registered_model.owner.attributes["teams"] == ["ml-team"]
|
||||||
|
assert registered_model.owner.attributes["projects"] == ["llama-3"]
|
||||||
|
|
||||||
# Verify another user without matching attributes can't access it
|
# Verify another user without matching attributes can't access it
|
||||||
mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]}
|
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]})
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await routing_table.get_model("auto-access-model")
|
await routing_table.get_model("auto-access-model")
|
||||||
|
|
||||||
# But a user with matching attributes can
|
# But a user with matching attributes can
|
||||||
mock_get_auth_attributes.return_value = {
|
mock_get_authenticated_user.return_value = User(
|
||||||
"roles": ["data-scientist", "engineer"],
|
"test-user",
|
||||||
"teams": ["ml-team", "platform-team"],
|
{
|
||||||
"projects": ["llama-3"],
|
"roles": ["data-scientist", "engineer"],
|
||||||
}
|
"teams": ["ml-team", "platform-team"],
|
||||||
|
"projects": ["llama-3"],
|
||||||
|
},
|
||||||
|
)
|
||||||
model = await routing_table.get_model("auto-access-model")
|
model = await routing_table.get_model("auto-access-model")
|
||||||
assert model.identifier == "auto-access-model"
|
assert model.identifier == "auto-access-model"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def test_setup_with_access_policy(cached_disk_dist_registry):
|
||||||
|
mock_inference = Mock()
|
||||||
|
mock_inference.__provider_spec__ = MagicMock()
|
||||||
|
mock_inference.__provider_spec__.api = Api.inference
|
||||||
|
mock_inference.register_model = AsyncMock(side_effect=_return_model)
|
||||||
|
mock_inference.unregister_model = AsyncMock(side_effect=_return_model)
|
||||||
|
|
||||||
|
config = """
|
||||||
|
- permit:
|
||||||
|
principal: user-1
|
||||||
|
actions: [create, read, delete]
|
||||||
|
description: user-1 has full access to all models
|
||||||
|
- permit:
|
||||||
|
principal: user-2
|
||||||
|
actions: [read]
|
||||||
|
resource: model::model-1
|
||||||
|
description: user-2 has read access to model-1 only
|
||||||
|
- permit:
|
||||||
|
principal: user-3
|
||||||
|
actions: [read]
|
||||||
|
resource: model::model-2
|
||||||
|
description: user-3 has read access to model-2 only
|
||||||
|
- forbid:
|
||||||
|
actions: [create, read, delete]
|
||||||
|
"""
|
||||||
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
routing_table = ModelsRoutingTable(
|
||||||
|
impls_by_provider_id={"test_provider": mock_inference},
|
||||||
|
dist_registry=cached_disk_dist_registry,
|
||||||
|
policy=policy,
|
||||||
|
)
|
||||||
|
yield routing_table
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||||
|
async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy):
|
||||||
|
routing_table = test_setup_with_access_policy
|
||||||
|
mock_get_authenticated_user.return_value = User(
|
||||||
|
"user-1",
|
||||||
|
{
|
||||||
|
"roles": ["admin"],
|
||||||
|
"projects": ["foo", "bar"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await routing_table.register_model("model-1", provider_id="test_provider")
|
||||||
|
await routing_table.register_model("model-2", provider_id="test_provider")
|
||||||
|
await routing_table.register_model("model-3", provider_id="test_provider")
|
||||||
|
model = await routing_table.get_model("model-1")
|
||||||
|
assert model.identifier == "model-1"
|
||||||
|
model = await routing_table.get_model("model-2")
|
||||||
|
assert model.identifier == "model-2"
|
||||||
|
model = await routing_table.get_model("model-3")
|
||||||
|
assert model.identifier == "model-3"
|
||||||
|
|
||||||
|
mock_get_authenticated_user.return_value = User(
|
||||||
|
"user-2",
|
||||||
|
{
|
||||||
|
"roles": ["user"],
|
||||||
|
"projects": ["foo"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
model = await routing_table.get_model("model-1")
|
||||||
|
assert model.identifier == "model-1"
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await routing_table.get_model("model-2")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await routing_table.get_model("model-3")
|
||||||
|
with pytest.raises(AccessDeniedError):
|
||||||
|
await routing_table.register_model("model-4", provider_id="test_provider")
|
||||||
|
with pytest.raises(AccessDeniedError):
|
||||||
|
await routing_table.unregister_model("model-1")
|
||||||
|
|
||||||
|
mock_get_authenticated_user.return_value = User(
|
||||||
|
"user-3",
|
||||||
|
{
|
||||||
|
"roles": ["user"],
|
||||||
|
"projects": ["bar"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
model = await routing_table.get_model("model-2")
|
||||||
|
assert model.identifier == "model-2"
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await routing_table.get_model("model-1")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await routing_table.get_model("model-3")
|
||||||
|
with pytest.raises(AccessDeniedError):
|
||||||
|
await routing_table.register_model("model-5", provider_id="test_provider")
|
||||||
|
with pytest.raises(AccessDeniedError):
|
||||||
|
await routing_table.unregister_model("model-2")
|
||||||
|
|
||||||
|
mock_get_authenticated_user.return_value = User(
|
||||||
|
"user-1",
|
||||||
|
{
|
||||||
|
"roles": ["admin"],
|
||||||
|
"projects": ["foo", "bar"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await routing_table.unregister_model("model-3")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await routing_table.get_model("model-3")
|
||||||
|
|
||||||
|
|
||||||
|
def test_permit_when():
|
||||||
|
config = """
|
||||||
|
- permit:
|
||||||
|
principal: user-1
|
||||||
|
actions: [read]
|
||||||
|
when: user in owners namespaces
|
||||||
|
"""
|
||||||
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
model = ModelWithOwner(
|
||||||
|
identifier="mymodel",
|
||||||
|
provider_id="myprovider",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
owner=User("testuser", {"namespaces": ["foo"]}),
|
||||||
|
)
|
||||||
|
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
|
||||||
|
|
||||||
|
|
||||||
|
def test_permit_unless():
|
||||||
|
config = """
|
||||||
|
- permit:
|
||||||
|
principal: user-1
|
||||||
|
actions: [read]
|
||||||
|
resource: model::*
|
||||||
|
unless:
|
||||||
|
- user not in owners namespaces
|
||||||
|
- user in owners teams
|
||||||
|
"""
|
||||||
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
model = ModelWithOwner(
|
||||||
|
identifier="mymodel",
|
||||||
|
provider_id="myprovider",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
owner=User("testuser", {"namespaces": ["foo"]}),
|
||||||
|
)
|
||||||
|
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
|
||||||
|
|
||||||
|
|
||||||
|
def test_forbid_when():
|
||||||
|
config = """
|
||||||
|
- forbid:
|
||||||
|
principal: user-1
|
||||||
|
actions: [read]
|
||||||
|
when:
|
||||||
|
user in owners namespaces
|
||||||
|
- permit:
|
||||||
|
actions: [read]
|
||||||
|
"""
|
||||||
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
model = ModelWithOwner(
|
||||||
|
identifier="mymodel",
|
||||||
|
provider_id="myprovider",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
owner=User("testuser", {"namespaces": ["foo"]}),
|
||||||
|
)
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||||
|
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||||
|
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
|
||||||
|
|
||||||
|
|
||||||
|
def test_forbid_unless():
|
||||||
|
config = """
|
||||||
|
- forbid:
|
||||||
|
principal: user-1
|
||||||
|
actions: [read]
|
||||||
|
unless:
|
||||||
|
user in owners namespaces
|
||||||
|
- permit:
|
||||||
|
actions: [read]
|
||||||
|
"""
|
||||||
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
model = ModelWithOwner(
|
||||||
|
identifier="mymodel",
|
||||||
|
provider_id="myprovider",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
owner=User("testuser", {"namespaces": ["foo"]}),
|
||||||
|
)
|
||||||
|
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||||
|
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_has_attribute():
|
||||||
|
config = """
|
||||||
|
- permit:
|
||||||
|
actions: [read]
|
||||||
|
when: user with admin in roles
|
||||||
|
"""
|
||||||
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
model = ModelWithOwner(
|
||||||
|
identifier="mymodel",
|
||||||
|
provider_id="myprovider",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||||
|
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-4", None))
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_does_not_have_attribute():
|
||||||
|
config = """
|
||||||
|
- permit:
|
||||||
|
actions: [read]
|
||||||
|
unless: user with admin not in roles
|
||||||
|
"""
|
||||||
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
model = ModelWithOwner(
|
||||||
|
identifier="mymodel",
|
||||||
|
provider_id="myprovider",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||||
|
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-4", None))
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_owner():
|
||||||
|
config = """
|
||||||
|
- permit:
|
||||||
|
actions: [read]
|
||||||
|
when: user is owner
|
||||||
|
"""
|
||||||
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
model = ModelWithOwner(
|
||||||
|
identifier="mymodel",
|
||||||
|
provider_id="myprovider",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
owner=User("user-2", {"namespaces": ["foo"]}),
|
||||||
|
)
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||||
|
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-4", None))
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_not_owner():
|
||||||
|
config = """
|
||||||
|
- permit:
|
||||||
|
actions: [read]
|
||||||
|
unless: user is not owner
|
||||||
|
"""
|
||||||
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
model = ModelWithOwner(
|
||||||
|
identifier="mymodel",
|
||||||
|
provider_id="myprovider",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
owner=User("user-2", {"namespaces": ["foo"]}),
|
||||||
|
)
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||||
|
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-4", None))
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_rule_permit_and_forbid_both_specified():
|
||||||
|
config = """
|
||||||
|
- permit:
|
||||||
|
actions: [read]
|
||||||
|
forbid:
|
||||||
|
actions: [create]
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_rule_neither_permit_or_forbid_specified():
|
||||||
|
config = """
|
||||||
|
- when: user is owner
|
||||||
|
unless: user with admin in roles
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_rule_when_and_unless_both_specified():
|
||||||
|
config = """
|
||||||
|
- permit:
|
||||||
|
actions: [read]
|
||||||
|
when: user is owner
|
||||||
|
unless: user with admin in roles
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_condition():
|
||||||
|
config = """
|
||||||
|
- permit:
|
||||||
|
actions: [read]
|
||||||
|
when: random words that are not valid
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"condition",
|
||||||
|
[
|
||||||
|
"user is owner",
|
||||||
|
"user is not owner",
|
||||||
|
"user with dev in teams",
|
||||||
|
"user with default not in namespaces",
|
||||||
|
"user in owners roles",
|
||||||
|
"user not in owners projects",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_condition_reprs(condition):
|
||||||
|
from llama_stack.distribution.access_control.conditions import parse_condition
|
||||||
|
|
||||||
|
assert condition == str(parse_condition(condition))
|
||||||
|
|
|
@ -139,7 +139,7 @@ async def mock_post_success(*args, **kwargs):
|
||||||
{
|
{
|
||||||
"message": "Authentication successful",
|
"message": "Authentication successful",
|
||||||
"principal": "test-principal",
|
"principal": "test-principal",
|
||||||
"access_attributes": {
|
"attributes": {
|
||||||
"roles": ["admin", "user"],
|
"roles": ["admin", "user"],
|
||||||
"teams": ["ml-team", "nlp-team"],
|
"teams": ["ml-team", "nlp-team"],
|
||||||
"projects": ["llama-3", "project-x"],
|
"projects": ["llama-3", "project-x"],
|
||||||
|
@ -233,7 +233,7 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
|
||||||
{
|
{
|
||||||
"message": "Authentication successful",
|
"message": "Authentication successful",
|
||||||
"principal": "test-principal",
|
"principal": "test-principal",
|
||||||
"access_attributes": {
|
"attributes": {
|
||||||
"roles": ["admin", "user"],
|
"roles": ["admin", "user"],
|
||||||
"teams": ["ml-team", "nlp-team"],
|
"teams": ["ml-team", "nlp-team"],
|
||||||
"projects": ["llama-3", "project-x"],
|
"projects": ["llama-3", "project-x"],
|
||||||
|
@ -255,33 +255,6 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
|
||||||
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope):
|
|
||||||
"""Test middleware behavior with no access attributes"""
|
|
||||||
middleware, mock_app = mock_http_middleware
|
|
||||||
mock_receive = AsyncMock()
|
|
||||||
mock_send = AsyncMock()
|
|
||||||
|
|
||||||
with patch("httpx.AsyncClient") as mock_client:
|
|
||||||
mock_client_instance = AsyncMock()
|
|
||||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
|
||||||
|
|
||||||
mock_client_instance.post.return_value = MockResponse(
|
|
||||||
200,
|
|
||||||
{
|
|
||||||
"message": "Authentication successful"
|
|
||||||
# No access_attributes
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
await middleware(mock_scope, mock_receive, mock_send)
|
|
||||||
|
|
||||||
assert "user_attributes" in mock_scope
|
|
||||||
attributes = mock_scope["user_attributes"]
|
|
||||||
assert "roles" in attributes
|
|
||||||
assert attributes["roles"] == ["test.jwt.token"]
|
|
||||||
|
|
||||||
|
|
||||||
# oauth2 token provider tests
|
# oauth2 token provider tests
|
||||||
|
|
||||||
|
|
||||||
|
@ -380,16 +353,16 @@ def test_get_attributes_from_claims():
|
||||||
"aud": "llama-stack",
|
"aud": "llama-stack",
|
||||||
}
|
}
|
||||||
attributes = get_attributes_from_claims(claims, {"sub": "roles", "groups": "teams"})
|
attributes = get_attributes_from_claims(claims, {"sub": "roles", "groups": "teams"})
|
||||||
assert attributes.roles == ["my-user"]
|
assert attributes["roles"] == ["my-user"]
|
||||||
assert attributes.teams == ["group1", "group2"]
|
assert attributes["teams"] == ["group1", "group2"]
|
||||||
|
|
||||||
claims = {
|
claims = {
|
||||||
"sub": "my-user",
|
"sub": "my-user",
|
||||||
"tenant": "my-tenant",
|
"tenant": "my-tenant",
|
||||||
}
|
}
|
||||||
attributes = get_attributes_from_claims(claims, {"sub": "roles", "tenant": "namespaces"})
|
attributes = get_attributes_from_claims(claims, {"sub": "roles", "tenant": "namespaces"})
|
||||||
assert attributes.roles == ["my-user"]
|
assert attributes["roles"] == ["my-user"]
|
||||||
assert attributes.namespaces == ["my-tenant"]
|
assert attributes["namespaces"] == ["my-tenant"]
|
||||||
|
|
||||||
claims = {
|
claims = {
|
||||||
"sub": "my-user",
|
"sub": "my-user",
|
||||||
|
@ -408,9 +381,9 @@ def test_get_attributes_from_claims():
|
||||||
"groups": "teams",
|
"groups": "teams",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert set(attributes.roles) == {"my-user", "my-username"}
|
assert set(attributes["roles"]) == {"my-user", "my-username"}
|
||||||
assert set(attributes.teams) == {"my-team", "group1", "group2"}
|
assert set(attributes["teams"]) == {"my-team", "group1", "group2"}
|
||||||
assert attributes.namespaces == ["my-tenant"]
|
assert attributes["namespaces"] == ["my-tenant"]
|
||||||
|
|
||||||
|
|
||||||
# TODO: add more tests for oauth2 token provider
|
# TODO: add more tests for oauth2 token provider
|
||||||
|
|
|
@ -100,9 +100,10 @@ async def test_resolve_impls_basic():
|
||||||
add_protocol_methods(SampleImpl, Inference)
|
add_protocol_methods(SampleImpl, Inference)
|
||||||
|
|
||||||
mock_module.get_provider_impl = AsyncMock(return_value=impl)
|
mock_module.get_provider_impl = AsyncMock(return_value=impl)
|
||||||
|
mock_module.get_provider_impl.__text_signature__ = "()"
|
||||||
sys.modules["test_module"] = mock_module
|
sys.modules["test_module"] = mock_module
|
||||||
|
|
||||||
impls = await resolve_impls(run_config, provider_registry, dist_registry)
|
impls = await resolve_impls(run_config, provider_registry, dist_registry, policy={})
|
||||||
|
|
||||||
assert Api.inference in impls
|
assert Api.inference in impls
|
||||||
assert isinstance(impls[Api.inference], InferenceRouter)
|
assert isinstance(impls[Api.inference], InferenceRouter)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue