llama-stack-mirror/tests/unit/providers/agent/test_meta_reference_agent.py
Gordon Sim 01ad876012 feat: fine grained access control policy
This allows a set of rules to be defined for determining access to resources.

Signed-off-by: Gordon Sim <gsim@redhat.com>
2025-05-27 21:37:56 +01:00

235 lines
7.8 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from unittest.mock import AsyncMock
import pytest
import pytest_asyncio
from llama_stack.apis.agents import (
Agent,
AgentConfig,
AgentCreateResponse,
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import Inference
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
@pytest.fixture
def mock_apis():
return {
"inference_api": AsyncMock(spec=Inference),
"vector_io_api": AsyncMock(spec=VectorIO),
"safety_api": AsyncMock(spec=Safety),
"tool_runtime_api": AsyncMock(spec=ToolRuntime),
"tool_groups_api": AsyncMock(spec=ToolGroups),
}
@pytest.fixture
def config(tmp_path):
return MetaReferenceAgentsImplConfig(
persistence_store={
"type": "sqlite",
"db_path": str(tmp_path / "test.db"),
},
responses_store={
"type": "sqlite",
"db_path": str(tmp_path / "test.db"),
},
)
@pytest_asyncio.fixture
async def agents_impl(config, mock_apis):
impl = MetaReferenceAgentsImpl(
config,
mock_apis["inference_api"],
mock_apis["vector_io_api"],
mock_apis["safety_api"],
mock_apis["tool_runtime_api"],
mock_apis["tool_groups_api"],
{},
)
await impl.initialize()
yield impl
await impl.shutdown()
@pytest.fixture
def sample_agent_config():
return AgentConfig(
sampling_params={
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 1.0,
},
input_shields=["string"],
output_shields=["string"],
toolgroups=["string"],
client_tools=[
{
"name": "string",
"description": "string",
"parameters": [
{
"name": "string",
"parameter_type": "string",
"description": "string",
"required": True,
"default": None,
}
],
"metadata": {
"property1": None,
"property2": None,
},
}
],
tool_choice="auto",
tool_prompt_format="json",
tool_config={
"tool_choice": "auto",
"tool_prompt_format": "json",
"system_message_behavior": "append",
},
max_infer_iters=10,
model="string",
instructions="string",
enable_session_persistence=False,
response_format={
"type": "json_schema",
"json_schema": {
"property1": None,
"property2": None,
},
},
)
@pytest.mark.asyncio
async def test_create_agent(agents_impl, sample_agent_config):
response = await agents_impl.create_agent(sample_agent_config)
assert isinstance(response, AgentCreateResponse)
assert response.agent_id is not None
stored_agent = await agents_impl.persistence_store.get(f"agent:{response.agent_id}")
assert stored_agent is not None
agent_info = AgentInfo.model_validate_json(stored_agent)
assert agent_info.model == sample_agent_config.model
assert agent_info.created_at is not None
assert isinstance(agent_info.created_at, datetime)
@pytest.mark.asyncio
async def test_get_agent(agents_impl, sample_agent_config):
create_response = await agents_impl.create_agent(sample_agent_config)
agent_id = create_response.agent_id
agent = await agents_impl.get_agent(agent_id)
assert isinstance(agent, Agent)
assert agent.agent_id == agent_id
assert agent.agent_config.model == sample_agent_config.model
assert agent.created_at is not None
assert isinstance(agent.created_at, datetime)
@pytest.mark.asyncio
async def test_list_agents(agents_impl, sample_agent_config):
agent1_response = await agents_impl.create_agent(sample_agent_config)
agent2_response = await agents_impl.create_agent(sample_agent_config)
response = await agents_impl.list_agents()
assert isinstance(response, PaginatedResponse)
assert len(response.data) == 2
agent_ids = {agent["agent_id"] for agent in response.data}
assert agent1_response.agent_id in agent_ids
assert agent2_response.agent_id in agent_ids
@pytest.mark.asyncio
@pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
# Create an agent with specified persistence setting
config = sample_agent_config.model_copy()
config.enable_session_persistence = enable_session_persistence
response = await agents_impl.create_agent(config)
agent_id = response.agent_id
# Create a session
session_response = await agents_impl.create_agent_session(agent_id, "test_session")
assert session_response.session_id is not None
# Verify the session was stored
session = await agents_impl.get_agents_session(agent_id, session_response.session_id)
assert session.session_name == "test_session"
assert session.session_id == session_response.session_id
assert session.started_at is not None
assert session.turns == []
# Delete the session
await agents_impl.delete_agents_session(agent_id, session_response.session_id)
# Verify the session was deleted
with pytest.raises(ValueError):
await agents_impl.get_agents_session(agent_id, session_response.session_id)
@pytest.mark.asyncio
@pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
# Create an agent with specified persistence setting
config = sample_agent_config.model_copy()
config.enable_session_persistence = enable_session_persistence
response = await agents_impl.create_agent(config)
agent_id = response.agent_id
# Create multiple sessions
session1 = await agents_impl.create_agent_session(agent_id, "session1")
session2 = await agents_impl.create_agent_session(agent_id, "session2")
# List sessions
sessions = await agents_impl.list_agent_sessions(agent_id)
assert len(sessions.data) == 2
session_ids = {s["session_id"] for s in sessions.data}
assert session1.session_id in session_ids
assert session2.session_id in session_ids
# Delete one session
await agents_impl.delete_agents_session(agent_id, session1.session_id)
# Verify the session was deleted
with pytest.raises(ValueError):
await agents_impl.get_agents_session(agent_id, session1.session_id)
# List sessions again
sessions = await agents_impl.list_agent_sessions(agent_id)
assert len(sessions.data) == 1
assert session2.session_id in {s["session_id"] for s in sessions.data}
@pytest.mark.asyncio
async def test_delete_agent(agents_impl, sample_agent_config):
# Create an agent
response = await agents_impl.create_agent(sample_agent_config)
agent_id = response.agent_id
# Delete the agent
await agents_impl.delete_agent(agent_id)
# Verify the agent was deleted
with pytest.raises(ValueError):
await agents_impl.get_agent(agent_id)