mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Changes to access rule conditions:
* change from access_attributes to owner on dynamically created resources * define simpler string based conditions for more intuitive restriction
This commit is contained in:
parent
01ad876012
commit
96cd51a0c8
20 changed files with 427 additions and 431 deletions
|
@ -4,16 +4,18 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Protocol
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.distribution.request_headers import User
|
from llama_stack.distribution.datatypes import User
|
||||||
|
|
||||||
|
from .conditions import (
|
||||||
|
Condition,
|
||||||
|
ProtectedResource,
|
||||||
|
parse_conditions,
|
||||||
|
)
|
||||||
from .datatypes import (
|
from .datatypes import (
|
||||||
AccessAttributes,
|
|
||||||
AccessRule,
|
AccessRule,
|
||||||
Action,
|
Action,
|
||||||
AttributeReference,
|
|
||||||
Condition,
|
|
||||||
Scope,
|
Scope,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -37,43 +39,6 @@ def matches_scope(
|
||||||
return action in scope.actions
|
return action in scope.actions
|
||||||
|
|
||||||
|
|
||||||
def user_in_literal(
|
|
||||||
literal: str,
|
|
||||||
user_attributes: dict[str, list[str]] | None,
|
|
||||||
) -> bool:
|
|
||||||
for qualifier in ["role::", "team::", "project::", "namespace::"]:
|
|
||||||
if literal.startswith(qualifier):
|
|
||||||
if not user_attributes:
|
|
||||||
return False
|
|
||||||
ref = qualifier.replace("::", "s")
|
|
||||||
if ref in user_attributes:
|
|
||||||
value = literal.removeprefix(qualifier)
|
|
||||||
return value in user_attributes[ref]
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def user_in(
|
|
||||||
ref: AttributeReference | str,
|
|
||||||
resource_attributes: AccessAttributes | None,
|
|
||||||
user_attributes: dict[str, list[str]] | None,
|
|
||||||
) -> bool:
|
|
||||||
if not ref.startswith("resource."):
|
|
||||||
return user_in_literal(ref, user_attributes)
|
|
||||||
name = ref.removeprefix("resource.")
|
|
||||||
required = resource_attributes and getattr(resource_attributes, name)
|
|
||||||
if not required:
|
|
||||||
return True
|
|
||||||
if not user_attributes or name not in user_attributes:
|
|
||||||
return False
|
|
||||||
actual = user_attributes[name]
|
|
||||||
for value in required:
|
|
||||||
if value in actual:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def as_list(obj: Any) -> list[Any]:
|
def as_list(obj: Any) -> list[Any]:
|
||||||
if isinstance(obj, list):
|
if isinstance(obj, list):
|
||||||
return obj
|
return obj
|
||||||
|
@ -82,55 +47,27 @@ def as_list(obj: Any) -> list[Any]:
|
||||||
|
|
||||||
def matches_conditions(
|
def matches_conditions(
|
||||||
conditions: list[Condition],
|
conditions: list[Condition],
|
||||||
resource_attributes: AccessAttributes | None,
|
resource: ProtectedResource,
|
||||||
user_attributes: dict[str, list[str]] | None,
|
user: User,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
for condition in conditions:
|
for condition in conditions:
|
||||||
# must match all conditions
|
# must match all conditions
|
||||||
if not matches_condition(condition, resource_attributes, user_attributes):
|
if not condition.matches(resource, user):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def matches_condition(
|
|
||||||
condition: Condition | list[Condition],
|
|
||||||
resource_attributes: AccessAttributes | None,
|
|
||||||
user_attributes: dict[str, list[str]] | None,
|
|
||||||
) -> bool:
|
|
||||||
if isinstance(condition, list):
|
|
||||||
return matches_conditions(as_list(condition), resource_attributes, user_attributes)
|
|
||||||
if condition.user_in:
|
|
||||||
for ref in as_list(condition.user_in):
|
|
||||||
# if multiple references are specified, all must match
|
|
||||||
if not user_in(ref, resource_attributes, user_attributes):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
if condition.user_not_in:
|
|
||||||
for ref in as_list(condition.user_not_in):
|
|
||||||
# if multiple references are specified, none must match
|
|
||||||
if user_in(ref, resource_attributes, user_attributes):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def default_policy() -> list[AccessRule]:
|
def default_policy() -> list[AccessRule]:
|
||||||
# for backwards compatibility, if no rules are provided , assume
|
# for backwards compatibility, if no rules are provided, assume
|
||||||
# full access to all subject to attribute matching rules
|
# full access subject to previous attribute matching rules
|
||||||
return [
|
return [
|
||||||
AccessRule(
|
AccessRule(
|
||||||
permit=Scope(actions=list(Action)),
|
permit=Scope(actions=list(Action)),
|
||||||
when=Condition(user_in=list(AttributeReference)),
|
when=["user in owners " + name for name in ["roles", "teams", "projects", "namespaces"]],
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class ProtectedResource(Protocol):
|
|
||||||
type: str
|
|
||||||
identifier: str
|
|
||||||
access_attributes: AccessAttributes
|
|
||||||
|
|
||||||
|
|
||||||
def is_action_allowed(
|
def is_action_allowed(
|
||||||
policy: list[AccessRule],
|
policy: list[AccessRule],
|
||||||
action: Action,
|
action: Action,
|
||||||
|
@ -144,26 +81,23 @@ def is_action_allowed(
|
||||||
if not len(policy):
|
if not len(policy):
|
||||||
policy = default_policy()
|
policy = default_policy()
|
||||||
|
|
||||||
resource_attributes = AccessAttributes()
|
|
||||||
if hasattr(resource, "access_attributes"):
|
|
||||||
resource_attributes = resource.access_attributes
|
|
||||||
qualified_resource_id = resource.type + "::" + resource.identifier
|
qualified_resource_id = resource.type + "::" + resource.identifier
|
||||||
for rule in policy:
|
for rule in policy:
|
||||||
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
|
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
|
||||||
if rule.when:
|
if rule.when:
|
||||||
if matches_condition(rule.when, resource_attributes, user.attributes):
|
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
|
||||||
return False
|
return False
|
||||||
elif rule.unless:
|
elif rule.unless:
|
||||||
if not matches_condition(rule.unless, resource_attributes, user.attributes):
|
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
|
elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
|
||||||
if rule.when:
|
if rule.when:
|
||||||
if matches_condition(rule.when, resource_attributes, user.attributes):
|
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
|
||||||
return True
|
return True
|
||||||
elif rule.unless:
|
elif rule.unless:
|
||||||
if not matches_condition(rule.unless, resource_attributes, user.attributes):
|
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
129
llama_stack/distribution/access_control/conditions.py
Normal file
129
llama_stack/distribution/access_control/conditions.py
Normal file
|
@ -0,0 +1,129 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
|
||||||
|
class User(Protocol):
|
||||||
|
principal: str
|
||||||
|
attributes: dict[str, list[str]] | None
|
||||||
|
|
||||||
|
|
||||||
|
class ProtectedResource(Protocol):
|
||||||
|
type: str
|
||||||
|
identifier: str
|
||||||
|
owner: User
|
||||||
|
|
||||||
|
|
||||||
|
class Condition(Protocol):
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
class UserInOwnersList:
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def owners_values(self, resource: ProtectedResource) -> list[str] | None:
|
||||||
|
if (
|
||||||
|
hasattr(resource, "owner")
|
||||||
|
and resource.owner
|
||||||
|
and resource.owner.attributes
|
||||||
|
and self.name in resource.owner.attributes
|
||||||
|
):
|
||||||
|
return resource.owner.attributes[self.name]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
required = self.owners_values(resource)
|
||||||
|
if not required:
|
||||||
|
return True
|
||||||
|
if not user.attributes or self.name not in user.attributes or not user.attributes[self.name]:
|
||||||
|
return False
|
||||||
|
user_values = user.attributes[self.name]
|
||||||
|
for value in required:
|
||||||
|
if value in user_values:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"user in owners {self.name}"
|
||||||
|
|
||||||
|
|
||||||
|
class UserNotInOwnersList(UserInOwnersList):
|
||||||
|
def __init__(self, name: str):
|
||||||
|
super().__init__(name)
|
||||||
|
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
return not super().matches(resource, user)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"user not in owners {self.name}"
|
||||||
|
|
||||||
|
|
||||||
|
class UserWithValueInList:
|
||||||
|
def __init__(self, name: str, value: str):
|
||||||
|
self.name = name
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
if user.attributes and self.name in user.attributes:
|
||||||
|
return self.value in user.attributes[self.name]
|
||||||
|
print(f"User does not have {self.value} in {self.name}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"user with {self.value} in {self.name}"
|
||||||
|
|
||||||
|
|
||||||
|
class UserWithValueNotInList(UserWithValueInList):
|
||||||
|
def __init__(self, name: str, value: str):
|
||||||
|
super().__init__(name, value)
|
||||||
|
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
return not super().matches(resource, user)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"user with {self.value} not in {self.name}"
|
||||||
|
|
||||||
|
|
||||||
|
class UserIsOwner:
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
return resource.owner.principal == user.principal if resource.owner else False
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "user is owner"
|
||||||
|
|
||||||
|
|
||||||
|
class UserIsNotOwner:
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
return not resource.owner or resource.owner.principal != user.principal
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "user is not owner"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_condition(condition: str) -> Condition:
|
||||||
|
words = condition.split()
|
||||||
|
match words:
|
||||||
|
case ["user", "is", "owner"]:
|
||||||
|
return UserIsOwner()
|
||||||
|
case ["user", "is", "not", "owner"]:
|
||||||
|
return UserIsNotOwner()
|
||||||
|
case ["user", "with", value, "in", name]:
|
||||||
|
return UserWithValueInList(name, value)
|
||||||
|
case ["user", "with", value, "not", "in", name]:
|
||||||
|
return UserWithValueNotInList(name, value)
|
||||||
|
case ["user", "in", "owners", name]:
|
||||||
|
return UserInOwnersList(name)
|
||||||
|
case ["user", "not", "in", "owners", name]:
|
||||||
|
return UserNotInOwnersList(name)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid condition: {condition}")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_conditions(conditions: list[str]) -> list[Condition]:
|
||||||
|
return [parse_condition(c) for c in conditions]
|
|
@ -6,37 +6,10 @@
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from .conditions import parse_conditions
|
||||||
class AccessAttributes(BaseModel):
|
|
||||||
"""Structured representation of user attributes for access control.
|
|
||||||
|
|
||||||
This model defines a structured approach to representing user attributes
|
|
||||||
with common standard categories for access control.
|
|
||||||
|
|
||||||
Standard attribute categories include:
|
|
||||||
- roles: Role-based attributes (e.g., admin, data-scientist)
|
|
||||||
- teams: Team-based attributes (e.g., ml-team, infra-team)
|
|
||||||
- projects: Project access attributes (e.g., llama-3, customer-insights)
|
|
||||||
- namespaces: Namespace-based access control for resource isolation
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Standard attribute categories - the minimal set we need now
|
|
||||||
roles: list[str] | None = Field(
|
|
||||||
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
|
|
||||||
)
|
|
||||||
|
|
||||||
teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
|
|
||||||
|
|
||||||
projects: list[str] | None = Field(
|
|
||||||
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
|
|
||||||
)
|
|
||||||
|
|
||||||
namespaces: list[str] | None = Field(
|
|
||||||
default=None, description="Namespace-based access control for resource isolation"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Action(str, Enum):
|
class Action(str, Enum):
|
||||||
|
@ -62,18 +35,6 @@ def _require_one_of(obj, a: str, b: str):
|
||||||
raise ValueError(f"on of {a} or {b} is required")
|
raise ValueError(f"on of {a} or {b} is required")
|
||||||
|
|
||||||
|
|
||||||
class AttributeReference(str, Enum):
|
|
||||||
RESOURCE_ROLES = "resource.roles"
|
|
||||||
RESOURCE_TEAMS = "resource.teams"
|
|
||||||
RESOURCE_PROJECTS = "resource.projects"
|
|
||||||
RESOURCE_NAMESPACES = "resource.namespaces"
|
|
||||||
|
|
||||||
|
|
||||||
class Condition(BaseModel):
|
|
||||||
user_in: AttributeReference | list[AttributeReference] | str | None = None
|
|
||||||
user_not_in: AttributeReference | list[AttributeReference] | str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class AccessRule(BaseModel):
|
class AccessRule(BaseModel):
|
||||||
"""Access rule based loosely on cedar policy language
|
"""Access rule based loosely on cedar policy language
|
||||||
|
|
||||||
|
@ -85,10 +46,14 @@ class AccessRule(BaseModel):
|
||||||
requests.
|
requests.
|
||||||
|
|
||||||
A rule may also specify a condition, either a 'when' or an 'unless', with additional
|
A rule may also specify a condition, either a 'when' or an 'unless', with additional
|
||||||
constraints as to where the rule applies. The constraints at present are whether the
|
constraints as to where the rule applies. The constraints supported at present are:
|
||||||
user requesting access is in or not in some set. This set can either be a particular
|
|
||||||
set of attributes on the resource e.g. resource.roles or a literal value of some
|
- 'user with <attr-value> in <attr-name>'
|
||||||
notion of group, e.g. role::admin or namespace::foo.
|
- 'user with <attr-value> not in <attr-name>'
|
||||||
|
- 'user is owner'
|
||||||
|
- 'user is not owner'
|
||||||
|
- 'user in owners <attr-name>'
|
||||||
|
- 'user not in owners <attr-name>'
|
||||||
|
|
||||||
Rules are tested in order to find a match. If a match is found, the request is
|
Rules are tested in order to find a match. If a match is found, the request is
|
||||||
permitted or forbidden depending on the type of rule. If no match is found, the
|
permitted or forbidden depending on the type of rule. If no match is found, the
|
||||||
|
@ -99,33 +64,31 @@ class AccessRule(BaseModel):
|
||||||
Some examples in yaml:
|
Some examples in yaml:
|
||||||
|
|
||||||
- permit:
|
- permit:
|
||||||
principal: user-1
|
principal: user-1
|
||||||
actions: [create, read, delete]
|
actions: [create, read, delete]
|
||||||
resource: model::*
|
resource: model::*
|
||||||
description: user-1 has full access to all models
|
description: user-1 has full access to all models
|
||||||
- permit:
|
- permit:
|
||||||
principal: user-2
|
principal: user-2
|
||||||
actions: [read]
|
actions: [read]
|
||||||
resource: model::model-1
|
resource: model::model-1
|
||||||
description: user-2 has read access to model-1 only
|
description: user-2 has read access to model-1 only
|
||||||
- permit:
|
- permit:
|
||||||
actions: [read]
|
actions: [read]
|
||||||
when:
|
when: user in owner teams
|
||||||
user_in: resource.namespaces
|
description: any user has read access to any resource created by a member of their team
|
||||||
description: any user has read access to any resource with matching attributes
|
|
||||||
- forbid:
|
- forbid:
|
||||||
actions: [create, read, delete]
|
actions: [create, read, delete]
|
||||||
resource: vector_db::*
|
resource: vector_db::*
|
||||||
unless:
|
unless: user with admin in roles
|
||||||
user_in: role::admin
|
|
||||||
description: only user with admin role can use vector_db resources
|
description: only user with admin role can use vector_db resources
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
permit: Scope | None = None
|
permit: Scope | None = None
|
||||||
forbid: Scope | None = None
|
forbid: Scope | None = None
|
||||||
when: Condition | list[Condition] | None = None
|
when: str | list[str] | None = None
|
||||||
unless: Condition | list[Condition] | None = None
|
unless: str | list[str] | None = None
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
|
@ -133,4 +96,12 @@ class AccessRule(BaseModel):
|
||||||
_require_one_of(self, "permit", "forbid")
|
_require_one_of(self, "permit", "forbid")
|
||||||
_mutually_exclusive(self, "permit", "forbid")
|
_mutually_exclusive(self, "permit", "forbid")
|
||||||
_mutually_exclusive(self, "when", "unless")
|
_mutually_exclusive(self, "when", "unless")
|
||||||
|
if isinstance(self.when, list):
|
||||||
|
parse_conditions(self.when)
|
||||||
|
elif self.when:
|
||||||
|
parse_conditions([self.when])
|
||||||
|
if isinstance(self.unless, list):
|
||||||
|
parse_conditions(self.unless)
|
||||||
|
elif self.unless:
|
||||||
|
parse_conditions([self.unless])
|
||||||
return self
|
return self
|
||||||
|
|
|
@ -24,7 +24,7 @@ from llama_stack.apis.shields import Shield, ShieldInput
|
||||||
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
||||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule
|
from llama_stack.distribution.access_control.datatypes import AccessRule
|
||||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
|
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
|
||||||
|
@ -36,97 +36,66 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
||||||
RoutingKey = str | list[str]
|
RoutingKey = str | list[str]
|
||||||
|
|
||||||
|
|
||||||
class ResourceWithACL(Resource):
|
class User(BaseModel):
|
||||||
"""Extension of Resource that adds attribute-based access control capabilities.
|
principal: str
|
||||||
|
# further attributes that may be used for access control decisions
|
||||||
|
attributes: dict[str, list[str]] | None = None
|
||||||
|
|
||||||
This class adds an optional access_attributes field that allows fine-grained control
|
def __init__(self, principal: str, attributes: dict[str, list[str]] | None):
|
||||||
over which users can access each resource. When attributes are defined, a user must have
|
super().__init__(principal=principal, attributes=attributes)
|
||||||
matching attributes to access the resource.
|
|
||||||
|
|
||||||
Attribute Matching Algorithm:
|
|
||||||
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
|
|
||||||
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
|
|
||||||
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
|
|
||||||
4. Within each category, ANY value match is sufficient (OR relationship within a category)
|
|
||||||
|
|
||||||
Examples:
|
class ResourceWithOwner(Resource):
|
||||||
# Resource visible to everyone (no access control)
|
"""Extension of Resource that adds an optional owner, i.e. the user that created the
|
||||||
model = Model(identifier="llama-2", ...)
|
resource. This can be used to constrain access to the resource."""
|
||||||
|
|
||||||
# Resource visible only to admins
|
owner: User | None = None
|
||||||
model = Model(
|
|
||||||
identifier="gpt-4",
|
|
||||||
access_attributes=AccessAttributes(roles=["admin"])
|
|
||||||
)
|
|
||||||
|
|
||||||
# Resource visible to data scientists on the ML team
|
|
||||||
model = Model(
|
|
||||||
identifier="private-model",
|
|
||||||
access_attributes=AccessAttributes(
|
|
||||||
roles=["data-scientist", "researcher"],
|
|
||||||
teams=["ml-team"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# ^ User must have at least one of the roles AND be on the ml-team
|
|
||||||
|
|
||||||
# Resource visible to users with specific project access
|
|
||||||
vector_db = VectorDB(
|
|
||||||
identifier="customer-embeddings",
|
|
||||||
access_attributes=AccessAttributes(
|
|
||||||
projects=["customer-insights"],
|
|
||||||
namespaces=["confidential"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# ^ User must have access to the customer-insights project AND have confidential namespace
|
|
||||||
"""
|
|
||||||
|
|
||||||
access_attributes: AccessAttributes | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# Use the extended Resource for all routable objects
|
# Use the extended Resource for all routable objects
|
||||||
class ModelWithACL(Model, ResourceWithACL):
|
class ModelWithOwner(Model, ResourceWithOwner):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ShieldWithACL(Shield, ResourceWithACL):
|
class ShieldWithOwner(Shield, ResourceWithOwner):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class VectorDBWithACL(VectorDB, ResourceWithACL):
|
class VectorDBWithOwner(VectorDB, ResourceWithOwner):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DatasetWithACL(Dataset, ResourceWithACL):
|
class DatasetWithOwner(Dataset, ResourceWithOwner):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ScoringFnWithACL(ScoringFn, ResourceWithACL):
|
class ScoringFnWithOwner(ScoringFn, ResourceWithOwner):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkWithACL(Benchmark, ResourceWithACL):
|
class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ToolWithACL(Tool, ResourceWithACL):
|
class ToolWithOwner(Tool, ResourceWithOwner):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ToolGroupWithACL(ToolGroup, ResourceWithACL):
|
class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
|
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
|
||||||
|
|
||||||
RoutableObjectWithProvider = Annotated[
|
RoutableObjectWithProvider = Annotated[
|
||||||
ModelWithACL
|
ModelWithOwner
|
||||||
| ShieldWithACL
|
| ShieldWithOwner
|
||||||
| VectorDBWithACL
|
| VectorDBWithOwner
|
||||||
| DatasetWithACL
|
| DatasetWithOwner
|
||||||
| ScoringFnWithACL
|
| ScoringFnWithOwner
|
||||||
| BenchmarkWithACL
|
| BenchmarkWithOwner
|
||||||
| ToolWithACL
|
| ToolWithOwner
|
||||||
| ToolGroupWithACL,
|
| ToolGroupWithOwner,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,8 @@ import logging
|
||||||
from contextlib import AbstractContextManager
|
from contextlib import AbstractContextManager
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import User
|
||||||
|
|
||||||
from .utils.dynamic import instantiate_class_type
|
from .utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
@ -18,16 +20,6 @@ log = logging.getLogger(__name__)
|
||||||
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
||||||
|
|
||||||
|
|
||||||
class User:
|
|
||||||
principal: str
|
|
||||||
# further attributes that may be used for access control decisions
|
|
||||||
attributes: dict[str, list[str]]
|
|
||||||
|
|
||||||
def __init__(self, principal: str, attributes: dict[str, list[str]]):
|
|
||||||
self.principal = principal
|
|
||||||
self.attributes = attributes
|
|
||||||
|
|
||||||
|
|
||||||
class RequestProviderDataContext(AbstractContextManager):
|
class RequestProviderDataContext(AbstractContextManager):
|
||||||
"""Context manager for request provider data"""
|
"""Context manager for request provider data"""
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
|
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
BenchmarkWithACL,
|
BenchmarkWithOwner,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
||||||
)
|
)
|
||||||
if provider_benchmark_id is None:
|
if provider_benchmark_id is None:
|
||||||
provider_benchmark_id = benchmark_id
|
provider_benchmark_id = benchmark_id
|
||||||
benchmark = BenchmarkWithACL(
|
benchmark = BenchmarkWithOwner(
|
||||||
identifier=benchmark_id,
|
identifier=benchmark_id,
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
scoring_functions=scoring_functions,
|
scoring_functions=scoring_functions,
|
||||||
|
|
|
@ -10,7 +10,6 @@ from llama_stack.apis.resource import ResourceType
|
||||||
from llama_stack.apis.scoring_functions import ScoringFn
|
from llama_stack.apis.scoring_functions import ScoringFn
|
||||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
AccessAttributes,
|
|
||||||
AccessRule,
|
AccessRule,
|
||||||
RoutableObject,
|
RoutableObject,
|
||||||
RoutableObjectWithProvider,
|
RoutableObjectWithProvider,
|
||||||
|
@ -195,9 +194,9 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
creator = get_authenticated_user()
|
creator = get_authenticated_user()
|
||||||
if not is_action_allowed(self.policy, "create", obj, creator):
|
if not is_action_allowed(self.policy, "create", obj, creator):
|
||||||
raise AccessDeniedError()
|
raise AccessDeniedError()
|
||||||
if creator and creator.attributes:
|
if creator:
|
||||||
obj.access_attributes = AccessAttributes(**creator.attributes)
|
obj.owner = creator
|
||||||
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
|
logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
|
||||||
|
|
||||||
registered_obj = await register_object_with_provider(obj, p)
|
registered_obj = await register_object_with_provider(obj, p)
|
||||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||||
|
|
|
@ -19,7 +19,7 @@ from llama_stack.apis.datasets import (
|
||||||
)
|
)
|
||||||
from llama_stack.apis.resource import ResourceType
|
from llama_stack.apis.resource import ResourceType
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
DatasetWithACL,
|
DatasetWithOwner,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
@ -74,7 +74,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
dataset = DatasetWithACL(
|
dataset = DatasetWithOwner(
|
||||||
identifier=dataset_id,
|
identifier=dataset_id,
|
||||||
provider_resource_id=provider_dataset_id,
|
provider_resource_id=provider_dataset_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
|
|
|
@ -9,7 +9,7 @@ from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
ModelWithACL,
|
ModelWithOwner,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
model_type = ModelType.llm
|
model_type = ModelType.llm
|
||||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier=model_id,
|
identifier=model_id,
|
||||||
provider_resource_id=provider_model_id,
|
provider_resource_id=provider_model_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
|
|
|
@ -13,7 +13,7 @@ from llama_stack.apis.scoring_functions import (
|
||||||
ScoringFunctions,
|
ScoringFunctions,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
ScoringFnWithACL,
|
ScoringFnWithOwner,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||||
)
|
)
|
||||||
scoring_fn = ScoringFnWithACL(
|
scoring_fn = ScoringFnWithOwner(
|
||||||
identifier=scoring_fn_id,
|
identifier=scoring_fn_id,
|
||||||
description=description,
|
description=description,
|
||||||
return_type=return_type,
|
return_type=return_type,
|
||||||
|
|
|
@ -9,7 +9,7 @@ from typing import Any
|
||||||
from llama_stack.apis.resource import ResourceType
|
from llama_stack.apis.resource import ResourceType
|
||||||
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
|
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
ShieldWithACL,
|
ShieldWithOwner,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
)
|
)
|
||||||
if params is None:
|
if params is None:
|
||||||
params = {}
|
params = {}
|
||||||
shield = ShieldWithACL(
|
shield = ShieldWithOwner(
|
||||||
identifier=shield_id,
|
identifier=shield_id,
|
||||||
provider_resource_id=provider_shield_id,
|
provider_resource_id=provider_shield_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||||
from llama_stack.distribution.datatypes import ToolGroupWithACL
|
from llama_stack.distribution.datatypes import ToolGroupWithOwner
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl
|
from .common import CommonRoutingTableImpl
|
||||||
|
@ -88,7 +88,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
mcp_endpoint: URL | None = None,
|
mcp_endpoint: URL | None = None,
|
||||||
args: dict[str, Any] | None = None,
|
args: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
toolgroup = ToolGroupWithACL(
|
toolgroup = ToolGroupWithOwner(
|
||||||
identifier=toolgroup_id,
|
identifier=toolgroup_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
provider_resource_id=toolgroup_id,
|
provider_resource_id=toolgroup_id,
|
||||||
|
|
|
@ -10,7 +10,7 @@ from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.apis.resource import ResourceType
|
from llama_stack.apis.resource import ResourceType
|
||||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
VectorDBWithACL,
|
VectorDBWithOwner,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
"embedding_model": embedding_model,
|
"embedding_model": embedding_model,
|
||||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||||
}
|
}
|
||||||
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
|
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
||||||
await self.register_object(vector_db)
|
await self.register_object(vector_db)
|
||||||
return vector_db
|
return vector_db
|
||||||
|
|
||||||
|
|
|
@ -105,24 +105,16 @@ class AuthenticationMiddleware:
|
||||||
logger.exception("Error during authentication")
|
logger.exception("Error during authentication")
|
||||||
return await self._send_auth_error(send, "Authentication service error")
|
return await self._send_auth_error(send, "Authentication service error")
|
||||||
|
|
||||||
# Store attributes in request scope for access control
|
|
||||||
if validation_result.access_attributes:
|
|
||||||
user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
|
|
||||||
else:
|
|
||||||
logger.warning("No access attributes, setting namespace to token by default")
|
|
||||||
user_attributes = {
|
|
||||||
"roles": [token],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
|
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
|
||||||
# can identify the requester and enforce per-client rate limits.
|
# can identify the requester and enforce per-client rate limits.
|
||||||
scope["authenticated_client_id"] = token
|
scope["authenticated_client_id"] = token
|
||||||
|
|
||||||
# Store attributes in request scope
|
# Store attributes in request scope
|
||||||
scope["user_attributes"] = user_attributes
|
|
||||||
scope["principal"] = validation_result.principal
|
scope["principal"] = validation_result.principal
|
||||||
|
if validation_result.attributes:
|
||||||
|
scope["user_attributes"] = validation_result.attributes
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
|
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self.app(scope, receive, send)
|
return await self.app(scope, receive, send)
|
||||||
|
|
|
@ -16,43 +16,18 @@ from jose import jwt
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType
|
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="auth")
|
logger = get_logger(name=__name__, category="auth")
|
||||||
|
|
||||||
|
|
||||||
class TokenValidationResult(BaseModel):
|
class AuthResponse(BaseModel):
|
||||||
principal: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="The principal (username or persistent identifier) of the authenticated user",
|
|
||||||
)
|
|
||||||
access_attributes: AccessAttributes | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="""
|
|
||||||
Structured user attributes for attribute-based access control.
|
|
||||||
|
|
||||||
These attributes determine which resources the user can access.
|
|
||||||
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
|
|
||||||
Each attribute category contains a list of values that the user has for that category.
|
|
||||||
During access control checks, these values are compared against resource requirements.
|
|
||||||
|
|
||||||
Example with standard categories:
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"roles": ["admin", "data-scientist"],
|
|
||||||
"teams": ["ml-team"],
|
|
||||||
"projects": ["llama-3"],
|
|
||||||
"namespaces": ["research"]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AuthResponse(TokenValidationResult):
|
|
||||||
"""The format of the authentication response from the auth endpoint."""
|
"""The format of the authentication response from the auth endpoint."""
|
||||||
|
|
||||||
|
principal: str
|
||||||
|
# further attributes that may be used for access control decisions
|
||||||
|
attributes: dict[str, list[str]] | None = None
|
||||||
message: str | None = Field(
|
message: str | None = Field(
|
||||||
default=None, description="Optional message providing additional context about the authentication result."
|
default=None, description="Optional message providing additional context about the authentication result."
|
||||||
)
|
)
|
||||||
|
@ -78,7 +53,7 @@ class AuthProvider(ABC):
|
||||||
"""Abstract base class for authentication providers."""
|
"""Abstract base class for authentication providers."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
"""Validate a token and return access attributes."""
|
"""Validate a token and return access attributes."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -88,10 +63,10 @@ class AuthProvider(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
|
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
|
||||||
attributes = AccessAttributes()
|
attributes: dict[str, list[str]] = {}
|
||||||
for claim_key, attribute_key in mapping.items():
|
for claim_key, attribute_key in mapping.items():
|
||||||
if claim_key not in claims or not hasattr(attributes, attribute_key):
|
if claim_key not in claims:
|
||||||
continue
|
continue
|
||||||
claim = claims[claim_key]
|
claim = claims[claim_key]
|
||||||
if isinstance(claim, list):
|
if isinstance(claim, list):
|
||||||
|
@ -99,11 +74,10 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
|
||||||
else:
|
else:
|
||||||
values = claim.split()
|
values = claim.split()
|
||||||
|
|
||||||
current = getattr(attributes, attribute_key)
|
if attribute_key in attributes:
|
||||||
if current:
|
attributes[attribute_key].extend(values)
|
||||||
current.extend(values)
|
|
||||||
else:
|
else:
|
||||||
setattr(attributes, attribute_key, values)
|
attributes[attribute_key] = values
|
||||||
return attributes
|
return attributes
|
||||||
|
|
||||||
|
|
||||||
|
@ -145,8 +119,6 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
|
||||||
for key, value in v.items():
|
for key, value in v.items():
|
||||||
if not value:
|
if not value:
|
||||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||||
if value not in AccessAttributes.model_fields:
|
|
||||||
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
|
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
|
@ -171,14 +143,14 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
self._jwks: dict[str, str] = {}
|
self._jwks: dict[str, str] = {}
|
||||||
self._jwks_lock = Lock()
|
self._jwks_lock = Lock()
|
||||||
|
|
||||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
if self.config.jwks:
|
if self.config.jwks:
|
||||||
return await self.validate_jwt_token(token, scope)
|
return await self.validate_jwt_token(token, scope)
|
||||||
if self.config.introspection:
|
if self.config.introspection:
|
||||||
return await self.introspect_token(token, scope)
|
return await self.introspect_token(token, scope)
|
||||||
raise ValueError("One of jwks or introspection must be configured")
|
raise ValueError("One of jwks or introspection must be configured")
|
||||||
|
|
||||||
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
"""Validate a token using the JWT token."""
|
"""Validate a token using the JWT token."""
|
||||||
await self._refresh_jwks()
|
await self._refresh_jwks()
|
||||||
|
|
||||||
|
@ -203,12 +175,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
# We should incorporate these into the access attributes.
|
# We should incorporate these into the access attributes.
|
||||||
principal = claims["sub"]
|
principal = claims["sub"]
|
||||||
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
|
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
|
||||||
return TokenValidationResult(
|
return User(
|
||||||
principal=principal,
|
principal=principal,
|
||||||
access_attributes=access_attributes,
|
attributes=access_attributes,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
async def introspect_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
"""Validate a token using token introspection as defined by RFC 7662."""
|
"""Validate a token using token introspection as defined by RFC 7662."""
|
||||||
form = {
|
form = {
|
||||||
"token": token,
|
"token": token,
|
||||||
|
@ -242,9 +214,9 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
raise ValueError("Token not active")
|
raise ValueError("Token not active")
|
||||||
principal = fields["sub"] or fields["username"]
|
principal = fields["sub"] or fields["username"]
|
||||||
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
|
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
|
||||||
return TokenValidationResult(
|
return User(
|
||||||
principal=principal,
|
principal=principal,
|
||||||
access_attributes=access_attributes,
|
attributes=access_attributes,
|
||||||
)
|
)
|
||||||
except httpx.TimeoutException:
|
except httpx.TimeoutException:
|
||||||
logger.exception("Token introspection request timed out")
|
logger.exception("Token introspection request timed out")
|
||||||
|
@ -299,7 +271,7 @@ class CustomAuthProvider(AuthProvider):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._client = None
|
self._client = None
|
||||||
|
|
||||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
"""Validate a token using the custom authentication endpoint."""
|
"""Validate a token using the custom authentication endpoint."""
|
||||||
if scope is None:
|
if scope is None:
|
||||||
scope = {}
|
scope = {}
|
||||||
|
@ -333,6 +305,7 @@ class CustomAuthProvider(AuthProvider):
|
||||||
json=auth_request.model_dump(),
|
json=auth_request.model_dump(),
|
||||||
timeout=10.0, # Add a reasonable timeout
|
timeout=10.0, # Add a reasonable timeout
|
||||||
)
|
)
|
||||||
|
print("MADE CALL")
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
logger.warning(f"Authentication failed with status code: {response.status_code}")
|
logger.warning(f"Authentication failed with status code: {response.status_code}")
|
||||||
raise ValueError(f"Authentication failed: {response.status_code}")
|
raise ValueError(f"Authentication failed: {response.status_code}")
|
||||||
|
@ -341,7 +314,7 @@ class CustomAuthProvider(AuthProvider):
|
||||||
try:
|
try:
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
auth_response = AuthResponse(**response_data)
|
auth_response = AuthResponse(**response_data)
|
||||||
return auth_response
|
return User(auth_response.principal, auth_response.attributes)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error parsing authentication response")
|
logger.exception("Error parsing authentication response")
|
||||||
raise ValueError("Invalid authentication response format") from e
|
raise ValueError("Invalid authentication response format") from e
|
||||||
|
|
|
@ -11,7 +11,8 @@ from datetime import datetime, timezone
|
||||||
|
|
||||||
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||||
from llama_stack.distribution.access_control.datatypes import AccessAttributes, AccessRule
|
from llama_stack.distribution.access_control.datatypes import AccessRule
|
||||||
|
from llama_stack.distribution.datatypes import User
|
||||||
from llama_stack.distribution.request_headers import get_authenticated_user
|
from llama_stack.distribution.request_headers import get_authenticated_user
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
|
||||||
|
@ -22,7 +23,7 @@ class AgentSessionInfo(Session):
|
||||||
# TODO: is this used anywhere?
|
# TODO: is this used anywhere?
|
||||||
vector_db_id: str | None = None
|
vector_db_id: str | None = None
|
||||||
started_at: datetime
|
started_at: datetime
|
||||||
access_attributes: AccessAttributes | None = None
|
owner: User | None = None
|
||||||
identifier: str | None = None
|
identifier: str | None = None
|
||||||
type: str = "session"
|
type: str = "session"
|
||||||
|
|
||||||
|
@ -42,14 +43,12 @@ class AgentPersistence:
|
||||||
|
|
||||||
# Get current user's auth attributes for new sessions
|
# Get current user's auth attributes for new sessions
|
||||||
user = get_authenticated_user()
|
user = get_authenticated_user()
|
||||||
auth_attributes = user and user.attributes
|
|
||||||
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
|
|
||||||
|
|
||||||
session_info = AgentSessionInfo(
|
session_info = AgentSessionInfo(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
session_name=name,
|
session_name=name,
|
||||||
started_at=datetime.now(timezone.utc),
|
started_at=datetime.now(timezone.utc),
|
||||||
access_attributes=access_attributes,
|
owner=user,
|
||||||
turns=[],
|
turns=[],
|
||||||
identifier=name, # should this be qualified in any way?
|
identifier=name, # should this be qualified in any way?
|
||||||
)
|
)
|
||||||
|
@ -80,7 +79,7 @@ class AgentPersistence:
|
||||||
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
|
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
|
||||||
"""Check if current user has access to the session."""
|
"""Check if current user has access to the session."""
|
||||||
# Handle backward compatibility for old sessions without access control
|
# Handle backward compatibility for old sessions without access control
|
||||||
if not hasattr(session_info, "access_attributes"):
|
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
|
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
|
||||||
|
|
|
@ -12,8 +12,7 @@ import pytest
|
||||||
|
|
||||||
from llama_stack.apis.agents import Turn
|
from llama_stack.apis.agents import Turn
|
||||||
from llama_stack.apis.inference import CompletionMessage, StopReason
|
from llama_stack.apis.inference import CompletionMessage, StopReason
|
||||||
from llama_stack.distribution.datatypes import AccessAttributes
|
from llama_stack.distribution.datatypes import User
|
||||||
from llama_stack.distribution.request_headers import User
|
|
||||||
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
|
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,9 +37,10 @@ async def test_session_creation_with_access_attributes(mock_get_authenticated_us
|
||||||
# Get the session and verify access attributes were set
|
# Get the session and verify access attributes were set
|
||||||
session_info = await agent_persistence.get_session_info(session_id)
|
session_info = await agent_persistence.get_session_info(session_id)
|
||||||
assert session_info is not None
|
assert session_info is not None
|
||||||
assert session_info.access_attributes is not None
|
assert session_info.owner is not None
|
||||||
assert session_info.access_attributes.roles == ["researcher"]
|
assert session_info.owner.attributes is not None
|
||||||
assert session_info.access_attributes.teams == ["ai-team"]
|
assert session_info.owner.attributes["roles"] == ["researcher"]
|
||||||
|
assert session_info.owner.attributes["teams"] == ["ai-team"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -54,7 +54,7 @@ async def test_session_access_control(mock_get_authenticated_user, test_setup):
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
session_name="Restricted Session",
|
session_name="Restricted Session",
|
||||||
started_at=datetime.now(),
|
started_at=datetime.now(),
|
||||||
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
|
owner=User("someone", {"roles": ["admin"], "teams": ["security-team"]}),
|
||||||
turns=[],
|
turns=[],
|
||||||
identifier="Restricted Session",
|
identifier="Restricted Session",
|
||||||
)
|
)
|
||||||
|
@ -89,7 +89,7 @@ async def test_turn_access_control(mock_get_authenticated_user, test_setup):
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
session_name="Restricted Session",
|
session_name="Restricted Session",
|
||||||
started_at=datetime.now(),
|
started_at=datetime.now(),
|
||||||
access_attributes=AccessAttributes(roles=["admin"]),
|
owner=User("someone", {"roles": ["admin"]}),
|
||||||
turns=[],
|
turns=[],
|
||||||
identifier="Restricted Session",
|
identifier="Restricted Session",
|
||||||
)
|
)
|
||||||
|
@ -143,7 +143,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_u
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
session_name="Restricted Session",
|
session_name="Restricted Session",
|
||||||
started_at=datetime.now(),
|
started_at=datetime.now(),
|
||||||
access_attributes=AccessAttributes(roles=["admin"]),
|
owner=User("someone", {"roles": ["admin"]}),
|
||||||
turns=[],
|
turns=[],
|
||||||
identifier="Restricted Session",
|
identifier="Restricted Session",
|
||||||
)
|
)
|
||||||
|
|
|
@ -8,19 +8,18 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.models import ModelType
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.distribution.datatypes import ModelWithACL
|
from llama_stack.distribution.datatypes import ModelWithOwner, User
|
||||||
from llama_stack.distribution.server.auth_providers import AccessAttributes
|
|
||||||
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_registry_cache_with_acl(cached_disk_dist_registry):
|
async def test_registry_cache_with_acl(cached_disk_dist_registry):
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="model-acl",
|
identifier="model-acl",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
provider_resource_id="model-acl-resource",
|
provider_resource_id="model-acl-resource",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
|
owner=User("testuser", {"roles": ["admin"], "teams": ["ai-team"]}),
|
||||||
)
|
)
|
||||||
|
|
||||||
success = await cached_disk_dist_registry.register(model)
|
success = await cached_disk_dist_registry.register(model)
|
||||||
|
@ -29,22 +28,14 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
|
||||||
cached_model = cached_disk_dist_registry.get_cached("model", "model-acl")
|
cached_model = cached_disk_dist_registry.get_cached("model", "model-acl")
|
||||||
assert cached_model is not None
|
assert cached_model is not None
|
||||||
assert cached_model.identifier == "model-acl"
|
assert cached_model.identifier == "model-acl"
|
||||||
assert cached_model.access_attributes.roles == ["admin"]
|
assert cached_model.owner.principal == "testuser"
|
||||||
assert cached_model.access_attributes.teams == ["ai-team"]
|
assert cached_model.owner.attributes["roles"] == ["admin"]
|
||||||
|
assert cached_model.owner.attributes["teams"] == ["ai-team"]
|
||||||
|
|
||||||
fetched_model = await cached_disk_dist_registry.get("model", "model-acl")
|
fetched_model = await cached_disk_dist_registry.get("model", "model-acl")
|
||||||
assert fetched_model is not None
|
assert fetched_model is not None
|
||||||
assert fetched_model.identifier == "model-acl"
|
assert fetched_model.identifier == "model-acl"
|
||||||
assert fetched_model.access_attributes.roles == ["admin"]
|
assert fetched_model.owner.attributes["roles"] == ["admin"]
|
||||||
|
|
||||||
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
|
|
||||||
await cached_disk_dist_registry.update(model)
|
|
||||||
|
|
||||||
updated_cached = cached_disk_dist_registry.get_cached("model", "model-acl")
|
|
||||||
assert updated_cached is not None
|
|
||||||
assert updated_cached.access_attributes.roles == ["admin", "user"]
|
|
||||||
assert updated_cached.access_attributes.projects == ["project-x"]
|
|
||||||
assert updated_cached.access_attributes.teams is None
|
|
||||||
|
|
||||||
new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
|
new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
|
||||||
await new_registry.initialize()
|
await new_registry.initialize()
|
||||||
|
@ -52,35 +43,32 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
|
||||||
new_model = await new_registry.get("model", "model-acl")
|
new_model = await new_registry.get("model", "model-acl")
|
||||||
assert new_model is not None
|
assert new_model is not None
|
||||||
assert new_model.identifier == "model-acl"
|
assert new_model.identifier == "model-acl"
|
||||||
assert new_model.access_attributes.roles == ["admin", "user"]
|
assert new_model.owner.principal == "testuser"
|
||||||
assert new_model.access_attributes.projects == ["project-x"]
|
assert new_model.owner.attributes["roles"] == ["admin"]
|
||||||
assert new_model.access_attributes.teams is None
|
assert new_model.owner.attributes["teams"] == ["ai-team"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_registry_empty_acl(cached_disk_dist_registry):
|
async def test_registry_empty_acl(cached_disk_dist_registry):
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="model-empty-acl",
|
identifier="model-empty-acl",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
provider_resource_id="model-resource",
|
provider_resource_id="model-resource",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(),
|
owner=User("testuser", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
await cached_disk_dist_registry.register(model)
|
await cached_disk_dist_registry.register(model)
|
||||||
|
|
||||||
cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl")
|
cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl")
|
||||||
assert cached_model is not None
|
assert cached_model is not None
|
||||||
assert cached_model.access_attributes is not None
|
assert cached_model.owner is not None
|
||||||
assert cached_model.access_attributes.roles is None
|
assert cached_model.owner.attributes is None
|
||||||
assert cached_model.access_attributes.teams is None
|
|
||||||
assert cached_model.access_attributes.projects is None
|
|
||||||
assert cached_model.access_attributes.namespaces is None
|
|
||||||
|
|
||||||
all_models = await cached_disk_dist_registry.get_all()
|
all_models = await cached_disk_dist_registry.get_all()
|
||||||
assert len(all_models) == 1
|
assert len(all_models) == 1
|
||||||
|
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="model-no-acl",
|
identifier="model-no-acl",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
provider_resource_id="model-resource-2",
|
provider_resource_id="model-resource-2",
|
||||||
|
@ -91,7 +79,7 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
|
||||||
|
|
||||||
cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl")
|
cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl")
|
||||||
assert cached_model is not None
|
assert cached_model is not None
|
||||||
assert cached_model.access_attributes is None
|
assert cached_model.owner is None
|
||||||
|
|
||||||
all_models = await cached_disk_dist_registry.get_all()
|
all_models = await cached_disk_dist_registry.get_all()
|
||||||
assert len(all_models) == 2
|
assert len(all_models) == 2
|
||||||
|
@ -99,19 +87,19 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_registry_serialization(cached_disk_dist_registry):
|
async def test_registry_serialization(cached_disk_dist_registry):
|
||||||
attributes = AccessAttributes(
|
attributes = {
|
||||||
roles=["admin", "researcher"],
|
"roles": ["admin", "researcher"],
|
||||||
teams=["ai-team", "ml-team"],
|
"teams": ["ai-team", "ml-team"],
|
||||||
projects=["project-a", "project-b"],
|
"projects": ["project-a", "project-b"],
|
||||||
namespaces=["prod", "staging"],
|
"namespaces": ["prod", "staging"],
|
||||||
)
|
}
|
||||||
|
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="model-serialize",
|
identifier="model-serialize",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
provider_resource_id="model-resource",
|
provider_resource_id="model-resource",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=attributes,
|
owner=User("bob", attributes),
|
||||||
)
|
)
|
||||||
|
|
||||||
await cached_disk_dist_registry.register(model)
|
await cached_disk_dist_registry.register(model)
|
||||||
|
@ -122,7 +110,7 @@ async def test_registry_serialization(cached_disk_dist_registry):
|
||||||
loaded_model = await new_registry.get("model", "model-serialize")
|
loaded_model = await new_registry.get("model", "model-serialize")
|
||||||
assert loaded_model is not None
|
assert loaded_model is not None
|
||||||
|
|
||||||
assert loaded_model.access_attributes.roles == ["admin", "researcher"]
|
assert loaded_model.owner.attributes["roles"] == ["admin", "researcher"]
|
||||||
assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"]
|
assert loaded_model.owner.attributes["teams"] == ["ai-team", "ml-team"]
|
||||||
assert loaded_model.access_attributes.projects == ["project-a", "project-b"]
|
assert loaded_model.owner.attributes["projects"] == ["project-a", "project-b"]
|
||||||
assert loaded_model.access_attributes.namespaces == ["prod", "staging"]
|
assert loaded_model.owner.attributes["namespaces"] == ["prod", "staging"]
|
||||||
|
|
|
@ -8,13 +8,12 @@ from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter, ValidationError
|
||||||
|
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
from llama_stack.apis.models import ModelType
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||||
from llama_stack.distribution.datatypes import AccessAttributes, AccessRule, ModelWithACL
|
from llama_stack.distribution.datatypes import AccessRule, ModelWithOwner, User
|
||||||
from llama_stack.distribution.request_headers import User
|
|
||||||
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
|
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,25 +44,25 @@ async def test_setup(cached_disk_dist_registry):
|
||||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||||
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
|
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
|
||||||
registry, routing_table = test_setup
|
registry, routing_table = test_setup
|
||||||
model_public = ModelWithACL(
|
model_public = ModelWithOwner(
|
||||||
identifier="model-public",
|
identifier="model-public",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
provider_resource_id="model-public",
|
provider_resource_id="model-public",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
)
|
)
|
||||||
model_admin_only = ModelWithACL(
|
model_admin_only = ModelWithOwner(
|
||||||
identifier="model-admin",
|
identifier="model-admin",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
provider_resource_id="model-admin",
|
provider_resource_id="model-admin",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(roles=["admin"]),
|
owner=User("testuser", {"roles": ["admin"]}),
|
||||||
)
|
)
|
||||||
model_data_scientist = ModelWithACL(
|
model_data_scientist = ModelWithOwner(
|
||||||
identifier="model-data-scientist",
|
identifier="model-data-scientist",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
provider_resource_id="model-data-scientist",
|
provider_resource_id="model-data-scientist",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]),
|
owner=User("testuser", {"roles": ["data-scientist", "researcher"], "teams": ["ml-team"]}),
|
||||||
)
|
)
|
||||||
await registry.register(model_public)
|
await registry.register(model_public)
|
||||||
await registry.register(model_admin_only)
|
await registry.register(model_admin_only)
|
||||||
|
@ -110,7 +109,7 @@ async def test_access_control_with_cache(mock_get_authenticated_user, test_setup
|
||||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||||
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
|
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
|
||||||
registry, routing_table = test_setup
|
registry, routing_table = test_setup
|
||||||
model_public = ModelWithACL(
|
model_public = ModelWithOwner(
|
||||||
identifier="model-updates",
|
identifier="model-updates",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
provider_resource_id="model-updates",
|
provider_resource_id="model-updates",
|
||||||
|
@ -125,7 +124,7 @@ async def test_access_control_and_updates(mock_get_authenticated_user, test_setu
|
||||||
)
|
)
|
||||||
model = await routing_table.get_model("model-updates")
|
model = await routing_table.get_model("model-updates")
|
||||||
assert model.identifier == "model-updates"
|
assert model.identifier == "model-updates"
|
||||||
model_public.access_attributes = AccessAttributes(roles=["admin"])
|
model_public.owner = User("testuser", {"roles": ["admin"]})
|
||||||
await registry.update(model_public)
|
await registry.update(model_public)
|
||||||
mock_get_authenticated_user.return_value = User(
|
mock_get_authenticated_user.return_value = User(
|
||||||
"test-user",
|
"test-user",
|
||||||
|
@ -149,12 +148,12 @@ async def test_access_control_and_updates(mock_get_authenticated_user, test_setu
|
||||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||||
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
|
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
|
||||||
registry, routing_table = test_setup
|
registry, routing_table = test_setup
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="model-empty-attrs",
|
identifier="model-empty-attrs",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
provider_resource_id="model-empty-attrs",
|
provider_resource_id="model-empty-attrs",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(),
|
owner=User("testuser", {}),
|
||||||
)
|
)
|
||||||
await registry.register(model)
|
await registry.register(model)
|
||||||
mock_get_authenticated_user.return_value = User(
|
mock_get_authenticated_user.return_value = User(
|
||||||
|
@ -174,18 +173,18 @@ async def test_access_control_empty_attributes(mock_get_authenticated_user, test
|
||||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||||
async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
|
async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
|
||||||
registry, routing_table = test_setup
|
registry, routing_table = test_setup
|
||||||
model_public = ModelWithACL(
|
model_public = ModelWithOwner(
|
||||||
identifier="model-public-2",
|
identifier="model-public-2",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
provider_resource_id="model-public-2",
|
provider_resource_id="model-public-2",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
)
|
)
|
||||||
model_restricted = ModelWithACL(
|
model_restricted = ModelWithOwner(
|
||||||
identifier="model-restricted",
|
identifier="model-restricted",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
provider_resource_id="model-restricted",
|
provider_resource_id="model-restricted",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(roles=["admin"]),
|
owner=User("testuser", {"roles": ["admin"]}),
|
||||||
)
|
)
|
||||||
await registry.register(model_public)
|
await registry.register(model_public)
|
||||||
await registry.register(model_restricted)
|
await registry.register(model_restricted)
|
||||||
|
@ -212,7 +211,7 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set
|
||||||
mock_get_authenticated_user.return_value = User("test-user", creator_attributes)
|
mock_get_authenticated_user.return_value = User("test-user", creator_attributes)
|
||||||
|
|
||||||
# Create model without explicit access attributes
|
# Create model without explicit access attributes
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="auto-access-model",
|
identifier="auto-access-model",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
provider_resource_id="auto-access-model",
|
provider_resource_id="auto-access-model",
|
||||||
|
@ -222,10 +221,11 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set
|
||||||
|
|
||||||
# Verify the model got creator's attributes
|
# Verify the model got creator's attributes
|
||||||
registered_model = await routing_table.get_model("auto-access-model")
|
registered_model = await routing_table.get_model("auto-access-model")
|
||||||
assert registered_model.access_attributes is not None
|
assert registered_model.owner is not None
|
||||||
assert registered_model.access_attributes.roles == ["data-scientist"]
|
assert registered_model.owner.attributes is not None
|
||||||
assert registered_model.access_attributes.teams == ["ml-team"]
|
assert registered_model.owner.attributes["roles"] == ["data-scientist"]
|
||||||
assert registered_model.access_attributes.projects == ["llama-3"]
|
assert registered_model.owner.attributes["teams"] == ["ml-team"]
|
||||||
|
assert registered_model.owner.attributes["projects"] == ["llama-3"]
|
||||||
|
|
||||||
# Verify another user without matching attributes can't access it
|
# Verify another user without matching attributes can't access it
|
||||||
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]})
|
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]})
|
||||||
|
@ -354,15 +354,14 @@ def test_permit_when():
|
||||||
- permit:
|
- permit:
|
||||||
principal: user-1
|
principal: user-1
|
||||||
actions: [read]
|
actions: [read]
|
||||||
when:
|
when: user in owners namespaces
|
||||||
user_in: resource.namespaces
|
|
||||||
"""
|
"""
|
||||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="mymodel",
|
identifier="mymodel",
|
||||||
provider_id="myprovider",
|
provider_id="myprovider",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
owner=User("testuser", {"namespaces": ["foo"]}),
|
||||||
)
|
)
|
||||||
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||||
|
@ -376,15 +375,15 @@ def test_permit_unless():
|
||||||
actions: [read]
|
actions: [read]
|
||||||
resource: model::*
|
resource: model::*
|
||||||
unless:
|
unless:
|
||||||
- user_not_in: resource.namespaces
|
- user not in owners namespaces
|
||||||
- user_in: resource.teams
|
- user in owners teams
|
||||||
"""
|
"""
|
||||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="mymodel",
|
identifier="mymodel",
|
||||||
provider_id="myprovider",
|
provider_id="myprovider",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
owner=User("testuser", {"namespaces": ["foo"]}),
|
||||||
)
|
)
|
||||||
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||||
|
@ -397,16 +396,16 @@ def test_forbid_when():
|
||||||
principal: user-1
|
principal: user-1
|
||||||
actions: [read]
|
actions: [read]
|
||||||
when:
|
when:
|
||||||
user_in: resource.namespaces
|
user in owners namespaces
|
||||||
- permit:
|
- permit:
|
||||||
actions: [read]
|
actions: [read]
|
||||||
"""
|
"""
|
||||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="mymodel",
|
identifier="mymodel",
|
||||||
provider_id="myprovider",
|
provider_id="myprovider",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
owner=User("testuser", {"namespaces": ["foo"]}),
|
||||||
)
|
)
|
||||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||||
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||||
|
@ -419,35 +418,33 @@ def test_forbid_unless():
|
||||||
principal: user-1
|
principal: user-1
|
||||||
actions: [read]
|
actions: [read]
|
||||||
unless:
|
unless:
|
||||||
user_in: resource.namespaces
|
user in owners namespaces
|
||||||
- permit:
|
- permit:
|
||||||
actions: [read]
|
actions: [read]
|
||||||
"""
|
"""
|
||||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="mymodel",
|
identifier="mymodel",
|
||||||
provider_id="myprovider",
|
provider_id="myprovider",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
owner=User("testuser", {"namespaces": ["foo"]}),
|
||||||
)
|
)
|
||||||
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
|
||||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
|
||||||
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
|
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
|
||||||
|
|
||||||
|
|
||||||
def test_condition_with_literal():
|
def test_user_has_attribute():
|
||||||
config = """
|
config = """
|
||||||
- permit:
|
- permit:
|
||||||
actions: [read]
|
actions: [read]
|
||||||
when:
|
when: user with admin in roles
|
||||||
user_in: role::admin
|
|
||||||
"""
|
"""
|
||||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="mymodel",
|
identifier="mymodel",
|
||||||
provider_id="myprovider",
|
provider_id="myprovider",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
|
||||||
)
|
)
|
||||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||||
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
|
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
|
||||||
|
@ -455,35 +452,115 @@ def test_condition_with_literal():
|
||||||
assert not is_action_allowed(policy, "read", model, User("user-4", None))
|
assert not is_action_allowed(policy, "read", model, User("user-4", None))
|
||||||
|
|
||||||
|
|
||||||
def test_condition_with_unrecognised_literal():
|
def test_user_does_not_have_attribute():
|
||||||
config = """
|
config = """
|
||||||
- permit:
|
- permit:
|
||||||
actions: [read]
|
actions: [read]
|
||||||
when:
|
unless: user with admin not in roles
|
||||||
user_in: whatever
|
|
||||||
"""
|
"""
|
||||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="mymodel",
|
identifier="mymodel",
|
||||||
provider_id="myprovider",
|
provider_id="myprovider",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
access_attributes=AccessAttributes(namespaces=["foo"]),
|
|
||||||
)
|
)
|
||||||
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||||
assert not is_action_allowed(policy, "read", model, User("user-2", None))
|
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-4", None))
|
||||||
|
|
||||||
|
|
||||||
def test_empty_condition():
|
def test_is_owner():
|
||||||
config = """
|
config = """
|
||||||
- permit:
|
- permit:
|
||||||
actions: [read]
|
actions: [read]
|
||||||
when: {}
|
when: user is owner
|
||||||
"""
|
"""
|
||||||
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
model = ModelWithACL(
|
model = ModelWithOwner(
|
||||||
identifier="mymodel",
|
identifier="mymodel",
|
||||||
provider_id="myprovider",
|
provider_id="myprovider",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
|
owner=User("user-2", {"namespaces": ["foo"]}),
|
||||||
)
|
)
|
||||||
assert is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||||
assert is_action_allowed(policy, "read", model, User("user-2", None))
|
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-4", None))
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_not_owner():
|
||||||
|
config = """
|
||||||
|
- permit:
|
||||||
|
actions: [read]
|
||||||
|
unless: user is not owner
|
||||||
|
"""
|
||||||
|
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
model = ModelWithOwner(
|
||||||
|
identifier="mymodel",
|
||||||
|
provider_id="myprovider",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
owner=User("user-2", {"namespaces": ["foo"]}),
|
||||||
|
)
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
|
||||||
|
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
|
||||||
|
assert not is_action_allowed(policy, "read", model, User("user-4", None))
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_rule_permit_and_forbid_both_specified():
|
||||||
|
config = """
|
||||||
|
- permit:
|
||||||
|
actions: [read]
|
||||||
|
forbid:
|
||||||
|
actions: [create]
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_rule_neither_permit_or_forbid_specified():
|
||||||
|
config = """
|
||||||
|
- when: user is owner
|
||||||
|
unless: user with admin in roles
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_rule_when_and_unless_both_specified():
|
||||||
|
config = """
|
||||||
|
- permit:
|
||||||
|
actions: [read]
|
||||||
|
when: user is owner
|
||||||
|
unless: user with admin in roles
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_condition():
|
||||||
|
config = """
|
||||||
|
- permit:
|
||||||
|
actions: [read]
|
||||||
|
when: random words that are not valid
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"condition",
|
||||||
|
[
|
||||||
|
"user is owner",
|
||||||
|
"user is not owner",
|
||||||
|
"user with dev in teams",
|
||||||
|
"user with default not in namespaces",
|
||||||
|
"user in owners roles",
|
||||||
|
"user not in owners projects",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_condition_reprs(condition):
|
||||||
|
from llama_stack.distribution.access_control.conditions import parse_condition
|
||||||
|
|
||||||
|
assert condition == str(parse_condition(condition))
|
||||||
|
|
|
@ -139,7 +139,7 @@ async def mock_post_success(*args, **kwargs):
|
||||||
{
|
{
|
||||||
"message": "Authentication successful",
|
"message": "Authentication successful",
|
||||||
"principal": "test-principal",
|
"principal": "test-principal",
|
||||||
"access_attributes": {
|
"attributes": {
|
||||||
"roles": ["admin", "user"],
|
"roles": ["admin", "user"],
|
||||||
"teams": ["ml-team", "nlp-team"],
|
"teams": ["ml-team", "nlp-team"],
|
||||||
"projects": ["llama-3", "project-x"],
|
"projects": ["llama-3", "project-x"],
|
||||||
|
@ -233,7 +233,7 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
|
||||||
{
|
{
|
||||||
"message": "Authentication successful",
|
"message": "Authentication successful",
|
||||||
"principal": "test-principal",
|
"principal": "test-principal",
|
||||||
"access_attributes": {
|
"attributes": {
|
||||||
"roles": ["admin", "user"],
|
"roles": ["admin", "user"],
|
||||||
"teams": ["ml-team", "nlp-team"],
|
"teams": ["ml-team", "nlp-team"],
|
||||||
"projects": ["llama-3", "project-x"],
|
"projects": ["llama-3", "project-x"],
|
||||||
|
@ -255,33 +255,6 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
|
||||||
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope):
|
|
||||||
"""Test middleware behavior with no access attributes"""
|
|
||||||
middleware, mock_app = mock_http_middleware
|
|
||||||
mock_receive = AsyncMock()
|
|
||||||
mock_send = AsyncMock()
|
|
||||||
|
|
||||||
with patch("httpx.AsyncClient") as mock_client:
|
|
||||||
mock_client_instance = AsyncMock()
|
|
||||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
|
||||||
|
|
||||||
mock_client_instance.post.return_value = MockResponse(
|
|
||||||
200,
|
|
||||||
{
|
|
||||||
"message": "Authentication successful"
|
|
||||||
# No access_attributes
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
await middleware(mock_scope, mock_receive, mock_send)
|
|
||||||
|
|
||||||
assert "user_attributes" in mock_scope
|
|
||||||
attributes = mock_scope["user_attributes"]
|
|
||||||
assert "roles" in attributes
|
|
||||||
assert attributes["roles"] == ["test.jwt.token"]
|
|
||||||
|
|
||||||
|
|
||||||
# oauth2 token provider tests
|
# oauth2 token provider tests
|
||||||
|
|
||||||
|
|
||||||
|
@ -380,16 +353,16 @@ def test_get_attributes_from_claims():
|
||||||
"aud": "llama-stack",
|
"aud": "llama-stack",
|
||||||
}
|
}
|
||||||
attributes = get_attributes_from_claims(claims, {"sub": "roles", "groups": "teams"})
|
attributes = get_attributes_from_claims(claims, {"sub": "roles", "groups": "teams"})
|
||||||
assert attributes.roles == ["my-user"]
|
assert attributes["roles"] == ["my-user"]
|
||||||
assert attributes.teams == ["group1", "group2"]
|
assert attributes["teams"] == ["group1", "group2"]
|
||||||
|
|
||||||
claims = {
|
claims = {
|
||||||
"sub": "my-user",
|
"sub": "my-user",
|
||||||
"tenant": "my-tenant",
|
"tenant": "my-tenant",
|
||||||
}
|
}
|
||||||
attributes = get_attributes_from_claims(claims, {"sub": "roles", "tenant": "namespaces"})
|
attributes = get_attributes_from_claims(claims, {"sub": "roles", "tenant": "namespaces"})
|
||||||
assert attributes.roles == ["my-user"]
|
assert attributes["roles"] == ["my-user"]
|
||||||
assert attributes.namespaces == ["my-tenant"]
|
assert attributes["namespaces"] == ["my-tenant"]
|
||||||
|
|
||||||
claims = {
|
claims = {
|
||||||
"sub": "my-user",
|
"sub": "my-user",
|
||||||
|
@ -408,9 +381,9 @@ def test_get_attributes_from_claims():
|
||||||
"groups": "teams",
|
"groups": "teams",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert set(attributes.roles) == {"my-user", "my-username"}
|
assert set(attributes["roles"]) == {"my-user", "my-username"}
|
||||||
assert set(attributes.teams) == {"my-team", "group1", "group2"}
|
assert set(attributes["teams"]) == {"my-team", "group1", "group2"}
|
||||||
assert attributes.namespaces == ["my-tenant"]
|
assert attributes["namespaces"] == ["my-tenant"]
|
||||||
|
|
||||||
|
|
||||||
# TODO: add more tests for oauth2 token provider
|
# TODO: add more tests for oauth2 token provider
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue