From 46570774182f3e7fa55bb82f04b2c85734aa162f Mon Sep 17 00:00:00 2001
From: Derek Higgins
Date: Fri, 2 May 2025 11:10:07 +0100
Subject: [PATCH] feat(openai-responses): Support multiple message roles in API inputs

Also update the nesting to add multiple messages (where appropriate)
rather than a single message with multiple content parts.

Signed-off-by: Derek Higgins
---
 .../agents/meta_reference/openai_responses.py | 46 ++++++++++++-----
 .../meta_reference/test_openai_responses.py   | 51 +++++++++++++++++++
 2 files changed, 83 insertions(+), 14 deletions(-)

diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
index 24a99dd6e..172dbd5ae 100644
--- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
+++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
@@ -34,8 +34,10 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletionContentPartTextParam,
     OpenAIChatCompletionToolCallFunction,
     OpenAIChoice,
+    OpenAIDeveloperMessageParam,
     OpenAIImageURL,
     OpenAIMessageParam,
+    OpenAISystemMessageParam,
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
 )
@@ -77,6 +79,16 @@ async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> lis
     return output_messages
 
 
+async def _get_message_type_by_role(role: str):
+    role_to_type = {
+        "user": OpenAIUserMessageParam,
+        "system": OpenAISystemMessageParam,
+        "assistant": OpenAIAssistantMessageParam,
+        "developer": OpenAIDeveloperMessageParam,
+    }
+    return role_to_type.get(role)
+
+
 class OpenAIResponsesImpl:
     def __init__(
         self,
@@ -116,26 +128,32 @@ class OpenAIResponsesImpl:
         if previous_response_id:
             previous_response = await self.get_openai_response(previous_response_id)
             messages.extend(await _previous_response_to_messages(previous_response))
+
         # TODO: refactor this user_content parsing out into a separate method
-        user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
+        content: str | list[OpenAIChatCompletionContentPartParam] = ""
         if isinstance(input, list):
-            user_content = []
-            for user_input in input:
-                if isinstance(user_input.content, list):
-                    for user_input_content in user_input.content:
-                        if isinstance(user_input_content, OpenAIResponseInputMessageContentText):
-                            user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input_content.text))
-                        elif isinstance(user_input_content, OpenAIResponseInputMessageContentImage):
-                            if user_input_content.image_url:
+            for input_message in input:
+                if isinstance(input_message.content, list):
+                    content = []
+                    for input_message_content in input_message.content:
+                        if isinstance(input_message_content, OpenAIResponseInputMessageContentText):
+                            content.append(OpenAIChatCompletionContentPartTextParam(text=input_message_content.text))
+                        elif isinstance(input_message_content, OpenAIResponseInputMessageContentImage):
+                            if input_message_content.image_url:
                                 image_url = OpenAIImageURL(
-                                    url=user_input_content.image_url, detail=user_input_content.detail
+                                    url=input_message_content.image_url, detail=input_message_content.detail
                                 )
-                                user_content.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
+                                content.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
                 else:
-                    user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input.content))
+                    content = input_message.content
+                message_type = await _get_message_type_by_role(input_message.role)
+                if message_type is None:
+                    raise ValueError(
+                        f"Llama Stack OpenAI Responses does not yet support message role '{input_message.role}' in this context"
+                    )
+                messages.append(message_type(content=content))
         else:
-            user_content = input
-            messages.append(OpenAIUserMessageParam(content=user_content))
+            messages.append(OpenAIUserMessageParam(content=input))
 
         chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
         chat_response = await self.inference_api.openai_chat_completion(
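
A minimal usage sketch of what the hunk above enables: each entry in a mixed-role
input list is now converted to its own chat message (user, system, assistant, or
developer) instead of being folded into a single user message. The `impl` argument
is an assumption for illustration, standing in for an initialized
OpenAIResponsesImpl; the messages and model mirror the new unit test below:

    from llama_stack.apis.agents.openai_responses import OpenAIResponseInputMessage

    async def demo(impl):
        # impl: an initialized OpenAIResponsesImpl (assumed to be supplied by the caller)
        return await impl.create_openai_response(
            input=[
                # Each role resolves to its own message type via _get_message_type_by_role()
                OpenAIResponseInputMessage(role="developer", content="You are a helpful assistant", name=None),
                OpenAIResponseInputMessage(role="user", content="Name some towns in Ireland", name=None),
            ],
            model="meta-llama/Llama-3.1-8B-Instruct",
            temperature=0.1,
        )
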
f"Llama Stack OpenAI Responses does not yet support message role '{input_message.role}' in this context" + ) + messages.append(message_type(content=content)) else: - user_content = input - messages.append(OpenAIUserMessageParam(content=user_content)) + messages.append(OpenAIUserMessageParam(content=input)) chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None chat_response = await self.inference_api.openai_chat_completion( diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index d321b29b9..e0acac643 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -9,10 +9,15 @@ from unittest.mock import AsyncMock import pytest from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseInputMessage, + OpenAIResponseInputMessageContentText, OpenAIResponseInputToolWebSearch, OpenAIResponseOutputMessage, ) from llama_stack.apis.inference.inference import ( + OpenAIAssistantMessageParam, + OpenAIChatCompletionContentPartTextParam, + OpenAIDeveloperMessageParam, OpenAIUserMessageParam, ) from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime @@ -156,3 +161,49 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon assert len(result.output) >= 1 assert isinstance(result.output[1], OpenAIResponseOutputMessage) assert result.output[1].content[0].text == "Dublin" + + +@pytest.mark.asyncio +async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api): + """Test creating an OpenAI response with multiple messages.""" + # Setup + input_messages = [ + OpenAIResponseInputMessage(role="developer", content="You are a helpful assistant", name=None), + OpenAIResponseInputMessage(role="user", content="Name some towns in Ireland", name=None), + OpenAIResponseInputMessage( + role="assistant", + content=[ + OpenAIResponseInputMessageContentText(text="Galway, Longford, Sligo"), + OpenAIResponseInputMessageContentText(text="Dublin"), + ], + name=None, + ), + OpenAIResponseInputMessage(role="user", content="Which is the largest town in Ireland?", name=None), + ] + model = "meta-llama/Llama-3.1-8B-Instruct" + + mock_inference_api.openai_chat_completion.return_value = load_chat_completion_fixture("simple_chat_completion.yaml") + + # Execute + await openai_responses_impl.create_openai_response( + input=input_messages, + model=model, + temperature=0.1, + ) + + # Verify the the correct messages were sent to the inference API i.e. + # All of the responses message were convered to the chat completion message objects + inference_messages = mock_inference_api.openai_chat_completion.call_args_list[0].kwargs["messages"] + for i, m in enumerate(input_messages): + if isinstance(m.content, str): + assert inference_messages[i].content == m.content + else: + assert inference_messages[i].content[0].text == m.content[0].text + assert isinstance(inference_messages[i].content[0], OpenAIChatCompletionContentPartTextParam) + assert inference_messages[i].role == m.role + if m.role == "user": + assert isinstance(inference_messages[i], OpenAIUserMessageParam) + elif m.role == "assistant": + assert isinstance(inference_messages[i], OpenAIAssistantMessageParam) + else: + assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam)