diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index f90245c08..e2314d44f 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -108,6 +108,7 @@ from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, OpenAICompletion, OpenAICompletionChoice, + OpenAIMessageParam, OpenAIResponseFormatParam, ToolConfig, ) @@ -987,7 +988,7 @@ def _convert_openai_sampling_params( def openai_messages_to_messages( - messages: list[OpenAIChatCompletionMessage], + messages: list[OpenAIMessageParam], ) -> list[Message]: """ - Convert a list of OpenAIChatCompletionMessage into a list of Message. + Convert a list of OpenAIMessageParam into a list of Message. """ @@ -995,12 +996,12 @@ def openai_messages_to_messages( converted_messages = [] for message in messages: if message.role == "system": - converted_message = SystemMessage(content=message.content) + converted_message = SystemMessage(content=openai_content_to_content(message.content)) elif message.role == "user": converted_message = UserMessage(content=openai_content_to_content(message.content)) elif message.role == "assistant": converted_message = CompletionMessage( - content=message.content, + content=openai_content_to_content(message.content), tool_calls=_convert_openai_tool_calls(message.tool_calls), stop_reason=StopReason.end_of_turn, ) @@ -1331,7 +1332,7 @@ class OpenAIChatCompletionToLlamaStackMixin: async def openai_chat_completion( self, model: str, - messages: list[OpenAIChatCompletionMessage], + messages: list[OpenAIMessageParam], frequency_penalty: float | None = None, function_call: str | dict[str, Any] | None = None, functions: list[dict[str, Any]] | None = None, diff --git a/tests/unit/providers/utils/inference/test_openai_compat.py b/tests/unit/providers/utils/inference/test_openai_compat.py index fda762d7f..4c75b8a2f 100644 --- a/tests/unit/providers/utils/inference/test_openai_compat.py +++ 
b/tests/unit/providers/utils/inference/test_openai_compat.py @@ -7,9 +7,20 @@ import pytest from llama_stack.apis.common.content_types import TextContentItem -from llama_stack.apis.inference.inference import CompletionMessage, UserMessage +from llama_stack.apis.inference.inference import ( + CompletionMessage, + OpenAIAssistantMessageParam, + OpenAIChatCompletionContentPartTextParam, + OpenAISystemMessageParam, + OpenAIUserMessageParam, + SystemMessage, + UserMessage, +) from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall -from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict +from llama_stack.providers.utils.inference.openai_compat import ( + convert_message_to_openai_dict, + openai_messages_to_messages, +) @pytest.mark.asyncio @@ -67,3 +78,39 @@ async def test_convert_message_to_openai_dict_with_builtin_tool_call(): {"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}} ], } + + +@pytest.mark.asyncio +async def test_openai_messages_to_messages_with_content_str(): + openai_messages = [ + OpenAISystemMessageParam(content="system message"), + OpenAIUserMessageParam(content="user message"), + OpenAIAssistantMessageParam(content="assistant message"), + ] + + llama_messages = openai_messages_to_messages(openai_messages) + assert len(llama_messages) == 3 + assert isinstance(llama_messages[0], SystemMessage) + assert isinstance(llama_messages[1], UserMessage) + assert isinstance(llama_messages[2], CompletionMessage) + assert llama_messages[0].content == "system message" + assert llama_messages[1].content == "user message" + assert llama_messages[2].content == "assistant message" + + +@pytest.mark.asyncio +async def test_openai_messages_to_messages_with_content_list(): + openai_messages = [ + OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]), + 
OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]), + OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]), + ] + + llama_messages = openai_messages_to_messages(openai_messages) + assert len(llama_messages) == 3 + assert isinstance(llama_messages[0], SystemMessage) + assert isinstance(llama_messages[1], UserMessage) + assert isinstance(llama_messages[2], CompletionMessage) + assert llama_messages[0].content[0].text == "system message" + assert llama_messages[1].content[0].text == "user message" + assert llama_messages[2].content[0].text == "assistant message"