mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
feat(openai-responses): Support multiple message roles in API inputs
Also update the nesting to add multiple messages (where appropriate) rather than a single message with multiple content parts. Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
parent
e4888b930b
commit
4657077418
2 changed files with 83 additions and 14 deletions
|
@ -34,8 +34,10 @@ from llama_stack.apis.inference.inference import (
|
||||||
OpenAIChatCompletionContentPartTextParam,
|
OpenAIChatCompletionContentPartTextParam,
|
||||||
OpenAIChatCompletionToolCallFunction,
|
OpenAIChatCompletionToolCallFunction,
|
||||||
OpenAIChoice,
|
OpenAIChoice,
|
||||||
|
OpenAIDeveloperMessageParam,
|
||||||
OpenAIImageURL,
|
OpenAIImageURL,
|
||||||
OpenAIMessageParam,
|
OpenAIMessageParam,
|
||||||
|
OpenAISystemMessageParam,
|
||||||
OpenAIToolMessageParam,
|
OpenAIToolMessageParam,
|
||||||
OpenAIUserMessageParam,
|
OpenAIUserMessageParam,
|
||||||
)
|
)
|
||||||
|
@ -77,6 +79,16 @@ async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> lis
|
||||||
return output_messages
|
return output_messages
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_message_type_by_role(role: str):
    """Look up the chat-completion message class for a message role.

    Returns the message param class mapped to ``role`` ("user", "system",
    "assistant" or "developer"), or ``None`` when the role has no mapping.
    """
    message_classes = dict(
        user=OpenAIUserMessageParam,
        system=OpenAISystemMessageParam,
        assistant=OpenAIAssistantMessageParam,
        developer=OpenAIDeveloperMessageParam,
    )
    try:
        return message_classes[role]
    except KeyError:
        # Unknown role: signal "no mapping" to the caller instead of raising.
        return None
|
||||||
|
|
||||||
|
|
||||||
class OpenAIResponsesImpl:
|
class OpenAIResponsesImpl:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -116,26 +128,32 @@ class OpenAIResponsesImpl:
|
||||||
if previous_response_id:
|
if previous_response_id:
|
||||||
previous_response = await self.get_openai_response(previous_response_id)
|
previous_response = await self.get_openai_response(previous_response_id)
|
||||||
messages.extend(await _previous_response_to_messages(previous_response))
|
messages.extend(await _previous_response_to_messages(previous_response))
|
||||||
|
|
||||||
# TODO: refactor this user_content parsing out into a separate method
|
# TODO: refactor this user_content parsing out into a separate method
|
||||||
user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
|
content: str | list[OpenAIChatCompletionContentPartParam] = ""
|
||||||
if isinstance(input, list):
|
if isinstance(input, list):
|
||||||
user_content = []
|
for input_message in input:
|
||||||
for user_input in input:
|
if isinstance(input_message.content, list):
|
||||||
if isinstance(user_input.content, list):
|
content = []
|
||||||
for user_input_content in user_input.content:
|
for input_message_content in input_message.content:
|
||||||
if isinstance(user_input_content, OpenAIResponseInputMessageContentText):
|
if isinstance(input_message_content, OpenAIResponseInputMessageContentText):
|
||||||
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input_content.text))
|
content.append(OpenAIChatCompletionContentPartTextParam(text=input_message_content.text))
|
||||||
elif isinstance(user_input_content, OpenAIResponseInputMessageContentImage):
|
elif isinstance(input_message_content, OpenAIResponseInputMessageContentImage):
|
||||||
if user_input_content.image_url:
|
if input_message_content.image_url:
|
||||||
image_url = OpenAIImageURL(
|
image_url = OpenAIImageURL(
|
||||||
url=user_input_content.image_url, detail=user_input_content.detail
|
url=input_message_content.image_url, detail=input_message_content.detail
|
||||||
)
|
)
|
||||||
user_content.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
|
content.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
|
||||||
else:
|
else:
|
||||||
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input.content))
|
content = input_message.content
|
||||||
|
message_type = await _get_message_type_by_role(input_message.role)
|
||||||
|
if message_type is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Llama Stack OpenAI Responses does not yet support message role '{input_message.role}' in this context"
|
||||||
|
)
|
||||||
|
messages.append(message_type(content=content))
|
||||||
else:
|
else:
|
||||||
user_content = input
|
messages.append(OpenAIUserMessageParam(content=input))
|
||||||
messages.append(OpenAIUserMessageParam(content=user_content))
|
|
||||||
|
|
||||||
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
|
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
|
||||||
chat_response = await self.inference_api.openai_chat_completion(
|
chat_response = await self.inference_api.openai_chat_completion(
|
||||||
|
|
|
@ -9,10 +9,15 @@ from unittest.mock import AsyncMock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.agents.openai_responses import (
|
from llama_stack.apis.agents.openai_responses import (
|
||||||
|
OpenAIResponseInputMessage,
|
||||||
|
OpenAIResponseInputMessageContentText,
|
||||||
OpenAIResponseInputToolWebSearch,
|
OpenAIResponseInputToolWebSearch,
|
||||||
OpenAIResponseOutputMessage,
|
OpenAIResponseOutputMessage,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference.inference import (
|
from llama_stack.apis.inference.inference import (
|
||||||
|
OpenAIAssistantMessageParam,
|
||||||
|
OpenAIChatCompletionContentPartTextParam,
|
||||||
|
OpenAIDeveloperMessageParam,
|
||||||
OpenAIUserMessageParam,
|
OpenAIUserMessageParam,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
|
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
|
||||||
|
@ -156,3 +161,49 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
|
||||||
assert len(result.output) >= 1
|
assert len(result.output) >= 1
|
||||||
assert isinstance(result.output[1], OpenAIResponseOutputMessage)
|
assert isinstance(result.output[1], OpenAIResponseOutputMessage)
|
||||||
assert result.output[1].content[0].text == "Dublin"
|
assert result.output[1].content[0].text == "Dublin"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
    """Check that a list of input messages with mixed roles is forwarded message-by-message."""
    # Setup: a conversation mixing developer, user and assistant roles, using
    # both plain-string content and a list of text content parts.
    assistant_parts = [
        OpenAIResponseInputMessageContentText(text="Galway, Longford, Sligo"),
        OpenAIResponseInputMessageContentText(text="Dublin"),
    ]
    input_messages = [
        OpenAIResponseInputMessage(role="developer", content="You are a helpful assistant", name=None),
        OpenAIResponseInputMessage(role="user", content="Name some towns in Ireland", name=None),
        OpenAIResponseInputMessage(role="assistant", content=assistant_parts, name=None),
        OpenAIResponseInputMessage(role="user", content="Which is the largest town in Ireland?", name=None),
    ]

    mock_inference_api.openai_chat_completion.return_value = load_chat_completion_fixture("simple_chat_completion.yaml")

    # Execute
    await openai_responses_impl.create_openai_response(
        input=input_messages,
        model="meta-llama/Llama-3.1-8B-Instruct",
        temperature=0.1,
    )

    # Verify that the correct messages were sent to the inference API, i.e.
    # each Responses input message was converted to the matching chat
    # completion message object.
    first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
    inference_messages = first_call.kwargs["messages"]

    # Any role not listed here is expected to map to the developer message class.
    expected_class_by_role = {
        "user": OpenAIUserMessageParam,
        "assistant": OpenAIAssistantMessageParam,
    }

    for index, original in enumerate(input_messages):
        sent = inference_messages[index]
        if not isinstance(original.content, str):
            # List content: only the first content part is compared.
            assert sent.content[0].text == original.content[0].text
            assert isinstance(sent.content[0], OpenAIChatCompletionContentPartTextParam)
        else:
            assert sent.content == original.content
        assert sent.role == original.role
        expected_class = expected_class_by_role.get(original.role, OpenAIDeveloperMessageParam)
        assert isinstance(sent, expected_class)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue