diff --git a/docs/static/deprecated-llama-stack-spec.html b/docs/static/deprecated-llama-stack-spec.html
index 2fa339eeb..0ea2e8c43 100644
--- a/docs/static/deprecated-llama-stack-spec.html
+++ b/docs/static/deprecated-llama-stack-spec.html
@@ -10083,6 +10083,10 @@
           "type": "string",
           "description": "(Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses."
         },
+        "conversation": {
+          "type": "string",
+          "description": "(Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation."
+        },
         "store": {
           "type": "boolean"
         },
diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml
index 98af89fa8..008cd8673 100644
--- a/docs/static/deprecated-llama-stack-spec.yaml
+++ b/docs/static/deprecated-llama-stack-spec.yaml
@@ -7493,6 +7493,12 @@ components:
             (Optional) if specified, the new response will be a continuation of the
             previous response. This can be used to easily fork-off new responses
             from existing responses.
+        conversation:
+          type: string
+          description: >-
+            (Optional) The ID of a conversation to add the response to. Must begin
+            with 'conv_'. Input and output messages will be automatically added to
+            the conversation.
         store:
           type: boolean
         stream:
diff --git a/docs/static/llama-stack-spec.html b/docs/static/llama-stack-spec.html
index 1064c1433..7e534f995 100644
--- a/docs/static/llama-stack-spec.html
+++ b/docs/static/llama-stack-spec.html
@@ -8178,6 +8178,10 @@
           "type": "string",
           "description": "(Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses."
         },
+        "conversation": {
+          "type": "string",
+          "description": "(Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation."
+        },
         "store": {
           "type": "boolean"
         },
diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml
index f36d69e3a..bad40c87d 100644
--- a/docs/static/llama-stack-spec.yaml
+++ b/docs/static/llama-stack-spec.yaml
@@ -6189,6 +6189,12 @@ components:
             (Optional) if specified, the new response will be a continuation of the
             previous response. This can be used to easily fork-off new responses
             from existing responses.
+        conversation:
+          type: string
+          description: >-
+            (Optional) The ID of a conversation to add the response to. Must begin
+            with 'conv_'. Input and output messages will be automatically added to
+            the conversation.
         store:
           type: boolean
         stream:
diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html
index 25fa2bc03..36c63367c 100644
--- a/docs/static/stainless-llama-stack-spec.html
+++ b/docs/static/stainless-llama-stack-spec.html
@@ -10187,6 +10187,10 @@
           "type": "string",
           "description": "(Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses."
         },
+        "conversation": {
+          "type": "string",
+          "description": "(Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation."
+        },
         "store": {
           "type": "boolean"
         },
diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml
index df0112be7..4475cc8f0 100644
--- a/docs/static/stainless-llama-stack-spec.yaml
+++ b/docs/static/stainless-llama-stack-spec.yaml
@@ -7634,6 +7634,12 @@ components:
             (Optional) if specified, the new response will be a continuation of the
             previous response. This can be used to easily fork-off new responses
             from existing responses.
+        conversation:
+          type: string
+          description: >-
+            (Optional) The ID of a conversation to add the response to. Must begin
+            with 'conv_'. Input and output messages will be automatically added to
+            the conversation.
         store:
           type: boolean
         stream:
diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 5983b5c45..ff4412c12 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -812,6 +812,7 @@ class Agents(Protocol):
         model: str,
         instructions: str | None = None,
         previous_response_id: str | None = None,
+        conversation: str | None = None,
         store: bool | None = True,
         stream: bool | None = False,
         temperature: float | None = None,
@@ -831,6 +832,7 @@
         :param input: Input message(s) to create the response.
         :param model: The underlying LLM used for completions.
         :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
+        :param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
         :param include: (Optional) Additional fields to include in the response.
         :param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications.
         :returns: An OpenAIResponseObject.
diff --git a/llama_stack/apis/common/errors.py b/llama_stack/apis/common/errors.py
index 4c9c0a818..a421d0c6f 100644
--- a/llama_stack/apis/common/errors.py
+++ b/llama_stack/apis/common/errors.py
@@ -86,3 +86,18 @@ class TokenValidationError(ValueError):
     def __init__(self, message: str) -> None:
         super().__init__(message)
+
+
+class ConversationNotFoundError(ResourceNotFoundError):
+    """raised when Llama Stack cannot find a referenced conversation"""
+
+    def __init__(self, conversation_id: str) -> None:
+        super().__init__(conversation_id, "Conversation", "client.conversations.list()")
+
+
+class InvalidConversationIdError(ValueError):
+    """raised when a conversation ID has an invalid format"""
+
+    def __init__(self, conversation_id: str) -> None:
+        message = f"Invalid conversation ID '{conversation_id}'. Expected an ID that begins with 'conv_'."
+        super().__init__(message)
diff --git a/llama_stack/core/resolver.py b/llama_stack/core/resolver.py
index 0d6f54f9e..749253865 100644
--- a/llama_stack/core/resolver.py
+++ b/llama_stack/core/resolver.py
@@ -150,6 +150,7 @@ async def resolve_impls(
     provider_registry: ProviderRegistry,
     dist_registry: DistributionRegistry,
     policy: list[AccessRule],
+    internal_impls: dict[Api, Any] | None = None,
 ) -> dict[Api, Any]:
     """
     Resolves provider implementations by:
@@ -172,7 +173,7 @@

     sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)

-    return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy)
+    return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy, internal_impls)


 def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
@@ -280,9 +281,10 @@ async def instantiate_providers(
     dist_registry: DistributionRegistry,
     run_config: StackRunConfig,
     policy: list[AccessRule],
+    internal_impls: dict[Api, Any] | None = None,
 ) -> dict[Api, Any]:
     """Instantiates providers asynchronously while managing dependencies."""
-    impls: dict[Api, Any] = {}
+    impls: dict[Api, Any] = internal_impls.copy() if internal_impls else {}
     inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
     for api_str, provider in sorted_providers:
         # Skip providers that are not enabled
diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py
index 49f6b9cc9..2eab9344f 100644
--- a/llama_stack/core/stack.py
+++ b/llama_stack/core/stack.py
@@ -326,12 +326,17 @@ class Stack:
         dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
         policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
-        impls = await resolve_impls(
-            self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
-        )
-        # Add internal implementations after all other providers are resolved
-        add_internal_implementations(impls, self.run_config)
+        internal_impls = {}
+        add_internal_implementations(internal_impls, self.run_config)
+
+        impls = await resolve_impls(
+            self.run_config,
+            self.provider_registry or get_provider_registry(self.run_config),
+            dist_registry,
+            policy,
+            internal_impls,
+        )

         if Api.prompts in impls:
             await impls[Api.prompts].initialize()
diff --git a/llama_stack/providers/inline/agents/meta_reference/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py
index 37b0b50c8..d5cfd2e5b 100644
--- a/llama_stack/providers/inline/agents/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/agents/meta_reference/__init__.py
@@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
         deps[Api.safety],
         deps[Api.tool_runtime],
         deps[Api.tool_groups],
+        deps[Api.conversations],
         policy,
         Api.telemetry in deps,
     )
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index cfaf56a34..27d3a94cc 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -30,6 +30,7 @@ from llama_stack.apis.agents import (
 )
 from llama_stack.apis.agents.openai_responses import OpenAIResponseText
 from llama_stack.apis.common.responses import PaginatedResponse
+from llama_stack.apis.conversations import Conversations
 from llama_stack.apis.inference import (
     Inference,
     ToolConfig,
@@ -63,6 +64,7 @@ class MetaReferenceAgentsImpl(Agents):
         safety_api: Safety,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
+        conversations_api: Conversations,
         policy: list[AccessRule],
         telemetry_enabled: bool = False,
     ):
@@ -72,6 +74,7 @@ class MetaReferenceAgentsImpl(Agents):
         self.safety_api = safety_api
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
+        self.conversations_api = conversations_api
         self.telemetry_enabled = telemetry_enabled
         self.in_memory_store = InmemoryKVStoreImpl()
@@ -88,6 +91,7 @@ class MetaReferenceAgentsImpl(Agents):
             tool_runtime_api=self.tool_runtime_api,
             responses_store=self.responses_store,
             vector_io_api=self.vector_io_api,
+            conversations_api=self.conversations_api,
         )

     async def create_agent(
@@ -325,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents):
         model: str,
         instructions: str | None = None,
         previous_response_id: str | None = None,
+        conversation: str | None = None,
         store: bool | None = True,
         stream: bool | None = False,
         temperature: float | None = None,
@@ -339,6 +344,7 @@
             model,
             instructions,
             previous_response_id,
+            conversation,
             store,
             stream,
             temperature,
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
index fabe46f43..b317d6672 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
@@ -24,6 +24,12 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseText,
     OpenAIResponseTextFormat,
 )
+from llama_stack.apis.common.errors import (
+    ConversationNotFoundError,
+    InvalidConversationIdError,
+)
+from llama_stack.apis.conversations import Conversations
+from llama_stack.apis.conversations.conversations import ConversationItem
 from llama_stack.apis.inference import (
     Inference,
     OpenAIMessageParam,
@@ -61,12 +67,14 @@ class OpenAIResponsesImpl:
         tool_runtime_api: ToolRuntime,
         responses_store: ResponsesStore,
         vector_io_api: VectorIO,  # VectorIO
+        conversations_api: Conversations,
     ):
         self.inference_api = inference_api
         self.tool_groups_api = tool_groups_api
         self.tool_runtime_api = tool_runtime_api
         self.responses_store = responses_store
         self.vector_io_api = vector_io_api
+        self.conversations_api = conversations_api
         self.tool_executor = ToolExecutor(
             tool_groups_api=tool_groups_api,
             tool_runtime_api=tool_runtime_api,
@@ -205,6 +213,7 @@ class OpenAIResponsesImpl:
         model: str,
         instructions: str | None = None,
         previous_response_id: str | None = None,
+        conversation: str | None = None,
         store: bool | None = True,
         stream: bool | None = False,
         temperature: float | None = None,
@@ -221,11 +230,22 @@
         if shields is not None:
             raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")

+        if conversation is not None:
+            if not conversation.startswith("conv_"):
+                raise InvalidConversationIdError(conversation)
+
+            conversation_exists = await self._check_conversation_exists(conversation)
+            if not conversation_exists:
+                raise ConversationNotFoundError(conversation)
+
+            input = await self._load_conversation_context(conversation, input)
+
         stream_gen = self._create_streaming_response(
             input=input,
             model=model,
             instructions=instructions,
             previous_response_id=previous_response_id,
+            conversation=conversation,
             store=store,
             temperature=temperature,
             text=text,
@@ -270,6 +290,7 @@
         model: str,
         instructions: str | None = None,
         previous_response_id: str | None = None,
+        conversation: str | None = None,
         store: bool | None = True,
         temperature: float | None = None,
         text: OpenAIResponseText | None = None,
@@ -296,7 +317,7 @@
         )

         # Create orchestrator and delegate streaming logic
-        response_id = f"resp-{uuid.uuid4()}"
+        response_id = f"resp_{uuid.uuid4()}"
         created_at = int(time.time())

         orchestrator = StreamingResponseOrchestrator(
@@ -327,5 +348,98 @@
             messages=orchestrator.final_messages,
         )

+        if conversation and final_response:
+            await self._sync_response_to_conversation(conversation, all_input, final_response)
+
     async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
         return await self.responses_store.delete_response_object(response_id)
+
+    async def _check_conversation_exists(self, conversation_id: str) -> bool:
+        """Check if a conversation exists."""
+        try:
+            await self.conversations_api.get_conversation(conversation_id)
+            return True
+        except ConversationNotFoundError:
+            return False
+
+    async def _load_conversation_context(
+        self, conversation_id: str, content: str | list[OpenAIResponseInput]
+    ) -> list[OpenAIResponseInput]:
+        """Load conversation history and merge with provided content."""
+        try:
+            conversation_items = await self.conversations_api.list(conversation_id, order="asc")
+
+            context_messages = []
+            for item in conversation_items.data:
+                if isinstance(item, OpenAIResponseMessage):
+                    if item.role == "user":
+                        context_messages.append(
+                            OpenAIResponseMessage(
+                                role="user", content=item.content, id=item.id if hasattr(item, "id") else None
+                            )
+                        )
+                    elif item.role == "assistant":
+                        context_messages.append(
+                            OpenAIResponseMessage(
+                                role="assistant", content=item.content, id=item.id if hasattr(item, "id") else None
+                            )
+                        )
+
+            # add new content to context
+            if isinstance(content, str):
+                context_messages.append(OpenAIResponseMessage(role="user", content=content))
+            elif isinstance(content, list):
+                context_messages.extend(content)
+
+            return context_messages
+
+        except Exception as e:
+            logger.error(f"Failed to load conversation context for {conversation_id}: {e}")
+            if isinstance(content, str):
+                return [OpenAIResponseMessage(role="user", content=content)]
+            return content
+
+    async def _sync_response_to_conversation(
+        self, conversation_id: str, content: str | list[OpenAIResponseInput], response: OpenAIResponseObject
+    ) -> None:
+        """Sync content and response messages to the conversation."""
+        try:
+            conversation_items = []
+
+            # add user content message(s)
+            if isinstance(content, str):
+                conversation_items.append(
+                    {"type": "message", "role": "user", "content": [{"type": "input_text", "text": content}]}
+                )
+            elif isinstance(content, list):
+                for item in content:
+                    if isinstance(item, OpenAIResponseMessage) and item.role == "user":
+                        if isinstance(item.content, str):
+                            conversation_items.append(
+                                {
+                                    "type": "message",
+                                    "role": "user",
+                                    "content": [{"type": "input_text", "text": item.content}],
+                                }
+                            )
+                        elif isinstance(item.content, list):
+                            conversation_items.append({"type": "message", "role": "user", "content": item.content})
+
+            # add assistant response message
+            for output_item in response.output:
+                if isinstance(output_item, OpenAIResponseMessage) and output_item.role == "assistant":
+                    if hasattr(output_item, "content") and isinstance(output_item.content, list):
+                        conversation_items.append(
+                            {"type": "message", "role": "assistant", "content": output_item.content}
+                        )
+
+            if conversation_items:
+                adapter = TypeAdapter(list[ConversationItem])
+                validated_items = adapter.validate_python(conversation_items)
+                await self.conversations_api.add_items(conversation_id, validated_items)
+
+        except Exception as e:
+            logger.error(f"Failed to sync response {response.id} to conversation {conversation_id}: {e}")
+            # don't fail response creation if conversation sync fails
+
+        return None
diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py
index bc46b4de2..d7e9bed88 100644
--- a/llama_stack/providers/registry/agents.py
+++ b/llama_stack/providers/registry/agents.py
@@ -35,6 +35,7 @@ def available_providers() -> list[ProviderSpec]:
                 Api.vector_dbs,
                 Api.tool_runtime,
                 Api.tool_groups,
+                Api.conversations,
             ],
             optional_api_dependencies=[
                 Api.telemetry,
diff --git a/tests/integration/responses/test_conversation_responses.py b/tests/integration/responses/test_conversation_responses.py
new file mode 100644
index 000000000..ed9753884
--- /dev/null
+++ b/tests/integration/responses/test_conversation_responses.py
@@ -0,0 +1,128 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+
+@pytest.mark.integration
+class TestConversationResponses:
+    """Integration tests for the conversation parameter in responses API."""
+
+    def test_conversation_basic_workflow(self, openai_client, text_model_id):
+        """Test basic conversation workflow: create conversation, add response, verify sync."""
+        conversation = openai_client.conversations.create(metadata={"topic": "test"})
+        assert conversation.id.startswith("conv_")
+
+        response = openai_client.responses.create(
+            model=text_model_id,
+            input=[{"role": "user", "content": "What are the 5 Ds of dodgeball?"}],
+            conversation=conversation.id,
+        )
+
+        assert response.id.startswith("resp_")
+        assert len(response.output_text.strip()) > 0
+
+        # Verify conversation was synced
+        conversation_items = openai_client.conversations.items.list(conversation.id)
+        assert len(conversation_items.data) >= 2
+
+        roles = [item.role for item in conversation_items.data if hasattr(item, "role")]
+        assert "user" in roles and "assistant" in roles
+
+    def test_conversation_multi_turn_and_streaming(self, openai_client, text_model_id):
+        """Test multi-turn conversations and streaming responses."""
+        conversation = openai_client.conversations.create()
+
+        # First turn
+        response1 = openai_client.responses.create(
+            model=text_model_id,
+            input=[{"role": "user", "content": "Say hello"}],
+            conversation=conversation.id,
+        )
+
+        # Second turn with streaming
+        response_stream = openai_client.responses.create(
+            model=text_model_id,
+            input=[{"role": "user", "content": "Say goodbye"}],
+            conversation=conversation.id,
+            stream=True,
+        )
+
+        final_response = None
+        for chunk in response_stream:
+            if chunk.type == "response.completed":
+                final_response = chunk.response
+                break
+
+        assert response1.id != final_response.id
+        assert len(response1.output_text.strip()) > 0
+        assert len(final_response.output_text.strip()) > 0
+
+        # Verify all turns are in conversation
+        conversation_items = openai_client.conversations.items.list(conversation.id)
+        assert len(conversation_items.data) >= 4  # 2 user + 2 assistant messages
+
+    def test_conversation_context_loading(self, openai_client, text_model_id):
+        """Test that conversation context is properly loaded for responses."""
+        conversation = openai_client.conversations.create(
+            items=[
+                {"type": "message", "role": "user", "content": "My name is Alice"},
+                {"type": "message", "role": "assistant", "content": "Hello Alice!"},
+            ]
+        )
+
+        response = openai_client.responses.create(
+            model=text_model_id,
+            input=[{"role": "user", "content": "What's my name?"}],
+            conversation=conversation.id,
+        )
+
+        assert "alice" in response.output_text.lower()
+
+    def test_conversation_error_handling(self, openai_client, text_model_id):
+        """Test error handling for invalid and nonexistent conversations."""
+        # Invalid conversation ID format
+        with pytest.raises(Exception) as exc_info:
+            openai_client.responses.create(
+                model=text_model_id,
+                input=[{"role": "user", "content": "Hello"}],
+                conversation="invalid_id",
+            )
+        assert any(word in str(exc_info.value).lower() for word in ["conv", "invalid", "bad"])
+
+        # Nonexistent conversation ID
+        with pytest.raises(Exception) as exc_info:
+            openai_client.responses.create(
+                model=text_model_id,
+                input=[{"role": "user", "content": "Hello"}],
+                conversation="conv_nonexistent123",
+            )
+        assert any(word in str(exc_info.value).lower() for word in ["not found", "404"])
+
+    def test_conversation_backward_compatibility(self, openai_client, text_model_id):
+        """Test that responses work without conversation parameter (backward compatibility)."""
+        response = openai_client.responses.create(
+            model=text_model_id, input=[{"role": "user", "content": "Hello world"}]
+        )
+
+        assert response.id.startswith("resp_")
+        assert len(response.output_text.strip()) > 0
+
+    def test_conversation_compat_client(self, compat_client, text_model_id):
+        """Test conversation parameter works with compatibility client."""
+        if not hasattr(compat_client, "conversations"):
+            pytest.skip("compat_client does not support conversations API")
+
+        conversation = compat_client.conversations.create()
+        response = compat_client.responses.create(
+            model=text_model_id, input="Tell me a joke", conversation=conversation.id
+        )
+
+        assert response is not None
+        assert len(response.output_text.strip()) > 0
+
+        conversation_items = compat_client.conversations.items.list(conversation.id)
+        assert len(conversation_items.data) >= 2
diff --git a/tests/unit/providers/agent/test_meta_reference_agent.py b/tests/unit/providers/agent/test_meta_reference_agent.py
index fdbb2b8e9..cfb3e1327 100644
--- a/tests/unit/providers/agent/test_meta_reference_agent.py
+++ b/tests/unit/providers/agent/test_meta_reference_agent.py
@@ -15,6 +15,7 @@ from llama_stack.apis.agents import (
     AgentCreateResponse,
 )
 from llama_stack.apis.common.responses import PaginatedResponse
+from llama_stack.apis.conversations import Conversations
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolRuntime
@@ -33,6 +34,7 @@ def mock_apis():
         "safety_api": AsyncMock(spec=Safety),
         "tool_runtime_api": AsyncMock(spec=ToolRuntime),
         "tool_groups_api": AsyncMock(spec=ToolGroups),
+        "conversations_api": AsyncMock(spec=Conversations),
     }


@@ -59,7 +61,8 @@ async def agents_impl(config, mock_apis):
         mock_apis["safety_api"],
         mock_apis["tool_runtime_api"],
         mock_apis["tool_groups_api"],
-        {},
+        mock_apis["conversations_api"],
+        [],
     )
     await impl.initialize()
     yield impl
diff --git a/tests/unit/providers/agents/meta_reference/test_conversation_integration.py b/tests/unit/providers/agents/meta_reference/test_conversation_integration.py
new file mode 100644
index 000000000..fd99c7514
--- /dev/null
+++ b/tests/unit/providers/agents/meta_reference/test_conversation_integration.py
@@ -0,0 +1,332 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+import pytest
+
+from llama_stack.apis.agents.openai_responses import (
+    OpenAIResponseMessage,
+    OpenAIResponseObject,
+    OpenAIResponseObjectStreamResponseCompleted,
+    OpenAIResponseOutputMessageContentOutputText,
+)
+from llama_stack.apis.common.errors import (
+    ConversationNotFoundError,
+    InvalidConversationIdError,
+)
+from llama_stack.apis.conversations.conversations import (
+    Conversation,
+    ConversationItemList,
+)
+
+# Import existing fixtures from the main responses test file
+pytest_plugins = ["tests.unit.providers.agents.meta_reference.test_openai_responses"]
+
+from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
+    OpenAIResponsesImpl,
+)
+
+
+@pytest.fixture
+def responses_impl_with_conversations(
+    mock_inference_api,
+    mock_tool_groups_api,
+    mock_tool_runtime_api,
+    mock_responses_store,
+    mock_vector_io_api,
+    mock_conversations_api,
+):
+    """Create OpenAIResponsesImpl instance with conversations API."""
+    return OpenAIResponsesImpl(
+        inference_api=mock_inference_api,
+        tool_groups_api=mock_tool_groups_api,
+        tool_runtime_api=mock_tool_runtime_api,
+        responses_store=mock_responses_store,
+        vector_io_api=mock_vector_io_api,
+        conversations_api=mock_conversations_api,
+    )
+
+
+class TestConversationValidation:
+    """Test conversation ID validation logic."""
+
+    async def test_conversation_existence_check_valid(self, responses_impl_with_conversations, mock_conversations_api):
+        """Test conversation existence check for valid conversation."""
+        conv_id = "conv_valid123"
+
+        # Mock successful conversation retrieval
+        mock_conversations_api.get_conversation.return_value = Conversation(
+            id=conv_id, created_at=1234567890, metadata={}, object="conversation"
+        )
+
+        result = await responses_impl_with_conversations._check_conversation_exists(conv_id)
+
+        assert result is True
+        mock_conversations_api.get_conversation.assert_called_once_with(conv_id)
+
+    async def test_conversation_existence_check_invalid(
+        self, responses_impl_with_conversations, mock_conversations_api
+    ):
+        """Test conversation existence check for non-existent conversation."""
+        conv_id = "conv_nonexistent"
+
+        # Mock conversation not found
+        mock_conversations_api.get_conversation.side_effect = ConversationNotFoundError("conv_nonexistent")
+
+        result = await responses_impl_with_conversations._check_conversation_exists(conv_id)
+
+        assert result is False
+        mock_conversations_api.get_conversation.assert_called_once_with(conv_id)
+
+
+class TestConversationContextLoading:
+    """Test conversation context loading functionality."""
+
+    async def test_load_conversation_context_simple_input(
+        self, responses_impl_with_conversations, mock_conversations_api
+    ):
+        """Test loading conversation context with simple string input."""
+        conv_id = "conv_test123"
+        input_text = "Hello, how are you?"
+
+        # mock items in chronological order (a consequence of order="asc")
+        mock_conversation_items = ConversationItemList(
+            data=[
+                OpenAIResponseMessage(
+                    id="msg_1",
+                    content=[{"type": "input_text", "text": "Previous user message"}],
+                    role="user",
+                    status="completed",
+                    type="message",
+                ),
+                OpenAIResponseMessage(
+                    id="msg_2",
+                    content=[{"type": "output_text", "text": "Previous assistant response"}],
+                    role="assistant",
+                    status="completed",
+                    type="message",
+                ),
+            ],
+            first_id="msg_1",
+            has_more=False,
+            last_id="msg_2",
+            object="list",
+        )
+
+        mock_conversations_api.list.return_value = mock_conversation_items
+
+        result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)
+
+        # should have conversation history + new input
+        assert len(result) == 3
+        assert isinstance(result[0], OpenAIResponseMessage)
+        assert result[0].role == "user"
+        assert isinstance(result[1], OpenAIResponseMessage)
+        assert result[1].role == "assistant"
+        assert isinstance(result[2], OpenAIResponseMessage)
+        assert result[2].role == "user"
+        assert result[2].content == input_text
+
+    async def test_load_conversation_context_api_error(self, responses_impl_with_conversations, mock_conversations_api):
+        """Test loading conversation context when API call fails."""
+        conv_id = "conv_test123"
+        input_text = "Hello"
+
+        mock_conversations_api.list.side_effect = Exception("API Error")
+
+        result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)
+
+        assert len(result) == 1
+        assert isinstance(result[0], OpenAIResponseMessage)
+        assert result[0].role == "user"
+        assert result[0].content == input_text
+
+    async def test_load_conversation_context_with_list_input(
+        self, responses_impl_with_conversations, mock_conversations_api
+    ):
+        """Test loading conversation context with list input."""
+        conv_id = "conv_test123"
+        input_messages = [
+            OpenAIResponseMessage(role="user", content="First message"),
+            OpenAIResponseMessage(role="user", content="Second message"),
+        ]
+
+        mock_conversations_api.list.return_value = ConversationItemList(
+            data=[], first_id=None, has_more=False, last_id=None, object="list"
+        )
+
+        result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_messages)
+
+        assert len(result) == 2
+        assert result == input_messages
+
+    async def test_load_conversation_context_empty_conversation(
+        self, responses_impl_with_conversations, mock_conversations_api
+    ):
+        """Test loading context from empty conversation."""
+        conv_id = "conv_empty"
+        input_text = "Hello"
+
+        mock_conversations_api.list.return_value = ConversationItemList(
+            data=[], first_id=None, has_more=False, last_id=None, object="list"
+        )
+
+        result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)
+
+        assert len(result) == 1
+        assert result[0].role == "user"
+        assert result[0].content == input_text
+
+
+class TestMessageSyncing:
+    """Test message syncing to conversations."""
+
+    async def test_sync_response_to_conversation_simple(
+        self, responses_impl_with_conversations, mock_conversations_api
+    ):
+        """Test syncing simple response to conversation."""
+        conv_id = "conv_test123"
+        input_text = "What are the 5 Ds of dodgeball?"
+
+        # mock response
+        mock_response = OpenAIResponseObject(
+            id="resp_123",
+            created_at=1234567890,
+            model="test-model",
+            object="response",
+            output=[
+                OpenAIResponseMessage(
+                    id="msg_response",
+                    content=[
+                        OpenAIResponseOutputMessageContentOutputText(
+                            text="The 5 Ds are: Dodge, Duck, Dip, Dive, and Dodge.", type="output_text", annotations=[]
+                        )
+                    ],
+                    role="assistant",
+                    status="completed",
+                    type="message",
+                )
+            ],
+            status="completed",
+        )
+
+        await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_text, mock_response)
+
+        # should call add_items with user input and assistant response
+        mock_conversations_api.add_items.assert_called_once()
+        call_args = mock_conversations_api.add_items.call_args
+
+        assert call_args[0][0] == conv_id  # conversation_id
+        items = call_args[0][1]  # conversation_items
+
+        assert len(items) == 2
+        # User message
+        assert items[0].type == "message"
+        assert items[0].role == "user"
+        assert items[0].content[0].type == "input_text"
+        assert items[0].content[0].text == input_text
+
+        # Assistant message
+        assert items[1].type == "message"
+        assert items[1].role == "assistant"
+
+    async def test_sync_response_to_conversation_api_error(
+        self, responses_impl_with_conversations, mock_conversations_api
+    ):
+        """Test syncing when conversations API call fails."""
+        conv_id = "conv_test123"
+
+        mock_response = OpenAIResponseObject(
+            id="resp_123", created_at=1234567890, model="test-model", object="response", output=[], status="completed"
+        )
+
+        # Mock API error
+        mock_conversations_api.add_items.side_effect = Exception("API Error")
+
+        # Should not raise exception (graceful failure)
+        result = await responses_impl_with_conversations._sync_response_to_conversation(conv_id, "Hello", mock_response)
+        assert result is None
+
+
+class TestIntegrationWorkflow:
+    """Integration tests for the full conversation workflow."""
+
+    async def test_create_response_with_valid_conversation(
+        self, responses_impl_with_conversations, mock_conversations_api
+    ):
+        """Test creating a response with a valid conversation parameter."""
+        mock_conversations_api.get_conversation.return_value = Conversation(
+            id="conv_test123", created_at=1234567890, metadata={}, object="conversation"
+        )
+
+        mock_conversations_api.list.return_value = ConversationItemList(
+            data=[], first_id=None, has_more=False, last_id=None, object="list"
+        )
+
+        async def mock_streaming_response(*args, **kwargs):
+            mock_response = OpenAIResponseObject(
+                id="resp_test123",
+                created_at=1234567890,
+                model="test-model",
+                object="response",
+                output=[
+                    OpenAIResponseMessage(
+                        id="msg_response",
+                        content=[
+                            OpenAIResponseOutputMessageContentOutputText(
+                                text="Test response", type="output_text", annotations=[]
+                            )
+                        ],
+                        role="assistant",
+                        status="completed",
+                        type="message",
+                    )
+                ],
+                status="completed",
+            )
+
+            yield OpenAIResponseObjectStreamResponseCompleted(response=mock_response, type="response.completed")
+
+        responses_impl_with_conversations._create_streaming_response = mock_streaming_response
+
+        input_text = "Hello, how are you?"
+        conversation_id = "conv_test123"
+
+        response = await responses_impl_with_conversations.create_openai_response(
+            input=input_text, model="test-model", conversation=conversation_id, stream=False
+        )
+
+        assert response is not None
+        assert response.id == "resp_test123"
+
+        mock_conversations_api.get_conversation.assert_called_once_with(conversation_id)
+
+        mock_conversations_api.list.assert_called_once_with(conversation_id, order="asc")
+
+        # Note: conversation sync happens in the streaming response flow,
+        # which is complex to mock fully in this unit test
+
+    async def test_create_response_with_invalid_conversation_id(self, responses_impl_with_conversations):
+        """Test creating a response with an invalid conversation ID."""
+        with pytest.raises(InvalidConversationIdError) as exc_info:
+            await responses_impl_with_conversations.create_openai_response(
+                input="Hello", model="test-model", conversation="invalid_id", stream=False
+            )
+
+        assert "Expected an ID that begins with 'conv_'" in str(exc_info.value)
+
+    async def test_create_response_with_nonexistent_conversation(
+        self, responses_impl_with_conversations, mock_conversations_api
+    ):
+        """Test creating a response with a non-existent conversation."""
+        mock_conversations_api.get_conversation.side_effect = ConversationNotFoundError("conv_nonexistent")
+
+        with pytest.raises(ConversationNotFoundError) as exc_info:
+            await responses_impl_with_conversations.create_openai_response(
+                input="Hello", model="test-model", conversation="conv_nonexistent", stream=False
+            )
+
+        assert "not found" in str(exc_info.value)
diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py
index 2ff586a08..2c09ad1d7 100644
--- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py
+++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py
@@ -83,9 +83,21 @@ def mock_vector_io_api():
     return vector_io_api


+@pytest.fixture
+def mock_conversations_api():
+    """Mock conversations API for testing."""
+    mock_api = AsyncMock()
+    return mock_api
+
+
 @pytest.fixture
 def openai_responses_impl(
-    mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store, mock_vector_io_api
+    mock_inference_api,
+    mock_tool_groups_api,
+    mock_tool_runtime_api,
+    mock_responses_store,
+    mock_vector_io_api,
+    mock_conversations_api,
 ):
     return OpenAIResponsesImpl(
         inference_api=mock_inference_api,
@@ -93,6 +105,7 @@ def openai_responses_impl(
         tool_runtime_api=mock_tool_runtime_api,
         responses_store=mock_responses_store,
         vector_io_api=mock_vector_io_api,
+        conversations_api=mock_conversations_api,
     )
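Usage sketch (editorial note, not part of the patch): the integration tests above drive the new conversation parameter through an OpenAI-compatible client pointed at a Llama Stack server. A minimal end-to-end flow, with the base URL, API key, and model id as placeholder assumptions, looks roughly like this:

    # Assumed setup: an OpenAI-compatible Python client against a running Llama Stack
    # distribution; base_url, api_key, and the model id are hypothetical placeholders.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

    # Create a conversation; its id is expected to start with "conv_".
    conversation = client.conversations.create(metadata={"topic": "demo"})

    # Responses created with conversation=<id> load prior conversation items as
    # context, and their input/output messages are synced back automatically.
    response = client.responses.create(
        model="<your-model-id>",
        input=[{"role": "user", "content": "What are the 5 Ds of dodgeball?"}],
        conversation=conversation.id,
    )
    print(response.output_text)

    # Both the user message and the assistant reply should now appear here.
    items = client.conversations.items.list(conversation.id)
    print(len(items.data))

Per the new errors added in llama_stack/apis/common/errors.py, an id that does not start with "conv_" raises InvalidConversationIdError, and an unknown conversation id raises ConversationNotFoundError.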