Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-12-03 18:00:36 +00:00
chore!: remove the agents (sessions and turns) API (#4055)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Pre-commit / pre-commit (push) Failing after 3s
Python Package Build Test / build (3.12) (push) Failing after 2s
Python Package Build Test / build (3.13) (push) Failing after 2s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 5s
Test External API and Providers / test-external (venv) (push) Failing after 5s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 9s
Unit Tests / unit-tests (3.13) (push) Failing after 5s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
API Conformance Tests / check-schema-compatibility (push) Successful in 13s
UI Tests / ui-tests (22) (push) Successful in 1m10s
- Removes the deprecated agents (sessions and turns) API that was marked alpha in 0.3.0
- Cleans up unused imports and orphaned types after the API removal
- Removes `SessionNotFoundError` and `AgentTurnInputType` which are no longer needed

The agents API is completely superseded by the Responses + Conversations APIs, and the client SDK Agent class already uses those implementations.

Corresponding client-side PR: https://github.com/llamastack/llama-stack-client-python/pull/295
This commit is contained in:
parent a6ddbae0ed
commit a8a8aa56c0
1037 changed files with 393 additions and 309806 deletions
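Since the commit description points migrators at the Responses + Conversations APIs, here is a minimal sketch of that replacement flow. It is not part of this diff: it assumes the OpenAI-compatible client surface of llama-stack-client (`LlamaStackClient`, `conversations.create`, `responses.create` with a `conversation` parameter, as exercised by the deleted tests below); exact method names, defaults, and the server port should be checked against the client SDK.

# Hypothetical migration sketch (not part of this commit): Conversations + Responses
# take over from the removed agent sessions/turns API. The client surface is assumed
# to be OpenAI-compatible; verify the names against llama-stack-client before use.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed default port

# A conversation plays the role of the old agent "session".
conversation = client.conversations.create()

# Each response plays the role of a "turn"; history is carried by the conversation.
response = client.responses.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    input="if i had a cluster i would want to call it persistence01",
    conversation=conversation.id,
)

# Output items follow the Responses data model seen in the deleted tests:
# a message whose content holds output_text parts.
print(response.output[0].content[0].text)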
@@ -1,347 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch

import pytest

from llama_stack.apis.agents import Session
from llama_stack.core.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import (
    AgentPersistence,
    AgentSessionInfo,
)
from llama_stack.providers.utils.kvstore import KVStore


@pytest.fixture
def mock_kvstore():
    return AsyncMock(spec=KVStore)


@pytest.fixture
def mock_policy():
    return []


@pytest.fixture
def agent_persistence(mock_kvstore, mock_policy):
    return AgentPersistence(agent_id="test-agent-123", kvstore=mock_kvstore, policy=mock_policy)


@pytest.fixture
def sample_session():
    return AgentSessionInfo(
        session_id="session-123",
        session_name="Test Session",
        started_at=datetime.now(UTC),
        owner=User(principal="user-123", attributes=None),
        turns=[],
        identifier="test-session",
        type="session",
    )


@pytest.fixture
def sample_session_json(sample_session):
    return sample_session.model_dump_json()


class TestAgentPersistenceListSessions:
    def setup_mock_kvstore(self, mock_kvstore, session_keys=None, turn_keys=None, invalid_keys=None, custom_data=None):
        """Helper to set up the mock kvstore with sessions, turns, and custom/invalid data.

        Args:
            mock_kvstore: The mock KVStore object
            session_keys: List of session keys or dict mapping keys to custom session data
            turn_keys: List of turn keys or dict mapping keys to custom turn data
            invalid_keys: Dict mapping keys to invalid/corrupt data
            custom_data: Additional custom data to add to the mock responses
        """
        all_keys = []
        mock_data = {}

        # session keys
        if session_keys:
            if isinstance(session_keys, dict):
                all_keys.extend(session_keys.keys())
                mock_data.update({k: json.dumps(v) if isinstance(v, dict) else v for k, v in session_keys.items()})
            else:
                all_keys.extend(session_keys)
                for key in session_keys:
                    session_id = key.split(":")[-1]
                    mock_data[key] = json.dumps(
                        {
                            "session_id": session_id,
                            "session_name": f"Session {session_id}",
                            "started_at": datetime.now(UTC).isoformat(),
                            "turns": [],
                        }
                    )

        # turn keys
        if turn_keys:
            if isinstance(turn_keys, dict):
                all_keys.extend(turn_keys.keys())
                mock_data.update({k: json.dumps(v) if isinstance(v, dict) else v for k, v in turn_keys.items()})
            else:
                all_keys.extend(turn_keys)
                for key in turn_keys:
                    parts = key.split(":")
                    session_id = parts[-2]
                    turn_id = parts[-1]
                    mock_data[key] = json.dumps(
                        {
                            "turn_id": turn_id,
                            "session_id": session_id,
                            "input_messages": [],
                            "started_at": datetime.now(UTC).isoformat(),
                        }
                    )

        if invalid_keys:
            all_keys.extend(invalid_keys.keys())
            mock_data.update(invalid_keys)

        if custom_data:
            mock_data.update(custom_data)

        values_list = list(mock_data.values())
        mock_kvstore.values_in_range.return_value = values_list

        async def mock_get(key):
            return mock_data.get(key)

        mock_kvstore.get.side_effect = mock_get

        return mock_data

    @pytest.mark.parametrize(
        "scenario",
        [
            {
                # from this issue: https://github.com/meta-llama/llama-stack/issues/3048
                "name": "reported_bug",
                "session_keys": ["session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"],
                "turn_keys": [
                    "session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d:eb7e818f-41fb-49a0-bdd6-464974a2d2ad"
                ],
                "expected_sessions": ["1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"],
            },
            {
                "name": "basic_filtering",
                "session_keys": ["session:test-agent-123:session-1", "session:test-agent-123:session-2"],
                "turn_keys": ["session:test-agent-123:session-1:turn-1", "session:test-agent-123:session-1:turn-2"],
                "expected_sessions": ["session-1", "session-2"],
            },
            {
                "name": "multiple_turns_per_session",
                "session_keys": ["session:test-agent-123:session-456"],
                "turn_keys": [
                    "session:test-agent-123:session-456:turn-789",
                    "session:test-agent-123:session-456:turn-790",
                ],
                "expected_sessions": ["session-456"],
            },
            {
                "name": "multiple_sessions_with_turns",
                "session_keys": ["session:test-agent-123:session-1", "session:test-agent-123:session-2"],
                "turn_keys": [
                    "session:test-agent-123:session-1:turn-1",
                    "session:test-agent-123:session-1:turn-2",
                    "session:test-agent-123:session-2:turn-3",
                ],
                "expected_sessions": ["session-1", "session-2"],
            },
        ],
    )
    async def test_list_sessions_key_filtering(self, agent_persistence, mock_kvstore, scenario):
        self.setup_mock_kvstore(mock_kvstore, session_keys=scenario["session_keys"], turn_keys=scenario["turn_keys"])

        with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
            result = await agent_persistence.list_sessions()

        assert len(result) == len(scenario["expected_sessions"])
        session_ids = {s.session_id for s in result}
        for expected_id in scenario["expected_sessions"]:
            assert expected_id in session_ids

        # no errors should be logged
        mock_log.error.assert_not_called()

    @pytest.mark.parametrize(
        "error_scenario",
        [
            {
                "name": "invalid_json",
                "valid_keys": ["session:test-agent-123:valid-session"],
                "invalid_data": {"session:test-agent-123:invalid-json": "corrupted-json-data{"},
                "expected_valid_sessions": ["valid-session"],
                "expected_error_count": 1,
            },
            {
                "name": "missing_fields",
                "valid_keys": ["session:test-agent-123:valid-session"],
                "invalid_data": {
                    "session:test-agent-123:invalid-schema": json.dumps(
                        {
                            "session_id": "invalid-schema",
                            "session_name": "Missing Fields",
                            # missing `started_at` and `turns`
                        }
                    )
                },
                "expected_valid_sessions": ["valid-session"],
                "expected_error_count": 1,
            },
            {
                "name": "multiple_invalid",
                "valid_keys": ["session:test-agent-123:valid-session-1", "session:test-agent-123:valid-session-2"],
                "invalid_data": {
                    "session:test-agent-123:corrupted-json": "not-valid-json{",
                    "session:test-agent-123:incomplete-data": json.dumps({"incomplete": "data"}),
                },
                "expected_valid_sessions": ["valid-session-1", "valid-session-2"],
                "expected_error_count": 2,
            },
        ],
    )
    async def test_list_sessions_error_handling(self, agent_persistence, mock_kvstore, error_scenario):
        session_keys = {}
        for key in error_scenario["valid_keys"]:
            session_id = key.split(":")[-1]
            session_keys[key] = {
                "session_id": session_id,
                "session_name": f"Valid {session_id}",
                "started_at": datetime.now(UTC).isoformat(),
                "turns": [],
            }

        self.setup_mock_kvstore(mock_kvstore, session_keys=session_keys, invalid_keys=error_scenario["invalid_data"])

        with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
            result = await agent_persistence.list_sessions()

        # only valid sessions should be returned
        assert len(result) == len(error_scenario["expected_valid_sessions"])
        session_ids = {s.session_id for s in result}
        for expected_id in error_scenario["expected_valid_sessions"]:
            assert expected_id in session_ids

        # errors should be logged
        assert mock_log.error.call_count > 0
        assert mock_log.error.call_count == error_scenario["expected_error_count"]

    async def test_list_sessions_empty(self, agent_persistence, mock_kvstore):
        mock_kvstore.values_in_range.return_value = []

        result = await agent_persistence.list_sessions()

        assert result == []
        mock_kvstore.values_in_range.assert_called_once_with(
            start_key="session:test-agent-123:", end_key="session:test-agent-123:\xff\xff\xff\xff"
        )

    async def test_list_sessions_properties(self, agent_persistence, mock_kvstore):
        session_data = {
            "session_id": "session-123",
            "session_name": "Test Session",
            "started_at": datetime.now(UTC).isoformat(),
            "owner": {"principal": "user-123", "attributes": None},
            "turns": [],
        }

        self.setup_mock_kvstore(mock_kvstore, session_keys={"session:test-agent-123:session-123": session_data})

        result = await agent_persistence.list_sessions()

        assert len(result) == 1
        assert isinstance(result[0], Session)
        assert result[0].session_id == "session-123"
        assert result[0].session_name == "Test Session"
        assert result[0].turns == []
        assert hasattr(result[0], "started_at")

    async def test_list_sessions_kvstore_exception(self, agent_persistence, mock_kvstore):
        mock_kvstore.values_in_range.side_effect = Exception("KVStore error")

        with pytest.raises(Exception, match="KVStore error"):
            await agent_persistence.list_sessions()

    async def test_bug_data_loss_with_real_data(self, agent_persistence, mock_kvstore):
        # tests the handling of the issue reported in: https://github.com/meta-llama/llama-stack/issues/3048
        session_data = {
            "session_id": "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d",
            "session_name": "Test Session",
            "started_at": datetime.now(UTC).isoformat(),
            "turns": [],
        }

        turn_data = {
            "turn_id": "eb7e818f-41fb-49a0-bdd6-464974a2d2ad",
            "session_id": "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d",
            "input_messages": [
                {"role": "user", "content": "if i had a cluster i would want to call it persistence01", "context": None}
            ],
            "steps": [
                {
                    "turn_id": "eb7e818f-41fb-49a0-bdd6-464974a2d2ad",
                    "step_id": "c0f797dd-3d34-4bc5-a8f4-db6af9455132",
                    "started_at": "2025-08-05T14:31:50.000484Z",
                    "completed_at": "2025-08-05T14:31:51.303691Z",
                    "step_type": "inference",
                    "model_response": {
                        "role": "assistant",
                        "content": "OK, I can create a cluster named 'persistence01' for you.",
                        "stop_reason": "end_of_turn",
                        "tool_calls": [],
                    },
                }
            ],
            "output_message": {
                "role": "assistant",
                "content": "OK, I can create a cluster named 'persistence01' for you.",
                "stop_reason": "end_of_turn",
                "tool_calls": [],
            },
            "output_attachments": [],
            "started_at": "2025-08-05T14:31:49.999950Z",
            "completed_at": "2025-08-05T14:31:51.305384Z",
        }

        mock_data = {
            "session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d": json.dumps(session_data),
            "session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d:eb7e818f-41fb-49a0-bdd6-464974a2d2ad": json.dumps(
                turn_data
            ),
        }

        mock_kvstore.values_in_range.return_value = list(mock_data.values())

        async def mock_get(key):
            return mock_data.get(key)

        mock_kvstore.get.side_effect = mock_get

        with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
            result = await agent_persistence.list_sessions()

        assert len(result) == 1
        assert result[0].session_id == "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"

        # confirm no errors logged
        mock_log.error.assert_not_called()

    async def test_list_sessions_key_range_construction(self, agent_persistence, mock_kvstore):
        mock_kvstore.values_in_range.return_value = []

        await agent_persistence.list_sessions()

        mock_kvstore.values_in_range.assert_called_once_with(
            start_key="session:test-agent-123:", end_key="session:test-agent-123:\xff\xff\xff\xff"
        )
@@ -1,196 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import warnings
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from llama_stack.apis.agents import Document
from llama_stack.apis.common.content_types import URL, TextContentItem
from llama_stack.providers.inline.agents.meta_reference.agent_instance import get_raw_document_text


async def test_get_raw_document_text_supports_text_mime_types():
    """Test that the function accepts text/* mime types."""
    document = Document(content="Sample text content", mime_type="text/plain")

    result = await get_raw_document_text(document)
    assert result == "Sample text content"


async def test_get_raw_document_text_supports_yaml_mime_type():
    """Test that the function accepts application/yaml mime type."""
    yaml_content = """
name: test
version: 1.0
items:
- item1
- item2
"""

    document = Document(content=yaml_content, mime_type="application/yaml")

    result = await get_raw_document_text(document)
    assert result == yaml_content


async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning():
    """Test that the function accepts text/yaml but emits a deprecation warning."""
    yaml_content = """
name: test
version: 1.0
items:
- item1
- item2
"""

    document = Document(content=yaml_content, mime_type="text/yaml")

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        result = await get_raw_document_text(document)

        # Check that result is correct
        assert result == yaml_content

        # Check that exactly one warning was issued
        assert len(w) == 1
        assert issubclass(w[0].category, DeprecationWarning)
        assert "text/yaml" in str(w[0].message)
        assert "application/yaml" in str(w[0].message)
        assert "deprecated" in str(w[0].message).lower()


async def test_get_raw_document_text_deprecated_text_yaml_with_url():
    """Test that text/yaml works with URL content and emits warning."""
    yaml_content = "name: test\nversion: 1.0"

    with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
        mock_load.return_value = yaml_content

        document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="text/yaml")

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            result = await get_raw_document_text(document)

            # Check that result is correct
            assert result == yaml_content
            mock_load.assert_called_once_with("https://example.com/config.yaml")

            # Check that deprecation warning was issued
            assert len(w) == 1
            assert issubclass(w[0].category, DeprecationWarning)
            assert "text/yaml" in str(w[0].message)


async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item():
    """Test that text/yaml works with TextContentItem and emits warning."""
    yaml_content = "key: value\nlist:\n - item1\n - item2"

    document = Document(content=TextContentItem(text=yaml_content), mime_type="text/yaml")

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        result = await get_raw_document_text(document)

        # Check that result is correct
        assert result == yaml_content

        # Check that deprecation warning was issued
        assert len(w) == 1
        assert issubclass(w[0].category, DeprecationWarning)
        assert "text/yaml" in str(w[0].message)


async def test_get_raw_document_text_supports_json_mime_type():
    """Test that the function accepts application/json mime type."""
    json_content = '{"name": "test", "version": "1.0", "items": ["item1", "item2"]}'

    document = Document(content=json_content, mime_type="application/json")

    result = await get_raw_document_text(document)
    assert result == json_content


async def test_get_raw_document_text_with_json_text_content_item():
    """Test that the function handles JSON TextContentItem correctly."""
    json_content = '{"key": "value", "nested": {"array": [1, 2, 3]}}'

    document = Document(content=TextContentItem(text=json_content), mime_type="application/json")

    result = await get_raw_document_text(document)
    assert result == json_content


async def test_get_raw_document_text_rejects_unsupported_mime_types():
    """Test that the function rejects unsupported mime types."""
    document = Document(
        content="Some content",
        mime_type="application/pdf",  # Not supported
    )

    with pytest.raises(ValueError, match="Unexpected document mime type: application/pdf"):
        await get_raw_document_text(document)


async def test_get_raw_document_text_with_url_content():
    """Test that the function handles URL content correctly."""
    mock_response = AsyncMock()
    mock_response.text = "Content from URL"

    with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
        mock_load.return_value = "Content from URL"

        document = Document(content=URL(uri="https://example.com/test.txt"), mime_type="text/plain")

        result = await get_raw_document_text(document)
        assert result == "Content from URL"
        mock_load.assert_called_once_with("https://example.com/test.txt")


async def test_get_raw_document_text_with_yaml_url():
    """Test that the function handles YAML URLs correctly."""
    yaml_content = "name: test\nversion: 1.0"

    with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
        mock_load.return_value = yaml_content

        document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="application/yaml")

        result = await get_raw_document_text(document)
        assert result == yaml_content
        mock_load.assert_called_once_with("https://example.com/config.yaml")


async def test_get_raw_document_text_with_text_content_item():
    """Test that the function handles TextContentItem correctly."""
    document = Document(content=TextContentItem(text="Text content item"), mime_type="text/plain")

    result = await get_raw_document_text(document)
    assert result == "Text content item"


async def test_get_raw_document_text_with_yaml_text_content_item():
    """Test that the function handles YAML TextContentItem correctly."""
    yaml_content = "key: value\nlist:\n - item1\n - item2"

    document = Document(content=TextContentItem(text=yaml_content), mime_type="application/yaml")

    result = await get_raw_document_text(document)
    assert result == yaml_content


async def test_get_raw_document_text_rejects_unexpected_content_type():
    """Test that the function rejects unexpected document content types."""
    # Create a mock document that bypasses Pydantic validation
    mock_document = MagicMock(spec=Document)
    mock_document.mime_type = "text/plain"
    mock_document.content = 123  # Unexpected content type (not str, URL, or TextContentItem)

    with pytest.raises(ValueError, match="Unexpected document content type: <class 'int'>"):
        await get_raw_document_text(mock_document)
@@ -1,325 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime
from unittest.mock import AsyncMock

import pytest

from llama_stack.apis.agents import (
    Agent,
    AgentConfig,
    AgentCreateResponse,
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.inference import Inference
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo


@pytest.fixture(autouse=True)
def setup_backends(tmp_path):
    """Register KV and SQL store backends for testing."""
    from llama_stack.core.storage.datatypes import SqliteKVStoreConfig, SqliteSqlStoreConfig
    from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
    from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends

    kv_path = str(tmp_path / "test_kv.db")
    sql_path = str(tmp_path / "test_sql.db")

    register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path=kv_path)})
    register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=sql_path)})


@pytest.fixture
def mock_apis():
    return {
        "inference_api": AsyncMock(spec=Inference),
        "vector_io_api": AsyncMock(spec=VectorIO),
        "safety_api": AsyncMock(spec=Safety),
        "tool_runtime_api": AsyncMock(spec=ToolRuntime),
        "tool_groups_api": AsyncMock(spec=ToolGroups),
        "conversations_api": AsyncMock(spec=Conversations),
    }


@pytest.fixture
def config(tmp_path):
    from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
    from llama_stack.providers.inline.agents.meta_reference.config import AgentPersistenceConfig

    return MetaReferenceAgentsImplConfig(
        persistence=AgentPersistenceConfig(
            agent_state=KVStoreReference(
                backend="kv_default",
                namespace="agents",
            ),
            responses=ResponsesStoreReference(
                backend="sql_default",
                table_name="responses",
            ),
        )
    )


@pytest.fixture
async def agents_impl(config, mock_apis):
    impl = MetaReferenceAgentsImpl(
        config,
        mock_apis["inference_api"],
        mock_apis["vector_io_api"],
        mock_apis["safety_api"],
        mock_apis["tool_runtime_api"],
        mock_apis["tool_groups_api"],
        mock_apis["conversations_api"],
        [],
    )
    await impl.initialize()
    yield impl
    await impl.shutdown()


@pytest.fixture
def sample_agent_config():
    return AgentConfig(
        sampling_params={
            "strategy": {"type": "greedy"},
            "max_tokens": 0,
            "repetition_penalty": 1.0,
        },
        input_shields=["string"],
        output_shields=["string"],
        toolgroups=["mcp::my_mcp_server"],
        client_tools=[
            {
                "name": "client_tool",
                "description": "Client Tool",
                "parameters": [
                    {
                        "name": "string",
                        "parameter_type": "string",
                        "description": "string",
                        "required": True,
                        "default": None,
                    }
                ],
                "metadata": {
                    "property1": None,
                    "property2": None,
                },
            }
        ],
        tool_choice="auto",
        tool_prompt_format="json",
        tool_config={
            "tool_choice": "auto",
            "tool_prompt_format": "json",
            "system_message_behavior": "append",
        },
        max_infer_iters=10,
        model="string",
        instructions="string",
        enable_session_persistence=False,
        response_format={
            "type": "json_schema",
            "json_schema": {
                "property1": None,
                "property2": None,
            },
        },
    )


async def test_create_agent(agents_impl, sample_agent_config):
    response = await agents_impl.create_agent(sample_agent_config)

    assert isinstance(response, AgentCreateResponse)
    assert response.agent_id is not None

    stored_agent = await agents_impl.persistence_store.get(f"agent:{response.agent_id}")
    assert stored_agent is not None
    agent_info = AgentInfo.model_validate_json(stored_agent)
    assert agent_info.model == sample_agent_config.model
    assert agent_info.created_at is not None
    assert isinstance(agent_info.created_at, datetime)


async def test_get_agent(agents_impl, sample_agent_config):
    create_response = await agents_impl.create_agent(sample_agent_config)
    agent_id = create_response.agent_id

    agent = await agents_impl.get_agent(agent_id)

    assert isinstance(agent, Agent)
    assert agent.agent_id == agent_id
    assert agent.agent_config.model == sample_agent_config.model
    assert agent.created_at is not None
    assert isinstance(agent.created_at, datetime)


async def test_list_agents(agents_impl, sample_agent_config):
    agent1_response = await agents_impl.create_agent(sample_agent_config)
    agent2_response = await agents_impl.create_agent(sample_agent_config)

    response = await agents_impl.list_agents()

    assert isinstance(response, PaginatedResponse)
    assert len(response.data) == 2
    agent_ids = {agent["agent_id"] for agent in response.data}
    assert agent1_response.agent_id in agent_ids
    assert agent2_response.agent_id in agent_ids


@pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
    # Create an agent with specified persistence setting
    config = sample_agent_config.model_copy()
    config.enable_session_persistence = enable_session_persistence
    response = await agents_impl.create_agent(config)
    agent_id = response.agent_id

    # Create a session
    session_response = await agents_impl.create_agent_session(agent_id, "test_session")
    assert session_response.session_id is not None

    # Verify the session was stored
    session = await agents_impl.get_agents_session(session_response.session_id, agent_id)
    assert session.session_name == "test_session"
    assert session.session_id == session_response.session_id
    assert session.started_at is not None
    assert session.turns == []

    # Delete the session
    await agents_impl.delete_agents_session(session_response.session_id, agent_id)

    # Verify the session was deleted
    with pytest.raises(ValueError):
        await agents_impl.get_agents_session(session_response.session_id, agent_id)


@pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
    # Create an agent with specified persistence setting
    config = sample_agent_config.model_copy()
    config.enable_session_persistence = enable_session_persistence
    response = await agents_impl.create_agent(config)
    agent_id = response.agent_id

    # Create multiple sessions
    session1 = await agents_impl.create_agent_session(agent_id, "session1")
    session2 = await agents_impl.create_agent_session(agent_id, "session2")

    # List sessions
    sessions = await agents_impl.list_agent_sessions(agent_id)
    assert len(sessions.data) == 2
    session_ids = {s["session_id"] for s in sessions.data}
    assert session1.session_id in session_ids
    assert session2.session_id in session_ids

    # Delete one session
    await agents_impl.delete_agents_session(session1.session_id, agent_id)

    # Verify the session was deleted
    with pytest.raises(ValueError):
        await agents_impl.get_agents_session(session1.session_id, agent_id)

    # List sessions again
    sessions = await agents_impl.list_agent_sessions(agent_id)
    assert len(sessions.data) == 1
    assert session2.session_id in {s["session_id"] for s in sessions.data}


async def test_delete_agent(agents_impl, sample_agent_config):
    # Create an agent
    response = await agents_impl.create_agent(sample_agent_config)
    agent_id = response.agent_id

    # Delete the agent
    await agents_impl.delete_agent(agent_id)

    # Verify the agent was deleted
    with pytest.raises(ValueError):
        await agents_impl.get_agent(agent_id)


async def test__initialize_tools(agents_impl, sample_agent_config):
    # Mock tool_groups_api.list_tools()
    agents_impl.tool_groups_api.list_tools.return_value = ListToolDefsResponse(
        data=[
            ToolDef(
                name="story_maker",
                toolgroup_id="mcp::my_mcp_server",
                description="Make a story",
                input_schema={
                    "type": "object",
                    "properties": {
                        "story_title": {"type": "string", "description": "Title of the story", "title": "Story Title"},
                        "input_words": {
                            "type": "array",
                            "description": "Input words",
                            "items": {"type": "string"},
                            "title": "Input Words",
                            "default": [],
                        },
                    },
                    "required": ["story_title"],
                },
            )
        ]
    )

    create_response = await agents_impl.create_agent(sample_agent_config)
    agent_id = create_response.agent_id

    # Get an instance of ChatAgent
    chat_agent = await agents_impl._get_agent_impl(agent_id)
    assert chat_agent is not None
    assert isinstance(chat_agent, ChatAgent)

    # Initialize tool definitions
    await chat_agent._initialize_tools()
    assert len(chat_agent.tool_defs) == 2

    # Verify the first tool, which is a client tool
    first_tool = chat_agent.tool_defs[0]
    assert first_tool.tool_name == "client_tool"
    assert first_tool.description == "Client Tool"

    # Verify the second tool, which is an MCP tool that has an array-type property
    second_tool = chat_agent.tool_defs[1]
    assert second_tool.tool_name == "story_maker"
    assert second_tool.description == "Make a story"

    # Verify the input schema
    input_schema = second_tool.input_schema
    assert input_schema is not None
    assert input_schema["type"] == "object"

    properties = input_schema["properties"]
    assert len(properties) == 2

    # Verify a string property
    story_title = properties["story_title"]
    assert story_title["type"] == "string"
    assert story_title["description"] == "Title of the story"
    assert story_title["title"] == "Story Title"

    # Verify an array property
    input_words = properties["input_words"]
    assert input_words["type"] == "array"
    assert input_words["description"] == "Input words"
    assert input_words["items"]["type"] == "string"
    assert input_words["title"] == "Input Words"
    assert input_words["default"] == []

    # Verify required fields
    assert input_schema["required"] == ["story_title"]
@@ -1,23 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

import yaml

from llama_stack.apis.inference import (
    OpenAIChatCompletion,
)

FIXTURES_DIR = os.path.dirname(os.path.abspath(__file__))


def load_chat_completion_fixture(filename: str) -> OpenAIChatCompletion:
    fixture_path = os.path.join(FIXTURES_DIR, filename)

    with open(fixture_path) as f:
        data = yaml.safe_load(f)
        return OpenAIChatCompletion(**data)
@@ -1,9 +0,0 @@
id: chat-completion-123
choices:
  - message:
      content: "Dublin"
      role: assistant
    finish_reason: stop
    index: 0
created: 1234567890
model: meta-llama/Llama-3.1-8B-Instruct
@@ -1,14 +0,0 @@
id: chat-completion-123
choices:
  - message:
      tool_calls:
        - id: tool_call_123
          type: function
          function:
            name: web_search
            arguments: '{"query":"What is the capital of Ireland?"}'
      role: assistant
    finish_reason: stop
    index: 0
created: 1234567890
model: meta-llama/Llama-3.1-8B-Instruct
File diff suppressed because it is too large
@@ -1,249 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import pytest

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseMessage,
    OpenAIResponseObject,
    OpenAIResponseObjectStreamResponseCompleted,
    OpenAIResponseObjectStreamResponseOutputItemDone,
    OpenAIResponseOutputMessageContentOutputText,
)
from llama_stack.apis.common.errors import (
    ConversationNotFoundError,
    InvalidConversationIdError,
)
from llama_stack.apis.conversations.conversations import (
    ConversationItemList,
)

# Import existing fixtures from the main responses test file
pytest_plugins = ["tests.unit.providers.agents.meta_reference.test_openai_responses"]

from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
    OpenAIResponsesImpl,
)


@pytest.fixture
def responses_impl_with_conversations(
    mock_inference_api,
    mock_tool_groups_api,
    mock_tool_runtime_api,
    mock_responses_store,
    mock_vector_io_api,
    mock_conversations_api,
    mock_safety_api,
):
    """Create OpenAIResponsesImpl instance with conversations API."""
    return OpenAIResponsesImpl(
        inference_api=mock_inference_api,
        tool_groups_api=mock_tool_groups_api,
        tool_runtime_api=mock_tool_runtime_api,
        responses_store=mock_responses_store,
        vector_io_api=mock_vector_io_api,
        conversations_api=mock_conversations_api,
        safety_api=mock_safety_api,
    )


class TestConversationValidation:
    """Test conversation ID validation logic."""

    async def test_nonexistent_conversation_raises_error(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        """Test that ConversationNotFoundError is raised for non-existent conversation."""
        conv_id = "conv_nonexistent"

        # Mock conversation not found
        mock_conversations_api.list_items.side_effect = ConversationNotFoundError("conv_nonexistent")

        with pytest.raises(ConversationNotFoundError):
            await responses_impl_with_conversations.create_openai_response(
                input="Hello", model="test-model", conversation=conv_id, stream=False
            )


class TestMessageSyncing:
    """Test message syncing to conversations."""

    async def test_sync_response_to_conversation_simple(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        """Test syncing simple response to conversation."""
        conv_id = "conv_test123"
        input_text = "What are the 5 Ds of dodgeball?"

        # Output items (what the model generated)
        output_items = [
            OpenAIResponseMessage(
                id="msg_response",
                content=[
                    OpenAIResponseOutputMessageContentOutputText(
                        text="The 5 Ds are: Dodge, Duck, Dip, Dive, and Dodge.", type="output_text", annotations=[]
                    )
                ],
                role="assistant",
                status="completed",
                type="message",
            )
        ]

        await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_text, output_items)

        # should call add_items with user input and assistant response
        mock_conversations_api.add_items.assert_called_once()
        call_args = mock_conversations_api.add_items.call_args

        assert call_args[0][0] == conv_id  # conversation_id
        items = call_args[0][1]  # conversation_items

        assert len(items) == 2
        # User message
        assert items[0].type == "message"
        assert items[0].role == "user"
        assert items[0].content[0].type == "input_text"
        assert items[0].content[0].text == input_text

        # Assistant message
        assert items[1].type == "message"
        assert items[1].role == "assistant"

    async def test_sync_response_to_conversation_api_error(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        mock_conversations_api.add_items.side_effect = Exception("API Error")
        output_items = []

        # matching the behavior of OpenAI here
        with pytest.raises(Exception, match="API Error"):
            await responses_impl_with_conversations._sync_response_to_conversation(
                "conv_test123", "Hello", output_items
            )

    async def test_sync_with_list_input(self, responses_impl_with_conversations, mock_conversations_api):
        """Test syncing with list of input messages."""
        conv_id = "conv_test123"
        input_messages = [
            OpenAIResponseMessage(role="user", content=[{"type": "input_text", "text": "First message"}]),
        ]
        output_items = [
            OpenAIResponseMessage(
                id="msg_response",
                content=[OpenAIResponseOutputMessageContentOutputText(text="Response", type="output_text")],
                role="assistant",
                status="completed",
                type="message",
            )
        ]

        await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_messages, output_items)

        mock_conversations_api.add_items.assert_called_once()
        call_args = mock_conversations_api.add_items.call_args

        items = call_args[0][1]
        # Should have input message + output message
        assert len(items) == 2


class TestIntegrationWorkflow:
    """Integration tests for the full conversation workflow."""

    async def test_create_response_with_valid_conversation(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        """Test creating a response with a valid conversation parameter."""
        mock_conversations_api.list_items.return_value = ConversationItemList(
            data=[], first_id=None, has_more=False, last_id=None, object="list"
        )

        async def mock_streaming_response(*args, **kwargs):
            message_item = OpenAIResponseMessage(
                id="msg_response",
                content=[
                    OpenAIResponseOutputMessageContentOutputText(
                        text="Test response", type="output_text", annotations=[]
                    )
                ],
                role="assistant",
                status="completed",
                type="message",
            )

            # Emit output_item.done event first (needed for conversation sync)
            yield OpenAIResponseObjectStreamResponseOutputItemDone(
                response_id="resp_test123",
                item=message_item,
                output_index=0,
                sequence_number=1,
                type="response.output_item.done",
            )

            # Then emit response.completed
            mock_response = OpenAIResponseObject(
                id="resp_test123",
                created_at=1234567890,
                model="test-model",
                object="response",
                output=[message_item],
                status="completed",
            )

            yield OpenAIResponseObjectStreamResponseCompleted(response=mock_response, type="response.completed")

        responses_impl_with_conversations._create_streaming_response = mock_streaming_response

        input_text = "Hello, how are you?"
        conversation_id = "conv_test123"

        response = await responses_impl_with_conversations.create_openai_response(
            input=input_text, model="test-model", conversation=conversation_id, stream=False
        )

        assert response is not None
        assert response.id == "resp_test123"

        # Note: conversation sync happens inside _create_streaming_response,
        # which we're mocking here, so we can't test it in this unit test.
        # The sync logic is tested separately in TestMessageSyncing.

    async def test_create_response_with_invalid_conversation_id(self, responses_impl_with_conversations):
        """Test creating a response with an invalid conversation ID."""
        with pytest.raises(InvalidConversationIdError) as exc_info:
            await responses_impl_with_conversations.create_openai_response(
                input="Hello", model="test-model", conversation="invalid_id", stream=False
            )

        assert "Expected an ID that begins with 'conv_'" in str(exc_info.value)

    async def test_create_response_with_nonexistent_conversation(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        """Test creating a response with a non-existent conversation."""
        mock_conversations_api.list_items.side_effect = ConversationNotFoundError("conv_nonexistent")

        with pytest.raises(ConversationNotFoundError) as exc_info:
            await responses_impl_with_conversations.create_openai_response(
                input="Hello", model="test-model", conversation="conv_nonexistent", stream=False
            )

        assert "not found" in str(exc_info.value)

    async def test_conversation_and_previous_response_id(
        self, responses_impl_with_conversations, mock_conversations_api, mock_responses_store
    ):
        with pytest.raises(ValueError) as exc_info:
            await responses_impl_with_conversations.create_openai_response(
                input="test", model="test", conversation="conv_123", previous_response_id="resp_123"
            )

        assert "Mutually exclusive parameters" in str(exc_info.value)
        assert "previous_response_id" in str(exc_info.value)
        assert "conversation" in str(exc_info.value)
@@ -1,367 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import pytest

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseAnnotationFileCitation,
    OpenAIResponseInputFunctionToolCallOutput,
    OpenAIResponseInputMessageContentImage,
    OpenAIResponseInputMessageContentText,
    OpenAIResponseInputToolFunction,
    OpenAIResponseInputToolWebSearch,
    OpenAIResponseMessage,
    OpenAIResponseOutputMessageContentOutputText,
    OpenAIResponseOutputMessageFunctionToolCall,
    OpenAIResponseText,
    OpenAIResponseTextFormat,
)
from llama_stack.apis.inference import (
    OpenAIAssistantMessageParam,
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAIChatCompletionToolCall,
    OpenAIChatCompletionToolCallFunction,
    OpenAIChoice,
    OpenAIDeveloperMessageParam,
    OpenAIResponseFormatJSONObject,
    OpenAIResponseFormatJSONSchema,
    OpenAIResponseFormatText,
    OpenAISystemMessageParam,
    OpenAIToolMessageParam,
    OpenAIUserMessageParam,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
    _extract_citations_from_text,
    convert_chat_choice_to_response_message,
    convert_response_content_to_chat_content,
    convert_response_input_to_chat_messages,
    convert_response_text_to_chat_response_format,
    get_message_type_by_role,
    is_function_tool_call,
)


class TestConvertChatChoiceToResponseMessage:
    async def test_convert_string_content(self):
        choice = OpenAIChoice(
            message=OpenAIAssistantMessageParam(content="Test message"),
            finish_reason="stop",
            index=0,
        )

        result = await convert_chat_choice_to_response_message(choice)

        assert result.role == "assistant"
        assert result.status == "completed"
        assert len(result.content) == 1
        assert isinstance(result.content[0], OpenAIResponseOutputMessageContentOutputText)
        assert result.content[0].text == "Test message"

    async def test_convert_text_param_content(self):
        choice = OpenAIChoice(
            message=OpenAIAssistantMessageParam(
                content=[OpenAIChatCompletionContentPartTextParam(text="Test text param")]
            ),
            finish_reason="stop",
            index=0,
        )

        with pytest.raises(ValueError) as exc_info:
            await convert_chat_choice_to_response_message(choice)

        assert "does not yet support output content type" in str(exc_info.value)


class TestConvertResponseContentToChatContent:
    async def test_convert_string_content(self):
        result = await convert_response_content_to_chat_content("Simple string")
        assert result == "Simple string"

    async def test_convert_text_content_parts(self):
        content = [
            OpenAIResponseInputMessageContentText(text="First part"),
            OpenAIResponseOutputMessageContentOutputText(text="Second part"),
        ]

        result = await convert_response_content_to_chat_content(content)

        assert len(result) == 2
        assert isinstance(result[0], OpenAIChatCompletionContentPartTextParam)
        assert result[0].text == "First part"
        assert isinstance(result[1], OpenAIChatCompletionContentPartTextParam)
        assert result[1].text == "Second part"

    async def test_convert_image_content(self):
        content = [OpenAIResponseInputMessageContentImage(image_url="https://example.com/image.jpg", detail="high")]

        result = await convert_response_content_to_chat_content(content)

        assert len(result) == 1
        assert isinstance(result[0], OpenAIChatCompletionContentPartImageParam)
        assert result[0].image_url.url == "https://example.com/image.jpg"
        assert result[0].image_url.detail == "high"


class TestConvertResponseInputToChatMessages:
    async def test_convert_string_input(self):
        result = await convert_response_input_to_chat_messages("User message")

        assert len(result) == 1
        assert isinstance(result[0], OpenAIUserMessageParam)
        assert result[0].content == "User message"

    async def test_convert_function_tool_call_output(self):
        input_items = [
            OpenAIResponseOutputMessageFunctionToolCall(
                call_id="call_123",
                name="test_function",
                arguments='{"param": "value"}',
            ),
            OpenAIResponseInputFunctionToolCallOutput(
                output="Tool output",
                call_id="call_123",
            ),
        ]

        result = await convert_response_input_to_chat_messages(input_items)

        assert len(result) == 2
        assert isinstance(result[0], OpenAIAssistantMessageParam)
        assert result[0].tool_calls[0].id == "call_123"
        assert result[0].tool_calls[0].function.name == "test_function"
        assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
        assert isinstance(result[1], OpenAIToolMessageParam)
        assert result[1].content == "Tool output"
        assert result[1].tool_call_id == "call_123"

    async def test_convert_function_tool_call(self):
        input_items = [
            OpenAIResponseOutputMessageFunctionToolCall(
                call_id="call_456",
                name="test_function",
                arguments='{"param": "value"}',
            )
        ]

        result = await convert_response_input_to_chat_messages(input_items)

        assert len(result) == 1
        assert isinstance(result[0], OpenAIAssistantMessageParam)
        assert len(result[0].tool_calls) == 1
        assert result[0].tool_calls[0].id == "call_456"
        assert result[0].tool_calls[0].function.name == "test_function"
        assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'

    async def test_convert_function_call_ordering(self):
        input_items = [
            OpenAIResponseOutputMessageFunctionToolCall(
                call_id="call_123",
                name="test_function_a",
                arguments='{"param": "value"}',
            ),
            OpenAIResponseOutputMessageFunctionToolCall(
                call_id="call_456",
                name="test_function_b",
                arguments='{"param": "value"}',
            ),
            OpenAIResponseInputFunctionToolCallOutput(
                output="AAA",
                call_id="call_123",
            ),
            OpenAIResponseInputFunctionToolCallOutput(
                output="BBB",
                call_id="call_456",
            ),
        ]

        result = await convert_response_input_to_chat_messages(input_items)
        assert len(result) == 4
        assert isinstance(result[0], OpenAIAssistantMessageParam)
        assert len(result[0].tool_calls) == 1
        assert result[0].tool_calls[0].id == "call_123"
        assert result[0].tool_calls[0].function.name == "test_function_a"
        assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
        assert isinstance(result[1], OpenAIToolMessageParam)
        assert result[1].content == "AAA"
        assert result[1].tool_call_id == "call_123"
        assert isinstance(result[2], OpenAIAssistantMessageParam)
        assert len(result[2].tool_calls) == 1
        assert result[2].tool_calls[0].id == "call_456"
        assert result[2].tool_calls[0].function.name == "test_function_b"
        assert result[2].tool_calls[0].function.arguments == '{"param": "value"}'
        assert isinstance(result[3], OpenAIToolMessageParam)
        assert result[3].content == "BBB"
        assert result[3].tool_call_id == "call_456"

    async def test_convert_response_message(self):
        input_items = [
            OpenAIResponseMessage(
                role="user",
                content=[OpenAIResponseInputMessageContentText(text="User text")],
            )
        ]

        result = await convert_response_input_to_chat_messages(input_items)

        assert len(result) == 1
        assert isinstance(result[0], OpenAIUserMessageParam)
        # Content should be converted to chat content format
        assert len(result[0].content) == 1
        assert result[0].content[0].text == "User text"


class TestConvertResponseTextToChatResponseFormat:
    async def test_convert_text_format(self):
        text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
        result = await convert_response_text_to_chat_response_format(text)

        assert isinstance(result, OpenAIResponseFormatText)
        assert result.type == "text"

    async def test_convert_json_object_format(self):
        text = OpenAIResponseText(format={"type": "json_object"})
        result = await convert_response_text_to_chat_response_format(text)

        assert isinstance(result, OpenAIResponseFormatJSONObject)

    async def test_convert_json_schema_format(self):
        schema_def = {"type": "object", "properties": {"test": {"type": "string"}}}
        text = OpenAIResponseText(
            format={
                "type": "json_schema",
                "name": "test_schema",
                "schema": schema_def,
            }
        )
        result = await convert_response_text_to_chat_response_format(text)

        assert isinstance(result, OpenAIResponseFormatJSONSchema)
        assert result.json_schema["name"] == "test_schema"
        assert result.json_schema["schema"] == schema_def

    async def test_default_text_format(self):
        text = OpenAIResponseText()
        result = await convert_response_text_to_chat_response_format(text)

        assert isinstance(result, OpenAIResponseFormatText)
        assert result.type == "text"


class TestGetMessageTypeByRole:
    async def test_user_role(self):
        result = await get_message_type_by_role("user")
        assert result == OpenAIUserMessageParam

    async def test_system_role(self):
        result = await get_message_type_by_role("system")
        assert result == OpenAISystemMessageParam

    async def test_assistant_role(self):
        result = await get_message_type_by_role("assistant")
        assert result == OpenAIAssistantMessageParam

    async def test_developer_role(self):
        result = await get_message_type_by_role("developer")
        assert result == OpenAIDeveloperMessageParam

    async def test_unknown_role(self):
        result = await get_message_type_by_role("unknown")
        assert result is None


class TestIsFunctionToolCall:
    def test_is_function_tool_call_true(self):
        tool_call = OpenAIChatCompletionToolCall(
            index=0,
            id="call_123",
            function=OpenAIChatCompletionToolCallFunction(
                name="test_function",
                arguments="{}",
            ),
        )
        tools = [
            OpenAIResponseInputToolFunction(
                type="function", name="test_function", parameters={"type": "object", "properties": {}}
            ),
            OpenAIResponseInputToolWebSearch(type="web_search"),
        ]

        result = is_function_tool_call(tool_call, tools)
        assert result is True

    def test_is_function_tool_call_false_different_name(self):
        tool_call = OpenAIChatCompletionToolCall(
            index=0,
            id="call_123",
            function=OpenAIChatCompletionToolCallFunction(
                name="other_function",
                arguments="{}",
            ),
        )
        tools = [
            OpenAIResponseInputToolFunction(
                type="function", name="test_function", parameters={"type": "object", "properties": {}}
            ),
        ]

        result = is_function_tool_call(tool_call, tools)
        assert result is False

    def test_is_function_tool_call_false_no_function(self):
        tool_call = OpenAIChatCompletionToolCall(
            index=0,
            id="call_123",
            function=None,
        )
        tools = [
            OpenAIResponseInputToolFunction(
                type="function", name="test_function", parameters={"type": "object", "properties": {}}
            ),
        ]

        result = is_function_tool_call(tool_call, tools)
        assert result is False

    def test_is_function_tool_call_false_wrong_type(self):
        tool_call = OpenAIChatCompletionToolCall(
            index=0,
            id="call_123",
            function=OpenAIChatCompletionToolCallFunction(
                name="web_search",
                arguments="{}",
            ),
        )
        tools = [
            OpenAIResponseInputToolWebSearch(type="web_search"),
        ]

        result = is_function_tool_call(tool_call, tools)
        assert result is False


class TestExtractCitationsFromText:
    def test_extract_citations_and_annotations(self):
        text = "Start [not-a-file]. New source <|file-abc123|>. "
        text += "Other source <|file-def456|>? Repeat source <|file-abc123|>! No citation."
        file_mapping = {"file-abc123": "doc1.pdf", "file-def456": "doc2.txt"}

        annotations, cleaned_text = _extract_citations_from_text(text, file_mapping)

        expected_annotations = [
            OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=30),
            OpenAIResponseAnnotationFileCitation(file_id="file-def456", filename="doc2.txt", index=44),
|
||||
OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=59),
|
||||
]
|
||||
expected_clean_text = "Start [not-a-file]. New source. Other source? Repeat source! No citation."
|
||||
|
||||
assert cleaned_text == expected_clean_text
|
||||
assert annotations == expected_annotations
|
||||
# OpenAI cites at the end of the sentence
|
||||
assert cleaned_text[expected_annotations[0].index] == "."
|
||||
assert cleaned_text[expected_annotations[1].index] == "?"
|
||||
assert cleaned_text[expected_annotations[2].index] == "!"
|
||||
|
|
@@ -1,183 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from llama_stack.apis.agents.openai_responses import (
    MCPListToolsTool,
    OpenAIResponseInputToolFileSearch,
    OpenAIResponseInputToolFunction,
    OpenAIResponseInputToolMCP,
    OpenAIResponseInputToolWebSearch,
    OpenAIResponseObject,
    OpenAIResponseOutputMessageMCPListTools,
    OpenAIResponseToolMCP,
)
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext


class TestToolContext:
    def test_no_tools(self):
        tools = []
        context = ToolContext(tools)
        previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
        context.recover_tools_from_previous_response(previous_response)

        assert len(context.tools_to_process) == 0
        assert len(context.previous_tools) == 0
        assert len(context.previous_tool_listings) == 0

    def test_no_previous_tools(self):
        tools = [
            OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
            OpenAIResponseInputToolMCP(server_label="label", server_url="url"),
        ]
        context = ToolContext(tools)
        previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
        context.recover_tools_from_previous_response(previous_response)

        assert len(context.tools_to_process) == 2
        assert len(context.previous_tools) == 0
        assert len(context.previous_tool_listings) == 0

    def test_reusable_server(self):
        tools = [
            OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
        ]
        context = ToolContext(tools)
        output = [
            OpenAIResponseOutputMessageMCPListTools(
                id="test", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
            )
        ]
        previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
        previous_response.tools = [
            OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
            OpenAIResponseToolMCP(server_label="alabel"),
        ]
        context.recover_tools_from_previous_response(previous_response)

        assert len(context.tools_to_process) == 1
        assert context.tools_to_process[0].type == "file_search"
        assert len(context.previous_tools) == 1
        assert context.previous_tools["test_tool"].server_label == "alabel"
        assert context.previous_tools["test_tool"].server_url == "aurl"
        assert len(context.previous_tool_listings) == 1
        assert len(context.previous_tool_listings[0].tools) == 1
        assert context.previous_tool_listings[0].server_label == "alabel"

    def test_multiple_reusable_servers(self):
        tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(),
            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
        ]
        context = ToolContext(tools)
        output = [
            OpenAIResponseOutputMessageMCPListTools(
                id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
            ),
            OpenAIResponseOutputMessageMCPListTools(
                id="test2",
                server_label="anotherlabel",
                tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
            ),
        ]
        previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
        previous_response.tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(type="web_search"),
            OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
        ]
        context.recover_tools_from_previous_response(previous_response)

        assert len(context.tools_to_process) == 2
        assert context.tools_to_process[0].type == "function"
        assert context.tools_to_process[1].type == "web_search"
        assert len(context.previous_tools) == 2
        assert context.previous_tools["test_tool"].server_label == "alabel"
        assert context.previous_tools["test_tool"].server_url == "aurl"
        assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
        assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
        assert len(context.previous_tool_listings) == 2
        assert len(context.previous_tool_listings[0].tools) == 1
        assert context.previous_tool_listings[0].server_label == "alabel"
        assert len(context.previous_tool_listings[1].tools) == 1
        assert context.previous_tool_listings[1].server_label == "anotherlabel"

    def test_multiple_servers_only_one_reusable(self):
        tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(type="web_search"),
            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
        ]
        context = ToolContext(tools)
        output = [
            OpenAIResponseOutputMessageMCPListTools(
                id="test2",
                server_label="anotherlabel",
                tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
            )
        ]
        previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
        previous_response.tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(type="web_search"),
        ]
        context.recover_tools_from_previous_response(previous_response)

        assert len(context.tools_to_process) == 3
        assert context.tools_to_process[0].type == "function"
        assert context.tools_to_process[1].type == "web_search"
        assert context.tools_to_process[2].type == "mcp"
        assert len(context.previous_tools) == 1
        assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
        assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
        assert len(context.previous_tool_listings) == 1
        assert len(context.previous_tool_listings[0].tools) == 1
        assert context.previous_tool_listings[0].server_label == "anotherlabel"

    def test_mismatched_allowed_tools(self):
        tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(type="web_search"),
            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl", allowed_tools=["test_tool_2"]),
        ]
        context = ToolContext(tools)
        output = [
            OpenAIResponseOutputMessageMCPListTools(
                id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool_1", input_schema={})]
            ),
            OpenAIResponseOutputMessageMCPListTools(
                id="test2",
                server_label="anotherlabel",
                tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
            ),
        ]
        previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
        previous_response.tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(type="web_search"),
            OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
        ]
        context.recover_tools_from_previous_response(previous_response)

        assert len(context.tools_to_process) == 3
        assert context.tools_to_process[0].type == "function"
        assert context.tools_to_process[1].type == "web_search"
        assert context.tools_to_process[2].type == "mcp"
        assert len(context.previous_tools) == 1
        assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
        assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
        assert len(context.previous_tool_listings) == 1
        assert len(context.previous_tool_listings[0].tools) == 1
        assert context.previous_tool_listings[0].server_label == "anotherlabel"
@@ -1,155 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock

import pytest

from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
    OpenAIResponsesImpl,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
    extract_guardrail_ids,
    run_guardrails,
)


@pytest.fixture
def mock_apis():
    """Create mock APIs for testing."""
    return {
        "inference_api": AsyncMock(),
        "tool_groups_api": AsyncMock(),
        "tool_runtime_api": AsyncMock(),
        "responses_store": AsyncMock(),
        "vector_io_api": AsyncMock(),
        "conversations_api": AsyncMock(),
        "safety_api": AsyncMock(),
    }


@pytest.fixture
def responses_impl(mock_apis):
    """Create OpenAIResponsesImpl instance with mocked dependencies."""
    return OpenAIResponsesImpl(**mock_apis)


def test_extract_guardrail_ids_from_strings(responses_impl):
    """Test extraction from simple string guardrail IDs."""
    guardrails = ["llama-guard", "content-filter", "nsfw-detector"]
    result = extract_guardrail_ids(guardrails)
    assert result == ["llama-guard", "content-filter", "nsfw-detector"]


def test_extract_guardrail_ids_from_objects(responses_impl):
    """Test extraction from ResponseGuardrailSpec objects."""
    guardrails = [
        ResponseGuardrailSpec(type="llama-guard"),
        ResponseGuardrailSpec(type="content-filter"),
    ]
    result = extract_guardrail_ids(guardrails)
    assert result == ["llama-guard", "content-filter"]


def test_extract_guardrail_ids_mixed_formats(responses_impl):
    """Test extraction from mixed string and object formats."""
    guardrails = [
        "llama-guard",
        ResponseGuardrailSpec(type="content-filter"),
        "nsfw-detector",
    ]
    result = extract_guardrail_ids(guardrails)
    assert result == ["llama-guard", "content-filter", "nsfw-detector"]


def test_extract_guardrail_ids_none_input(responses_impl):
    """Test extraction with None input."""
    result = extract_guardrail_ids(None)
    assert result == []


def test_extract_guardrail_ids_empty_list(responses_impl):
    """Test extraction with empty list."""
    result = extract_guardrail_ids([])
    assert result == []


def test_extract_guardrail_ids_unknown_format(responses_impl):
    """Test extraction with unknown guardrail format raises ValueError."""
    # Create an object that's neither string nor ResponseGuardrailSpec
    unknown_object = {"invalid": "format"}  # Plain dict, not ResponseGuardrailSpec
    guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
    with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
        extract_guardrail_ids(guardrails)


@pytest.fixture
def mock_safety_api():
    """Create mock safety API for guardrails testing."""
    safety_api = AsyncMock()
    # Mock the routing table and shields list for guardrails lookup
    safety_api.routing_table = AsyncMock()
    shield = AsyncMock()
    shield.identifier = "llama-guard"
    shield.provider_resource_id = "llama-guard-model"
    safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
    return safety_api


async def test_run_guardrails_no_violation(mock_safety_api):
    """Test guardrails validation with no violations."""
    text = "Hello world"
    guardrail_ids = ["llama-guard"]

    # Mock moderation to return non-flagged content
    unflagged_result = ModerationObjectResults(flagged=False, categories={"violence": False})
    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[unflagged_result])
    mock_safety_api.run_moderation.return_value = mock_moderation_object

    result = await run_guardrails(mock_safety_api, text, guardrail_ids)

    assert result is None
    # Verify run_moderation was called with the correct model
    mock_safety_api.run_moderation.assert_called_once()
    call_args = mock_safety_api.run_moderation.call_args
    assert call_args[1]["model"] == "llama-guard-model"


async def test_run_guardrails_with_violation(mock_safety_api):
    """Test guardrails validation with safety violation."""
    text = "Harmful content"
    guardrail_ids = ["llama-guard"]

    # Mock moderation to return flagged content
    flagged_result = ModerationObjectResults(
        flagged=True,
        categories={"violence": True},
        user_message="Content flagged by moderation",
        metadata={"violation_type": ["S1"]},
    )
    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
    mock_safety_api.run_moderation.return_value = mock_moderation_object

    result = await run_guardrails(mock_safety_api, text, guardrail_ids)

    assert result == "Content flagged by moderation (flagged for: violence) (violation type: S1)"


async def test_run_guardrails_empty_inputs(mock_safety_api):
    """Test guardrails validation with empty inputs."""
    # Test empty guardrail_ids
    result = await run_guardrails(mock_safety_api, "test", [])
    assert result is None

    # Test empty text
    result = await run_guardrails(mock_safety_api, "", ["llama-guard"])
    assert result is None

    # Test both empty
    result = await run_guardrails(mock_safety_api, "", [])
    assert result is None
@@ -1,169 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid
from datetime import datetime
from unittest.mock import patch

import pytest

from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.core.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo


@pytest.fixture
async def test_setup(sqlite_kvstore):
    agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
    yield agent_persistence


@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
    agent_persistence = test_setup

    # Set creator's attributes for the session
    creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
    mock_get_authenticated_user.return_value = User("test_user", creator_attributes)

    # Create a session
    session_id = await agent_persistence.create_session("Test Session")

    # Get the session and verify access attributes were set
    session_info = await agent_persistence.get_session_info(session_id)
    assert session_info is not None
    assert session_info.owner is not None
    assert session_info.owner.attributes is not None
    assert session_info.owner.attributes["roles"] == ["researcher"]
    assert session_info.owner.attributes["teams"] == ["ai-team"]


@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_access_control(mock_get_authenticated_user, test_setup):
    agent_persistence = test_setup

    # Create a session with specific access attributes
    session_id = str(uuid.uuid4())
    session_info = AgentSessionInfo(
        session_id=session_id,
        session_name="Restricted Session",
        started_at=datetime.now(),
        owner=User("someone", {"roles": ["admin"], "teams": ["security-team"]}),
        turns=[],
        identifier="Restricted Session",
    )

    await agent_persistence.kvstore.set(
        key=f"session:{agent_persistence.agent_id}:{session_id}",
        value=session_info.model_dump_json(),
    )

    # User with matching attributes can access
    mock_get_authenticated_user.return_value = User(
        "testuser", {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
    )
    retrieved_session = await agent_persistence.get_session_info(session_id)
    assert retrieved_session is not None
    assert retrieved_session.session_id == session_id

    # User without matching attributes cannot access
    mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"], "teams": ["other-team"]})
    retrieved_session = await agent_persistence.get_session_info(session_id)
    assert retrieved_session is None


@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_turn_access_control(mock_get_authenticated_user, test_setup):
    agent_persistence = test_setup

    # Create a session with restricted access
    session_id = str(uuid.uuid4())
    session_info = AgentSessionInfo(
        session_id=session_id,
        session_name="Restricted Session",
        started_at=datetime.now(),
        owner=User("someone", {"roles": ["admin"]}),
        turns=[],
        identifier="Restricted Session",
    )

    await agent_persistence.kvstore.set(
        key=f"session:{agent_persistence.agent_id}:{session_id}",
        value=session_info.model_dump_json(),
    )

    # Create a turn for this session
    turn_id = str(uuid.uuid4())
    turn = Turn(
        session_id=session_id,
        turn_id=turn_id,
        steps=[],
        started_at=datetime.now(),
        input_messages=[],
        output_message=CompletionMessage(
            content="Hello",
            stop_reason=StopReason.end_of_turn,
        ),
    )

    # Admin can add turn
    mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
    await agent_persistence.add_turn_to_session(session_id, turn)

    # Admin can get turn
    retrieved_turn = await agent_persistence.get_session_turn(session_id, turn_id)
    assert retrieved_turn is not None
    assert retrieved_turn.turn_id == turn_id

    # Regular user cannot get turn
    mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
    with pytest.raises(ValueError):
        await agent_persistence.get_session_turn(session_id, turn_id)

    # Regular user cannot get turns for session
    with pytest.raises(ValueError):
        await agent_persistence.get_session_turns(session_id)


@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
    agent_persistence = test_setup

    # Create a session with restricted access
    session_id = str(uuid.uuid4())
    session_info = AgentSessionInfo(
        session_id=session_id,
        session_name="Restricted Session",
        started_at=datetime.now(),
        owner=User("someone", {"roles": ["admin"]}),
        turns=[],
        identifier="Restricted Session",
    )

    await agent_persistence.kvstore.set(
        key=f"session:{agent_persistence.agent_id}:{session_id}",
        value=session_info.model_dump_json(),
    )

    turn_id = str(uuid.uuid4())

    # Admin user can set inference iterations
    mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
    await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)

    # Admin user can get inference iterations
    infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
    assert infer_iters == 5

    # Regular user cannot get inference iterations
    mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
    infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
    assert infer_iters is None

    # Regular user cannot set inference iterations (should raise ValueError)
    with pytest.raises(ValueError):
        await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 10)