feat: Add support for Conversations in Responses API

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Javier Arceo 2025-10-08 15:51:11 -04:00
parent 548ccff368
commit 1e59793288
18 changed files with 662 additions and 10 deletions
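
The new `conversation` parameter on the Responses API accepts an existing conversation ID (prefixed with `conv_`); prior items from that conversation are loaded as context and the new turn is synced back automatically. A minimal client-side sketch, following the integration tests added in this commit (the client object and model ID are illustrative):

```python
# Hedged sketch: `client` is an OpenAI-compatible client pointed at a Llama Stack
# server; the model ID is illustrative.
conversation = client.conversations.create(metadata={"topic": "demo"})
assert conversation.id.startswith("conv_")

response = client.responses.create(
    model="llama3.2:3b",  # illustrative model ID
    input=[{"role": "user", "content": "What are the 5 Ds of dodgeball?"}],
    conversation=conversation.id,
)
# The user input and the assistant output are synced onto the conversation automatically.
```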

View file

@@ -10083,6 +10083,10 @@
"type": "string",
"description": "(Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses."
},
"conversation": {
"type": "string",
"description": "(Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation."
},
"store": { "store": {
"type": "boolean" "type": "boolean"
}, },

View file

@@ -7493,6 +7493,12 @@ components:
(Optional) if specified, the new response will be a continuation of the
previous response. This can be used to easily fork-off new responses from
existing responses.
conversation:
type: string
description: >-
(Optional) The ID of a conversation to add the response to. Must begin
with 'conv_'. Input and output messages will be automatically added to
the conversation.
store:
type: boolean
stream:

View file

@@ -8178,6 +8178,10 @@
"type": "string",
"description": "(Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses."
},
"conversation": {
"type": "string",
"description": "(Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation."
},
"store": { "store": {
"type": "boolean" "type": "boolean"
}, },

View file

@@ -6189,6 +6189,12 @@ components:
(Optional) if specified, the new response will be a continuation of the
previous response. This can be used to easily fork-off new responses from
existing responses.
conversation:
type: string
description: >-
(Optional) The ID of a conversation to add the response to. Must begin
with 'conv_'. Input and output messages will be automatically added to
the conversation.
store:
type: boolean
stream:

View file

@@ -10187,6 +10187,10 @@
"type": "string",
"description": "(Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses."
},
"conversation": {
"type": "string",
"description": "(Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation."
},
"store": { "store": {
"type": "boolean" "type": "boolean"
}, },

View file

@@ -7634,6 +7634,12 @@ components:
(Optional) if specified, the new response will be a continuation of the
previous response. This can be used to easily fork-off new responses from
existing responses.
conversation:
type: string
description: >-
(Optional) The ID of a conversation to add the response to. Must begin
with 'conv_'. Input and output messages will be automatically added to
the conversation.
store:
type: boolean
stream:

View file

@@ -812,6 +812,7 @@ class Agents(Protocol):
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
@@ -831,6 +832,7 @@ class Agents(Protocol):
:param input: Input message(s) to create the response.
:param model: The underlying LLM used for completions.
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
:param include: (Optional) Additional fields to include in the response.
:param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications.
:returns: An OpenAIResponseObject.

View file

@@ -86,3 +86,18 @@ class TokenValidationError(ValueError):
def __init__(self, message: str) -> None:
super().__init__(message)
class ConversationNotFoundError(ResourceNotFoundError):
"""raised when Llama Stack cannot find a referenced conversation"""
def __init__(self, conversation_id: str) -> None:
super().__init__(conversation_id, "Conversation", "client.conversations.list()")
class InvalidConversationIdError(ValueError):
"""raised when a conversation ID has an invalid format"""
def __init__(self, conversation_id: str) -> None:
message = f"Invalid conversation ID '{conversation_id}'. Expected an ID that begins with 'conv_'."
super().__init__(message)
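
A hedged sketch of how these errors are meant to surface, mirroring the validation added to `OpenAIResponsesImpl` later in this commit (the standalone helper below is hypothetical; the real checks live inside `create_openai_response`):

```python
from llama_stack.apis.common.errors import (
    ConversationNotFoundError,
    InvalidConversationIdError,
)

def check_conversation_ref(conversation_id: str, known_ids: set[str]) -> None:
    # Hypothetical helper illustrating the two failure modes.
    if not conversation_id.startswith("conv_"):
        # Malformed reference: missing the 'conv_' prefix.
        raise InvalidConversationIdError(conversation_id)
    if conversation_id not in known_ids:
        # Well-formed reference, but no such conversation exists.
        raise ConversationNotFoundError(conversation_id)

check_conversation_ref("conv_123", known_ids={"conv_123"})  # passes
```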

View file

@@ -150,6 +150,7 @@ async def resolve_impls(
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
policy: list[AccessRule],
internal_impls: dict[Api, Any] | None = None,
) -> dict[Api, Any]:
"""
Resolves provider implementations by:
@@ -172,7 +173,7 @@ async def resolve_impls(
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy, internal_impls)
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
@@ -280,9 +281,10 @@ async def instantiate_providers(
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
policy: list[AccessRule],
internal_impls: dict[Api, Any] | None = None,
) -> dict[Api, Any]:
"""Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = internal_impls.copy() if internal_impls else {}
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers:
# Skip providers that are not enabled

View file

@@ -326,12 +326,17 @@ class Stack:
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
internal_impls = {}
add_internal_implementations(internal_impls, self.run_config)
impls = await resolve_impls(
self.run_config,
self.provider_registry or get_provider_registry(self.run_config),
dist_registry,
policy,
internal_impls,
)
if Api.prompts in impls:
await impls[Api.prompts].initialize()

View file

@@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
deps[Api.safety],
deps[Api.tool_runtime],
deps[Api.tool_groups],
deps[Api.conversations],
policy,
Api.telemetry in deps,
)

View file

@@ -30,6 +30,7 @@ from llama_stack.apis.agents import (
)
from llama_stack.apis.agents.openai_responses import OpenAIResponseText
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.inference import (
Inference,
ToolConfig,
@@ -63,6 +64,7 @@ class MetaReferenceAgentsImpl(Agents):
safety_api: Safety,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
conversations_api: Conversations,
policy: list[AccessRule],
telemetry_enabled: bool = False,
):
@@ -72,6 +74,7 @@ class MetaReferenceAgentsImpl(Agents):
self.safety_api = safety_api
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.conversations_api = conversations_api
self.telemetry_enabled = telemetry_enabled
self.in_memory_store = InmemoryKVStoreImpl()
@@ -88,6 +91,7 @@ class MetaReferenceAgentsImpl(Agents):
tool_runtime_api=self.tool_runtime_api,
responses_store=self.responses_store,
vector_io_api=self.vector_io_api,
conversations_api=self.conversations_api,
)
async def create_agent(
@@ -325,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents):
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
@@ -339,6 +344,7 @@ class MetaReferenceAgentsImpl(Agents):
model,
instructions,
previous_response_id,
conversation,
store,
stream,
temperature,

View file

@@ -24,6 +24,12 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.common.errors import (
ConversationNotFoundError,
InvalidConversationIdError,
)
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.conversations.conversations import ConversationItem
from llama_stack.apis.inference import (
Inference,
OpenAIMessageParam,
@@ -61,12 +67,14 @@ class OpenAIResponsesImpl:
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
vector_io_api: VectorIO,  # VectorIO
conversations_api: Conversations,
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.responses_store = responses_store
self.vector_io_api = vector_io_api
self.conversations_api = conversations_api
self.tool_executor = ToolExecutor(
tool_groups_api=tool_groups_api,
tool_runtime_api=tool_runtime_api,
@@ -205,6 +213,7 @@ class OpenAIResponsesImpl:
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
@@ -221,11 +230,22 @@ class OpenAIResponsesImpl:
if shields is not None:
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
if conversation is not None:
if not conversation.startswith("conv_"):
raise InvalidConversationIdError(conversation)
conversation_exists = await self._check_conversation_exists(conversation)
if not conversation_exists:
raise ConversationNotFoundError(conversation)
input = await self._load_conversation_context(conversation, input)
stream_gen = self._create_streaming_response(
input=input,
model=model,
instructions=instructions,
previous_response_id=previous_response_id,
conversation=conversation,
store=store,
temperature=temperature,
text=text,
@@ -270,6 +290,7 @@ class OpenAIResponsesImpl:
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
@@ -296,7 +317,7 @@ class OpenAIResponsesImpl:
)
# Create orchestrator and delegate streaming logic
response_id = f"resp_{uuid.uuid4()}"
created_at = int(time.time())
orchestrator = StreamingResponseOrchestrator(
@@ -327,5 +348,98 @@ class OpenAIResponsesImpl:
messages=orchestrator.final_messages,
)
if conversation and final_response:
await self._sync_response_to_conversation(conversation, all_input, final_response)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
return await self.responses_store.delete_response_object(response_id)
async def _check_conversation_exists(self, conversation_id: str) -> bool:
"""Check if a conversation exists."""
try:
await self.conversations_api.get_conversation(conversation_id)
return True
except ConversationNotFoundError:
return False
async def _load_conversation_context(
self, conversation_id: str, content: str | list[OpenAIResponseInput]
) -> list[OpenAIResponseInput]:
"""Load conversation history and merge with provided content."""
try:
conversation_items = await self.conversations_api.list(conversation_id, order="asc")
context_messages = []
for item in conversation_items.data:
if isinstance(item, OpenAIResponseMessage):
if item.role == "user":
context_messages.append(
OpenAIResponseMessage(
role="user", content=item.content, id=item.id if hasattr(item, "id") else None
)
)
elif item.role == "assistant":
context_messages.append(
OpenAIResponseMessage(
role="assistant", content=item.content, id=item.id if hasattr(item, "id") else None
)
)
# add new content to context
if isinstance(content, str):
context_messages.append(OpenAIResponseMessage(role="user", content=content))
elif isinstance(content, list):
context_messages.extend(content)
return context_messages
except Exception as e:
logger.error(f"Failed to load conversation context for {conversation_id}: {e}")
if isinstance(content, str):
return [OpenAIResponseMessage(role="user", content=content)]
return content
async def _sync_response_to_conversation(
self, conversation_id: str, content: str | list[OpenAIResponseInput], response: OpenAIResponseObject
) -> None:
"""Sync content and response messages to the conversation."""
try:
conversation_items = []
# add user content message(s)
if isinstance(content, str):
conversation_items.append(
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": content}]}
)
elif isinstance(content, list):
for item in content:
if isinstance(item, OpenAIResponseMessage) and item.role == "user":
if isinstance(item.content, str):
conversation_items.append(
{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": item.content}],
}
)
elif isinstance(item.content, list):
conversation_items.append({"type": "message", "role": "user", "content": item.content})
# add assistant response message
for output_item in response.output:
if isinstance(output_item, OpenAIResponseMessage) and output_item.role == "assistant":
if hasattr(output_item, "content") and isinstance(output_item.content, list):
conversation_items.append(
{"type": "message", "role": "assistant", "content": output_item.content}
)
if conversation_items:
adapter = TypeAdapter(list[ConversationItem])
validated_items = adapter.validate_python(conversation_items)
await self.conversations_api.add_items(conversation_id, validated_items)
except Exception as e:
logger.error(f"Failed to sync response {response.id} to conversation {conversation_id}: {e}")
# don't fail response creation if conversation sync fails
return None
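
From the caller's side, the effect of `_load_conversation_context` and `_sync_response_to_conversation` is that earlier turns are replayed as context and the new turn shows up as conversation items. A hedged sketch, reusing the client-style objects from the integration tests added later in this commit:

```python
# Hedged sketch: `client` and `conversation` follow the integration tests added
# later in this commit; a response has already been created with
# conversation=conversation.id.
items = client.conversations.items.list(conversation.id)
roles = [item.role for item in items.data if hasattr(item, "role")]
assert "user" in roles and "assistant" in roles  # both sides of the turn were synced
```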

View file

@@ -35,6 +35,7 @@ def available_providers() -> list[ProviderSpec]:
Api.vector_dbs,
Api.tool_runtime,
Api.tool_groups,
Api.conversations,
],
optional_api_dependencies=[
Api.telemetry,

View file

@@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
@pytest.mark.integration
class TestConversationResponses:
"""Integration tests for the conversation parameter in responses API."""
def test_conversation_basic_workflow(self, openai_client, text_model_id):
"""Test basic conversation workflow: create conversation, add response, verify sync."""
conversation = openai_client.conversations.create(metadata={"topic": "test"})
assert conversation.id.startswith("conv_")
response = openai_client.responses.create(
model=text_model_id,
input=[{"role": "user", "content": "What are the 5 Ds of dodgeball?"}],
conversation=conversation.id,
)
assert response.id.startswith("resp_")
assert len(response.output_text.strip()) > 0
# Verify conversation was synced
conversation_items = openai_client.conversations.items.list(conversation.id)
assert len(conversation_items.data) >= 2
roles = [item.role for item in conversation_items.data if hasattr(item, "role")]
assert "user" in roles and "assistant" in roles
def test_conversation_multi_turn_and_streaming(self, openai_client, text_model_id):
"""Test multi-turn conversations and streaming responses."""
conversation = openai_client.conversations.create()
# First turn
response1 = openai_client.responses.create(
model=text_model_id,
input=[{"role": "user", "content": "Say hello"}],
conversation=conversation.id,
)
# Second turn with streaming
response_stream = openai_client.responses.create(
model=text_model_id,
input=[{"role": "user", "content": "Say goodbye"}],
conversation=conversation.id,
stream=True,
)
final_response = None
for chunk in response_stream:
if chunk.type == "response.completed":
final_response = chunk.response
break
assert response1.id != final_response.id
assert len(response1.output_text.strip()) > 0
assert len(final_response.output_text.strip()) > 0
# Verify all turns are in conversation
conversation_items = openai_client.conversations.items.list(conversation.id)
assert len(conversation_items.data) >= 4 # 2 user + 2 assistant messages
def test_conversation_context_loading(self, openai_client, text_model_id):
"""Test that conversation context is properly loaded for responses."""
conversation = openai_client.conversations.create(
items=[
{"type": "message", "role": "user", "content": "My name is Alice"},
{"type": "message", "role": "assistant", "content": "Hello Alice!"},
]
)
response = openai_client.responses.create(
model=text_model_id,
input=[{"role": "user", "content": "What's my name?"}],
conversation=conversation.id,
)
assert "alice" in response.output_text.lower()
def test_conversation_error_handling(self, openai_client, text_model_id):
"""Test error handling for invalid and nonexistent conversations."""
# Invalid conversation ID format
with pytest.raises(Exception) as exc_info:
openai_client.responses.create(
model=text_model_id,
input=[{"role": "user", "content": "Hello"}],
conversation="invalid_id",
)
assert any(word in str(exc_info.value).lower() for word in ["conv", "invalid", "bad"])
# Nonexistent conversation ID
with pytest.raises(Exception) as exc_info:
openai_client.responses.create(
model=text_model_id,
input=[{"role": "user", "content": "Hello"}],
conversation="conv_nonexistent123",
)
assert any(word in str(exc_info.value).lower() for word in ["not found", "404"])
def test_conversation_backward_compatibility(self, openai_client, text_model_id):
"""Test that responses work without conversation parameter (backward compatibility)."""
response = openai_client.responses.create(
model=text_model_id, input=[{"role": "user", "content": "Hello world"}]
)
assert response.id.startswith("resp_")
assert len(response.output_text.strip()) > 0
def test_conversation_compat_client(self, compat_client, text_model_id):
"""Test conversation parameter works with compatibility client."""
if not hasattr(compat_client, "conversations"):
pytest.skip("compat_client does not support conversations API")
conversation = compat_client.conversations.create()
response = compat_client.responses.create(
model=text_model_id, input="Tell me a joke", conversation=conversation.id
)
assert response is not None
assert len(response.output_text.strip()) > 0
conversation_items = compat_client.conversations.items.list(conversation.id)
assert len(conversation_items.data) >= 2

View file

@@ -15,6 +15,7 @@ from llama_stack.apis.agents import (
AgentCreateResponse,
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.inference import Inference
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolRuntime
@@ -33,6 +34,7 @@ def mock_apis():
"safety_api": AsyncMock(spec=Safety),
"tool_runtime_api": AsyncMock(spec=ToolRuntime),
"tool_groups_api": AsyncMock(spec=ToolGroups),
"conversations_api": AsyncMock(spec=Conversations),
}
@@ -59,7 +61,8 @@ async def agents_impl(config, mock_apis):
mock_apis["safety_api"],
mock_apis["tool_runtime_api"],
mock_apis["tool_groups_api"],
mock_apis["conversations_api"],
[],
)
await impl.initialize()
yield impl

View file

@@ -0,0 +1,332 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseOutputMessageContentOutputText,
)
from llama_stack.apis.common.errors import (
ConversationNotFoundError,
InvalidConversationIdError,
)
from llama_stack.apis.conversations.conversations import (
Conversation,
ConversationItemList,
)
# Import existing fixtures from the main responses test file
pytest_plugins = ["tests.unit.providers.agents.meta_reference.test_openai_responses"]
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
@pytest.fixture
def responses_impl_with_conversations(
mock_inference_api,
mock_tool_groups_api,
mock_tool_runtime_api,
mock_responses_store,
mock_vector_io_api,
mock_conversations_api,
):
"""Create OpenAIResponsesImpl instance with conversations API."""
return OpenAIResponsesImpl(
inference_api=mock_inference_api,
tool_groups_api=mock_tool_groups_api,
tool_runtime_api=mock_tool_runtime_api,
responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api,
conversations_api=mock_conversations_api,
)
class TestConversationValidation:
"""Test conversation ID validation logic."""
async def test_conversation_existence_check_valid(self, responses_impl_with_conversations, mock_conversations_api):
"""Test conversation existence check for valid conversation."""
conv_id = "conv_valid123"
# Mock successful conversation retrieval
mock_conversations_api.get_conversation.return_value = Conversation(
id=conv_id, created_at=1234567890, metadata={}, object="conversation"
)
result = await responses_impl_with_conversations._check_conversation_exists(conv_id)
assert result is True
mock_conversations_api.get_conversation.assert_called_once_with(conv_id)
async def test_conversation_existence_check_invalid(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test conversation existence check for non-existent conversation."""
conv_id = "conv_nonexistent"
# Mock conversation not found
mock_conversations_api.get_conversation.side_effect = ConversationNotFoundError("conv_nonexistent")
result = await responses_impl_with_conversations._check_conversation_exists(conv_id)
assert result is False
mock_conversations_api.get_conversation.assert_called_once_with(conv_id)
class TestConversationContextLoading:
"""Test conversation context loading functionality."""
async def test_load_conversation_context_simple_input(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test loading conversation context with simple string input."""
conv_id = "conv_test123"
input_text = "Hello, how are you?"
# mock items in chronological order (a consequence of order="asc")
mock_conversation_items = ConversationItemList(
data=[
OpenAIResponseMessage(
id="msg_1",
content=[{"type": "input_text", "text": "Previous user message"}],
role="user",
status="completed",
type="message",
),
OpenAIResponseMessage(
id="msg_2",
content=[{"type": "output_text", "text": "Previous assistant response"}],
role="assistant",
status="completed",
type="message",
),
],
first_id="msg_1",
has_more=False,
last_id="msg_2",
object="list",
)
mock_conversations_api.list.return_value = mock_conversation_items
result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)
# should have conversation history + new input
assert len(result) == 3
assert isinstance(result[0], OpenAIResponseMessage)
assert result[0].role == "user"
assert isinstance(result[1], OpenAIResponseMessage)
assert result[1].role == "assistant"
assert isinstance(result[2], OpenAIResponseMessage)
assert result[2].role == "user"
assert result[2].content == input_text
async def test_load_conversation_context_api_error(self, responses_impl_with_conversations, mock_conversations_api):
"""Test loading conversation context when API call fails."""
conv_id = "conv_test123"
input_text = "Hello"
mock_conversations_api.list.side_effect = Exception("API Error")
result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)
assert len(result) == 1
assert isinstance(result[0], OpenAIResponseMessage)
assert result[0].role == "user"
assert result[0].content == input_text
async def test_load_conversation_context_with_list_input(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test loading conversation context with list input."""
conv_id = "conv_test123"
input_messages = [
OpenAIResponseMessage(role="user", content="First message"),
OpenAIResponseMessage(role="user", content="Second message"),
]
mock_conversations_api.list.return_value = ConversationItemList(
data=[], first_id=None, has_more=False, last_id=None, object="list"
)
result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_messages)
assert len(result) == 2
assert result == input_messages
async def test_load_conversation_context_empty_conversation(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test loading context from empty conversation."""
conv_id = "conv_empty"
input_text = "Hello"
mock_conversations_api.list.return_value = ConversationItemList(
data=[], first_id=None, has_more=False, last_id=None, object="list"
)
result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)
assert len(result) == 1
assert result[0].role == "user"
assert result[0].content == input_text
class TestMessageSyncing:
"""Test message syncing to conversations."""
async def test_sync_response_to_conversation_simple(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test syncing simple response to conversation."""
conv_id = "conv_test123"
input_text = "What are the 5 Ds of dodgeball?"
# mock response
mock_response = OpenAIResponseObject(
id="resp_123",
created_at=1234567890,
model="test-model",
object="response",
output=[
OpenAIResponseMessage(
id="msg_response",
content=[
OpenAIResponseOutputMessageContentOutputText(
text="The 5 Ds are: Dodge, Duck, Dip, Dive, and Dodge.", type="output_text", annotations=[]
)
],
role="assistant",
status="completed",
type="message",
)
],
status="completed",
)
await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_text, mock_response)
# should call add_items with user input and assistant response
mock_conversations_api.add_items.assert_called_once()
call_args = mock_conversations_api.add_items.call_args
assert call_args[0][0] == conv_id # conversation_id
items = call_args[0][1] # conversation_items
assert len(items) == 2
# User message
assert items[0].type == "message"
assert items[0].role == "user"
assert items[0].content[0].type == "input_text"
assert items[0].content[0].text == input_text
# Assistant message
assert items[1].type == "message"
assert items[1].role == "assistant"
async def test_sync_response_to_conversation_api_error(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test syncing when conversations API call fails."""
conv_id = "conv_test123"
mock_response = OpenAIResponseObject(
id="resp_123", created_at=1234567890, model="test-model", object="response", output=[], status="completed"
)
# Mock API error
mock_conversations_api.add_items.side_effect = Exception("API Error")
# Should not raise exception (graceful failure)
result = await responses_impl_with_conversations._sync_response_to_conversation(conv_id, "Hello", mock_response)
assert result is None
class TestIntegrationWorkflow:
"""Integration tests for the full conversation workflow."""
async def test_create_response_with_valid_conversation(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test creating a response with a valid conversation parameter."""
mock_conversations_api.get_conversation.return_value = Conversation(
id="conv_test123", created_at=1234567890, metadata={}, object="conversation"
)
mock_conversations_api.list.return_value = ConversationItemList(
data=[], first_id=None, has_more=False, last_id=None, object="list"
)
async def mock_streaming_response(*args, **kwargs):
mock_response = OpenAIResponseObject(
id="resp_test123",
created_at=1234567890,
model="test-model",
object="response",
output=[
OpenAIResponseMessage(
id="msg_response",
content=[
OpenAIResponseOutputMessageContentOutputText(
text="Test response", type="output_text", annotations=[]
)
],
role="assistant",
status="completed",
type="message",
)
],
status="completed",
)
yield OpenAIResponseObjectStreamResponseCompleted(response=mock_response, type="response.completed")
responses_impl_with_conversations._create_streaming_response = mock_streaming_response
input_text = "Hello, how are you?"
conversation_id = "conv_test123"
response = await responses_impl_with_conversations.create_openai_response(
input=input_text, model="test-model", conversation=conversation_id, stream=False
)
assert response is not None
assert response.id == "resp_test123"
mock_conversations_api.get_conversation.assert_called_once_with(conversation_id)
mock_conversations_api.list.assert_called_once_with(conversation_id, order="asc")
# Note: conversation sync happens in the streaming response flow,
# which is complex to mock fully in this unit test
async def test_create_response_with_invalid_conversation_id(self, responses_impl_with_conversations):
"""Test creating a response with an invalid conversation ID."""
with pytest.raises(InvalidConversationIdError) as exc_info:
await responses_impl_with_conversations.create_openai_response(
input="Hello", model="test-model", conversation="invalid_id", stream=False
)
assert "Expected an ID that begins with 'conv_'" in str(exc_info.value)
async def test_create_response_with_nonexistent_conversation(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test creating a response with a non-existent conversation."""
mock_conversations_api.get_conversation.side_effect = ConversationNotFoundError("conv_nonexistent")
with pytest.raises(ConversationNotFoundError) as exc_info:
await responses_impl_with_conversations.create_openai_response(
input="Hello", model="test-model", conversation="conv_nonexistent", stream=False
)
assert "not found" in str(exc_info.value)

View file

@@ -83,9 +83,21 @@ def mock_vector_io_api():
return vector_io_api
@pytest.fixture
def mock_conversations_api():
"""Mock conversations API for testing."""
mock_api = AsyncMock()
return mock_api
@pytest.fixture
def openai_responses_impl(
mock_inference_api,
mock_tool_groups_api,
mock_tool_runtime_api,
mock_responses_store,
mock_vector_io_api,
mock_conversations_api,
):
return OpenAIResponsesImpl(
inference_api=mock_inference_api,
@@ -93,6 +105,7 @@ def openai_responses_impl(
tool_runtime_api=mock_tool_runtime_api,
responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api,
conversations_api=mock_conversations_api,
)