mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
fix: openai_compat messages system/assistant non-str content (#2095)
# What does this PR do? When converting OpenAI message content for the "system" and "assistant" roles to Llama Stack inference APIs (used for some providers when dealing with Llama models via OpenAI API requests to get proper prompt / tool handling), we were not properly converting any non-string content. I discovered this while running the new Responses AI verification suite against the Fireworks provider, but instead of fixing it as part of some ongoing work there split this out into a separate PR. This fixes that, by using the `openai_content_to_content` helper we used elsewhere to ensure content parts were mapped properly. ## Test Plan I added a couple of new tests to `test_openai_compat` to reproduce this issue and validate its fix. I ran those as below: ``` python -m pytest -s -v tests/unit/providers/utils/inference/test_openai_compat.py ``` Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
272d3359ee
commit
f1b103e6c8
2 changed files with 54 additions and 6 deletions
|
@ -108,6 +108,7 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIChatCompletion,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionChoice,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ToolConfig,
|
||||
)
|
||||
|
@ -987,7 +988,7 @@ def _convert_openai_sampling_params(
|
|||
|
||||
|
||||
def openai_messages_to_messages(
|
||||
messages: list[OpenAIChatCompletionMessage],
|
||||
messages: list[OpenAIMessageParam],
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Convert a list of OpenAIChatCompletionMessage into a list of Message.
|
||||
|
@ -995,12 +996,12 @@ def openai_messages_to_messages(
|
|||
converted_messages = []
|
||||
for message in messages:
|
||||
if message.role == "system":
|
||||
converted_message = SystemMessage(content=message.content)
|
||||
converted_message = SystemMessage(content=openai_content_to_content(message.content))
|
||||
elif message.role == "user":
|
||||
converted_message = UserMessage(content=openai_content_to_content(message.content))
|
||||
elif message.role == "assistant":
|
||||
converted_message = CompletionMessage(
|
||||
content=message.content,
|
||||
content=openai_content_to_content(message.content),
|
||||
tool_calls=_convert_openai_tool_calls(message.tool_calls),
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
|
@ -1331,7 +1332,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIChatCompletionMessage],
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
|
|
|
@ -7,9 +7,20 @@
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.common.content_types import TextContentItem
|
||||
from llama_stack.apis.inference.inference import CompletionMessage, UserMessage
|
||||
from llama_stack.apis.inference.inference import (
|
||||
CompletionMessage,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
openai_messages_to_messages,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -67,3 +78,39 @@ async def test_convert_message_to_openai_dict_with_builtin_tool_call():
|
|||
{"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_messages_to_messages_with_content_str():
|
||||
openai_messages = [
|
||||
OpenAISystemMessageParam(content="system message"),
|
||||
OpenAIUserMessageParam(content="user message"),
|
||||
OpenAIAssistantMessageParam(content="assistant message"),
|
||||
]
|
||||
|
||||
llama_messages = openai_messages_to_messages(openai_messages)
|
||||
assert len(llama_messages) == 3
|
||||
assert isinstance(llama_messages[0], SystemMessage)
|
||||
assert isinstance(llama_messages[1], UserMessage)
|
||||
assert isinstance(llama_messages[2], CompletionMessage)
|
||||
assert llama_messages[0].content == "system message"
|
||||
assert llama_messages[1].content == "user message"
|
||||
assert llama_messages[2].content == "assistant message"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_messages_to_messages_with_content_list():
|
||||
openai_messages = [
|
||||
OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),
|
||||
OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]),
|
||||
OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]),
|
||||
]
|
||||
|
||||
llama_messages = openai_messages_to_messages(openai_messages)
|
||||
assert len(llama_messages) == 3
|
||||
assert isinstance(llama_messages[0], SystemMessage)
|
||||
assert isinstance(llama_messages[1], UserMessage)
|
||||
assert isinstance(llama_messages[2], CompletionMessage)
|
||||
assert llama_messages[0].content[0].text == "system message"
|
||||
assert llama_messages[1].content[0].text == "user message"
|
||||
assert llama_messages[2].content[0].text == "assistant message"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue