Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-16 14:57:20 +00:00)
feat(responses)!: improve responses + conversations implementations (#3810)
This PR updates the Conversation item-related types and improves a couple of critical parts of the implementation:

- It creates a streaming output item for the final assistant message output by the model. Until now we only added content parts and included that message in the final response.
- It completely rewrites the conversation update code to account for items other than messages (tool calls, outputs, etc.)

## Test Plan

Used the test script from https://github.com/llamastack/llama-stack-client-python/pull/281 for this:

```
TEST_API_BASE_URL=http://localhost:8321/v1 \
  pytest tests/integration/test_agent_turn_step_events.py::test_client_side_function_tool -xvs
```
This commit is contained in:
Parent: add8cd801b
Commit: e9b4278a51
129 changed files with 86266 additions and 903 deletions
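The headline behavioral change is the new `response.output_item.done` event emitted for the final assistant message. A minimal consumer-side sketch of what that looks like from an OpenAI-compatible client — the base URL, API key, and model name below are illustrative assumptions, not part of this PR:

```python
# Sketch only: consume a streamed response from a local Llama Stack server.
# Base URL, API key, and model name are placeholder assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

stream = client.responses.create(model="test-model", input="Hello", stream=True)
for event in stream:
    if event.type == "response.output_item.done":
        # New in this PR: the final assistant message is also emitted as a
        # completed output item, not just as content parts.
        print("completed item:", event.item)
    elif event.type == "response.completed":
        print("final response:", event.response.id)
```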
```diff
@@ -11,6 +11,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseMessage,
     OpenAIResponseObject,
     OpenAIResponseObjectStreamResponseCompleted,
+    OpenAIResponseObjectStreamResponseOutputItemDone,
     OpenAIResponseOutputMessageContentOutputText,
 )
 from llama_stack.apis.common.errors import (
```
```diff
@@ -67,101 +68,6 @@ class TestConversationValidation:
         )
 
 
-class TestConversationContextLoading:
-    """Test conversation context loading functionality."""
-
-    async def test_load_conversation_context_simple_input(
-        self, responses_impl_with_conversations, mock_conversations_api
-    ):
-        """Test loading conversation context with simple string input."""
-        conv_id = "conv_test123"
-        input_text = "Hello, how are you?"
-
-        # mock items in chronological order (a consequence of order="asc")
-        mock_conversation_items = ConversationItemList(
-            data=[
-                OpenAIResponseMessage(
-                    id="msg_1",
-                    content=[{"type": "input_text", "text": "Previous user message"}],
-                    role="user",
-                    status="completed",
-                    type="message",
-                ),
-                OpenAIResponseMessage(
-                    id="msg_2",
-                    content=[{"type": "output_text", "text": "Previous assistant response"}],
-                    role="assistant",
-                    status="completed",
-                    type="message",
-                ),
-            ],
-            first_id="msg_1",
-            has_more=False,
-            last_id="msg_2",
-            object="list",
-        )
-
-        mock_conversations_api.list.return_value = mock_conversation_items
-
-        result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)
-
-        # should have conversation history + new input
-        assert len(result) == 3
-        assert isinstance(result[0], OpenAIResponseMessage)
-        assert result[0].role == "user"
-        assert isinstance(result[1], OpenAIResponseMessage)
-        assert result[1].role == "assistant"
-        assert isinstance(result[2], OpenAIResponseMessage)
-        assert result[2].role == "user"
-        assert result[2].content == input_text
-
-    async def test_load_conversation_context_api_error(self, responses_impl_with_conversations, mock_conversations_api):
-        """Test loading conversation context when API call fails."""
-        conv_id = "conv_test123"
-        input_text = "Hello"
-
-        mock_conversations_api.list.side_effect = Exception("API Error")
-
-        with pytest.raises(Exception, match="API Error"):
-            await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)
-
-    async def test_load_conversation_context_with_list_input(
-        self, responses_impl_with_conversations, mock_conversations_api
-    ):
-        """Test loading conversation context with list input."""
-        conv_id = "conv_test123"
-        input_messages = [
-            OpenAIResponseMessage(role="user", content="First message"),
-            OpenAIResponseMessage(role="user", content="Second message"),
-        ]
-
-        mock_conversations_api.list.return_value = ConversationItemList(
-            data=[], first_id=None, has_more=False, last_id=None, object="list"
-        )
-
-        result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_messages)
-
-        assert len(result) == 2
-        assert result == input_messages
-
-    async def test_load_conversation_context_empty_conversation(
-        self, responses_impl_with_conversations, mock_conversations_api
-    ):
-        """Test loading context from empty conversation."""
-        conv_id = "conv_empty"
-        input_text = "Hello"
-
-        mock_conversations_api.list.return_value = ConversationItemList(
-            data=[], first_id=None, has_more=False, last_id=None, object="list"
-        )
-
-        result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)
-
-        assert len(result) == 1
-        assert result[0].role == "user"
-        assert result[0].content == input_text
-
-
 class TestMessageSyncing:
     """Test message syncing to conversations."""
```
```diff
@@ -172,29 +78,22 @@ class TestMessageSyncing:
         conv_id = "conv_test123"
         input_text = "What are the 5 Ds of dodgeball?"
 
-        # mock response
-        mock_response = OpenAIResponseObject(
-            id="resp_123",
-            created_at=1234567890,
-            model="test-model",
-            object="response",
-            output=[
-                OpenAIResponseMessage(
-                    id="msg_response",
-                    content=[
-                        OpenAIResponseOutputMessageContentOutputText(
-                            text="The 5 Ds are: Dodge, Duck, Dip, Dive, and Dodge.", type="output_text", annotations=[]
-                        )
-                    ],
-                    role="assistant",
-                    status="completed",
-                    type="message",
-                )
-            ],
-            status="completed",
-        )
+        # Output items (what the model generated)
+        output_items = [
+            OpenAIResponseMessage(
+                id="msg_response",
+                content=[
+                    OpenAIResponseOutputMessageContentOutputText(
+                        text="The 5 Ds are: Dodge, Duck, Dip, Dive, and Dodge.", type="output_text", annotations=[]
+                    )
+                ],
+                role="assistant",
+                status="completed",
+                type="message",
+            )
+        ]
 
-        await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_text, mock_response)
+        await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_text, output_items)
 
         # should call add_items with user input and assistant response
         mock_conversations_api.add_items.assert_called_once()
```
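As the hunk above shows, `_sync_response_to_conversation` now takes the model's output items directly rather than the full response object. The tests pin down its observable behavior: a single positional `add_items(conversation_id, items)` call containing the input messages plus the output items. A hypothetical sketch of that shape, inferred from the tests rather than taken from the diff:

```python
# Hypothetical reconstruction, inferred from the tests above; the actual
# implementation in the responses provider may differ in detail.
async def _sync_response_to_conversation(self, conversation_id, input, output_items):
    items = []
    if isinstance(input, str):
        # Bare string input becomes a single user message item.
        items.append(
            {
                "type": "message",
                "role": "user",
                "content": [{"type": "input_text", "text": input}],
            }
        )
    else:
        # Already a list of items (messages, tool calls, outputs, etc.).
        items.extend(input)
    items.extend(output_items)
    # Failures propagate to the caller, matching OpenAI's behavior.
    await self.conversations_api.add_items(conversation_id, items)
```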
```diff
@@ -218,30 +117,38 @@ class TestMessageSyncing:
         self, responses_impl_with_conversations, mock_conversations_api
     ):
         mock_conversations_api.add_items.side_effect = Exception("API Error")
-        mock_response = OpenAIResponseObject(
-            id="resp_123", created_at=1234567890, model="test-model", object="response", output=[], status="completed"
-        )
+        output_items = []
 
         # matching the behavior of OpenAI here
         with pytest.raises(Exception, match="API Error"):
             await responses_impl_with_conversations._sync_response_to_conversation(
-                "conv_test123", "Hello", mock_response
+                "conv_test123", "Hello", output_items
             )
 
-    async def test_sync_unsupported_types(self, responses_impl_with_conversations):
-        mock_response = OpenAIResponseObject(
-            id="resp_123", created_at=1234567890, model="test-model", object="response", output=[], status="completed"
-        )
-
-        with pytest.raises(NotImplementedError, match="Unsupported input item type"):
-            await responses_impl_with_conversations._sync_response_to_conversation(
-                "conv_123", [{"not": "message"}], mock_response
-            )
-
-        with pytest.raises(NotImplementedError, match="Unsupported message role: system"):
-            await responses_impl_with_conversations._sync_response_to_conversation(
-                "conv_123", [OpenAIResponseMessage(role="system", content="test")], mock_response
-            )
+    async def test_sync_with_list_input(self, responses_impl_with_conversations, mock_conversations_api):
+        """Test syncing with list of input messages."""
+        conv_id = "conv_test123"
+        input_messages = [
+            OpenAIResponseMessage(role="user", content=[{"type": "input_text", "text": "First message"}]),
+        ]
+        output_items = [
+            OpenAIResponseMessage(
+                id="msg_response",
+                content=[OpenAIResponseOutputMessageContentOutputText(text="Response", type="output_text")],
+                role="assistant",
+                status="completed",
+                type="message",
+            )
+        ]
+
+        await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_messages, output_items)
+
+        mock_conversations_api.add_items.assert_called_once()
+        call_args = mock_conversations_api.add_items.call_args
+
+        items = call_args[0][1]
+        # Should have input message + output message
+        assert len(items) == 2
 
 
 class TestIntegrationWorkflow:
```
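The error-path test above confirms that `add_items` failures propagate out of the sync step rather than being swallowed. From a caller's perspective, that means a conversation-backed response request can fail on the sync itself. A hedged illustration — the client setup, model name, and conversation id are placeholder assumptions:

```python
# Illustration only: guard a conversation-backed request against sync errors.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

try:
    response = client.responses.create(
        model="test-model",
        input="Hello",
        conversation="conv_test123",  # placeholder conversation id
    )
except Exception as err:
    # Conversation sync errors are not swallowed server-side.
    print(f"response creation failed: {err}")
```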
```diff
@@ -256,24 +163,34 @@ class TestIntegrationWorkflow:
         )
 
         async def mock_streaming_response(*args, **kwargs):
+            message_item = OpenAIResponseMessage(
+                id="msg_response",
+                content=[
+                    OpenAIResponseOutputMessageContentOutputText(
+                        text="Test response", type="output_text", annotations=[]
+                    )
+                ],
+                role="assistant",
+                status="completed",
+                type="message",
+            )
+
+            # Emit output_item.done event first (needed for conversation sync)
+            yield OpenAIResponseObjectStreamResponseOutputItemDone(
+                response_id="resp_test123",
+                item=message_item,
+                output_index=0,
+                sequence_number=1,
+                type="response.output_item.done",
+            )
+
+            # Then emit response.completed
             mock_response = OpenAIResponseObject(
                 id="resp_test123",
                 created_at=1234567890,
                 model="test-model",
                 object="response",
-                output=[
-                    OpenAIResponseMessage(
-                        id="msg_response",
-                        content=[
-                            OpenAIResponseOutputMessageContentOutputText(
-                                text="Test response", type="output_text", annotations=[]
-                            )
-                        ],
-                        role="assistant",
-                        status="completed",
-                        type="message",
-                    )
-                ],
+                output=[message_item],
                 status="completed",
             )
```
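The mock stream above encodes the ordering contract introduced here: `response.output_item.done` for the final assistant message arrives before `response.completed`, and the completed response's `output` carries the same item. A small sketch of consuming a stream under that contract — the helper name is invented for illustration:

```python
# Illustration with an invented helper name: collect completed output items
# as they stream, then return them alongside the final response.
async def collect_stream(stream):
    completed_items = []
    final_response = None
    async for event in stream:
        if event.type == "response.output_item.done":
            completed_items.append(event.item)
        elif event.type == "response.completed":
            final_response = event.response
    return final_response, completed_items
```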
```diff
@@ -291,10 +208,9 @@ class TestIntegrationWorkflow:
         assert response is not None
         assert response.id == "resp_test123"
 
         mock_conversations_api.list.assert_called_once_with(conversation_id, order="asc")
 
-        # Note: conversation sync happens in the streaming response flow,
-        # which is complex to mock fully in this unit test
+        # Note: conversation sync happens inside _create_streaming_response,
+        # which we're mocking here, so we can't test it in this unit test.
+        # The sync logic is tested separately in TestMessageSyncing.
 
     async def test_create_response_with_invalid_conversation_id(self, responses_impl_with_conversations):
         """Test creating a response with an invalid conversation ID."""
```