chore!: remove the agents (sessions and turns) API (#4055)

- Removes the deprecated agents (sessions and turns) API that was marked
as alpha in 0.3.0
- Cleans up unused imports and orphaned types after the API removal
- Removes `SessionNotFoundError` and `AgentTurnInputType`, which are no
longer needed

The agents API is completely superseded by the Responses + Conversations
APIs, and the client SDK Agent class already uses those implementations.
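
For illustration, a minimal sketch of the replacement flow (the client method names, base URL, and model ID below are assumptions for the example, not taken from this commit):

```python
# Hypothetical sketch only: a Conversation plays the role of the old agent "session",
# and each Responses call plays the role of a "turn". Adjust names to the real SDK surface.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # illustrative endpoint

conversation = client.conversations.create()  # replaces creating an agent session

response = client.responses.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model ID
    input="What is the capital of Ireland?",
    conversation=conversation.id,  # conversation and previous_response_id are mutually exclusive
)
print(response.output)
```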

Corresponding client-side PR:
https://github.com/llamastack/llama-stack-client-python/pull/295
Ashwin Bharambe 2025-11-04 09:38:39 -08:00 committed by GitHub
parent a6ddbae0ed
commit a8a8aa56c0
1037 changed files with 393 additions and 309806 deletions

@@ -1,23 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

import yaml

from llama_stack.apis.inference import (
    OpenAIChatCompletion,
)

FIXTURES_DIR = os.path.dirname(os.path.abspath(__file__))


def load_chat_completion_fixture(filename: str) -> OpenAIChatCompletion:
    fixture_path = os.path.join(FIXTURES_DIR, filename)
    with open(fixture_path) as f:
        data = yaml.safe_load(f)
    return OpenAIChatCompletion(**data)

@@ -1,9 +0,0 @@
id: chat-completion-123
choices:
  - message:
      content: "Dublin"
      role: assistant
    finish_reason: stop
    index: 0
created: 1234567890
model: meta-llama/Llama-3.1-8B-Instruct

@@ -1,14 +0,0 @@
id: chat-completion-123
choices:
  - message:
      tool_calls:
        - id: tool_call_123
          type: function
          function:
            name: web_search
            arguments: '{"query":"What is the capital of Ireland?"}'
      role: assistant
    finish_reason: stop
    index: 0
created: 1234567890
model: meta-llama/Llama-3.1-8B-Instruct

@@ -1,249 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseMessage,
    OpenAIResponseObject,
    OpenAIResponseObjectStreamResponseCompleted,
    OpenAIResponseObjectStreamResponseOutputItemDone,
    OpenAIResponseOutputMessageContentOutputText,
)
from llama_stack.apis.common.errors import (
    ConversationNotFoundError,
    InvalidConversationIdError,
)
from llama_stack.apis.conversations.conversations import (
    ConversationItemList,
)

# Import existing fixtures from the main responses test file
pytest_plugins = ["tests.unit.providers.agents.meta_reference.test_openai_responses"]

from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
    OpenAIResponsesImpl,
)


@pytest.fixture
def responses_impl_with_conversations(
    mock_inference_api,
    mock_tool_groups_api,
    mock_tool_runtime_api,
    mock_responses_store,
    mock_vector_io_api,
    mock_conversations_api,
    mock_safety_api,
):
    """Create OpenAIResponsesImpl instance with conversations API."""
    return OpenAIResponsesImpl(
        inference_api=mock_inference_api,
        tool_groups_api=mock_tool_groups_api,
        tool_runtime_api=mock_tool_runtime_api,
        responses_store=mock_responses_store,
        vector_io_api=mock_vector_io_api,
        conversations_api=mock_conversations_api,
        safety_api=mock_safety_api,
    )


class TestConversationValidation:
    """Test conversation ID validation logic."""

    async def test_nonexistent_conversation_raises_error(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        """Test that ConversationNotFoundError is raised for non-existent conversation."""
        conv_id = "conv_nonexistent"

        # Mock conversation not found
        mock_conversations_api.list_items.side_effect = ConversationNotFoundError("conv_nonexistent")

        with pytest.raises(ConversationNotFoundError):
            await responses_impl_with_conversations.create_openai_response(
                input="Hello", model="test-model", conversation=conv_id, stream=False
            )


class TestMessageSyncing:
    """Test message syncing to conversations."""

    async def test_sync_response_to_conversation_simple(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        """Test syncing simple response to conversation."""
        conv_id = "conv_test123"
        input_text = "What are the 5 Ds of dodgeball?"

        # Output items (what the model generated)
        output_items = [
            OpenAIResponseMessage(
                id="msg_response",
                content=[
                    OpenAIResponseOutputMessageContentOutputText(
                        text="The 5 Ds are: Dodge, Duck, Dip, Dive, and Dodge.", type="output_text", annotations=[]
                    )
                ],
                role="assistant",
                status="completed",
                type="message",
            )
        ]

        await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_text, output_items)

        # should call add_items with user input and assistant response
        mock_conversations_api.add_items.assert_called_once()
        call_args = mock_conversations_api.add_items.call_args

        assert call_args[0][0] == conv_id  # conversation_id
        items = call_args[0][1]  # conversation_items
        assert len(items) == 2

        # User message
        assert items[0].type == "message"
        assert items[0].role == "user"
        assert items[0].content[0].type == "input_text"
        assert items[0].content[0].text == input_text

        # Assistant message
        assert items[1].type == "message"
        assert items[1].role == "assistant"

    async def test_sync_response_to_conversation_api_error(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        mock_conversations_api.add_items.side_effect = Exception("API Error")

        output_items = []

        # matching the behavior of OpenAI here
        with pytest.raises(Exception, match="API Error"):
            await responses_impl_with_conversations._sync_response_to_conversation(
                "conv_test123", "Hello", output_items
            )

    async def test_sync_with_list_input(self, responses_impl_with_conversations, mock_conversations_api):
        """Test syncing with list of input messages."""
        conv_id = "conv_test123"
        input_messages = [
            OpenAIResponseMessage(role="user", content=[{"type": "input_text", "text": "First message"}]),
        ]
        output_items = [
            OpenAIResponseMessage(
                id="msg_response",
                content=[OpenAIResponseOutputMessageContentOutputText(text="Response", type="output_text")],
                role="assistant",
                status="completed",
                type="message",
            )
        ]

        await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_messages, output_items)

        mock_conversations_api.add_items.assert_called_once()
        call_args = mock_conversations_api.add_items.call_args
        items = call_args[0][1]

        # Should have input message + output message
        assert len(items) == 2


class TestIntegrationWorkflow:
    """Integration tests for the full conversation workflow."""

    async def test_create_response_with_valid_conversation(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        """Test creating a response with a valid conversation parameter."""
        mock_conversations_api.list_items.return_value = ConversationItemList(
            data=[], first_id=None, has_more=False, last_id=None, object="list"
        )

        async def mock_streaming_response(*args, **kwargs):
            message_item = OpenAIResponseMessage(
                id="msg_response",
                content=[
                    OpenAIResponseOutputMessageContentOutputText(
                        text="Test response", type="output_text", annotations=[]
                    )
                ],
                role="assistant",
                status="completed",
                type="message",
            )

            # Emit output_item.done event first (needed for conversation sync)
            yield OpenAIResponseObjectStreamResponseOutputItemDone(
                response_id="resp_test123",
                item=message_item,
                output_index=0,
                sequence_number=1,
                type="response.output_item.done",
            )

            # Then emit response.completed
            mock_response = OpenAIResponseObject(
                id="resp_test123",
                created_at=1234567890,
                model="test-model",
                object="response",
                output=[message_item],
                status="completed",
            )
            yield OpenAIResponseObjectStreamResponseCompleted(response=mock_response, type="response.completed")

        responses_impl_with_conversations._create_streaming_response = mock_streaming_response

        input_text = "Hello, how are you?"
        conversation_id = "conv_test123"

        response = await responses_impl_with_conversations.create_openai_response(
            input=input_text, model="test-model", conversation=conversation_id, stream=False
        )

        assert response is not None
        assert response.id == "resp_test123"

        # Note: conversation sync happens inside _create_streaming_response,
        # which we're mocking here, so we can't test it in this unit test.
        # The sync logic is tested separately in TestMessageSyncing.

    async def test_create_response_with_invalid_conversation_id(self, responses_impl_with_conversations):
        """Test creating a response with an invalid conversation ID."""
        with pytest.raises(InvalidConversationIdError) as exc_info:
            await responses_impl_with_conversations.create_openai_response(
                input="Hello", model="test-model", conversation="invalid_id", stream=False
            )

        assert "Expected an ID that begins with 'conv_'" in str(exc_info.value)

    async def test_create_response_with_nonexistent_conversation(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        """Test creating a response with a non-existent conversation."""
        mock_conversations_api.list_items.side_effect = ConversationNotFoundError("conv_nonexistent")

        with pytest.raises(ConversationNotFoundError) as exc_info:
            await responses_impl_with_conversations.create_openai_response(
                input="Hello", model="test-model", conversation="conv_nonexistent", stream=False
            )

        assert "not found" in str(exc_info.value)

    async def test_conversation_and_previous_response_id(
        self, responses_impl_with_conversations, mock_conversations_api, mock_responses_store
    ):
        with pytest.raises(ValueError) as exc_info:
            await responses_impl_with_conversations.create_openai_response(
                input="test", model="test", conversation="conv_123", previous_response_id="resp_123"
            )

        assert "Mutually exclusive parameters" in str(exc_info.value)
        assert "previous_response_id" in str(exc_info.value)
        assert "conversation" in str(exc_info.value)

@@ -1,367 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseAnnotationFileCitation,
    OpenAIResponseInputFunctionToolCallOutput,
    OpenAIResponseInputMessageContentImage,
    OpenAIResponseInputMessageContentText,
    OpenAIResponseInputToolFunction,
    OpenAIResponseInputToolWebSearch,
    OpenAIResponseMessage,
    OpenAIResponseOutputMessageContentOutputText,
    OpenAIResponseOutputMessageFunctionToolCall,
    OpenAIResponseText,
    OpenAIResponseTextFormat,
)
from llama_stack.apis.inference import (
    OpenAIAssistantMessageParam,
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAIChatCompletionToolCall,
    OpenAIChatCompletionToolCallFunction,
    OpenAIChoice,
    OpenAIDeveloperMessageParam,
    OpenAIResponseFormatJSONObject,
    OpenAIResponseFormatJSONSchema,
    OpenAIResponseFormatText,
    OpenAISystemMessageParam,
    OpenAIToolMessageParam,
    OpenAIUserMessageParam,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
    _extract_citations_from_text,
    convert_chat_choice_to_response_message,
    convert_response_content_to_chat_content,
    convert_response_input_to_chat_messages,
    convert_response_text_to_chat_response_format,
    get_message_type_by_role,
    is_function_tool_call,
)


class TestConvertChatChoiceToResponseMessage:
    async def test_convert_string_content(self):
        choice = OpenAIChoice(
            message=OpenAIAssistantMessageParam(content="Test message"),
            finish_reason="stop",
            index=0,
        )

        result = await convert_chat_choice_to_response_message(choice)

        assert result.role == "assistant"
        assert result.status == "completed"
        assert len(result.content) == 1
        assert isinstance(result.content[0], OpenAIResponseOutputMessageContentOutputText)
        assert result.content[0].text == "Test message"

    async def test_convert_text_param_content(self):
        choice = OpenAIChoice(
            message=OpenAIAssistantMessageParam(
                content=[OpenAIChatCompletionContentPartTextParam(text="Test text param")]
            ),
            finish_reason="stop",
            index=0,
        )

        with pytest.raises(ValueError) as exc_info:
            await convert_chat_choice_to_response_message(choice)

        assert "does not yet support output content type" in str(exc_info.value)


class TestConvertResponseContentToChatContent:
    async def test_convert_string_content(self):
        result = await convert_response_content_to_chat_content("Simple string")
        assert result == "Simple string"

    async def test_convert_text_content_parts(self):
        content = [
            OpenAIResponseInputMessageContentText(text="First part"),
            OpenAIResponseOutputMessageContentOutputText(text="Second part"),
        ]

        result = await convert_response_content_to_chat_content(content)

        assert len(result) == 2
        assert isinstance(result[0], OpenAIChatCompletionContentPartTextParam)
        assert result[0].text == "First part"
        assert isinstance(result[1], OpenAIChatCompletionContentPartTextParam)
        assert result[1].text == "Second part"

    async def test_convert_image_content(self):
        content = [OpenAIResponseInputMessageContentImage(image_url="https://example.com/image.jpg", detail="high")]

        result = await convert_response_content_to_chat_content(content)

        assert len(result) == 1
        assert isinstance(result[0], OpenAIChatCompletionContentPartImageParam)
        assert result[0].image_url.url == "https://example.com/image.jpg"
        assert result[0].image_url.detail == "high"


class TestConvertResponseInputToChatMessages:
    async def test_convert_string_input(self):
        result = await convert_response_input_to_chat_messages("User message")

        assert len(result) == 1
        assert isinstance(result[0], OpenAIUserMessageParam)
        assert result[0].content == "User message"

    async def test_convert_function_tool_call_output(self):
        input_items = [
            OpenAIResponseOutputMessageFunctionToolCall(
                call_id="call_123",
                name="test_function",
                arguments='{"param": "value"}',
            ),
            OpenAIResponseInputFunctionToolCallOutput(
                output="Tool output",
                call_id="call_123",
            ),
        ]

        result = await convert_response_input_to_chat_messages(input_items)

        assert len(result) == 2
        assert isinstance(result[0], OpenAIAssistantMessageParam)
        assert result[0].tool_calls[0].id == "call_123"
        assert result[0].tool_calls[0].function.name == "test_function"
        assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
        assert isinstance(result[1], OpenAIToolMessageParam)
        assert result[1].content == "Tool output"
        assert result[1].tool_call_id == "call_123"

    async def test_convert_function_tool_call(self):
        input_items = [
            OpenAIResponseOutputMessageFunctionToolCall(
                call_id="call_456",
                name="test_function",
                arguments='{"param": "value"}',
            )
        ]

        result = await convert_response_input_to_chat_messages(input_items)

        assert len(result) == 1
        assert isinstance(result[0], OpenAIAssistantMessageParam)
        assert len(result[0].tool_calls) == 1
        assert result[0].tool_calls[0].id == "call_456"
        assert result[0].tool_calls[0].function.name == "test_function"
        assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'

    async def test_convert_function_call_ordering(self):
        input_items = [
            OpenAIResponseOutputMessageFunctionToolCall(
                call_id="call_123",
                name="test_function_a",
                arguments='{"param": "value"}',
            ),
            OpenAIResponseOutputMessageFunctionToolCall(
                call_id="call_456",
                name="test_function_b",
                arguments='{"param": "value"}',
            ),
            OpenAIResponseInputFunctionToolCallOutput(
                output="AAA",
                call_id="call_123",
            ),
            OpenAIResponseInputFunctionToolCallOutput(
                output="BBB",
                call_id="call_456",
            ),
        ]

        result = await convert_response_input_to_chat_messages(input_items)

        assert len(result) == 4
        assert isinstance(result[0], OpenAIAssistantMessageParam)
        assert len(result[0].tool_calls) == 1
        assert result[0].tool_calls[0].id == "call_123"
        assert result[0].tool_calls[0].function.name == "test_function_a"
        assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
        assert isinstance(result[1], OpenAIToolMessageParam)
        assert result[1].content == "AAA"
        assert result[1].tool_call_id == "call_123"
        assert isinstance(result[2], OpenAIAssistantMessageParam)
        assert len(result[2].tool_calls) == 1
        assert result[2].tool_calls[0].id == "call_456"
        assert result[2].tool_calls[0].function.name == "test_function_b"
        assert result[2].tool_calls[0].function.arguments == '{"param": "value"}'
        assert isinstance(result[3], OpenAIToolMessageParam)
        assert result[3].content == "BBB"
        assert result[3].tool_call_id == "call_456"

    async def test_convert_response_message(self):
        input_items = [
            OpenAIResponseMessage(
                role="user",
                content=[OpenAIResponseInputMessageContentText(text="User text")],
            )
        ]

        result = await convert_response_input_to_chat_messages(input_items)

        assert len(result) == 1
        assert isinstance(result[0], OpenAIUserMessageParam)
        # Content should be converted to chat content format
        assert len(result[0].content) == 1
        assert result[0].content[0].text == "User text"


class TestConvertResponseTextToChatResponseFormat:
    async def test_convert_text_format(self):
        text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
        result = await convert_response_text_to_chat_response_format(text)

        assert isinstance(result, OpenAIResponseFormatText)
        assert result.type == "text"

    async def test_convert_json_object_format(self):
        text = OpenAIResponseText(format={"type": "json_object"})
        result = await convert_response_text_to_chat_response_format(text)

        assert isinstance(result, OpenAIResponseFormatJSONObject)

    async def test_convert_json_schema_format(self):
        schema_def = {"type": "object", "properties": {"test": {"type": "string"}}}
        text = OpenAIResponseText(
            format={
                "type": "json_schema",
                "name": "test_schema",
                "schema": schema_def,
            }
        )
        result = await convert_response_text_to_chat_response_format(text)

        assert isinstance(result, OpenAIResponseFormatJSONSchema)
        assert result.json_schema["name"] == "test_schema"
        assert result.json_schema["schema"] == schema_def

    async def test_default_text_format(self):
        text = OpenAIResponseText()
        result = await convert_response_text_to_chat_response_format(text)

        assert isinstance(result, OpenAIResponseFormatText)
        assert result.type == "text"


class TestGetMessageTypeByRole:
    async def test_user_role(self):
        result = await get_message_type_by_role("user")
        assert result == OpenAIUserMessageParam

    async def test_system_role(self):
        result = await get_message_type_by_role("system")
        assert result == OpenAISystemMessageParam

    async def test_assistant_role(self):
        result = await get_message_type_by_role("assistant")
        assert result == OpenAIAssistantMessageParam

    async def test_developer_role(self):
        result = await get_message_type_by_role("developer")
        assert result == OpenAIDeveloperMessageParam

    async def test_unknown_role(self):
        result = await get_message_type_by_role("unknown")
        assert result is None


class TestIsFunctionToolCall:
    def test_is_function_tool_call_true(self):
        tool_call = OpenAIChatCompletionToolCall(
            index=0,
            id="call_123",
            function=OpenAIChatCompletionToolCallFunction(
                name="test_function",
                arguments="{}",
            ),
        )
        tools = [
            OpenAIResponseInputToolFunction(
                type="function", name="test_function", parameters={"type": "object", "properties": {}}
            ),
            OpenAIResponseInputToolWebSearch(type="web_search"),
        ]

        result = is_function_tool_call(tool_call, tools)

        assert result is True

    def test_is_function_tool_call_false_different_name(self):
        tool_call = OpenAIChatCompletionToolCall(
            index=0,
            id="call_123",
            function=OpenAIChatCompletionToolCallFunction(
                name="other_function",
                arguments="{}",
            ),
        )
        tools = [
            OpenAIResponseInputToolFunction(
                type="function", name="test_function", parameters={"type": "object", "properties": {}}
            ),
        ]

        result = is_function_tool_call(tool_call, tools)

        assert result is False

    def test_is_function_tool_call_false_no_function(self):
        tool_call = OpenAIChatCompletionToolCall(
            index=0,
            id="call_123",
            function=None,
        )
        tools = [
            OpenAIResponseInputToolFunction(
                type="function", name="test_function", parameters={"type": "object", "properties": {}}
            ),
        ]

        result = is_function_tool_call(tool_call, tools)

        assert result is False

    def test_is_function_tool_call_false_wrong_type(self):
        tool_call = OpenAIChatCompletionToolCall(
            index=0,
            id="call_123",
            function=OpenAIChatCompletionToolCallFunction(
                name="web_search",
                arguments="{}",
            ),
        )
        tools = [
            OpenAIResponseInputToolWebSearch(type="web_search"),
        ]

        result = is_function_tool_call(tool_call, tools)

        assert result is False


class TestExtractCitationsFromText:
    def test_extract_citations_and_annotations(self):
        text = "Start [not-a-file]. New source <|file-abc123|>. "
        text += "Other source <|file-def456|>? Repeat source <|file-abc123|>! No citation."
        file_mapping = {"file-abc123": "doc1.pdf", "file-def456": "doc2.txt"}

        annotations, cleaned_text = _extract_citations_from_text(text, file_mapping)

        expected_annotations = [
            OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=30),
            OpenAIResponseAnnotationFileCitation(file_id="file-def456", filename="doc2.txt", index=44),
            OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=59),
        ]
        expected_clean_text = "Start [not-a-file]. New source. Other source? Repeat source! No citation."

        assert cleaned_text == expected_clean_text
        assert annotations == expected_annotations
        # OpenAI cites at the end of the sentence
        assert cleaned_text[expected_annotations[0].index] == "."
        assert cleaned_text[expected_annotations[1].index] == "?"
        assert cleaned_text[expected_annotations[2].index] == "!"

@@ -1,183 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.agents.openai_responses import (
    MCPListToolsTool,
    OpenAIResponseInputToolFileSearch,
    OpenAIResponseInputToolFunction,
    OpenAIResponseInputToolMCP,
    OpenAIResponseInputToolWebSearch,
    OpenAIResponseObject,
    OpenAIResponseOutputMessageMCPListTools,
    OpenAIResponseToolMCP,
)
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext


class TestToolContext:
    def test_no_tools(self):
        tools = []
        context = ToolContext(tools)

        previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
        context.recover_tools_from_previous_response(previous_response)

        assert len(context.tools_to_process) == 0
        assert len(context.previous_tools) == 0
        assert len(context.previous_tool_listings) == 0

    def test_no_previous_tools(self):
        tools = [
            OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
            OpenAIResponseInputToolMCP(server_label="label", server_url="url"),
        ]
        context = ToolContext(tools)

        previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
        context.recover_tools_from_previous_response(previous_response)

        assert len(context.tools_to_process) == 2
        assert len(context.previous_tools) == 0
        assert len(context.previous_tool_listings) == 0

    def test_reusable_server(self):
        tools = [
            OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
        ]
        context = ToolContext(tools)

        output = [
            OpenAIResponseOutputMessageMCPListTools(
                id="test", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
            )
        ]
        previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
        previous_response.tools = [
            OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
            OpenAIResponseToolMCP(server_label="alabel"),
        ]
        context.recover_tools_from_previous_response(previous_response)

        assert len(context.tools_to_process) == 1
        assert context.tools_to_process[0].type == "file_search"
        assert len(context.previous_tools) == 1
        assert context.previous_tools["test_tool"].server_label == "alabel"
        assert context.previous_tools["test_tool"].server_url == "aurl"
        assert len(context.previous_tool_listings) == 1
        assert len(context.previous_tool_listings[0].tools) == 1
        assert context.previous_tool_listings[0].server_label == "alabel"

    def test_multiple_reusable_servers(self):
        tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(),
            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
        ]
        context = ToolContext(tools)

        output = [
            OpenAIResponseOutputMessageMCPListTools(
                id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
            ),
            OpenAIResponseOutputMessageMCPListTools(
                id="test2",
                server_label="anotherlabel",
                tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
            ),
        ]
        previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
        previous_response.tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(type="web_search"),
            OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
        ]
        context.recover_tools_from_previous_response(previous_response)

        assert len(context.tools_to_process) == 2
        assert context.tools_to_process[0].type == "function"
        assert context.tools_to_process[1].type == "web_search"
        assert len(context.previous_tools) == 2
        assert context.previous_tools["test_tool"].server_label == "alabel"
        assert context.previous_tools["test_tool"].server_url == "aurl"
        assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
        assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
        assert len(context.previous_tool_listings) == 2
        assert len(context.previous_tool_listings[0].tools) == 1
        assert context.previous_tool_listings[0].server_label == "alabel"
        assert len(context.previous_tool_listings[1].tools) == 1
        assert context.previous_tool_listings[1].server_label == "anotherlabel"

    def test_multiple_servers_only_one_reusable(self):
        tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(type="web_search"),
            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
        ]
        context = ToolContext(tools)

        output = [
            OpenAIResponseOutputMessageMCPListTools(
                id="test2",
                server_label="anotherlabel",
                tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
            )
        ]
        previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
        previous_response.tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(type="web_search"),
        ]
        context.recover_tools_from_previous_response(previous_response)

        assert len(context.tools_to_process) == 3
        assert context.tools_to_process[0].type == "function"
        assert context.tools_to_process[1].type == "web_search"
        assert context.tools_to_process[2].type == "mcp"
        assert len(context.previous_tools) == 1
        assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
        assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
        assert len(context.previous_tool_listings) == 1
        assert len(context.previous_tool_listings[0].tools) == 1
        assert context.previous_tool_listings[0].server_label == "anotherlabel"

    def test_mismatched_allowed_tools(self):
        tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(type="web_search"),
            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl", allowed_tools=["test_tool_2"]),
        ]
        context = ToolContext(tools)

        output = [
            OpenAIResponseOutputMessageMCPListTools(
                id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool_1", input_schema={})]
            ),
            OpenAIResponseOutputMessageMCPListTools(
                id="test2",
                server_label="anotherlabel",
                tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
            ),
        ]
        previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
        previous_response.tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(type="web_search"),
            OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
        ]
        context.recover_tools_from_previous_response(previous_response)

        assert len(context.tools_to_process) == 3
        assert context.tools_to_process[0].type == "function"
        assert context.tools_to_process[1].type == "web_search"
        assert context.tools_to_process[2].type == "mcp"
        assert len(context.previous_tools) == 1
        assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
        assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
        assert len(context.previous_tool_listings) == 1
        assert len(context.previous_tool_listings[0].tools) == 1
        assert context.previous_tool_listings[0].server_label == "anotherlabel"

@@ -1,155 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock

import pytest

from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
    OpenAIResponsesImpl,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
    extract_guardrail_ids,
    run_guardrails,
)


@pytest.fixture
def mock_apis():
    """Create mock APIs for testing."""
    return {
        "inference_api": AsyncMock(),
        "tool_groups_api": AsyncMock(),
        "tool_runtime_api": AsyncMock(),
        "responses_store": AsyncMock(),
        "vector_io_api": AsyncMock(),
        "conversations_api": AsyncMock(),
        "safety_api": AsyncMock(),
    }


@pytest.fixture
def responses_impl(mock_apis):
    """Create OpenAIResponsesImpl instance with mocked dependencies."""
    return OpenAIResponsesImpl(**mock_apis)


def test_extract_guardrail_ids_from_strings(responses_impl):
    """Test extraction from simple string guardrail IDs."""
    guardrails = ["llama-guard", "content-filter", "nsfw-detector"]
    result = extract_guardrail_ids(guardrails)
    assert result == ["llama-guard", "content-filter", "nsfw-detector"]


def test_extract_guardrail_ids_from_objects(responses_impl):
    """Test extraction from ResponseGuardrailSpec objects."""
    guardrails = [
        ResponseGuardrailSpec(type="llama-guard"),
        ResponseGuardrailSpec(type="content-filter"),
    ]
    result = extract_guardrail_ids(guardrails)
    assert result == ["llama-guard", "content-filter"]


def test_extract_guardrail_ids_mixed_formats(responses_impl):
    """Test extraction from mixed string and object formats."""
    guardrails = [
        "llama-guard",
        ResponseGuardrailSpec(type="content-filter"),
        "nsfw-detector",
    ]
    result = extract_guardrail_ids(guardrails)
    assert result == ["llama-guard", "content-filter", "nsfw-detector"]


def test_extract_guardrail_ids_none_input(responses_impl):
    """Test extraction with None input."""
    result = extract_guardrail_ids(None)
    assert result == []


def test_extract_guardrail_ids_empty_list(responses_impl):
    """Test extraction with empty list."""
    result = extract_guardrail_ids([])
    assert result == []


def test_extract_guardrail_ids_unknown_format(responses_impl):
    """Test extraction with unknown guardrail format raises ValueError."""
    # Create an object that's neither string nor ResponseGuardrailSpec
    unknown_object = {"invalid": "format"}  # Plain dict, not ResponseGuardrailSpec
    guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]

    with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
        extract_guardrail_ids(guardrails)


@pytest.fixture
def mock_safety_api():
    """Create mock safety API for guardrails testing."""
    safety_api = AsyncMock()
    # Mock the routing table and shields list for guardrails lookup
    safety_api.routing_table = AsyncMock()
    shield = AsyncMock()
    shield.identifier = "llama-guard"
    shield.provider_resource_id = "llama-guard-model"
    safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
    return safety_api


async def test_run_guardrails_no_violation(mock_safety_api):
    """Test guardrails validation with no violations."""
    text = "Hello world"
    guardrail_ids = ["llama-guard"]

    # Mock moderation to return non-flagged content
    unflagged_result = ModerationObjectResults(flagged=False, categories={"violence": False})
    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[unflagged_result])
    mock_safety_api.run_moderation.return_value = mock_moderation_object

    result = await run_guardrails(mock_safety_api, text, guardrail_ids)

    assert result is None
    # Verify run_moderation was called with the correct model
    mock_safety_api.run_moderation.assert_called_once()
    call_args = mock_safety_api.run_moderation.call_args
    assert call_args[1]["model"] == "llama-guard-model"


async def test_run_guardrails_with_violation(mock_safety_api):
    """Test guardrails validation with safety violation."""
    text = "Harmful content"
    guardrail_ids = ["llama-guard"]

    # Mock moderation to return flagged content
    flagged_result = ModerationObjectResults(
        flagged=True,
        categories={"violence": True},
        user_message="Content flagged by moderation",
        metadata={"violation_type": ["S1"]},
    )
    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
    mock_safety_api.run_moderation.return_value = mock_moderation_object

    result = await run_guardrails(mock_safety_api, text, guardrail_ids)

    assert result == "Content flagged by moderation (flagged for: violence) (violation type: S1)"


async def test_run_guardrails_empty_inputs(mock_safety_api):
    """Test guardrails validation with empty inputs."""
    # Test empty guardrail_ids
    result = await run_guardrails(mock_safety_api, "test", [])
    assert result is None

    # Test empty text
    result = await run_guardrails(mock_safety_api, "", ["llama-guard"])
    assert result is None

    # Test both empty
    result = await run_guardrails(mock_safety_api, "", [])
    assert result is None

@@ -1,169 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid
from datetime import datetime
from unittest.mock import patch

import pytest

from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.core.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo


@pytest.fixture
async def test_setup(sqlite_kvstore):
    agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
    yield agent_persistence


@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
    agent_persistence = test_setup

    # Set creator's attributes for the session
    creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
    mock_get_authenticated_user.return_value = User("test_user", creator_attributes)

    # Create a session
    session_id = await agent_persistence.create_session("Test Session")

    # Get the session and verify access attributes were set
    session_info = await agent_persistence.get_session_info(session_id)
    assert session_info is not None
    assert session_info.owner is not None
    assert session_info.owner.attributes is not None
    assert session_info.owner.attributes["roles"] == ["researcher"]
    assert session_info.owner.attributes["teams"] == ["ai-team"]


@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_access_control(mock_get_authenticated_user, test_setup):
    agent_persistence = test_setup

    # Create a session with specific access attributes
    session_id = str(uuid.uuid4())
    session_info = AgentSessionInfo(
        session_id=session_id,
        session_name="Restricted Session",
        started_at=datetime.now(),
        owner=User("someone", {"roles": ["admin"], "teams": ["security-team"]}),
        turns=[],
        identifier="Restricted Session",
    )

    await agent_persistence.kvstore.set(
        key=f"session:{agent_persistence.agent_id}:{session_id}",
        value=session_info.model_dump_json(),
    )

    # User with matching attributes can access
    mock_get_authenticated_user.return_value = User(
        "testuser", {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
    )
    retrieved_session = await agent_persistence.get_session_info(session_id)
    assert retrieved_session is not None
    assert retrieved_session.session_id == session_id

    # User without matching attributes cannot access
    mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"], "teams": ["other-team"]})
    retrieved_session = await agent_persistence.get_session_info(session_id)
    assert retrieved_session is None


@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_turn_access_control(mock_get_authenticated_user, test_setup):
    agent_persistence = test_setup

    # Create a session with restricted access
    session_id = str(uuid.uuid4())
    session_info = AgentSessionInfo(
        session_id=session_id,
        session_name="Restricted Session",
        started_at=datetime.now(),
        owner=User("someone", {"roles": ["admin"]}),
        turns=[],
        identifier="Restricted Session",
    )

    await agent_persistence.kvstore.set(
        key=f"session:{agent_persistence.agent_id}:{session_id}",
        value=session_info.model_dump_json(),
    )

    # Create a turn for this session
    turn_id = str(uuid.uuid4())
    turn = Turn(
        session_id=session_id,
        turn_id=turn_id,
        steps=[],
        started_at=datetime.now(),
        input_messages=[],
        output_message=CompletionMessage(
            content="Hello",
            stop_reason=StopReason.end_of_turn,
        ),
    )

    # Admin can add turn
    mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
    await agent_persistence.add_turn_to_session(session_id, turn)

    # Admin can get turn
    retrieved_turn = await agent_persistence.get_session_turn(session_id, turn_id)
    assert retrieved_turn is not None
    assert retrieved_turn.turn_id == turn_id

    # Regular user cannot get turn
    mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
    with pytest.raises(ValueError):
        await agent_persistence.get_session_turn(session_id, turn_id)

    # Regular user cannot get turns for session
    with pytest.raises(ValueError):
        await agent_persistence.get_session_turns(session_id)


@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
    agent_persistence = test_setup

    # Create a session with restricted access
    session_id = str(uuid.uuid4())
    session_info = AgentSessionInfo(
        session_id=session_id,
        session_name="Restricted Session",
        started_at=datetime.now(),
        owner=User("someone", {"roles": ["admin"]}),
        turns=[],
        identifier="Restricted Session",
    )

    await agent_persistence.kvstore.set(
        key=f"session:{agent_persistence.agent_id}:{session_id}",
        value=session_info.model_dump_json(),
    )

    turn_id = str(uuid.uuid4())

    # Admin user can set inference iterations
    mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
    await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)

    # Admin user can get inference iterations
    infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
    assert infer_iters == 5

    # Regular user cannot get inference iterations
    mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
    infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
    assert infer_iters is None

    # Regular user cannot set inference iterations (should raise ValueError)
    with pytest.raises(ValueError):
        await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 10)