chore!: remove the agents (sessions and turns) API (#4055)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Pre-commit / pre-commit (push) Failing after 3s
Python Package Build Test / build (3.12) (push) Failing after 2s
Python Package Build Test / build (3.13) (push) Failing after 2s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 5s
Test External API and Providers / test-external (venv) (push) Failing after 5s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 9s
Unit Tests / unit-tests (3.13) (push) Failing after 5s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
API Conformance Tests / check-schema-compatibility (push) Successful in 13s
UI Tests / ui-tests (22) (push) Successful in 1m10s

- Removes the deprecated agents (sessions and turns) API that was marked
alpha in 0.3.0
- Cleans up unused imports and orphaned types after the API removal
- Removes `SessionNotFoundError` and `AgentTurnInputType` which are no
longer needed

The agents API is completely superseded by the Responses + Conversations
APIs, and the client SDK Agent class already uses those implementations.

Corresponding client-side PR:
https://github.com/llamastack/llama-stack-client-python/pull/295
This commit is contained in:
Ashwin Bharambe 2025-11-04 09:38:39 -08:00 committed by GitHub
parent a6ddbae0ed
commit a8a8aa56c0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
1037 changed files with 393 additions and 309806 deletions

View file

@ -1,347 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch
import pytest
from llama_stack.apis.agents import Session
from llama_stack.core.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import (
AgentPersistence,
AgentSessionInfo,
)
from llama_stack.providers.utils.kvstore import KVStore
@pytest.fixture
def mock_kvstore():
return AsyncMock(spec=KVStore)
@pytest.fixture
def mock_policy():
return []
@pytest.fixture
def agent_persistence(mock_kvstore, mock_policy):
return AgentPersistence(agent_id="test-agent-123", kvstore=mock_kvstore, policy=mock_policy)
@pytest.fixture
def sample_session():
return AgentSessionInfo(
session_id="session-123",
session_name="Test Session",
started_at=datetime.now(UTC),
owner=User(principal="user-123", attributes=None),
turns=[],
identifier="test-session",
type="session",
)
@pytest.fixture
def sample_session_json(sample_session):
return sample_session.model_dump_json()
class TestAgentPersistenceListSessions:
def setup_mock_kvstore(self, mock_kvstore, session_keys=None, turn_keys=None, invalid_keys=None, custom_data=None):
"""Helper to setup mock kvstore with sessions, turns, and custom/invalid data
Args:
mock_kvstore: The mock KVStore object
session_keys: List of session keys or dict mapping keys to custom session data
turn_keys: List of turn keys or dict mapping keys to custom turn data
invalid_keys: Dict mapping keys to invalid/corrupt data
custom_data: Additional custom data to add to the mock responses
"""
all_keys = []
mock_data = {}
# session keys
if session_keys:
if isinstance(session_keys, dict):
all_keys.extend(session_keys.keys())
mock_data.update({k: json.dumps(v) if isinstance(v, dict) else v for k, v in session_keys.items()})
else:
all_keys.extend(session_keys)
for key in session_keys:
session_id = key.split(":")[-1]
mock_data[key] = json.dumps(
{
"session_id": session_id,
"session_name": f"Session {session_id}",
"started_at": datetime.now(UTC).isoformat(),
"turns": [],
}
)
# turn keys
if turn_keys:
if isinstance(turn_keys, dict):
all_keys.extend(turn_keys.keys())
mock_data.update({k: json.dumps(v) if isinstance(v, dict) else v for k, v in turn_keys.items()})
else:
all_keys.extend(turn_keys)
for key in turn_keys:
parts = key.split(":")
session_id = parts[-2]
turn_id = parts[-1]
mock_data[key] = json.dumps(
{
"turn_id": turn_id,
"session_id": session_id,
"input_messages": [],
"started_at": datetime.now(UTC).isoformat(),
}
)
if invalid_keys:
all_keys.extend(invalid_keys.keys())
mock_data.update(invalid_keys)
if custom_data:
mock_data.update(custom_data)
values_list = list(mock_data.values())
mock_kvstore.values_in_range.return_value = values_list
async def mock_get(key):
return mock_data.get(key)
mock_kvstore.get.side_effect = mock_get
return mock_data
@pytest.mark.parametrize(
"scenario",
[
{
# from this issue: https://github.com/meta-llama/llama-stack/issues/3048
"name": "reported_bug",
"session_keys": ["session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"],
"turn_keys": [
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d:eb7e818f-41fb-49a0-bdd6-464974a2d2ad"
],
"expected_sessions": ["1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"],
},
{
"name": "basic_filtering",
"session_keys": ["session:test-agent-123:session-1", "session:test-agent-123:session-2"],
"turn_keys": ["session:test-agent-123:session-1:turn-1", "session:test-agent-123:session-1:turn-2"],
"expected_sessions": ["session-1", "session-2"],
},
{
"name": "multiple_turns_per_session",
"session_keys": ["session:test-agent-123:session-456"],
"turn_keys": [
"session:test-agent-123:session-456:turn-789",
"session:test-agent-123:session-456:turn-790",
],
"expected_sessions": ["session-456"],
},
{
"name": "multiple_sessions_with_turns",
"session_keys": ["session:test-agent-123:session-1", "session:test-agent-123:session-2"],
"turn_keys": [
"session:test-agent-123:session-1:turn-1",
"session:test-agent-123:session-1:turn-2",
"session:test-agent-123:session-2:turn-3",
],
"expected_sessions": ["session-1", "session-2"],
},
],
)
async def test_list_sessions_key_filtering(self, agent_persistence, mock_kvstore, scenario):
self.setup_mock_kvstore(mock_kvstore, session_keys=scenario["session_keys"], turn_keys=scenario["turn_keys"])
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
result = await agent_persistence.list_sessions()
assert len(result) == len(scenario["expected_sessions"])
session_ids = {s.session_id for s in result}
for expected_id in scenario["expected_sessions"]:
assert expected_id in session_ids
# no errors should be logged
mock_log.error.assert_not_called()
@pytest.mark.parametrize(
"error_scenario",
[
{
"name": "invalid_json",
"valid_keys": ["session:test-agent-123:valid-session"],
"invalid_data": {"session:test-agent-123:invalid-json": "corrupted-json-data{"},
"expected_valid_sessions": ["valid-session"],
"expected_error_count": 1,
},
{
"name": "missing_fields",
"valid_keys": ["session:test-agent-123:valid-session"],
"invalid_data": {
"session:test-agent-123:invalid-schema": json.dumps(
{
"session_id": "invalid-schema",
"session_name": "Missing Fields",
# missing `started_at` and `turns`
}
)
},
"expected_valid_sessions": ["valid-session"],
"expected_error_count": 1,
},
{
"name": "multiple_invalid",
"valid_keys": ["session:test-agent-123:valid-session-1", "session:test-agent-123:valid-session-2"],
"invalid_data": {
"session:test-agent-123:corrupted-json": "not-valid-json{",
"session:test-agent-123:incomplete-data": json.dumps({"incomplete": "data"}),
},
"expected_valid_sessions": ["valid-session-1", "valid-session-2"],
"expected_error_count": 2,
},
],
)
async def test_list_sessions_error_handling(self, agent_persistence, mock_kvstore, error_scenario):
session_keys = {}
for key in error_scenario["valid_keys"]:
session_id = key.split(":")[-1]
session_keys[key] = {
"session_id": session_id,
"session_name": f"Valid {session_id}",
"started_at": datetime.now(UTC).isoformat(),
"turns": [],
}
self.setup_mock_kvstore(mock_kvstore, session_keys=session_keys, invalid_keys=error_scenario["invalid_data"])
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
result = await agent_persistence.list_sessions()
# only valid sessions should be returned
assert len(result) == len(error_scenario["expected_valid_sessions"])
session_ids = {s.session_id for s in result}
for expected_id in error_scenario["expected_valid_sessions"]:
assert expected_id in session_ids
# error should be logged
assert mock_log.error.call_count > 0
assert mock_log.error.call_count == error_scenario["expected_error_count"]
async def test_list_sessions_empty(self, agent_persistence, mock_kvstore):
mock_kvstore.values_in_range.return_value = []
result = await agent_persistence.list_sessions()
assert result == []
mock_kvstore.values_in_range.assert_called_once_with(
start_key="session:test-agent-123:", end_key="session:test-agent-123:\xff\xff\xff\xff"
)
async def test_list_sessions_properties(self, agent_persistence, mock_kvstore):
session_data = {
"session_id": "session-123",
"session_name": "Test Session",
"started_at": datetime.now(UTC).isoformat(),
"owner": {"principal": "user-123", "attributes": None},
"turns": [],
}
self.setup_mock_kvstore(mock_kvstore, session_keys={"session:test-agent-123:session-123": session_data})
result = await agent_persistence.list_sessions()
assert len(result) == 1
assert isinstance(result[0], Session)
assert result[0].session_id == "session-123"
assert result[0].session_name == "Test Session"
assert result[0].turns == []
assert hasattr(result[0], "started_at")
async def test_list_sessions_kvstore_exception(self, agent_persistence, mock_kvstore):
mock_kvstore.values_in_range.side_effect = Exception("KVStore error")
with pytest.raises(Exception, match="KVStore error"):
await agent_persistence.list_sessions()
async def test_bug_data_loss_with_real_data(self, agent_persistence, mock_kvstore):
# tests the handling of the issue reported in: https://github.com/meta-llama/llama-stack/issues/3048
session_data = {
"session_id": "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d",
"session_name": "Test Session",
"started_at": datetime.now(UTC).isoformat(),
"turns": [],
}
turn_data = {
"turn_id": "eb7e818f-41fb-49a0-bdd6-464974a2d2ad",
"session_id": "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d",
"input_messages": [
{"role": "user", "content": "if i had a cluster i would want to call it persistence01", "context": None}
],
"steps": [
{
"turn_id": "eb7e818f-41fb-49a0-bdd6-464974a2d2ad",
"step_id": "c0f797dd-3d34-4bc5-a8f4-db6af9455132",
"started_at": "2025-08-05T14:31:50.000484Z",
"completed_at": "2025-08-05T14:31:51.303691Z",
"step_type": "inference",
"model_response": {
"role": "assistant",
"content": "OK, I can create a cluster named 'persistence01' for you.",
"stop_reason": "end_of_turn",
"tool_calls": [],
},
}
],
"output_message": {
"role": "assistant",
"content": "OK, I can create a cluster named 'persistence01' for you.",
"stop_reason": "end_of_turn",
"tool_calls": [],
},
"output_attachments": [],
"started_at": "2025-08-05T14:31:49.999950Z",
"completed_at": "2025-08-05T14:31:51.305384Z",
}
mock_data = {
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d": json.dumps(session_data),
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d:eb7e818f-41fb-49a0-bdd6-464974a2d2ad": json.dumps(
turn_data
),
}
mock_kvstore.values_in_range.return_value = list(mock_data.values())
async def mock_get(key):
return mock_data.get(key)
mock_kvstore.get.side_effect = mock_get
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
result = await agent_persistence.list_sessions()
assert len(result) == 1
assert result[0].session_id == "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"
# confirm no errors logged
mock_log.error.assert_not_called()
async def test_list_sessions_key_range_construction(self, agent_persistence, mock_kvstore):
mock_kvstore.values_in_range.return_value = []
await agent_persistence.list_sessions()
mock_kvstore.values_in_range.assert_called_once_with(
start_key="session:test-agent-123:", end_key="session:test-agent-123:\xff\xff\xff\xff"
)

View file

@ -1,196 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.apis.agents import Document
from llama_stack.apis.common.content_types import URL, TextContentItem
from llama_stack.providers.inline.agents.meta_reference.agent_instance import get_raw_document_text
async def test_get_raw_document_text_supports_text_mime_types():
"""Test that the function accepts text/* mime types."""
document = Document(content="Sample text content", mime_type="text/plain")
result = await get_raw_document_text(document)
assert result == "Sample text content"
async def test_get_raw_document_text_supports_yaml_mime_type():
"""Test that the function accepts application/yaml mime type."""
yaml_content = """
name: test
version: 1.0
items:
- item1
- item2
"""
document = Document(content=yaml_content, mime_type="application/yaml")
result = await get_raw_document_text(document)
assert result == yaml_content
async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning():
"""Test that the function accepts text/yaml but emits a deprecation warning."""
yaml_content = """
name: test
version: 1.0
items:
- item1
- item2
"""
document = Document(content=yaml_content, mime_type="text/yaml")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = await get_raw_document_text(document)
# Check that result is correct
assert result == yaml_content
# Check that exactly one warning was issued
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "text/yaml" in str(w[0].message)
assert "application/yaml" in str(w[0].message)
assert "deprecated" in str(w[0].message).lower()
async def test_get_raw_document_text_deprecated_text_yaml_with_url():
"""Test that text/yaml works with URL content and emits warning."""
yaml_content = "name: test\nversion: 1.0"
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
mock_load.return_value = yaml_content
document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="text/yaml")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = await get_raw_document_text(document)
# Check that result is correct
assert result == yaml_content
mock_load.assert_called_once_with("https://example.com/config.yaml")
# Check that deprecation warning was issued
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "text/yaml" in str(w[0].message)
async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item():
"""Test that text/yaml works with TextContentItem and emits warning."""
yaml_content = "key: value\nlist:\n - item1\n - item2"
document = Document(content=TextContentItem(text=yaml_content), mime_type="text/yaml")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = await get_raw_document_text(document)
# Check that result is correct
assert result == yaml_content
# Check that deprecation warning was issued
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "text/yaml" in str(w[0].message)
async def test_get_raw_document_text_supports_json_mime_type():
"""Test that the function accepts application/json mime type."""
json_content = '{"name": "test", "version": "1.0", "items": ["item1", "item2"]}'
document = Document(content=json_content, mime_type="application/json")
result = await get_raw_document_text(document)
assert result == json_content
async def test_get_raw_document_text_with_json_text_content_item():
"""Test that the function handles JSON TextContentItem correctly."""
json_content = '{"key": "value", "nested": {"array": [1, 2, 3]}}'
document = Document(content=TextContentItem(text=json_content), mime_type="application/json")
result = await get_raw_document_text(document)
assert result == json_content
async def test_get_raw_document_text_rejects_unsupported_mime_types():
"""Test that the function rejects unsupported mime types."""
document = Document(
content="Some content",
mime_type="application/pdf", # Not supported
)
with pytest.raises(ValueError, match="Unexpected document mime type: application/pdf"):
await get_raw_document_text(document)
async def test_get_raw_document_text_with_url_content():
"""Test that the function handles URL content correctly."""
mock_response = AsyncMock()
mock_response.text = "Content from URL"
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
mock_load.return_value = "Content from URL"
document = Document(content=URL(uri="https://example.com/test.txt"), mime_type="text/plain")
result = await get_raw_document_text(document)
assert result == "Content from URL"
mock_load.assert_called_once_with("https://example.com/test.txt")
async def test_get_raw_document_text_with_yaml_url():
"""Test that the function handles YAML URLs correctly."""
yaml_content = "name: test\nversion: 1.0"
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
mock_load.return_value = yaml_content
document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="application/yaml")
result = await get_raw_document_text(document)
assert result == yaml_content
mock_load.assert_called_once_with("https://example.com/config.yaml")
async def test_get_raw_document_text_with_text_content_item():
"""Test that the function handles TextContentItem correctly."""
document = Document(content=TextContentItem(text="Text content item"), mime_type="text/plain")
result = await get_raw_document_text(document)
assert result == "Text content item"
async def test_get_raw_document_text_with_yaml_text_content_item():
"""Test that the function handles YAML TextContentItem correctly."""
yaml_content = "key: value\nlist:\n - item1\n - item2"
document = Document(content=TextContentItem(text=yaml_content), mime_type="application/yaml")
result = await get_raw_document_text(document)
assert result == yaml_content
async def test_get_raw_document_text_rejects_unexpected_content_type():
"""Test that the function rejects unexpected document content types."""
# Create a mock document that bypasses Pydantic validation
mock_document = MagicMock(spec=Document)
mock_document.mime_type = "text/plain"
mock_document.content = 123 # Unexpected content type (not str, URL, or TextContentItem)
with pytest.raises(ValueError, match="Unexpected document content type: <class 'int'>"):
await get_raw_document_text(mock_document)

View file

@ -1,325 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.agents import (
Agent,
AgentConfig,
AgentCreateResponse,
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.inference import Inference
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
@pytest.fixture(autouse=True)
def setup_backends(tmp_path):
"""Register KV and SQL store backends for testing."""
from llama_stack.core.storage.datatypes import SqliteKVStoreConfig, SqliteSqlStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
kv_path = str(tmp_path / "test_kv.db")
sql_path = str(tmp_path / "test_sql.db")
register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path=kv_path)})
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=sql_path)})
@pytest.fixture
def mock_apis():
return {
"inference_api": AsyncMock(spec=Inference),
"vector_io_api": AsyncMock(spec=VectorIO),
"safety_api": AsyncMock(spec=Safety),
"tool_runtime_api": AsyncMock(spec=ToolRuntime),
"tool_groups_api": AsyncMock(spec=ToolGroups),
"conversations_api": AsyncMock(spec=Conversations),
}
@pytest.fixture
def config(tmp_path):
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
from llama_stack.providers.inline.agents.meta_reference.config import AgentPersistenceConfig
return MetaReferenceAgentsImplConfig(
persistence=AgentPersistenceConfig(
agent_state=KVStoreReference(
backend="kv_default",
namespace="agents",
),
responses=ResponsesStoreReference(
backend="sql_default",
table_name="responses",
),
)
)
@pytest.fixture
async def agents_impl(config, mock_apis):
impl = MetaReferenceAgentsImpl(
config,
mock_apis["inference_api"],
mock_apis["vector_io_api"],
mock_apis["safety_api"],
mock_apis["tool_runtime_api"],
mock_apis["tool_groups_api"],
mock_apis["conversations_api"],
[],
)
await impl.initialize()
yield impl
await impl.shutdown()
@pytest.fixture
def sample_agent_config():
return AgentConfig(
sampling_params={
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 1.0,
},
input_shields=["string"],
output_shields=["string"],
toolgroups=["mcp::my_mcp_server"],
client_tools=[
{
"name": "client_tool",
"description": "Client Tool",
"parameters": [
{
"name": "string",
"parameter_type": "string",
"description": "string",
"required": True,
"default": None,
}
],
"metadata": {
"property1": None,
"property2": None,
},
}
],
tool_choice="auto",
tool_prompt_format="json",
tool_config={
"tool_choice": "auto",
"tool_prompt_format": "json",
"system_message_behavior": "append",
},
max_infer_iters=10,
model="string",
instructions="string",
enable_session_persistence=False,
response_format={
"type": "json_schema",
"json_schema": {
"property1": None,
"property2": None,
},
},
)
async def test_create_agent(agents_impl, sample_agent_config):
response = await agents_impl.create_agent(sample_agent_config)
assert isinstance(response, AgentCreateResponse)
assert response.agent_id is not None
stored_agent = await agents_impl.persistence_store.get(f"agent:{response.agent_id}")
assert stored_agent is not None
agent_info = AgentInfo.model_validate_json(stored_agent)
assert agent_info.model == sample_agent_config.model
assert agent_info.created_at is not None
assert isinstance(agent_info.created_at, datetime)
async def test_get_agent(agents_impl, sample_agent_config):
create_response = await agents_impl.create_agent(sample_agent_config)
agent_id = create_response.agent_id
agent = await agents_impl.get_agent(agent_id)
assert isinstance(agent, Agent)
assert agent.agent_id == agent_id
assert agent.agent_config.model == sample_agent_config.model
assert agent.created_at is not None
assert isinstance(agent.created_at, datetime)
async def test_list_agents(agents_impl, sample_agent_config):
agent1_response = await agents_impl.create_agent(sample_agent_config)
agent2_response = await agents_impl.create_agent(sample_agent_config)
response = await agents_impl.list_agents()
assert isinstance(response, PaginatedResponse)
assert len(response.data) == 2
agent_ids = {agent["agent_id"] for agent in response.data}
assert agent1_response.agent_id in agent_ids
assert agent2_response.agent_id in agent_ids
@pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
# Create an agent with specified persistence setting
config = sample_agent_config.model_copy()
config.enable_session_persistence = enable_session_persistence
response = await agents_impl.create_agent(config)
agent_id = response.agent_id
# Create a session
session_response = await agents_impl.create_agent_session(agent_id, "test_session")
assert session_response.session_id is not None
# Verify the session was stored
session = await agents_impl.get_agents_session(session_response.session_id, agent_id)
assert session.session_name == "test_session"
assert session.session_id == session_response.session_id
assert session.started_at is not None
assert session.turns == []
# Delete the session
await agents_impl.delete_agents_session(session_response.session_id, agent_id)
# Verify the session was deleted
with pytest.raises(ValueError):
await agents_impl.get_agents_session(session_response.session_id, agent_id)
@pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
# Create an agent with specified persistence setting
config = sample_agent_config.model_copy()
config.enable_session_persistence = enable_session_persistence
response = await agents_impl.create_agent(config)
agent_id = response.agent_id
# Create multiple sessions
session1 = await agents_impl.create_agent_session(agent_id, "session1")
session2 = await agents_impl.create_agent_session(agent_id, "session2")
# List sessions
sessions = await agents_impl.list_agent_sessions(agent_id)
assert len(sessions.data) == 2
session_ids = {s["session_id"] for s in sessions.data}
assert session1.session_id in session_ids
assert session2.session_id in session_ids
# Delete one session
await agents_impl.delete_agents_session(session1.session_id, agent_id)
# Verify the session was deleted
with pytest.raises(ValueError):
await agents_impl.get_agents_session(session1.session_id, agent_id)
# List sessions again
sessions = await agents_impl.list_agent_sessions(agent_id)
assert len(sessions.data) == 1
assert session2.session_id in {s["session_id"] for s in sessions.data}
async def test_delete_agent(agents_impl, sample_agent_config):
# Create an agent
response = await agents_impl.create_agent(sample_agent_config)
agent_id = response.agent_id
# Delete the agent
await agents_impl.delete_agent(agent_id)
# Verify the agent was deleted
with pytest.raises(ValueError):
await agents_impl.get_agent(agent_id)
async def test__initialize_tools(agents_impl, sample_agent_config):
# Mock tool_groups_api.list_tools()
agents_impl.tool_groups_api.list_tools.return_value = ListToolDefsResponse(
data=[
ToolDef(
name="story_maker",
toolgroup_id="mcp::my_mcp_server",
description="Make a story",
input_schema={
"type": "object",
"properties": {
"story_title": {"type": "string", "description": "Title of the story", "title": "Story Title"},
"input_words": {
"type": "array",
"description": "Input words",
"items": {"type": "string"},
"title": "Input Words",
"default": [],
},
},
"required": ["story_title"],
},
)
]
)
create_response = await agents_impl.create_agent(sample_agent_config)
agent_id = create_response.agent_id
# Get an instance of ChatAgent
chat_agent = await agents_impl._get_agent_impl(agent_id)
assert chat_agent is not None
assert isinstance(chat_agent, ChatAgent)
# Initialize tool definitions
await chat_agent._initialize_tools()
assert len(chat_agent.tool_defs) == 2
# Verify the first tool, which is a client tool
first_tool = chat_agent.tool_defs[0]
assert first_tool.tool_name == "client_tool"
assert first_tool.description == "Client Tool"
# Verify the second tool, which is an MCP tool that has an array-type property
second_tool = chat_agent.tool_defs[1]
assert second_tool.tool_name == "story_maker"
assert second_tool.description == "Make a story"
# Verify the input schema
input_schema = second_tool.input_schema
assert input_schema is not None
assert input_schema["type"] == "object"
properties = input_schema["properties"]
assert len(properties) == 2
# Verify a string property
story_title = properties["story_title"]
assert story_title["type"] == "string"
assert story_title["description"] == "Title of the story"
assert story_title["title"] == "Story Title"
# Verify an array property
input_words = properties["input_words"]
assert input_words["type"] == "array"
assert input_words["description"] == "Input words"
assert input_words["items"]["type"] == "string"
assert input_words["title"] == "Input Words"
assert input_words["default"] == []
# Verify required fields
assert input_schema["required"] == ["story_title"]

View file

@ -1,23 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import yaml
from llama_stack.apis.inference import (
OpenAIChatCompletion,
)
FIXTURES_DIR = os.path.dirname(os.path.abspath(__file__))
def load_chat_completion_fixture(filename: str) -> OpenAIChatCompletion:
fixture_path = os.path.join(FIXTURES_DIR, filename)
with open(fixture_path) as f:
data = yaml.safe_load(f)
return OpenAIChatCompletion(**data)

View file

@ -1,9 +0,0 @@
id: chat-completion-123
choices:
- message:
content: "Dublin"
role: assistant
finish_reason: stop
index: 0
created: 1234567890
model: meta-llama/Llama-3.1-8B-Instruct

View file

@ -1,14 +0,0 @@
id: chat-completion-123
choices:
- message:
tool_calls:
- id: tool_call_123
type: function
function:
name: web_search
arguments: '{"query":"What is the capital of Ireland?"}'
role: assistant
finish_reason: stop
index: 0
created: 1234567890
model: meta-llama/Llama-3.1-8B-Instruct

View file

@ -1,249 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseOutputItemDone,
OpenAIResponseOutputMessageContentOutputText,
)
from llama_stack.apis.common.errors import (
ConversationNotFoundError,
InvalidConversationIdError,
)
from llama_stack.apis.conversations.conversations import (
ConversationItemList,
)
# Import existing fixtures from the main responses test file
pytest_plugins = ["tests.unit.providers.agents.meta_reference.test_openai_responses"]
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
@pytest.fixture
def responses_impl_with_conversations(
mock_inference_api,
mock_tool_groups_api,
mock_tool_runtime_api,
mock_responses_store,
mock_vector_io_api,
mock_conversations_api,
mock_safety_api,
):
"""Create OpenAIResponsesImpl instance with conversations API."""
return OpenAIResponsesImpl(
inference_api=mock_inference_api,
tool_groups_api=mock_tool_groups_api,
tool_runtime_api=mock_tool_runtime_api,
responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api,
conversations_api=mock_conversations_api,
safety_api=mock_safety_api,
)
class TestConversationValidation:
"""Test conversation ID validation logic."""
async def test_nonexistent_conversation_raises_error(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test that ConversationNotFoundError is raised for non-existent conversation."""
conv_id = "conv_nonexistent"
# Mock conversation not found
mock_conversations_api.list_items.side_effect = ConversationNotFoundError("conv_nonexistent")
with pytest.raises(ConversationNotFoundError):
await responses_impl_with_conversations.create_openai_response(
input="Hello", model="test-model", conversation=conv_id, stream=False
)
class TestMessageSyncing:
"""Test message syncing to conversations."""
async def test_sync_response_to_conversation_simple(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test syncing simple response to conversation."""
conv_id = "conv_test123"
input_text = "What are the 5 Ds of dodgeball?"
# Output items (what the model generated)
output_items = [
OpenAIResponseMessage(
id="msg_response",
content=[
OpenAIResponseOutputMessageContentOutputText(
text="The 5 Ds are: Dodge, Duck, Dip, Dive, and Dodge.", type="output_text", annotations=[]
)
],
role="assistant",
status="completed",
type="message",
)
]
await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_text, output_items)
# should call add_items with user input and assistant response
mock_conversations_api.add_items.assert_called_once()
call_args = mock_conversations_api.add_items.call_args
assert call_args[0][0] == conv_id # conversation_id
items = call_args[0][1] # conversation_items
assert len(items) == 2
# User message
assert items[0].type == "message"
assert items[0].role == "user"
assert items[0].content[0].type == "input_text"
assert items[0].content[0].text == input_text
# Assistant message
assert items[1].type == "message"
assert items[1].role == "assistant"
async def test_sync_response_to_conversation_api_error(
self, responses_impl_with_conversations, mock_conversations_api
):
mock_conversations_api.add_items.side_effect = Exception("API Error")
output_items = []
# matching the behavior of OpenAI here
with pytest.raises(Exception, match="API Error"):
await responses_impl_with_conversations._sync_response_to_conversation(
"conv_test123", "Hello", output_items
)
async def test_sync_with_list_input(self, responses_impl_with_conversations, mock_conversations_api):
"""Test syncing with list of input messages."""
conv_id = "conv_test123"
input_messages = [
OpenAIResponseMessage(role="user", content=[{"type": "input_text", "text": "First message"}]),
]
output_items = [
OpenAIResponseMessage(
id="msg_response",
content=[OpenAIResponseOutputMessageContentOutputText(text="Response", type="output_text")],
role="assistant",
status="completed",
type="message",
)
]
await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_messages, output_items)
mock_conversations_api.add_items.assert_called_once()
call_args = mock_conversations_api.add_items.call_args
items = call_args[0][1]
# Should have input message + output message
assert len(items) == 2
class TestIntegrationWorkflow:
"""Integration tests for the full conversation workflow."""
async def test_create_response_with_valid_conversation(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test creating a response with a valid conversation parameter."""
mock_conversations_api.list_items.return_value = ConversationItemList(
data=[], first_id=None, has_more=False, last_id=None, object="list"
)
async def mock_streaming_response(*args, **kwargs):
message_item = OpenAIResponseMessage(
id="msg_response",
content=[
OpenAIResponseOutputMessageContentOutputText(
text="Test response", type="output_text", annotations=[]
)
],
role="assistant",
status="completed",
type="message",
)
# Emit output_item.done event first (needed for conversation sync)
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id="resp_test123",
item=message_item,
output_index=0,
sequence_number=1,
type="response.output_item.done",
)
# Then emit response.completed
mock_response = OpenAIResponseObject(
id="resp_test123",
created_at=1234567890,
model="test-model",
object="response",
output=[message_item],
status="completed",
)
yield OpenAIResponseObjectStreamResponseCompleted(response=mock_response, type="response.completed")
responses_impl_with_conversations._create_streaming_response = mock_streaming_response
input_text = "Hello, how are you?"
conversation_id = "conv_test123"
response = await responses_impl_with_conversations.create_openai_response(
input=input_text, model="test-model", conversation=conversation_id, stream=False
)
assert response is not None
assert response.id == "resp_test123"
# Note: conversation sync happens inside _create_streaming_response,
# which we're mocking here, so we can't test it in this unit test.
# The sync logic is tested separately in TestMessageSyncing.
async def test_create_response_with_invalid_conversation_id(self, responses_impl_with_conversations):
"""Test creating a response with an invalid conversation ID."""
with pytest.raises(InvalidConversationIdError) as exc_info:
await responses_impl_with_conversations.create_openai_response(
input="Hello", model="test-model", conversation="invalid_id", stream=False
)
assert "Expected an ID that begins with 'conv_'" in str(exc_info.value)
async def test_create_response_with_nonexistent_conversation(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test creating a response with a non-existent conversation."""
mock_conversations_api.list_items.side_effect = ConversationNotFoundError("conv_nonexistent")
with pytest.raises(ConversationNotFoundError) as exc_info:
await responses_impl_with_conversations.create_openai_response(
input="Hello", model="test-model", conversation="conv_nonexistent", stream=False
)
assert "not found" in str(exc_info.value)
async def test_conversation_and_previous_response_id(
self, responses_impl_with_conversations, mock_conversations_api, mock_responses_store
):
with pytest.raises(ValueError) as exc_info:
await responses_impl_with_conversations.create_openai_response(
input="test", model="test", conversation="conv_123", previous_response_id="resp_123"
)
assert "Mutually exclusive parameters" in str(exc_info.value)
assert "previous_response_id" in str(exc_info.value)
assert "conversation" in str(exc_info.value)

View file

@ -1,367 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatText,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
_extract_citations_from_text,
convert_chat_choice_to_response_message,
convert_response_content_to_chat_content,
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
get_message_type_by_role,
is_function_tool_call,
)
class TestConvertChatChoiceToResponseMessage:
async def test_convert_string_content(self):
choice = OpenAIChoice(
message=OpenAIAssistantMessageParam(content="Test message"),
finish_reason="stop",
index=0,
)
result = await convert_chat_choice_to_response_message(choice)
assert result.role == "assistant"
assert result.status == "completed"
assert len(result.content) == 1
assert isinstance(result.content[0], OpenAIResponseOutputMessageContentOutputText)
assert result.content[0].text == "Test message"
async def test_convert_text_param_content(self):
choice = OpenAIChoice(
message=OpenAIAssistantMessageParam(
content=[OpenAIChatCompletionContentPartTextParam(text="Test text param")]
),
finish_reason="stop",
index=0,
)
with pytest.raises(ValueError) as exc_info:
await convert_chat_choice_to_response_message(choice)
assert "does not yet support output content type" in str(exc_info.value)
class TestConvertResponseContentToChatContent:
async def test_convert_string_content(self):
result = await convert_response_content_to_chat_content("Simple string")
assert result == "Simple string"
async def test_convert_text_content_parts(self):
content = [
OpenAIResponseInputMessageContentText(text="First part"),
OpenAIResponseOutputMessageContentOutputText(text="Second part"),
]
result = await convert_response_content_to_chat_content(content)
assert len(result) == 2
assert isinstance(result[0], OpenAIChatCompletionContentPartTextParam)
assert result[0].text == "First part"
assert isinstance(result[1], OpenAIChatCompletionContentPartTextParam)
assert result[1].text == "Second part"
async def test_convert_image_content(self):
content = [OpenAIResponseInputMessageContentImage(image_url="https://example.com/image.jpg", detail="high")]
result = await convert_response_content_to_chat_content(content)
assert len(result) == 1
assert isinstance(result[0], OpenAIChatCompletionContentPartImageParam)
assert result[0].image_url.url == "https://example.com/image.jpg"
assert result[0].image_url.detail == "high"
class TestConvertResponseInputToChatMessages:
async def test_convert_string_input(self):
result = await convert_response_input_to_chat_messages("User message")
assert len(result) == 1
assert isinstance(result[0], OpenAIUserMessageParam)
assert result[0].content == "User message"
async def test_convert_function_tool_call_output(self):
input_items = [
OpenAIResponseOutputMessageFunctionToolCall(
call_id="call_123",
name="test_function",
arguments='{"param": "value"}',
),
OpenAIResponseInputFunctionToolCallOutput(
output="Tool output",
call_id="call_123",
),
]
result = await convert_response_input_to_chat_messages(input_items)
assert len(result) == 2
assert isinstance(result[0], OpenAIAssistantMessageParam)
assert result[0].tool_calls[0].id == "call_123"
assert result[0].tool_calls[0].function.name == "test_function"
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
assert isinstance(result[1], OpenAIToolMessageParam)
assert result[1].content == "Tool output"
assert result[1].tool_call_id == "call_123"
async def test_convert_function_tool_call(self):
input_items = [
OpenAIResponseOutputMessageFunctionToolCall(
call_id="call_456",
name="test_function",
arguments='{"param": "value"}',
)
]
result = await convert_response_input_to_chat_messages(input_items)
assert len(result) == 1
assert isinstance(result[0], OpenAIAssistantMessageParam)
assert len(result[0].tool_calls) == 1
assert result[0].tool_calls[0].id == "call_456"
assert result[0].tool_calls[0].function.name == "test_function"
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
async def test_convert_function_call_ordering(self):
input_items = [
OpenAIResponseOutputMessageFunctionToolCall(
call_id="call_123",
name="test_function_a",
arguments='{"param": "value"}',
),
OpenAIResponseOutputMessageFunctionToolCall(
call_id="call_456",
name="test_function_b",
arguments='{"param": "value"}',
),
OpenAIResponseInputFunctionToolCallOutput(
output="AAA",
call_id="call_123",
),
OpenAIResponseInputFunctionToolCallOutput(
output="BBB",
call_id="call_456",
),
]
result = await convert_response_input_to_chat_messages(input_items)
assert len(result) == 4
assert isinstance(result[0], OpenAIAssistantMessageParam)
assert len(result[0].tool_calls) == 1
assert result[0].tool_calls[0].id == "call_123"
assert result[0].tool_calls[0].function.name == "test_function_a"
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
assert isinstance(result[1], OpenAIToolMessageParam)
assert result[1].content == "AAA"
assert result[1].tool_call_id == "call_123"
assert isinstance(result[2], OpenAIAssistantMessageParam)
assert len(result[2].tool_calls) == 1
assert result[2].tool_calls[0].id == "call_456"
assert result[2].tool_calls[0].function.name == "test_function_b"
assert result[2].tool_calls[0].function.arguments == '{"param": "value"}'
assert isinstance(result[3], OpenAIToolMessageParam)
assert result[3].content == "BBB"
assert result[3].tool_call_id == "call_456"
async def test_convert_response_message(self):
input_items = [
OpenAIResponseMessage(
role="user",
content=[OpenAIResponseInputMessageContentText(text="User text")],
)
]
result = await convert_response_input_to_chat_messages(input_items)
assert len(result) == 1
assert isinstance(result[0], OpenAIUserMessageParam)
# Content should be converted to chat content format
assert len(result[0].content) == 1
assert result[0].content[0].text == "User text"
class TestConvertResponseTextToChatResponseFormat:
async def test_convert_text_format(self):
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
result = await convert_response_text_to_chat_response_format(text)
assert isinstance(result, OpenAIResponseFormatText)
assert result.type == "text"
async def test_convert_json_object_format(self):
text = OpenAIResponseText(format={"type": "json_object"})
result = await convert_response_text_to_chat_response_format(text)
assert isinstance(result, OpenAIResponseFormatJSONObject)
async def test_convert_json_schema_format(self):
schema_def = {"type": "object", "properties": {"test": {"type": "string"}}}
text = OpenAIResponseText(
format={
"type": "json_schema",
"name": "test_schema",
"schema": schema_def,
}
)
result = await convert_response_text_to_chat_response_format(text)
assert isinstance(result, OpenAIResponseFormatJSONSchema)
assert result.json_schema["name"] == "test_schema"
assert result.json_schema["schema"] == schema_def
async def test_default_text_format(self):
text = OpenAIResponseText()
result = await convert_response_text_to_chat_response_format(text)
assert isinstance(result, OpenAIResponseFormatText)
assert result.type == "text"
class TestGetMessageTypeByRole:
async def test_user_role(self):
result = await get_message_type_by_role("user")
assert result == OpenAIUserMessageParam
async def test_system_role(self):
result = await get_message_type_by_role("system")
assert result == OpenAISystemMessageParam
async def test_assistant_role(self):
result = await get_message_type_by_role("assistant")
assert result == OpenAIAssistantMessageParam
async def test_developer_role(self):
result = await get_message_type_by_role("developer")
assert result == OpenAIDeveloperMessageParam
async def test_unknown_role(self):
result = await get_message_type_by_role("unknown")
assert result is None
class TestIsFunctionToolCall:
def test_is_function_tool_call_true(self):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id="call_123",
function=OpenAIChatCompletionToolCallFunction(
name="test_function",
arguments="{}",
),
)
tools = [
OpenAIResponseInputToolFunction(
type="function", name="test_function", parameters={"type": "object", "properties": {}}
),
OpenAIResponseInputToolWebSearch(type="web_search"),
]
result = is_function_tool_call(tool_call, tools)
assert result is True
def test_is_function_tool_call_false_different_name(self):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id="call_123",
function=OpenAIChatCompletionToolCallFunction(
name="other_function",
arguments="{}",
),
)
tools = [
OpenAIResponseInputToolFunction(
type="function", name="test_function", parameters={"type": "object", "properties": {}}
),
]
result = is_function_tool_call(tool_call, tools)
assert result is False
def test_is_function_tool_call_false_no_function(self):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id="call_123",
function=None,
)
tools = [
OpenAIResponseInputToolFunction(
type="function", name="test_function", parameters={"type": "object", "properties": {}}
),
]
result = is_function_tool_call(tool_call, tools)
assert result is False
def test_is_function_tool_call_false_wrong_type(self):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id="call_123",
function=OpenAIChatCompletionToolCallFunction(
name="web_search",
arguments="{}",
),
)
tools = [
OpenAIResponseInputToolWebSearch(type="web_search"),
]
result = is_function_tool_call(tool_call, tools)
assert result is False
class TestExtractCitationsFromText:
def test_extract_citations_and_annotations(self):
text = "Start [not-a-file]. New source <|file-abc123|>. "
text += "Other source <|file-def456|>? Repeat source <|file-abc123|>! No citation."
file_mapping = {"file-abc123": "doc1.pdf", "file-def456": "doc2.txt"}
annotations, cleaned_text = _extract_citations_from_text(text, file_mapping)
expected_annotations = [
OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=30),
OpenAIResponseAnnotationFileCitation(file_id="file-def456", filename="doc2.txt", index=44),
OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=59),
]
expected_clean_text = "Start [not-a-file]. New source. Other source? Repeat source! No citation."
assert cleaned_text == expected_clean_text
assert annotations == expected_annotations
# OpenAI cites at the end of the sentence
assert cleaned_text[expected_annotations[0].index] == "."
assert cleaned_text[expected_annotations[1].index] == "?"
assert cleaned_text[expected_annotations[2].index] == "!"

View file

@ -1,183 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseInputToolWebSearch,
OpenAIResponseObject,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseToolMCP,
)
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
class TestToolContext:
def test_no_tools(self):
tools = []
context = ToolContext(tools)
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 0
assert len(context.previous_tools) == 0
assert len(context.previous_tool_listings) == 0
def test_no_previous_tools(self):
tools = [
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
OpenAIResponseInputToolMCP(server_label="label", server_url="url"),
]
context = ToolContext(tools)
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 2
assert len(context.previous_tools) == 0
assert len(context.previous_tool_listings) == 0
def test_reusable_server(self):
tools = [
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
]
context = ToolContext(tools)
output = [
OpenAIResponseOutputMessageMCPListTools(
id="test", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
)
]
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
previous_response.tools = [
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
OpenAIResponseToolMCP(server_label="alabel"),
]
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 1
assert context.tools_to_process[0].type == "file_search"
assert len(context.previous_tools) == 1
assert context.previous_tools["test_tool"].server_label == "alabel"
assert context.previous_tools["test_tool"].server_url == "aurl"
assert len(context.previous_tool_listings) == 1
assert len(context.previous_tool_listings[0].tools) == 1
assert context.previous_tool_listings[0].server_label == "alabel"
def test_multiple_reusable_servers(self):
tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(),
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
]
context = ToolContext(tools)
output = [
OpenAIResponseOutputMessageMCPListTools(
id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
),
OpenAIResponseOutputMessageMCPListTools(
id="test2",
server_label="anotherlabel",
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
),
]
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
previous_response.tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(type="web_search"),
OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
]
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 2
assert context.tools_to_process[0].type == "function"
assert context.tools_to_process[1].type == "web_search"
assert len(context.previous_tools) == 2
assert context.previous_tools["test_tool"].server_label == "alabel"
assert context.previous_tools["test_tool"].server_url == "aurl"
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
assert len(context.previous_tool_listings) == 2
assert len(context.previous_tool_listings[0].tools) == 1
assert context.previous_tool_listings[0].server_label == "alabel"
assert len(context.previous_tool_listings[1].tools) == 1
assert context.previous_tool_listings[1].server_label == "anotherlabel"
def test_multiple_servers_only_one_reusable(self):
tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(type="web_search"),
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
]
context = ToolContext(tools)
output = [
OpenAIResponseOutputMessageMCPListTools(
id="test2",
server_label="anotherlabel",
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
)
]
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
previous_response.tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(type="web_search"),
]
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 3
assert context.tools_to_process[0].type == "function"
assert context.tools_to_process[1].type == "web_search"
assert context.tools_to_process[2].type == "mcp"
assert len(context.previous_tools) == 1
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
assert len(context.previous_tool_listings) == 1
assert len(context.previous_tool_listings[0].tools) == 1
assert context.previous_tool_listings[0].server_label == "anotherlabel"
def test_mismatched_allowed_tools(self):
tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(type="web_search"),
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl", allowed_tools=["test_tool_2"]),
]
context = ToolContext(tools)
output = [
OpenAIResponseOutputMessageMCPListTools(
id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool_1", input_schema={})]
),
OpenAIResponseOutputMessageMCPListTools(
id="test2",
server_label="anotherlabel",
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
),
]
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
previous_response.tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(type="web_search"),
OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
]
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 3
assert context.tools_to_process[0].type == "function"
assert context.tools_to_process[1].type == "web_search"
assert context.tools_to_process[2].type == "mcp"
assert len(context.previous_tools) == 1
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
assert len(context.previous_tool_listings) == 1
assert len(context.previous_tool_listings[0].tools) == 1
assert context.previous_tool_listings[0].server_label == "anotherlabel"

View file

@ -1,155 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
extract_guardrail_ids,
run_guardrails,
)
@pytest.fixture
def mock_apis():
"""Create mock APIs for testing."""
return {
"inference_api": AsyncMock(),
"tool_groups_api": AsyncMock(),
"tool_runtime_api": AsyncMock(),
"responses_store": AsyncMock(),
"vector_io_api": AsyncMock(),
"conversations_api": AsyncMock(),
"safety_api": AsyncMock(),
}
@pytest.fixture
def responses_impl(mock_apis):
"""Create OpenAIResponsesImpl instance with mocked dependencies."""
return OpenAIResponsesImpl(**mock_apis)
def test_extract_guardrail_ids_from_strings(responses_impl):
"""Test extraction from simple string guardrail IDs."""
guardrails = ["llama-guard", "content-filter", "nsfw-detector"]
result = extract_guardrail_ids(guardrails)
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
def test_extract_guardrail_ids_from_objects(responses_impl):
"""Test extraction from ResponseGuardrailSpec objects."""
guardrails = [
ResponseGuardrailSpec(type="llama-guard"),
ResponseGuardrailSpec(type="content-filter"),
]
result = extract_guardrail_ids(guardrails)
assert result == ["llama-guard", "content-filter"]
def test_extract_guardrail_ids_mixed_formats(responses_impl):
"""Test extraction from mixed string and object formats."""
guardrails = [
"llama-guard",
ResponseGuardrailSpec(type="content-filter"),
"nsfw-detector",
]
result = extract_guardrail_ids(guardrails)
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
def test_extract_guardrail_ids_none_input(responses_impl):
"""Test extraction with None input."""
result = extract_guardrail_ids(None)
assert result == []
def test_extract_guardrail_ids_empty_list(responses_impl):
"""Test extraction with empty list."""
result = extract_guardrail_ids([])
assert result == []
def test_extract_guardrail_ids_unknown_format(responses_impl):
"""Test extraction with unknown guardrail format raises ValueError."""
# Create an object that's neither string nor ResponseGuardrailSpec
unknown_object = {"invalid": "format"} # Plain dict, not ResponseGuardrailSpec
guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
extract_guardrail_ids(guardrails)
@pytest.fixture
def mock_safety_api():
"""Create mock safety API for guardrails testing."""
safety_api = AsyncMock()
# Mock the routing table and shields list for guardrails lookup
safety_api.routing_table = AsyncMock()
shield = AsyncMock()
shield.identifier = "llama-guard"
shield.provider_resource_id = "llama-guard-model"
safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
return safety_api
async def test_run_guardrails_no_violation(mock_safety_api):
"""Test guardrails validation with no violations."""
text = "Hello world"
guardrail_ids = ["llama-guard"]
# Mock moderation to return non-flagged content
unflagged_result = ModerationObjectResults(flagged=False, categories={"violence": False})
mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[unflagged_result])
mock_safety_api.run_moderation.return_value = mock_moderation_object
result = await run_guardrails(mock_safety_api, text, guardrail_ids)
assert result is None
# Verify run_moderation was called with the correct model
mock_safety_api.run_moderation.assert_called_once()
call_args = mock_safety_api.run_moderation.call_args
assert call_args[1]["model"] == "llama-guard-model"
async def test_run_guardrails_with_violation(mock_safety_api):
"""Test guardrails validation with safety violation."""
text = "Harmful content"
guardrail_ids = ["llama-guard"]
# Mock moderation to return flagged content
flagged_result = ModerationObjectResults(
flagged=True,
categories={"violence": True},
user_message="Content flagged by moderation",
metadata={"violation_type": ["S1"]},
)
mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
mock_safety_api.run_moderation.return_value = mock_moderation_object
result = await run_guardrails(mock_safety_api, text, guardrail_ids)
assert result == "Content flagged by moderation (flagged for: violence) (violation type: S1)"
async def test_run_guardrails_empty_inputs(mock_safety_api):
"""Test guardrails validation with empty inputs."""
# Test empty guardrail_ids
result = await run_guardrails(mock_safety_api, "test", [])
assert result is None
# Test empty text
result = await run_guardrails(mock_safety_api, "", ["llama-guard"])
assert result is None
# Test both empty
result = await run_guardrails(mock_safety_api, "", [])
assert result is None

View file

@ -1,169 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from datetime import datetime
from unittest.mock import patch
import pytest
from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.core.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
@pytest.fixture
async def test_setup(sqlite_kvstore):
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
yield agent_persistence
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Set creator's attributes for the session
creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
mock_get_authenticated_user.return_value = User("test_user", creator_attributes)
# Create a session
session_id = await agent_persistence.create_session("Test Session")
# Get the session and verify access attributes were set
session_info = await agent_persistence.get_session_info(session_id)
assert session_info is not None
assert session_info.owner is not None
assert session_info.owner.attributes is not None
assert session_info.owner.attributes["roles"] == ["researcher"]
assert session_info.owner.attributes["teams"] == ["ai-team"]
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Create a session with specific access attributes
session_id = str(uuid.uuid4())
session_info = AgentSessionInfo(
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
owner=User("someone", {"roles": ["admin"], "teams": ["security-team"]}),
turns=[],
identifier="Restricted Session",
)
await agent_persistence.kvstore.set(
key=f"session:{agent_persistence.agent_id}:{session_id}",
value=session_info.model_dump_json(),
)
# User with matching attributes can access
mock_get_authenticated_user.return_value = User(
"testuser", {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
)
retrieved_session = await agent_persistence.get_session_info(session_id)
assert retrieved_session is not None
assert retrieved_session.session_id == session_id
# User without matching attributes cannot access
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"], "teams": ["other-team"]})
retrieved_session = await agent_persistence.get_session_info(session_id)
assert retrieved_session is None
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_turn_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Create a session with restricted access
session_id = str(uuid.uuid4())
session_info = AgentSessionInfo(
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
owner=User("someone", {"roles": ["admin"]}),
turns=[],
identifier="Restricted Session",
)
await agent_persistence.kvstore.set(
key=f"session:{agent_persistence.agent_id}:{session_id}",
value=session_info.model_dump_json(),
)
# Create a turn for this session
turn_id = str(uuid.uuid4())
turn = Turn(
session_id=session_id,
turn_id=turn_id,
steps=[],
started_at=datetime.now(),
input_messages=[],
output_message=CompletionMessage(
content="Hello",
stop_reason=StopReason.end_of_turn,
),
)
# Admin can add turn
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
await agent_persistence.add_turn_to_session(session_id, turn)
# Admin can get turn
retrieved_turn = await agent_persistence.get_session_turn(session_id, turn_id)
assert retrieved_turn is not None
assert retrieved_turn.turn_id == turn_id
# Regular user cannot get turn
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
with pytest.raises(ValueError):
await agent_persistence.get_session_turn(session_id, turn_id)
# Regular user cannot get turns for session
with pytest.raises(ValueError):
await agent_persistence.get_session_turns(session_id)
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Create a session with restricted access
session_id = str(uuid.uuid4())
session_info = AgentSessionInfo(
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
owner=User("someone", {"roles": ["admin"]}),
turns=[],
identifier="Restricted Session",
)
await agent_persistence.kvstore.set(
key=f"session:{agent_persistence.agent_id}:{session_id}",
value=session_info.model_dump_json(),
)
turn_id = str(uuid.uuid4())
# Admin user can set inference iterations
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)
# Admin user can get inference iterations
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
assert infer_iters == 5
# Regular user cannot get inference iterations
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
assert infer_iters is None
# Regular user cannot set inference iterations (should raise ValueError)
with pytest.raises(ValueError):
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 10)