llama-stack-mirror/tests/unit/server/test_resolver.py
grs 7c1998db25
feat: fine grained access control policy (#2264)
This allows a set of rules to be defined for determining access to
resources. The rules are (loosely) based on the cedar policy format.

A rule defines a list of actions to either permit or forbid. It may
specify a principal or a resource that must match for the rule to take
effect. It may also specify a condition, either a 'when' or an 'unless',
with additional constraints as to where the rule applies.

A list of rules is held for each type to be protected and tried in order
to find a match. If a match is found, the request is permitted or
forbidden depending on the type of rule. If no match is found, the
request is denied. If no rules are specified for a given type, a rule
that allows any action as long as the resource attributes match the user
attributes is added (i.e. the previous behaviour is the default).

Some examples in yaml:

```
    model:
    - permit:
      principal: user-1
      actions: [create, read, delete]
      comment: user-1 has full access to all models
    - permit:
      principal: user-2
      actions: [read]
      resource: model-1
      comment: user-2 has read access to model-1 only
    - permit:
      actions: [read]
      when:
        user_in: resource.namespaces
      comment: any user has read access to models with matching attributes
    vector_db:
    - forbid:
      actions: [create, read, delete]
      unless:
        user_in: role::admin
      comment: only user with admin role can use vector_db resources
```

---------

Signed-off-by: Gordon Sim <gsim@redhat.com>
2025-06-03 14:51:12 -07:00

118 lines
3.7 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import sys
from typing import Any, Protocol
from unittest.mock import AsyncMock, MagicMock
import pytest
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Inference
from llama_stack.distribution.datatypes import (
Api,
Provider,
StackRunConfig,
)
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.routers.inference import InferenceRouter
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None:
    """Attach an async stub to ``cls`` for each web method found on ``protocol``.

    A member of ``protocol`` is stubbed when it is a plain function that
    carries a ``__webmethod__`` attribute. Every stub accepts arbitrary
    arguments, returns a new ``MagicMock`` when awaited, and reports the
    signature of the protocol function it stands in for.

    Args:
        cls: Class that receives the stub methods (mutated in place).
        protocol: Class whose members are inspected.
    """
    for member_name, member in inspect.getmembers(protocol, inspect.isfunction):
        # Only functions tagged with __webmethod__ are stubbed.
        if not hasattr(member, "__webmethod__"):
            continue

        protocol_signature = inspect.signature(member)

        async def mock_impl(*args, **kwargs):
            return MagicMock()

        # Advertise the protocol's signature instead of (*args, **kwargs).
        mock_impl.__signature__ = protocol_signature
        setattr(cls, member_name, mock_impl)
class SampleConfig(BaseModel):
    # Minimal pydantic config model with a single string setting.
    # `foo` falls back to "bar" when the caller does not supply it.
    foo: str = Field(default="bar", description="foo")

    @classmethod
    def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
        """Return a fixed sample run configuration; keyword arguments are ignored."""
        sample: dict[str, Any] = {"foo": "baz"}
        return sample
class SampleImpl:
    """Bare-bones provider implementation that records what it was built with."""

    def __init__(self, config: SampleConfig, deps: dict[Api, Any], provider_spec: ProviderSpec = None):
        """Store the constructor arguments on the instance.

        Args:
            config: Configuration object; its ``foo`` value is copied to ``self.foo``.
            deps: Mapping kept as-is under ``__deps__``.
            provider_spec: Optional spec kept under ``__provider_spec__``.
        """
        # Placeholder id assigned at construction time.
        self.__provider_id__ = "test_provider"
        self.__provider_spec__ = provider_spec
        self.__provider_config__ = config
        self.__deps__ = deps
        self.foo = config.foo

    async def initialize(self):
        """Do nothing; there is no state to set up."""
        return None
@pytest.mark.asyncio
async def test_resolve_impls_basic():
    """Resolve a single inline inference provider and check the resulting router.

    The provider module is a ``MagicMock`` registered in ``sys.modules`` under
    the name given in the provider spec. The registration is undone once
    resolution finishes so the fake module does not leak into other tests.
    """
    # Create a real provider spec
    provider_spec = InlineProviderSpec(
        api=Api.inference,
        provider_type="sample",
        module="test_module",
        config_class="test_resolver.SampleConfig",
        api_dependencies=[],
    )

    # Create provider registry with our provider
    provider_registry = {Api.inference: {provider_spec.provider_type: provider_spec}}

    run_config = StackRunConfig(
        image_name="test_image",
        providers={
            "inference": [
                Provider(
                    provider_id="sample_provider",
                    provider_type="sample",
                    config=SampleConfig.sample_run_config(),
                )
            ]
        },
    )

    dist_registry = MagicMock()

    mock_module = MagicMock()
    impl = SampleImpl(SampleConfig(foo="baz"), {}, provider_spec)
    # Give SampleImpl async stubs for the Inference web methods.
    add_protocol_methods(SampleImpl, Inference)

    mock_module.get_provider_impl = AsyncMock(return_value=impl)
    # NOTE(review): presumably needed so signature inspection of the mock
    # by the resolver behaves as expected — confirm against resolve_impls.
    mock_module.get_provider_impl.__text_signature__ = "()"

    # Register the fake module only for the duration of the resolve call and
    # restore whatever was there before, so sys.modules is not polluted.
    missing = object()
    previous_module = sys.modules.get("test_module", missing)
    sys.modules["test_module"] = mock_module
    try:
        impls = await resolve_impls(run_config, provider_registry, dist_registry, policy={})
    finally:
        if previous_module is missing:
            sys.modules.pop("test_module", None)
        else:
            sys.modules["test_module"] = previous_module

    assert Api.inference in impls
    assert isinstance(impls[Api.inference], InferenceRouter)

    table = impls[Api.inference].routing_table
    assert isinstance(table, ModelsRoutingTable)

    impl = table.impls_by_provider_id["sample_provider"]
    assert isinstance(impl, SampleImpl)
    assert impl.foo == "baz"
    assert impl.__provider_id__ == "sample_provider"
    assert impl.__provider_spec__ == provider_spec