mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
Changes to access rule conditions:
* change from access_attributes to owner on dynamically created resources * define simpler string based conditions for more intuitive restriction
This commit is contained in:
parent
01ad876012
commit
96cd51a0c8
20 changed files with 427 additions and 431 deletions
|
@ -4,16 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Protocol
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.request_headers import User
|
||||
from llama_stack.distribution.datatypes import User
|
||||
|
||||
from .conditions import (
|
||||
Condition,
|
||||
ProtectedResource,
|
||||
parse_conditions,
|
||||
)
|
||||
from .datatypes import (
|
||||
AccessAttributes,
|
||||
AccessRule,
|
||||
Action,
|
||||
AttributeReference,
|
||||
Condition,
|
||||
Scope,
|
||||
)
|
||||
|
||||
|
@ -37,43 +39,6 @@ def matches_scope(
|
|||
return action in scope.actions
|
||||
|
||||
|
||||
def user_in_literal(
|
||||
literal: str,
|
||||
user_attributes: dict[str, list[str]] | None,
|
||||
) -> bool:
|
||||
for qualifier in ["role::", "team::", "project::", "namespace::"]:
|
||||
if literal.startswith(qualifier):
|
||||
if not user_attributes:
|
||||
return False
|
||||
ref = qualifier.replace("::", "s")
|
||||
if ref in user_attributes:
|
||||
value = literal.removeprefix(qualifier)
|
||||
return value in user_attributes[ref]
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def user_in(
|
||||
ref: AttributeReference | str,
|
||||
resource_attributes: AccessAttributes | None,
|
||||
user_attributes: dict[str, list[str]] | None,
|
||||
) -> bool:
|
||||
if not ref.startswith("resource."):
|
||||
return user_in_literal(ref, user_attributes)
|
||||
name = ref.removeprefix("resource.")
|
||||
required = resource_attributes and getattr(resource_attributes, name)
|
||||
if not required:
|
||||
return True
|
||||
if not user_attributes or name not in user_attributes:
|
||||
return False
|
||||
actual = user_attributes[name]
|
||||
for value in required:
|
||||
if value in actual:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def as_list(obj: Any) -> list[Any]:
|
||||
if isinstance(obj, list):
|
||||
return obj
|
||||
|
@ -82,55 +47,27 @@ def as_list(obj: Any) -> list[Any]:
|
|||
|
||||
def matches_conditions(
|
||||
conditions: list[Condition],
|
||||
resource_attributes: AccessAttributes | None,
|
||||
user_attributes: dict[str, list[str]] | None,
|
||||
resource: ProtectedResource,
|
||||
user: User,
|
||||
) -> bool:
|
||||
for condition in conditions:
|
||||
# must match all conditions
|
||||
if not matches_condition(condition, resource_attributes, user_attributes):
|
||||
if not condition.matches(resource, user):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def matches_condition(
|
||||
condition: Condition | list[Condition],
|
||||
resource_attributes: AccessAttributes | None,
|
||||
user_attributes: dict[str, list[str]] | None,
|
||||
) -> bool:
|
||||
if isinstance(condition, list):
|
||||
return matches_conditions(as_list(condition), resource_attributes, user_attributes)
|
||||
if condition.user_in:
|
||||
for ref in as_list(condition.user_in):
|
||||
# if multiple references are specified, all must match
|
||||
if not user_in(ref, resource_attributes, user_attributes):
|
||||
return False
|
||||
return True
|
||||
if condition.user_not_in:
|
||||
for ref in as_list(condition.user_not_in):
|
||||
# if multiple references are specified, none must match
|
||||
if user_in(ref, resource_attributes, user_attributes):
|
||||
return False
|
||||
return True
|
||||
return True
|
||||
|
||||
|
||||
def default_policy() -> list[AccessRule]:
|
||||
# for backwards compatibility, if no rules are provided , assume
|
||||
# full access to all subject to attribute matching rules
|
||||
# for backwards compatibility, if no rules are provided, assume
|
||||
# full access subject to previous attribute matching rules
|
||||
return [
|
||||
AccessRule(
|
||||
permit=Scope(actions=list(Action)),
|
||||
when=Condition(user_in=list(AttributeReference)),
|
||||
)
|
||||
when=["user in owners " + name for name in ["roles", "teams", "projects", "namespaces"]],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class ProtectedResource(Protocol):
|
||||
type: str
|
||||
identifier: str
|
||||
access_attributes: AccessAttributes
|
||||
|
||||
|
||||
def is_action_allowed(
|
||||
policy: list[AccessRule],
|
||||
action: Action,
|
||||
|
@ -144,26 +81,23 @@ def is_action_allowed(
|
|||
if not len(policy):
|
||||
policy = default_policy()
|
||||
|
||||
resource_attributes = AccessAttributes()
|
||||
if hasattr(resource, "access_attributes"):
|
||||
resource_attributes = resource.access_attributes
|
||||
qualified_resource_id = resource.type + "::" + resource.identifier
|
||||
for rule in policy:
|
||||
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
|
||||
if rule.when:
|
||||
if matches_condition(rule.when, resource_attributes, user.attributes):
|
||||
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
|
||||
return False
|
||||
elif rule.unless:
|
||||
if not matches_condition(rule.unless, resource_attributes, user.attributes):
|
||||
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
|
||||
if rule.when:
|
||||
if matches_condition(rule.when, resource_attributes, user.attributes):
|
||||
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
|
||||
return True
|
||||
elif rule.unless:
|
||||
if not matches_condition(rule.unless, resource_attributes, user.attributes):
|
||||
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
|
||||
return True
|
||||
else:
|
||||
return True
|
||||
|
|
129
llama_stack/distribution/access_control/conditions.py
Normal file
129
llama_stack/distribution/access_control/conditions.py
Normal file
|
@ -0,0 +1,129 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class User(Protocol):
|
||||
principal: str
|
||||
attributes: dict[str, list[str]] | None
|
||||
|
||||
|
||||
class ProtectedResource(Protocol):
|
||||
type: str
|
||||
identifier: str
|
||||
owner: User
|
||||
|
||||
|
||||
class Condition(Protocol):
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool: ...
|
||||
|
||||
|
||||
class UserInOwnersList:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
def owners_values(self, resource: ProtectedResource) -> list[str] | None:
|
||||
if (
|
||||
hasattr(resource, "owner")
|
||||
and resource.owner
|
||||
and resource.owner.attributes
|
||||
and self.name in resource.owner.attributes
|
||||
):
|
||||
return resource.owner.attributes[self.name]
|
||||
else:
|
||||
return None
|
||||
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
required = self.owners_values(resource)
|
||||
if not required:
|
||||
return True
|
||||
if not user.attributes or self.name not in user.attributes or not user.attributes[self.name]:
|
||||
return False
|
||||
user_values = user.attributes[self.name]
|
||||
for value in required:
|
||||
if value in user_values:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
return f"user in owners {self.name}"
|
||||
|
||||
|
||||
class UserNotInOwnersList(UserInOwnersList):
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
return not super().matches(resource, user)
|
||||
|
||||
def __repr__(self):
|
||||
return f"user not in owners {self.name}"
|
||||
|
||||
|
||||
class UserWithValueInList:
|
||||
def __init__(self, name: str, value: str):
|
||||
self.name = name
|
||||
self.value = value
|
||||
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
if user.attributes and self.name in user.attributes:
|
||||
return self.value in user.attributes[self.name]
|
||||
print(f"User does not have {self.value} in {self.name}")
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
return f"user with {self.value} in {self.name}"
|
||||
|
||||
|
||||
class UserWithValueNotInList(UserWithValueInList):
|
||||
def __init__(self, name: str, value: str):
|
||||
super().__init__(name, value)
|
||||
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
return not super().matches(resource, user)
|
||||
|
||||
def __repr__(self):
|
||||
return f"user with {self.value} not in {self.name}"
|
||||
|
||||
|
||||
class UserIsOwner:
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
return resource.owner.principal == user.principal if resource.owner else False
|
||||
|
||||
def __repr__(self):
|
||||
return "user is owner"
|
||||
|
||||
|
||||
class UserIsNotOwner:
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
return not resource.owner or resource.owner.principal != user.principal
|
||||
|
||||
def __repr__(self):
|
||||
return "user is not owner"
|
||||
|
||||
|
||||
def parse_condition(condition: str) -> Condition:
|
||||
words = condition.split()
|
||||
match words:
|
||||
case ["user", "is", "owner"]:
|
||||
return UserIsOwner()
|
||||
case ["user", "is", "not", "owner"]:
|
||||
return UserIsNotOwner()
|
||||
case ["user", "with", value, "in", name]:
|
||||
return UserWithValueInList(name, value)
|
||||
case ["user", "with", value, "not", "in", name]:
|
||||
return UserWithValueNotInList(name, value)
|
||||
case ["user", "in", "owners", name]:
|
||||
return UserInOwnersList(name)
|
||||
case ["user", "not", "in", "owners", name]:
|
||||
return UserNotInOwnersList(name)
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def parse_conditions(conditions: list[str]) -> list[Condition]:
|
||||
return [parse_condition(c) for c in conditions]
|
|
@ -6,37 +6,10 @@
|
|||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class AccessAttributes(BaseModel):
|
||||
"""Structured representation of user attributes for access control.
|
||||
|
||||
This model defines a structured approach to representing user attributes
|
||||
with common standard categories for access control.
|
||||
|
||||
Standard attribute categories include:
|
||||
- roles: Role-based attributes (e.g., admin, data-scientist)
|
||||
- teams: Team-based attributes (e.g., ml-team, infra-team)
|
||||
- projects: Project access attributes (e.g., llama-3, customer-insights)
|
||||
- namespaces: Namespace-based access control for resource isolation
|
||||
"""
|
||||
|
||||
# Standard attribute categories - the minimal set we need now
|
||||
roles: list[str] | None = Field(
|
||||
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
|
||||
)
|
||||
|
||||
teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
|
||||
|
||||
projects: list[str] | None = Field(
|
||||
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
|
||||
)
|
||||
|
||||
namespaces: list[str] | None = Field(
|
||||
default=None, description="Namespace-based access control for resource isolation"
|
||||
)
|
||||
from .conditions import parse_conditions
|
||||
|
||||
|
||||
class Action(str, Enum):
|
||||
|
@ -62,18 +35,6 @@ def _require_one_of(obj, a: str, b: str):
|
|||
raise ValueError(f"on of {a} or {b} is required")
|
||||
|
||||
|
||||
class AttributeReference(str, Enum):
|
||||
RESOURCE_ROLES = "resource.roles"
|
||||
RESOURCE_TEAMS = "resource.teams"
|
||||
RESOURCE_PROJECTS = "resource.projects"
|
||||
RESOURCE_NAMESPACES = "resource.namespaces"
|
||||
|
||||
|
||||
class Condition(BaseModel):
|
||||
user_in: AttributeReference | list[AttributeReference] | str | None = None
|
||||
user_not_in: AttributeReference | list[AttributeReference] | str | None = None
|
||||
|
||||
|
||||
class AccessRule(BaseModel):
|
||||
"""Access rule based loosely on cedar policy language
|
||||
|
||||
|
@ -85,10 +46,14 @@ class AccessRule(BaseModel):
|
|||
requests.
|
||||
|
||||
A rule may also specify a condition, either a 'when' or an 'unless', with additional
|
||||
constraints as to where the rule applies. The constraints at present are whether the
|
||||
user requesting access is in or not in some set. This set can either be a particular
|
||||
set of attributes on the resource e.g. resource.roles or a literal value of some
|
||||
notion of group, e.g. role::admin or namespace::foo.
|
||||
constraints as to where the rule applies. The constraints supported at present are:
|
||||
|
||||
- 'user with <attr-value> in <attr-name>'
|
||||
- 'user with <attr-value> not in <attr-name>'
|
||||
- 'user is owner'
|
||||
- 'user is not owner'
|
||||
- 'user in owners <attr-name>'
|
||||
- 'user not in owners <attr-name>'
|
||||
|
||||
Rules are tested in order to find a match. If a match is found, the request is
|
||||
permitted or forbidden depending on the type of rule. If no match is found, the
|
||||
|
@ -110,22 +75,20 @@ class AccessRule(BaseModel):
|
|||
description: user-2 has read access to model-1 only
|
||||
- permit:
|
||||
actions: [read]
|
||||
when:
|
||||
user_in: resource.namespaces
|
||||
description: any user has read access to any resource with matching attributes
|
||||
when: user in owner teams
|
||||
description: any user has read access to any resource created by a member of their team
|
||||
- forbid:
|
||||
actions: [create, read, delete]
|
||||
resource: vector_db::*
|
||||
unless:
|
||||
user_in: role::admin
|
||||
unless: user with admin in roles
|
||||
description: only user with admin role can use vector_db resources
|
||||
|
||||
"""
|
||||
|
||||
permit: Scope | None = None
|
||||
forbid: Scope | None = None
|
||||
when: Condition | list[Condition] | None = None
|
||||
unless: Condition | list[Condition] | None = None
|
||||
when: str | list[str] | None = None
|
||||
unless: str | list[str] | None = None
|
||||
description: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
@ -133,4 +96,12 @@ class AccessRule(BaseModel):
|
|||
_require_one_of(self, "permit", "forbid")
|
||||
_mutually_exclusive(self, "permit", "forbid")
|
||||
_mutually_exclusive(self, "when", "unless")
|
||||
if isinstance(self.when, list):
|
||||
parse_conditions(self.when)
|
||||
elif self.when:
|
||||
parse_conditions([self.when])
|
||||
if isinstance(self.unless, list):
|
||||
parse_conditions(self.unless)
|
||||
elif self.unless:
|
||||
parse_conditions([self.unless])
|
||||
return self
|
||||
|
|
|
@ -24,7 +24,7 @@ from llama_stack.apis.shields import Shield, ShieldInput
|
|||
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule
|
||||
from llama_stack.distribution.access_control.datatypes import AccessRule
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
|
||||
|
@ -36,97 +36,66 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
|||
RoutingKey = str | list[str]
|
||||
|
||||
|
||||
class ResourceWithACL(Resource):
|
||||
"""Extension of Resource that adds attribute-based access control capabilities.
|
||||
class User(BaseModel):
|
||||
principal: str
|
||||
# further attributes that may be used for access control decisions
|
||||
attributes: dict[str, list[str]] | None = None
|
||||
|
||||
This class adds an optional access_attributes field that allows fine-grained control
|
||||
over which users can access each resource. When attributes are defined, a user must have
|
||||
matching attributes to access the resource.
|
||||
def __init__(self, principal: str, attributes: dict[str, list[str]] | None):
|
||||
super().__init__(principal=principal, attributes=attributes)
|
||||
|
||||
Attribute Matching Algorithm:
|
||||
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
|
||||
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
|
||||
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
|
||||
4. Within each category, ANY value match is sufficient (OR relationship within a category)
|
||||
|
||||
Examples:
|
||||
# Resource visible to everyone (no access control)
|
||||
model = Model(identifier="llama-2", ...)
|
||||
class ResourceWithOwner(Resource):
|
||||
"""Extension of Resource that adds an optional owner, i.e. the user that created the
|
||||
resource. This can be used to constrain access to the resource."""
|
||||
|
||||
# Resource visible only to admins
|
||||
model = Model(
|
||||
identifier="gpt-4",
|
||||
access_attributes=AccessAttributes(roles=["admin"])
|
||||
)
|
||||
|
||||
# Resource visible to data scientists on the ML team
|
||||
model = Model(
|
||||
identifier="private-model",
|
||||
access_attributes=AccessAttributes(
|
||||
roles=["data-scientist", "researcher"],
|
||||
teams=["ml-team"]
|
||||
)
|
||||
)
|
||||
# ^ User must have at least one of the roles AND be on the ml-team
|
||||
|
||||
# Resource visible to users with specific project access
|
||||
vector_db = VectorDB(
|
||||
identifier="customer-embeddings",
|
||||
access_attributes=AccessAttributes(
|
||||
projects=["customer-insights"],
|
||||
namespaces=["confidential"]
|
||||
)
|
||||
)
|
||||
# ^ User must have access to the customer-insights project AND have confidential namespace
|
||||
"""
|
||||
|
||||
access_attributes: AccessAttributes | None = None
|
||||
owner: User | None = None
|
||||
|
||||
|
||||
# Use the extended Resource for all routable objects
|
||||
class ModelWithACL(Model, ResourceWithACL):
|
||||
class ModelWithOwner(Model, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ShieldWithACL(Shield, ResourceWithACL):
|
||||
class ShieldWithOwner(Shield, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class VectorDBWithACL(VectorDB, ResourceWithACL):
|
||||
class VectorDBWithOwner(VectorDB, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class DatasetWithACL(Dataset, ResourceWithACL):
|
||||
class DatasetWithOwner(Dataset, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ScoringFnWithACL(ScoringFn, ResourceWithACL):
|
||||
class ScoringFnWithOwner(ScoringFn, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class BenchmarkWithACL(Benchmark, ResourceWithACL):
|
||||
class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ToolWithACL(Tool, ResourceWithACL):
|
||||
class ToolWithOwner(Tool, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ToolGroupWithACL(ToolGroup, ResourceWithACL):
|
||||
class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
|
||||
|
||||
RoutableObjectWithProvider = Annotated[
|
||||
ModelWithACL
|
||||
| ShieldWithACL
|
||||
| VectorDBWithACL
|
||||
| DatasetWithACL
|
||||
| ScoringFnWithACL
|
||||
| BenchmarkWithACL
|
||||
| ToolWithACL
|
||||
| ToolGroupWithACL,
|
||||
ModelWithOwner
|
||||
| ShieldWithOwner
|
||||
| VectorDBWithOwner
|
||||
| DatasetWithOwner
|
||||
| ScoringFnWithOwner
|
||||
| BenchmarkWithOwner
|
||||
| ToolWithOwner
|
||||
| ToolGroupWithOwner,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
|
|
@ -10,6 +10,8 @@ import logging
|
|||
from contextlib import AbstractContextManager
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import User
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -18,16 +20,6 @@ log = logging.getLogger(__name__)
|
|||
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
||||
|
||||
|
||||
class User:
|
||||
principal: str
|
||||
# further attributes that may be used for access control decisions
|
||||
attributes: dict[str, list[str]]
|
||||
|
||||
def __init__(self, principal: str, attributes: dict[str, list[str]]):
|
||||
self.principal = principal
|
||||
self.attributes = attributes
|
||||
|
||||
|
||||
class RequestProviderDataContext(AbstractContextManager):
|
||||
"""Context manager for request provider data"""
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BenchmarkWithACL,
|
||||
BenchmarkWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -47,7 +47,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
)
|
||||
if provider_benchmark_id is None:
|
||||
provider_benchmark_id = benchmark_id
|
||||
benchmark = BenchmarkWithACL(
|
||||
benchmark = BenchmarkWithOwner(
|
||||
identifier=benchmark_id,
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions=scoring_functions,
|
||||
|
|
|
@ -10,7 +10,6 @@ from llama_stack.apis.resource import ResourceType
|
|||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AccessAttributes,
|
||||
AccessRule,
|
||||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
|
@ -195,9 +194,9 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
creator = get_authenticated_user()
|
||||
if not is_action_allowed(self.policy, "create", obj, creator):
|
||||
raise AccessDeniedError()
|
||||
if creator and creator.attributes:
|
||||
obj.access_attributes = AccessAttributes(**creator.attributes)
|
||||
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
|
||||
if creator:
|
||||
obj.owner = creator
|
||||
logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
|
||||
|
||||
registered_obj = await register_object_with_provider(obj, p)
|
||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||
|
|
|
@ -19,7 +19,7 @@ from llama_stack.apis.datasets import (
|
|||
)
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.distribution.datatypes import (
|
||||
DatasetWithACL,
|
||||
DatasetWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -74,7 +74,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
dataset = DatasetWithACL(
|
||||
dataset = DatasetWithOwner(
|
||||
identifier=dataset_id,
|
||||
provider_resource_id=provider_dataset_id,
|
||||
provider_id=provider_id,
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ModelWithACL,
|
||||
ModelWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -65,7 +65,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
model_type = ModelType.llm
|
||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||
model = ModelWithACL(
|
||||
model = ModelWithOwner(
|
||||
identifier=model_id,
|
||||
provider_resource_id=provider_model_id,
|
||||
provider_id=provider_id,
|
||||
|
|
|
@ -13,7 +13,7 @@ from llama_stack.apis.scoring_functions import (
|
|||
ScoringFunctions,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ScoringFnWithACL,
|
||||
ScoringFnWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -50,7 +50,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
scoring_fn = ScoringFnWithACL(
|
||||
scoring_fn = ScoringFnWithOwner(
|
||||
identifier=scoring_fn_id,
|
||||
description=description,
|
||||
return_type=return_type,
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import Any
|
|||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ShieldWithACL,
|
||||
ShieldWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -47,7 +47,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
)
|
||||
if params is None:
|
||||
params = {}
|
||||
shield = ShieldWithACL(
|
||||
shield = ShieldWithOwner(
|
||||
identifier=shield_id,
|
||||
provider_resource_id=provider_shield_id,
|
||||
provider_id=provider_id,
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||
from llama_stack.distribution.datatypes import ToolGroupWithACL
|
||||
from llama_stack.distribution.datatypes import ToolGroupWithOwner
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
@ -88,7 +88,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
mcp_endpoint: URL | None = None,
|
||||
args: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
toolgroup = ToolGroupWithACL(
|
||||
toolgroup = ToolGroupWithOwner(
|
||||
identifier=toolgroup_id,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=toolgroup_id,
|
||||
|
|
|
@ -10,7 +10,7 @@ from llama_stack.apis.models import ModelType
|
|||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
from llama_stack.distribution.datatypes import (
|
||||
VectorDBWithACL,
|
||||
VectorDBWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -63,7 +63,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
}
|
||||
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
|
||||
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
return vector_db
|
||||
|
||||
|
|
|
@ -105,24 +105,16 @@ class AuthenticationMiddleware:
|
|||
logger.exception("Error during authentication")
|
||||
return await self._send_auth_error(send, "Authentication service error")
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if validation_result.access_attributes:
|
||||
user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
|
||||
else:
|
||||
logger.warning("No access attributes, setting namespace to token by default")
|
||||
user_attributes = {
|
||||
"roles": [token],
|
||||
}
|
||||
|
||||
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
|
||||
# can identify the requester and enforce per-client rate limits.
|
||||
scope["authenticated_client_id"] = token
|
||||
|
||||
# Store attributes in request scope
|
||||
scope["user_attributes"] = user_attributes
|
||||
scope["principal"] = validation_result.principal
|
||||
if validation_result.attributes:
|
||||
scope["user_attributes"] = validation_result.attributes
|
||||
logger.debug(
|
||||
f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
|
||||
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
|
||||
)
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
|
|
@ -16,43 +16,18 @@ from jose import jwt
|
|||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class TokenValidationResult(BaseModel):
|
||||
principal: str | None = Field(
|
||||
default=None,
|
||||
description="The principal (username or persistent identifier) of the authenticated user",
|
||||
)
|
||||
access_attributes: AccessAttributes | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Structured user attributes for attribute-based access control.
|
||||
|
||||
These attributes determine which resources the user can access.
|
||||
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
|
||||
Each attribute category contains a list of values that the user has for that category.
|
||||
During access control checks, these values are compared against resource requirements.
|
||||
|
||||
Example with standard categories:
|
||||
```json
|
||||
{
|
||||
"roles": ["admin", "data-scientist"],
|
||||
"teams": ["ml-team"],
|
||||
"projects": ["llama-3"],
|
||||
"namespaces": ["research"]
|
||||
}
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class AuthResponse(TokenValidationResult):
|
||||
class AuthResponse(BaseModel):
|
||||
"""The format of the authentication response from the auth endpoint."""
|
||||
|
||||
principal: str
|
||||
# further attributes that may be used for access control decisions
|
||||
attributes: dict[str, list[str]] | None = None
|
||||
message: str | None = Field(
|
||||
default=None, description="Optional message providing additional context about the authentication result."
|
||||
)
|
||||
|
@ -78,7 +53,7 @@ class AuthProvider(ABC):
|
|||
"""Abstract base class for authentication providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||
"""Validate a token and return access attributes."""
|
||||
pass
|
||||
|
||||
|
@ -88,10 +63,10 @@ class AuthProvider(ABC):
|
|||
pass
|
||||
|
||||
|
||||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
|
||||
attributes = AccessAttributes()
|
||||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
|
||||
attributes: dict[str, list[str]] = {}
|
||||
for claim_key, attribute_key in mapping.items():
|
||||
if claim_key not in claims or not hasattr(attributes, attribute_key):
|
||||
if claim_key not in claims:
|
||||
continue
|
||||
claim = claims[claim_key]
|
||||
if isinstance(claim, list):
|
||||
|
@ -99,11 +74,10 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
|
|||
else:
|
||||
values = claim.split()
|
||||
|
||||
current = getattr(attributes, attribute_key)
|
||||
if current:
|
||||
current.extend(values)
|
||||
if attribute_key in attributes:
|
||||
attributes[attribute_key].extend(values)
|
||||
else:
|
||||
setattr(attributes, attribute_key, values)
|
||||
attributes[attribute_key] = values
|
||||
return attributes
|
||||
|
||||
|
||||
|
@ -145,8 +119,6 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
|
|||
for key, value in v.items():
|
||||
if not value:
|
||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||
if value not in AccessAttributes.model_fields:
|
||||
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
@ -171,14 +143,14 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
self._jwks: dict[str, str] = {}
|
||||
self._jwks_lock = Lock()
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||
if self.config.jwks:
|
||||
return await self.validate_jwt_token(token, scope)
|
||||
if self.config.introspection:
|
||||
return await self.introspect_token(token, scope)
|
||||
raise ValueError("One of jwks or introspection must be configured")
|
||||
|
||||
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
|
||||
"""Validate a token using the JWT token."""
|
||||
await self._refresh_jwks()
|
||||
|
||||
|
@ -203,12 +175,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
# We should incorporate these into the access attributes.
|
||||
principal = claims["sub"]
|
||||
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
|
||||
return TokenValidationResult(
|
||||
return User(
|
||||
principal=principal,
|
||||
access_attributes=access_attributes,
|
||||
attributes=access_attributes,
|
||||
)
|
||||
|
||||
async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
async def introspect_token(self, token: str, scope: dict | None = None) -> User:
|
||||
"""Validate a token using token introspection as defined by RFC 7662."""
|
||||
form = {
|
||||
"token": token,
|
||||
|
@ -242,9 +214,9 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
raise ValueError("Token not active")
|
||||
principal = fields["sub"] or fields["username"]
|
||||
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
|
||||
return TokenValidationResult(
|
||||
return User(
|
||||
principal=principal,
|
||||
access_attributes=access_attributes,
|
||||
attributes=access_attributes,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Token introspection request timed out")
|
||||
|
@ -299,7 +271,7 @@ class CustomAuthProvider(AuthProvider):
|
|||
self.config = config
|
||||
self._client = None
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||
"""Validate a token using the custom authentication endpoint."""
|
||||
if scope is None:
|
||||
scope = {}
|
||||
|
@ -333,6 +305,7 @@ class CustomAuthProvider(AuthProvider):
|
|||
json=auth_request.model_dump(),
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
print("MADE CALL")
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Authentication failed with status code: {response.status_code}")
|
||||
raise ValueError(f"Authentication failed: {response.status_code}")
|
||||
|
@ -341,7 +314,7 @@ class CustomAuthProvider(AuthProvider):
|
|||
try:
|
||||
response_data = response.json()
|
||||
auth_response = AuthResponse(**response_data)
|
||||
return auth_response
|
||||
return User(auth_response.principal, auth_response.attributes)
|
||||
except Exception as e:
|
||||
logger.exception("Error parsing authentication response")
|
||||
raise ValueError("Invalid authentication response format") from e
|
||||
|
|
|
@ -11,7 +11,8 @@ from datetime import datetime, timezone
|
|||
|
||||
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule
|
||||
from llama_stack.distribution.access_control.datatypes import AccessRule
|
||||
from llama_stack.distribution.datatypes import User
|
||||
from llama_stack.distribution.request_headers import get_authenticated_user
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
|
@ -22,7 +23,7 @@ class AgentSessionInfo(Session):
|
|||
# TODO: is this used anywhere?
|
||||
vector_db_id: str | None = None
|
||||
started_at: datetime
|
||||
access_attributes: AccessAttributes | None = None
|
||||
owner: User | None = None
|
||||
identifier: str | None = None
|
||||
type: str = "session"
|
||||
|
||||
|
@ -42,14 +43,12 @@ class AgentPersistence:
|
|||
|
||||
# Get current user's auth attributes for new sessions
|
||||
user = get_authenticated_user()
|
||||
auth_attributes = user and user.attributes
|
||||
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
|
||||
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=name,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
access_attributes=access_attributes,
|
||||
owner=user,
|
||||
turns=[],
|
||||
identifier=name, # should this be qualified in any way?
|
||||
)
|
||||
|
@ -80,7 +79,7 @@ class AgentPersistence:
|
|||
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
|
||||
"""Check if current user has access to the session."""
|
||||
# Handle backward compatibility for old sessions without access control
|
||||
if not hasattr(session_info, "access_attributes"):
|
||||
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
|
||||
return True
|
||||
|
||||
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
|
||||
|
|
|
@ -12,8 +12,7 @@ import pytest
|
|||
|
||||
from llama_stack.apis.agents import Turn
|
||||
from llama_stack.apis.inference import CompletionMessage, StopReason
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.distribution.request_headers import User
|
||||
from llama_stack.distribution.datatypes import User
|
||||
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
|
||||
|
||||
|
||||
|
@ -38,9 +37,10 @@ async def test_session_creation_with_access_attributes(mock_get_authenticated_us
|
|||
# Get the session and verify access attributes were set
|
||||
session_info = await agent_persistence.get_session_info(session_id)
|
||||
assert session_info is not None
|
||||
assert session_info.access_attributes is not None
|
||||
assert session_info.access_attributes.roles == ["researcher"]
|
||||
assert session_info.access_attributes.teams == ["ai-team"]
|
||||
assert session_info.owner is not None
|
||||
assert session_info.owner.attributes is not None
|
||||
assert session_info.owner.attributes["roles"] == ["researcher"]
|
||||
assert session_info.owner.attributes["teams"] == ["ai-team"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -54,7 +54,7 @@ async def test_session_access_control(mock_get_authenticated_user, test_setup):
|
|||
session_id=session_id,
|
||||
session_name="Restricted Session",
|
||||
started_at=datetime.now(),
|
||||
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
|
||||
owner=User("someone", {"roles": ["admin"], "teams": ["security-team"]}),
|
||||
turns=[],
|
||||
identifier="Restricted Session",
|
||||
)
|
||||
|
@ -89,7 +89,7 @@ async def test_turn_access_control(mock_get_authenticated_user, test_setup):
|
|||
session_id=session_id,
|
||||
session_name="Restricted Session",
|
||||
started_at=datetime.now(),
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
owner=User("someone", {"roles": ["admin"]}),
|
||||
turns=[],
|
||||
identifier="Restricted Session",
|
||||
)
|
||||
|
@ -143,7 +143,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_u
|
|||
session_id=session_id,
|
||||
session_name="Restricted Session",
|
||||
started_at=datetime.now(),
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
owner=User("someone", {"roles": ["admin"]}),
|
||||
turns=[],
|
||||
identifier="Restricted Session",
|
||||
)
|
||||
|
|
|
@ -8,19 +8,18 @@
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import ModelWithACL
|
||||
from llama_stack.distribution.server.auth_providers import AccessAttributes
|
||||
from llama_stack.distribution.datatypes import ModelWithOwner, User
|
||||
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registry_cache_with_acl(cached_disk_dist_registry):
|
||||
model = ModelWithACL(
|
||||
model = ModelWithOwner(
|
||||
identifier="model-acl",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-acl-resource",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
|
||||
owner=User("testuser", {"roles": ["admin"], "teams": ["ai-team"]}),
|
||||
)
|
||||
|
||||
success = await cached_disk_dist_registry.register(model)
|
||||
|
@ -29,22 +28,14 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
|
|||
cached_model = cached_disk_dist_registry.get_cached("model", "model-acl")
|
||||
assert cached_model is not None
|
||||
assert cached_model.identifier == "model-acl"
|
||||
assert cached_model.access_attributes.roles == ["admin"]
|
||||
assert cached_model.access_attributes.teams == ["ai-team"]
|
||||
assert cached_model.owner.principal == "testuser"
|
||||
assert cached_model.owner.attributes["roles"] == ["admin"]
|
||||
assert cached_model.owner.attributes["teams"] == ["ai-team"]
|
||||
|
||||
fetched_model = await cached_disk_dist_registry.get("model", "model-acl")
|
||||
assert fetched_model is not None
|
||||
assert fetched_model.identifier == "model-acl"
|
||||
assert fetched_model.access_attributes.roles == ["admin"]
|
||||
|
||||
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
|
||||
await cached_disk_dist_registry.update(model)
|
||||
|
||||
updated_cached = cached_disk_dist_registry.get_cached("model", "model-acl")
|
||||
assert updated_cached is not None
|
||||
assert updated_cached.access_attributes.roles == ["admin", "user"]
|
||||
assert updated_cached.access_attributes.projects == ["project-x"]
|
||||
assert updated_cached.access_attributes.teams is None
|
||||
assert fetched_model.owner.attributes["roles"] == ["admin"]
|
||||
|
||||
new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
|
||||
await new_registry.initialize()
|
||||
|
@ -52,35 +43,32 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
|
|||
new_model = await new_registry.get("model", "model-acl")
|
||||
assert new_model is not None
|
||||
assert new_model.identifier == "model-acl"
|
||||
assert new_model.access_attributes.roles == ["admin", "user"]
|
||||
assert new_model.access_attributes.projects == ["project-x"]
|
||||
assert new_model.access_attributes.teams is None
|
||||
assert new_model.owner.principal == "testuser"
|
||||
assert new_model.owner.attributes["roles"] == ["admin"]
|
||||
assert new_model.owner.attributes["teams"] == ["ai-team"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registry_empty_acl(cached_disk_dist_registry):
|
||||
model = ModelWithACL(
|
||||
model = ModelWithOwner(
|
||||
identifier="model-empty-acl",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-resource",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(),
|
||||
owner=User("testuser", None),
|
||||
)
|
||||
|
||||
await cached_disk_dist_registry.register(model)
|
||||
|
||||
cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl")
|
||||
assert cached_model is not None
|
||||
assert cached_model.access_attributes is not None
|
||||
assert cached_model.access_attributes.roles is None
|
||||
assert cached_model.access_attributes.teams is None
|
||||
assert cached_model.access_attributes.projects is None
|
||||
assert cached_model.access_attributes.namespaces is None
|
||||
assert cached_model.owner is not None
|
||||
assert cached_model.owner.attributes is None
|
||||
|
||||
all_models = await cached_disk_dist_registry.get_all()
|
||||
assert len(all_models) == 1
|
||||
|
||||
model = ModelWithACL(
|
||||
model = ModelWithOwner(
|
||||
identifier="model-no-acl",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-resource-2",
|
||||
|
@ -91,7 +79,7 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
|
|||
|
||||
cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl")
|
||||
assert cached_model is not None
|
||||
assert cached_model.access_attributes is None
|
||||
assert cached_model.owner is None
|
||||
|
||||
all_models = await cached_disk_dist_registry.get_all()
|
||||
assert len(all_models) == 2
|
||||
|
@ -99,19 +87,19 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registry_serialization(cached_disk_dist_registry):
|
||||
attributes = AccessAttributes(
|
||||
roles=["admin", "researcher"],
|
||||
teams=["ai-team", "ml-team"],
|
||||
projects=["project-a", "project-b"],
|
||||
namespaces=["prod", "staging"],
|
||||
)
|
||||
attributes = {
|
||||
"roles": ["admin", "researcher"],
|
||||
"teams": ["ai-team", "ml-team"],
|
||||
"projects": ["project-a", "project-b"],
|
||||
"namespaces": ["prod", "staging"],
|
||||
}
|
||||
|
||||
model = ModelWithACL(
|
||||
model = ModelWithOwner(
|
||||
identifier="model-serialize",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-resource",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=attributes,
|
||||
owner=User("bob", attributes),
|
||||
)
|
||||
|
||||
await cached_disk_dist_registry.register(model)
|
||||
|
@ -122,7 +110,7 @@ async def test_registry_serialization(cached_disk_dist_registry):
|
|||
loaded_model = await new_registry.get("model", "model-serialize")
|
||||
assert loaded_model is not None
|
||||
|
||||
assert loaded_model.access_attributes.roles == ["admin", "researcher"]
|
||||
assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"]
|
||||
assert loaded_model.access_attributes.projects == ["project-a", "project-b"]
|
||||
assert loaded_model.access_attributes.namespaces == ["prod", "staging"]
|
||||
assert loaded_model.owner.attributes["roles"] == ["admin", "researcher"]
|
||||
assert loaded_model.owner.attributes["teams"] == ["ai-team", "ml-team"]
|
||||
assert loaded_model.owner.attributes["projects"] == ["project-a", "project-b"]
|
||||
assert loaded_model.owner.attributes["namespaces"] == ["prod", "staging"]
|
||||
|
|
|
@ -8,13 +8,12 @@ from unittest.mock import MagicMock, Mock, patch
|
|||
|
||||
import pytest
|
||||
import yaml
|
||||
from pydantic import TypeAdapter
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.distribution.datatypes import AccessAttributes, AccessRule, ModelWithACL
|
||||
from llama_stack.distribution.request_headers import User
|
||||
from llama_stack.distribution.datatypes import AccessRule, ModelWithOwner, User
|
||||
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
|
||||
|
||||
|
||||
|
@ -45,25 +44,25 @@ async def test_setup(cached_disk_dist_registry):
|
|||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
model_public = ModelWithOwner(
|
||||
identifier="model-public",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-public",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
model_admin_only = ModelWithACL(
|
||||
model_admin_only = ModelWithOwner(
|
||||
identifier="model-admin",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-admin",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
owner=User("testuser", {"roles": ["admin"]}),
|
||||
)
|
||||
model_data_scientist = ModelWithACL(
|
||||
model_data_scientist = ModelWithOwner(
|
||||
identifier="model-data-scientist",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-data-scientist",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]),
|
||||
owner=User("testuser", {"roles": ["data-scientist", "researcher"], "teams": ["ml-team"]}),
|
||||
)
|
||||
await registry.register(model_public)
|
||||
await registry.register(model_admin_only)
|
||||
|
@ -110,7 +109,7 @@ async def test_access_control_with_cache(mock_get_authenticated_user, test_setup
|
|||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
model_public = ModelWithOwner(
|
||||
identifier="model-updates",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-updates",
|
||||
|
@ -125,7 +124,7 @@ async def test_access_control_and_updates(mock_get_authenticated_user, test_setu
|
|||
)
|
||||
model = await routing_table.get_model("model-updates")
|
||||
assert model.identifier == "model-updates"
|
||||
model_public.access_attributes = AccessAttributes(roles=["admin"])
|
||||
model_public.owner = User("testuser", {"roles": ["admin"]})
|
||||
await registry.update(model_public)
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"test-user",
|
||||
|
@ -149,12 +148,12 @@ async def test_access_control_and_updates(mock_get_authenticated_user, test_setu
|
|||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model = ModelWithACL(
|
||||
model = ModelWithOwner(
|
||||
identifier="model-empty-attrs",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-empty-attrs",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(),
|
||||
owner=User("testuser", {}),
|
||||
)
|
||||
await registry.register(model)
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
|
@ -174,18 +173,18 @@ async def test_access_control_empty_attributes(mock_get_authenticated_user, test
|
|||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
model_public = ModelWithOwner(
|
||||
identifier="model-public-2",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-public-2",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
model_restricted = ModelWithACL(
|
||||
model_restricted = ModelWithOwner(
|
||||
identifier="model-restricted",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-restricted",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
owner=User("testuser", {"roles": ["admin"]}),
|
||||
)
|
||||
await registry.register(model_public)
|
||||
await registry.register(model_restricted)
|
||||
|
@ -212,7 +211,7 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set
|
|||
mock_get_authenticated_user.return_value = User("test-user", creator_attributes)
|
||||
|
||||
# Create model without explicit access attributes
|
||||
model = ModelWithACL(
|
||||
model = ModelWithOwner(
|
||||
identifier="auto-access-model",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="auto-access-model",
|
||||
|
@ -222,10 +221,11 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set
|
|||
|
||||
# Verify the model got creator's attributes
|
||||
registered_model = await routing_table.get_model("auto-access-model")
|
||||
assert registered_model.access_attributes is not None
|
||||
assert registered_model.access_attributes.roles == ["data-scientist"]
|
||||
assert registered_model.access_attributes.teams == ["ml-team"]
|
||||
assert registered_model.access_attributes.projects == ["llama-3"]
|
||||
assert registered_model.owner is not None
|
||||
assert registered_model.owner.attributes is not None
|
||||
assert registered_model.owner.attributes["roles"] == ["data-scientist"]
|
||||
assert registered_model.owner.attributes["teams"] == ["ml-team"]
|
||||
assert registered_model.owner.attributes["projects"] == ["llama-3"]
|
||||
|
||||
# Verify another user without matching attributes can't access it
|
||||
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]})
|
||||
|
@ -354,15 +354,14 @@ def test_permit_when():
|
|||
- permit:
|
||||
principal: user-1
|
||||
actions: [read]
|
||||
when:
|
||||
user_in: resource.namespaces
|
||||
when: user in owners namespaces
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
model = ModelWithOwner(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
owner=User("testuser", {"namespaces": ["foo"]}),
|
||||
)
|
||||
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||
|
@ -376,15 +375,15 @@ def test_permit_unless():
|
|||
actions: [read]
|
||||
resource: model::*
|
||||
unless:
|
||||
- user_not_in: resource.namespaces
|
||||
- user_in: resource.teams
|
||||
- user not in owners namespaces
|
||||
- user in owners teams
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
model = ModelWithOwner(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
owner=User("testuser", {"namespaces": ["foo"]}),
|
||||
)
|
||||
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||
|
@ -397,16 +396,16 @@ def test_forbid_when():
|
|||
principal: user-1
|
||||
actions: [read]
|
||||
when:
|
||||
user_in: resource.namespaces
|
||||
user in owners namespaces
|
||||
- permit:
|
||||
actions: [read]
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
model = ModelWithOwner(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
owner=User("testuser", {"namespaces": ["foo"]}),
|
||||
)
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||
|
@ -419,35 +418,33 @@ def test_forbid_unless():
|
|||
principal: user-1
|
||||
actions: [read]
|
||||
unless:
|
||||
user_in: resource.namespaces
|
||||
user in owners namespaces
|
||||
- permit:
|
||||
actions: [read]
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
model = ModelWithOwner(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
owner=User("testuser", {"namespaces": ["foo"]}),
|
||||
)
|
||||
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
|
||||
|
||||
|
||||
def test_condition_with_literal():
|
||||
def test_user_has_attribute():
|
||||
config = """
|
||||
- permit:
|
||||
actions: [read]
|
||||
when:
|
||||
user_in: role::admin
|
||||
when: user with admin in roles
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
model = ModelWithOwner(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
)
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
|
||||
|
@ -455,35 +452,115 @@ def test_condition_with_literal():
|
|||
assert not is_action_allowed(policy, "read", model, User("user-4", None))
|
||||
|
||||
|
||||
def test_condition_with_unrecognised_literal():
|
||||
def test_user_does_not_have_attribute():
|
||||
config = """
|
||||
- permit:
|
||||
actions: [read]
|
||||
when:
|
||||
user_in: whatever
|
||||
unless: user with admin not in roles
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
model = ModelWithOwner(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
||||
)
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-2", None))
|
||||
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-4", None))
|
||||
|
||||
|
||||
def test_empty_condition():
|
||||
def test_is_owner():
|
||||
config = """
|
||||
- permit:
|
||||
actions: [read]
|
||||
when: {}
|
||||
when: user is owner
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithACL(
|
||||
model = ModelWithOwner(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
owner=User("user-2", {"namespaces": ["foo"]}),
|
||||
)
|
||||
assert is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||
assert is_action_allowed(policy, "read", model, User("user-2", None))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-4", None))
|
||||
|
||||
|
||||
def test_is_not_owner():
|
||||
config = """
|
||||
- permit:
|
||||
actions: [read]
|
||||
unless: user is not owner
|
||||
"""
|
||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
model = ModelWithOwner(
|
||||
identifier="mymodel",
|
||||
provider_id="myprovider",
|
||||
model_type=ModelType.llm,
|
||||
owner=User("user-2", {"namespaces": ["foo"]}),
|
||||
)
|
||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
|
||||
assert not is_action_allowed(policy, "read", model, User("user-4", None))
|
||||
|
||||
|
||||
def test_invalid_rule_permit_and_forbid_both_specified():
|
||||
config = """
|
||||
- permit:
|
||||
actions: [read]
|
||||
forbid:
|
||||
actions: [create]
|
||||
"""
|
||||
with pytest.raises(ValidationError):
|
||||
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
|
||||
|
||||
def test_invalid_rule_neither_permit_or_forbid_specified():
|
||||
config = """
|
||||
- when: user is owner
|
||||
unless: user with admin in roles
|
||||
"""
|
||||
with pytest.raises(ValidationError):
|
||||
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
|
||||
|
||||
def test_invalid_rule_when_and_unless_both_specified():
|
||||
config = """
|
||||
- permit:
|
||||
actions: [read]
|
||||
when: user is owner
|
||||
unless: user with admin in roles
|
||||
"""
|
||||
with pytest.raises(ValidationError):
|
||||
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
|
||||
|
||||
def test_invalid_condition():
|
||||
config = """
|
||||
- permit:
|
||||
actions: [read]
|
||||
when: random words that are not valid
|
||||
"""
|
||||
with pytest.raises(ValidationError):
|
||||
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"condition",
|
||||
[
|
||||
"user is owner",
|
||||
"user is not owner",
|
||||
"user with dev in teams",
|
||||
"user with default not in namespaces",
|
||||
"user in owners roles",
|
||||
"user not in owners projects",
|
||||
],
|
||||
)
|
||||
def test_condition_reprs(condition):
|
||||
from llama_stack.distribution.access_control.conditions import parse_condition
|
||||
|
||||
assert condition == str(parse_condition(condition))
|
||||
|
|
|
@ -139,7 +139,7 @@ async def mock_post_success(*args, **kwargs):
|
|||
{
|
||||
"message": "Authentication successful",
|
||||
"principal": "test-principal",
|
||||
"access_attributes": {
|
||||
"attributes": {
|
||||
"roles": ["admin", "user"],
|
||||
"teams": ["ml-team", "nlp-team"],
|
||||
"projects": ["llama-3", "project-x"],
|
||||
|
@ -233,7 +233,7 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
|
|||
{
|
||||
"message": "Authentication successful",
|
||||
"principal": "test-principal",
|
||||
"access_attributes": {
|
||||
"attributes": {
|
||||
"roles": ["admin", "user"],
|
||||
"teams": ["ml-team", "nlp-team"],
|
||||
"projects": ["llama-3", "project-x"],
|
||||
|
@ -255,33 +255,6 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
|
|||
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope):
|
||||
"""Test middleware behavior with no access attributes"""
|
||||
middleware, mock_app = mock_http_middleware
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
mock_client_instance.post.return_value = MockResponse(
|
||||
200,
|
||||
{
|
||||
"message": "Authentication successful"
|
||||
# No access_attributes
|
||||
},
|
||||
)
|
||||
|
||||
await middleware(mock_scope, mock_receive, mock_send)
|
||||
|
||||
assert "user_attributes" in mock_scope
|
||||
attributes = mock_scope["user_attributes"]
|
||||
assert "roles" in attributes
|
||||
assert attributes["roles"] == ["test.jwt.token"]
|
||||
|
||||
|
||||
# oauth2 token provider tests
|
||||
|
||||
|
||||
|
@ -380,16 +353,16 @@ def test_get_attributes_from_claims():
|
|||
"aud": "llama-stack",
|
||||
}
|
||||
attributes = get_attributes_from_claims(claims, {"sub": "roles", "groups": "teams"})
|
||||
assert attributes.roles == ["my-user"]
|
||||
assert attributes.teams == ["group1", "group2"]
|
||||
assert attributes["roles"] == ["my-user"]
|
||||
assert attributes["teams"] == ["group1", "group2"]
|
||||
|
||||
claims = {
|
||||
"sub": "my-user",
|
||||
"tenant": "my-tenant",
|
||||
}
|
||||
attributes = get_attributes_from_claims(claims, {"sub": "roles", "tenant": "namespaces"})
|
||||
assert attributes.roles == ["my-user"]
|
||||
assert attributes.namespaces == ["my-tenant"]
|
||||
assert attributes["roles"] == ["my-user"]
|
||||
assert attributes["namespaces"] == ["my-tenant"]
|
||||
|
||||
claims = {
|
||||
"sub": "my-user",
|
||||
|
@ -408,9 +381,9 @@ def test_get_attributes_from_claims():
|
|||
"groups": "teams",
|
||||
},
|
||||
)
|
||||
assert set(attributes.roles) == {"my-user", "my-username"}
|
||||
assert set(attributes.teams) == {"my-team", "group1", "group2"}
|
||||
assert attributes.namespaces == ["my-tenant"]
|
||||
assert set(attributes["roles"]) == {"my-user", "my-username"}
|
||||
assert set(attributes["teams"]) == {"my-team", "group1", "group2"}
|
||||
assert attributes["namespaces"] == ["my-tenant"]
|
||||
|
||||
|
||||
# TODO: add more tests for oauth2 token provider
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue