forked from phoenix-oss/llama-stack-mirror
merge
This commit is contained in:
commit
a54d757ade
197 changed files with 9392 additions and 3089 deletions
|
@ -8,9 +8,7 @@ from typing import Any, Dict
|
|||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types.agents.turn_create_params import Document
|
||||
from llama_stack_client import Agent, AgentEventLogger, Document
|
||||
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
|
||||
|
||||
from llama_stack.apis.agents.agents import (
|
||||
|
@ -92,7 +90,7 @@ def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
logs = [str(log) for log in EventLogger().log(simple_hello) if log is not None]
|
||||
logs = [str(log) for log in AgentEventLogger().log(simple_hello) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
|
||||
assert "hello" in logs_str.lower()
|
||||
|
@ -111,7 +109,7 @@ def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
logs = [str(log) for log in EventLogger().log(bomb_response) if log is not None]
|
||||
logs = [str(log) for log in AgentEventLogger().log(bomb_response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
assert "I can't" in logs_str
|
||||
|
||||
|
@ -192,7 +190,7 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
|
||||
assert "tool_execution>" in logs_str
|
||||
|
@ -221,7 +219,7 @@ def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, a
|
|||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
|
||||
assert "541" in logs_str
|
||||
|
@ -262,7 +260,7 @@ def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inferen
|
|||
session_id=session_id,
|
||||
documents=input.get("documents", None),
|
||||
)
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
assert "Tool:code_interpreter" in logs_str
|
||||
|
||||
|
@ -287,7 +285,7 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
assert "-100" in logs_str
|
||||
assert "get_boiling_point" in logs_str
|
||||
|
|
|
@ -96,7 +96,7 @@ def test_evaluate_benchmark(llama_stack_client, text_model_id, scoring_fn_id):
|
|||
)
|
||||
assert response.job_id == "0"
|
||||
job_status = llama_stack_client.eval.jobs.status(job_id=response.job_id, benchmark_id=benchmark_id)
|
||||
assert job_status and job_status == "completed"
|
||||
assert job_status and job_status.status == "completed"
|
||||
|
||||
eval_response = llama_stack_client.eval.jobs.retrieve(job_id=response.job_id, benchmark_id=benchmark_id)
|
||||
assert eval_response is not None
|
||||
|
|
|
@ -12,6 +12,12 @@ from llama_stack import LlamaStackAsLibraryClient
|
|||
|
||||
class TestProviders:
|
||||
@pytest.mark.asyncio
|
||||
def test_list(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||
def test_providers(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||
provider_list = llama_stack_client.providers.list()
|
||||
assert provider_list is not None
|
||||
assert len(provider_list) > 0
|
||||
|
||||
for provider in provider_list:
|
||||
pid = provider.provider_id
|
||||
provider = llama_stack_client.providers.retrieve(pid)
|
||||
assert provider is not None
|
||||
|
|
12
tests/integration/tools/test_tools.py
Normal file
12
tests/integration/tools/test_tools.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
def test_toolsgroups_unregister(llama_stack_client):
|
||||
client = llama_stack_client
|
||||
client.toolgroups.unregister(
|
||||
toolgroup_id="builtin::websearch",
|
||||
)
|
|
@ -165,7 +165,10 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
|
|||
request.model = MODEL
|
||||
request.tool_config.tool_prompt_format = ToolPromptFormat.json
|
||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
||||
self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
|
||||
self.assertIn(
|
||||
'{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
|
||||
prompt,
|
||||
)
|
||||
|
||||
async def test_user_provided_system_message(self):
|
||||
content = "Hello !"
|
||||
|
|
|
@ -25,19 +25,21 @@ from llama_stack.models.llama.llama3.prompt_templates import (
|
|||
|
||||
|
||||
class PromptTemplateTests(unittest.TestCase):
|
||||
def check_generator_output(self, generator, expected_text):
|
||||
example = generator.data_examples()[0]
|
||||
|
||||
pt = generator.gen(example)
|
||||
text = pt.render()
|
||||
# print(text) # debugging
|
||||
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
|
||||
def check_generator_output(self, generator):
|
||||
for example in generator.data_examples():
|
||||
pt = generator.gen(example)
|
||||
text = pt.render()
|
||||
# print(text) # debugging
|
||||
if not example:
|
||||
continue
|
||||
for tool in example:
|
||||
assert tool.tool_name in text
|
||||
|
||||
def test_system_default(self):
|
||||
generator = SystemDefaultGenerator()
|
||||
today = datetime.now().strftime("%d %B %Y")
|
||||
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
|
||||
self.check_generator_output(generator, expected_text)
|
||||
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
|
||||
|
||||
def test_system_builtin_only(self):
|
||||
generator = BuiltinToolGenerator()
|
||||
|
@ -47,143 +49,24 @@ class PromptTemplateTests(unittest.TestCase):
|
|||
Tools: brave_search, wolfram_alpha
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
|
||||
|
||||
def test_system_custom_only(self):
|
||||
self.maxDiff = None
|
||||
generator = JsonCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Answer the user's question by making use of the following functions if needed.
|
||||
If none of the function can be used, please say so.
|
||||
Here is a list of functions in JSON format:
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "trending_songs",
|
||||
"description": "Returns the trending songs on a Music site",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": [
|
||||
{
|
||||
"n": {
|
||||
"type": "object",
|
||||
"description": "The number of songs to return"
|
||||
}
|
||||
},
|
||||
{
|
||||
"genre": {
|
||||
"type": "object",
|
||||
"description": "The genre of the songs to return"
|
||||
}
|
||||
}
|
||||
],
|
||||
"required": ["n"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Return function calls in JSON format.
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
self.check_generator_output(generator)
|
||||
|
||||
def test_system_custom_function_tag(self):
|
||||
self.maxDiff = None
|
||||
generator = FunctionTagCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
You have access to the following functions:
|
||||
|
||||
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
|
||||
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
|
||||
|
||||
Think very carefully before calling functions.
|
||||
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||
|
||||
<function=example_function_name>{"example_name": "example_value"}</function>
|
||||
|
||||
Reminder:
|
||||
- If looking for real time information use relevant functions before falling back to brave_search
|
||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
self.check_generator_output(generator)
|
||||
|
||||
def test_llama_3_2_system_zero_shot(self):
|
||||
generator = PythonListCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
You are a helpful assistant. You have access to functions, but you should only use them if they are required.
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you may or may not need to make one function/tool call to achieve the purpose.
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info for places",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": ["city"],
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for"
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
self.check_generator_output(generator)
|
||||
|
||||
def test_llama_3_2_provided_system_prompt(self):
|
||||
generator = PythonListCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Overriding message.
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info for places",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": ["city"],
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for"
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]"""
|
||||
)
|
||||
user_system_prompt = textwrap.dedent(
|
||||
"""
|
||||
Overriding message.
|
||||
|
@ -195,4 +78,5 @@ class PromptTemplateTests(unittest.TestCase):
|
|||
|
||||
pt = generator.gen(example, user_system_prompt)
|
||||
text = pt.render()
|
||||
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
|
||||
assert "Overriding message." in text
|
||||
assert '"name": "get_weather"' in text
|
||||
|
|
175
tests/unit/providers/agents/test_persistence_access_control.py
Normal file
175
tests/unit/providers/agents/test_persistence_access_control.py
Normal file
|
@ -0,0 +1,175 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents import Turn
|
||||
from llama_stack.apis.inference import CompletionMessage, StopReason
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_setup():
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
db_path = os.path.join(temp_dir, "test_persistence_access_control.db")
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
|
||||
kvstore = SqliteKVStoreImpl(kvstore_config)
|
||||
await kvstore.initialize()
|
||||
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=kvstore)
|
||||
yield agent_persistence
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
|
||||
async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
||||
# Set creator's attributes for the session
|
||||
creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
|
||||
mock_get_auth_attributes.return_value = creator_attributes
|
||||
|
||||
# Create a session
|
||||
session_id = await agent_persistence.create_session("Test Session")
|
||||
|
||||
# Get the session and verify access attributes were set
|
||||
session_info = await agent_persistence.get_session_info(session_id)
|
||||
assert session_info is not None
|
||||
assert session_info.access_attributes is not None
|
||||
assert session_info.access_attributes.roles == ["researcher"]
|
||||
assert session_info.access_attributes.teams == ["ai-team"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
|
||||
async def test_session_access_control(mock_get_auth_attributes, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
||||
# Create a session with specific access attributes
|
||||
session_id = str(uuid.uuid4())
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name="Restricted Session",
|
||||
started_at=datetime.now(),
|
||||
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
|
||||
)
|
||||
|
||||
await agent_persistence.kvstore.set(
|
||||
key=f"session:{agent_persistence.agent_id}:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
|
||||
# User with matching attributes can access
|
||||
mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
|
||||
retrieved_session = await agent_persistence.get_session_info(session_id)
|
||||
assert retrieved_session is not None
|
||||
assert retrieved_session.session_id == session_id
|
||||
|
||||
# User without matching attributes cannot access
|
||||
mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]}
|
||||
retrieved_session = await agent_persistence.get_session_info(session_id)
|
||||
assert retrieved_session is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
|
||||
async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
||||
# Create a session with restricted access
|
||||
session_id = str(uuid.uuid4())
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name="Restricted Session",
|
||||
started_at=datetime.now(),
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
)
|
||||
|
||||
await agent_persistence.kvstore.set(
|
||||
key=f"session:{agent_persistence.agent_id}:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
|
||||
# Create a turn for this session
|
||||
turn_id = str(uuid.uuid4())
|
||||
turn = Turn(
|
||||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
steps=[],
|
||||
started_at=datetime.now(),
|
||||
input_messages=[],
|
||||
output_message=CompletionMessage(
|
||||
content="Hello",
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
),
|
||||
)
|
||||
|
||||
# Admin can add turn
|
||||
mock_get_auth_attributes.return_value = {"roles": ["admin"]}
|
||||
await agent_persistence.add_turn_to_session(session_id, turn)
|
||||
|
||||
# Admin can get turn
|
||||
retrieved_turn = await agent_persistence.get_session_turn(session_id, turn_id)
|
||||
assert retrieved_turn is not None
|
||||
assert retrieved_turn.turn_id == turn_id
|
||||
|
||||
# Regular user cannot get turn
|
||||
mock_get_auth_attributes.return_value = {"roles": ["user"]}
|
||||
with pytest.raises(ValueError):
|
||||
await agent_persistence.get_session_turn(session_id, turn_id)
|
||||
|
||||
# Regular user cannot get turns for session
|
||||
with pytest.raises(ValueError):
|
||||
await agent_persistence.get_session_turns(session_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
|
||||
async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
||||
# Create a session with restricted access
|
||||
session_id = str(uuid.uuid4())
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name="Restricted Session",
|
||||
started_at=datetime.now(),
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
)
|
||||
|
||||
await agent_persistence.kvstore.set(
|
||||
key=f"session:{agent_persistence.agent_id}:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
|
||||
turn_id = str(uuid.uuid4())
|
||||
|
||||
# Admin user can set inference iterations
|
||||
mock_get_auth_attributes.return_value = {"roles": ["admin"]}
|
||||
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)
|
||||
|
||||
# Admin user can get inference iterations
|
||||
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
|
||||
assert infer_iters == 5
|
||||
|
||||
# Regular user cannot get inference iterations
|
||||
mock_get_auth_attributes.return_value = {"roles": ["user"]}
|
||||
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
|
||||
assert infer_iters is None
|
||||
|
||||
# Regular user cannot set inference iterations (should raise ValueError)
|
||||
with pytest.raises(ValueError):
|
||||
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 10)
|
|
@ -188,7 +188,7 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
|
|||
caplog.set_level(logging.WARNING)
|
||||
|
||||
# Log when event loop is blocked for more than 200ms
|
||||
loop.slow_callback_duration = 0.2
|
||||
loop.slow_callback_duration = 0.5
|
||||
# Sleep for 500ms in our delayed http response
|
||||
sleep_time = 0.5
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ import pytest_asyncio
|
|||
from llama_stack.apis.inference import Model
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.distribution.store.registry import (
|
||||
KEY_FORMAT,
|
||||
CachedDiskDistributionRegistry,
|
||||
DiskDistributionRegistry,
|
||||
)
|
||||
|
@ -197,3 +198,72 @@ async def test_get_all_objects(config):
|
|||
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
|
||||
assert stored_vector_db.provider_id == original_vector_db.provider_id
|
||||
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_registry_values_error_handling(config):
|
||||
kvstore = await kvstore_impl(config)
|
||||
|
||||
valid_db = VectorDB(
|
||||
identifier="valid_vector_db",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="valid_vector_db",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json())
|
||||
|
||||
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
|
||||
|
||||
await kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
|
||||
'{"type": "vector_db", "identifier": "missing_fields"}',
|
||||
)
|
||||
|
||||
test_registry = DiskDistributionRegistry(kvstore)
|
||||
await test_registry.initialize()
|
||||
|
||||
# Get all objects, which should only return the valid one
|
||||
all_objects = await test_registry.get_all()
|
||||
|
||||
# Should have filtered out the invalid entries
|
||||
assert len(all_objects) == 1
|
||||
assert all_objects[0].identifier == "valid_vector_db"
|
||||
|
||||
# Check that the get method also handles errors correctly
|
||||
invalid_obj = await test_registry.get("vector_db", "corrupted_json")
|
||||
assert invalid_obj is None
|
||||
|
||||
invalid_obj = await test_registry.get("vector_db", "missing_fields")
|
||||
assert invalid_obj is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cached_registry_error_handling(config):
|
||||
kvstore = await kvstore_impl(config)
|
||||
|
||||
valid_db = VectorDB(
|
||||
identifier="valid_cached_db",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="valid_cached_db",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json())
|
||||
|
||||
await kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"),
|
||||
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
|
||||
)
|
||||
|
||||
cached_registry = CachedDiskDistributionRegistry(kvstore)
|
||||
await cached_registry.initialize()
|
||||
|
||||
all_objects = await cached_registry.get_all()
|
||||
assert len(all_objects) == 1
|
||||
assert all_objects[0].identifier == "valid_cached_db"
|
||||
|
||||
invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db")
|
||||
assert invalid_obj is None
|
||||
|
|
151
tests/unit/registry/test_registry_acl.py
Normal file
151
tests/unit/registry/test_registry_acl.py
Normal file
|
@ -0,0 +1,151 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import ModelWithACL
|
||||
from llama_stack.distribution.server.auth import AccessAttributes
|
||||
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def kvstore():
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
db_path = os.path.join(temp_dir, "test_registry_acl.db")
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
|
||||
kvstore = SqliteKVStoreImpl(kvstore_config)
|
||||
await kvstore.initialize()
|
||||
yield kvstore
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def registry(kvstore):
|
||||
registry = CachedDiskDistributionRegistry(kvstore)
|
||||
await registry.initialize()
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registry_cache_with_acl(registry):
|
||||
model = ModelWithACL(
|
||||
identifier="model-acl",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-acl-resource",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
|
||||
)
|
||||
|
||||
success = await registry.register(model)
|
||||
assert success
|
||||
|
||||
cached_model = registry.get_cached("model", "model-acl")
|
||||
assert cached_model is not None
|
||||
assert cached_model.identifier == "model-acl"
|
||||
assert cached_model.access_attributes.roles == ["admin"]
|
||||
assert cached_model.access_attributes.teams == ["ai-team"]
|
||||
|
||||
fetched_model = await registry.get("model", "model-acl")
|
||||
assert fetched_model is not None
|
||||
assert fetched_model.identifier == "model-acl"
|
||||
assert fetched_model.access_attributes.roles == ["admin"]
|
||||
|
||||
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
|
||||
await registry.update(model)
|
||||
|
||||
updated_cached = registry.get_cached("model", "model-acl")
|
||||
assert updated_cached is not None
|
||||
assert updated_cached.access_attributes.roles == ["admin", "user"]
|
||||
assert updated_cached.access_attributes.projects == ["project-x"]
|
||||
assert updated_cached.access_attributes.teams is None
|
||||
|
||||
new_registry = CachedDiskDistributionRegistry(registry.kvstore)
|
||||
await new_registry.initialize()
|
||||
|
||||
new_model = await new_registry.get("model", "model-acl")
|
||||
assert new_model is not None
|
||||
assert new_model.identifier == "model-acl"
|
||||
assert new_model.access_attributes.roles == ["admin", "user"]
|
||||
assert new_model.access_attributes.projects == ["project-x"]
|
||||
assert new_model.access_attributes.teams is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registry_empty_acl(registry):
|
||||
model = ModelWithACL(
|
||||
identifier="model-empty-acl",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-resource",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(),
|
||||
)
|
||||
|
||||
await registry.register(model)
|
||||
|
||||
cached_model = registry.get_cached("model", "model-empty-acl")
|
||||
assert cached_model is not None
|
||||
assert cached_model.access_attributes is not None
|
||||
assert cached_model.access_attributes.roles is None
|
||||
assert cached_model.access_attributes.teams is None
|
||||
assert cached_model.access_attributes.projects is None
|
||||
assert cached_model.access_attributes.namespaces is None
|
||||
|
||||
all_models = await registry.get_all()
|
||||
assert len(all_models) == 1
|
||||
|
||||
model = ModelWithACL(
|
||||
identifier="model-no-acl",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-resource-2",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
await registry.register(model)
|
||||
|
||||
cached_model = registry.get_cached("model", "model-no-acl")
|
||||
assert cached_model is not None
|
||||
assert cached_model.access_attributes is None
|
||||
|
||||
all_models = await registry.get_all()
|
||||
assert len(all_models) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registry_serialization(registry):
|
||||
attributes = AccessAttributes(
|
||||
roles=["admin", "researcher"],
|
||||
teams=["ai-team", "ml-team"],
|
||||
projects=["project-a", "project-b"],
|
||||
namespaces=["prod", "staging"],
|
||||
)
|
||||
|
||||
model = ModelWithACL(
|
||||
identifier="model-serialize",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-resource",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=attributes,
|
||||
)
|
||||
|
||||
await registry.register(model)
|
||||
|
||||
new_registry = CachedDiskDistributionRegistry(registry.kvstore)
|
||||
await new_registry.initialize()
|
||||
|
||||
loaded_model = await new_registry.get("model", "model-serialize")
|
||||
assert loaded_model is not None
|
||||
|
||||
assert loaded_model.access_attributes.roles == ["admin", "researcher"]
|
||||
assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"]
|
||||
assert loaded_model.access_attributes.projects == ["project-a", "project-b"]
|
||||
assert loaded_model.access_attributes.namespaces == ["prod", "staging"]
|
240
tests/unit/server/test_access_control.py
Normal file
240
tests/unit/server/test_access_control.py
Normal file
|
@ -0,0 +1,240 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
|
||||
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
|
||||
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
||||
|
||||
|
||||
class AsyncMock(MagicMock):
|
||||
async def __call__(self, *args, **kwargs):
|
||||
return super(AsyncMock, self).__call__(*args, **kwargs)
|
||||
|
||||
|
||||
def _return_model(model):
|
||||
return model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_setup():
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
db_path = os.path.join(temp_dir, "test_access_control.db")
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
|
||||
kvstore = SqliteKVStoreImpl(kvstore_config)
|
||||
await kvstore.initialize()
|
||||
registry = CachedDiskDistributionRegistry(kvstore)
|
||||
await registry.initialize()
|
||||
|
||||
mock_inference = Mock()
|
||||
mock_inference.__provider_spec__ = MagicMock()
|
||||
mock_inference.__provider_spec__.api = Api.inference
|
||||
mock_inference.register_model = AsyncMock(side_effect=_return_model)
|
||||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"test_provider": mock_inference},
|
||||
dist_registry=registry,
|
||||
)
|
||||
yield registry, routing_table
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
identifier="model-public",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-public",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
model_admin_only = ModelWithACL(
|
||||
identifier="model-admin",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-admin",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
)
|
||||
model_data_scientist = ModelWithACL(
|
||||
identifier="model-data-scientist",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-data-scientist",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]),
|
||||
)
|
||||
await registry.register(model_public)
|
||||
await registry.register(model_admin_only)
|
||||
await registry.register(model_data_scientist)
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 2
|
||||
|
||||
model = await routing_table.get_model("model-public")
|
||||
assert model.identifier == "model-public"
|
||||
model = await routing_table.get_model("model-admin")
|
||||
assert model.identifier == "model-admin"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-data-scientist")
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 1
|
||||
assert all_models.data[0].identifier == "model-public"
|
||||
model = await routing_table.get_model("model-public")
|
||||
assert model.identifier == "model-public"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-admin")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-data-scientist")
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]}
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 2
|
||||
model_ids = [m.identifier for m in all_models.data]
|
||||
assert "model-public" in model_ids
|
||||
assert "model-data-scientist" in model_ids
|
||||
assert "model-admin" not in model_ids
|
||||
model = await routing_table.get_model("model-public")
|
||||
assert model.identifier == "model-public"
|
||||
model = await routing_table.get_model("model-data-scientist")
|
||||
assert model.identifier == "model-data-scientist"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-admin")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
identifier="model-updates",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-updates",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
await registry.register(model_public)
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["user"],
|
||||
}
|
||||
model = await routing_table.get_model("model-updates")
|
||||
assert model.identifier == "model-updates"
|
||||
model_public.access_attributes = AccessAttributes(roles=["admin"])
|
||||
await registry.update(model_public)
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["user"],
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-updates")
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["admin"],
|
||||
}
|
||||
model = await routing_table.get_model("model-updates")
|
||||
assert model.identifier == "model-updates"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model = ModelWithACL(
|
||||
identifier="model-empty-attrs",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-empty-attrs",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(),
|
||||
)
|
||||
await registry.register(model)
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": [],
|
||||
}
|
||||
result = await routing_table.get_model("model-empty-attrs")
|
||||
assert result.identifier == "model-empty-attrs"
|
||||
all_models = await routing_table.list_models()
|
||||
model_ids = [m.identifier for m in all_models.data]
|
||||
assert "model-empty-attrs" in model_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||
async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
identifier="model-public-2",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-public-2",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
model_restricted = ModelWithACL(
|
||||
identifier="model-restricted",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-restricted",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
)
|
||||
await registry.register(model_public)
|
||||
await registry.register(model_restricted)
|
||||
mock_get_auth_attributes.return_value = None
|
||||
model = await routing_table.get_model("model-public-2")
|
||||
assert model.identifier == "model-public-2"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-restricted")
|
||||
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 1
|
||||
assert all_models.data[0].identifier == "model-public-2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
|
||||
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
|
||||
"""Test that newly created resources inherit access attributes from their creator."""
|
||||
registry, routing_table = test_setup
|
||||
|
||||
# Set creator's attributes
|
||||
creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
|
||||
mock_get_auth_attributes.return_value = creator_attributes
|
||||
|
||||
# Create model without explicit access attributes
|
||||
model = ModelWithACL(
|
||||
identifier="auto-access-model",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="auto-access-model",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
await routing_table.register_object(model)
|
||||
|
||||
# Verify the model got creator's attributes
|
||||
registered_model = await routing_table.get_model("auto-access-model")
|
||||
assert registered_model.access_attributes is not None
|
||||
assert registered_model.access_attributes.roles == ["data-scientist"]
|
||||
assert registered_model.access_attributes.teams == ["ml-team"]
|
||||
assert registered_model.access_attributes.projects == ["llama-3"]
|
||||
|
||||
# Verify another user without matching attributes can't access it
|
||||
mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]}
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("auto-access-model")
|
||||
|
||||
# But a user with matching attributes can
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["data-scientist", "engineer"],
|
||||
"teams": ["ml-team", "platform-team"],
|
||||
"projects": ["llama-3"],
|
||||
}
|
||||
model = await routing_table.get_model("auto-access-model")
|
||||
assert model.identifier == "auto-access-model"
|
|
@ -13,6 +13,15 @@ from fastapi.testclient import TestClient
|
|||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, status_code, json_data):
|
||||
self.status_code = status_code
|
||||
self._json_data = json_data
|
||||
|
||||
def json(self):
|
||||
return self._json_data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_endpoint():
|
||||
return "http://mock-auth-service/validate"
|
||||
|
@ -45,16 +54,32 @@ def client(app):
|
|||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_scope():
|
||||
return {
|
||||
"type": "http",
|
||||
"path": "/models/list",
|
||||
"headers": [
|
||||
(b"content-type", b"application/json"),
|
||||
(b"authorization", b"Bearer test-api-key"),
|
||||
(b"user-agent", b"test-user-agent"),
|
||||
],
|
||||
"query_string": b"limit=100&offset=0",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_middleware(mock_auth_endpoint):
|
||||
mock_app = AsyncMock()
|
||||
return AuthenticationMiddleware(mock_app, mock_auth_endpoint), mock_app
|
||||
|
||||
|
||||
async def mock_post_success(*args, **kwargs):
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
return mock_response
|
||||
return MockResponse(200, {"message": "Authentication successful"})
|
||||
|
||||
|
||||
async def mock_post_failure(*args, **kwargs):
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 401
|
||||
return mock_response
|
||||
return MockResponse(401, {"message": "Authentication failed"})
|
||||
|
||||
|
||||
async def mock_post_exception(*args, **kwargs):
|
||||
|
@ -96,8 +121,7 @@ def test_auth_service_error(client, valid_api_key):
|
|||
|
||||
def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
|
||||
with patch("httpx.AsyncClient.post") as mock_post:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response = MockResponse(200, {"message": "Authentication successful"})
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
client.get(
|
||||
|
@ -119,6 +143,64 @@ def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
|
|||
payload = kwargs["json"]
|
||||
assert payload["api_key"] == valid_api_key
|
||||
assert payload["request"]["path"] == "/test"
|
||||
assert "authorization" in payload["request"]["headers"]
|
||||
assert "authorization" not in payload["request"]["headers"]
|
||||
assert "param1" in payload["request"]["params"]
|
||||
assert "param2" in payload["request"]["params"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_middleware_with_access_attributes(mock_middleware, mock_scope):
|
||||
middleware, mock_app = mock_middleware
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
mock_client_instance.post.return_value = MockResponse(
|
||||
200,
|
||||
{
|
||||
"access_attributes": {
|
||||
"roles": ["admin", "user"],
|
||||
"teams": ["ml-team"],
|
||||
"projects": ["project-x", "project-y"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
await middleware(mock_scope, mock_receive, mock_send)
|
||||
|
||||
assert "user_attributes" in mock_scope
|
||||
assert mock_scope["user_attributes"]["roles"] == ["admin", "user"]
|
||||
assert mock_scope["user_attributes"]["teams"] == ["ml-team"]
|
||||
assert mock_scope["user_attributes"]["projects"] == ["project-x", "project-y"]
|
||||
|
||||
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_middleware_no_attributes(mock_middleware, mock_scope):
|
||||
"""Test middleware behavior with no access attributes"""
|
||||
middleware, mock_app = mock_middleware
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
mock_client_instance.post.return_value = MockResponse(
|
||||
200,
|
||||
{
|
||||
"message": "Authentication successful"
|
||||
# No access_attributes
|
||||
},
|
||||
)
|
||||
|
||||
await middleware(mock_scope, mock_receive, mock_send)
|
||||
|
||||
assert "user_attributes" in mock_scope
|
||||
attributes = mock_scope["user_attributes"]
|
||||
assert "namespaces" in attributes
|
||||
assert attributes["namespaces"] == ["test-api-key"]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue