fix: ensure assistant message is followed by tool call message as expected by openai

Signed-off-by: Gordon Sim <gsim@redhat.com>
Gordon Sim 2025-08-15 21:56:19 +01:00
parent 58e164b8bc
commit 012088e084
3 changed files with 147 additions and 5 deletions
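For context: OpenAI-compatible chat completions APIs require that an assistant message carrying tool_calls be immediately followed by one "tool" role message per call_id it issued. A minimal sketch of the ordering this commit enforces (the ids and payloads below are illustrative, not taken from the diff):

    # Each assistant message with tool_calls must be directly followed by
    # a matching tool message, keyed by tool_call_id.
    messages = [
        {"role": "user", "content": "What is the weather in Paris?"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_123",  # illustrative id
                    "type": "function",
                    "function": {"name": "get_weather", "arguments": '{"location": "Paris"}'},
                }
            ],
        },
        # The tool result comes directly after the assistant turn that
        # requested it; other orderings are rejected by the API.
        {"role": "tool", "tool_call_id": "call_123", "content": "It is raining."},
    ]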

@@ -101,14 +101,22 @@ async def convert_response_input_to_chat_messages(
     """
     messages: list[OpenAIMessageParam] = []
     if isinstance(input, list):
+        # extract all OpenAIResponseInputFunctionToolCallOutput items
+        # so their corresponding OpenAIToolMessageParam instances can
+        # be added immediately following the corresponding
+        # OpenAIAssistantMessageParam
+        tool_call_results = {}
         for input_item in input:
             if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
-                messages.append(
-                    OpenAIToolMessageParam(
-                        content=input_item.output,
-                        tool_call_id=input_item.call_id,
-                    )
-                )
+                tool_call_results[input_item.call_id] = OpenAIToolMessageParam(
+                    content=input_item.output,
+                    tool_call_id=input_item.call_id,
+                )
+        for input_item in input:
+            if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
+                # skip as these have been extracted and inserted in order
+                pass
             elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
                 tool_call = OpenAIChatCompletionToolCall(
                     index=0,
@@ -119,6 +127,9 @@ async def convert_response_input_to_chat_messages(
                     ),
                 )
                 messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
+                if input_item.call_id in tool_call_results:
+                    messages.append(tool_call_results[input_item.call_id])
+                    del tool_call_results[input_item.call_id]
             elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall):
                 tool_call = OpenAIChatCompletionToolCall(
                     index=0,
@@ -146,6 +157,8 @@ async def convert_response_input_to_chat_messages(
                     f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
                 )
                 messages.append(message_type(content=content))
+        for result in tool_call_results.values():
+            messages.append(result)
     else:
         messages.append(OpenAIUserMessageParam(content=input))
     return messages
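Distilling the change above: the converter now collects tool outputs in a first pass, then re-inserts each one immediately after the assistant message that carries its tool call, flushing any unmatched outputs at the end. A standalone sketch of the same two-pass idea, using plain dicts in place of the real pydantic types (the function name here is hypothetical):

    def reorder_tool_outputs(items: list[dict]) -> list[dict]:
        # Pass 1: index function_call_output items by call_id.
        results = {i["call_id"]: i for i in items if i["type"] == "function_call_output"}
        ordered: list[dict] = []
        # Pass 2: emit everything else in input order, appending each
        # output directly after the function_call that produced it.
        for item in items:
            if item["type"] == "function_call_output":
                continue  # handled via the index built in pass 1
            ordered.append(item)
            if item["type"] == "function_call" and item["call_id"] in results:
                ordered.append(results.pop(item["call_id"]))
        # Outputs whose call never appeared are appended last.
        ordered.extend(results.values())
        return ordered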

@@ -260,6 +260,94 @@ def test_response_non_streaming_custom_tool(compat_client, text_model_id, case):
     assert response.output[0].name == "get_weather"
 
 
+@pytest.mark.parametrize("case", custom_tool_test_cases)
+def test_response_function_call_ordering_1(compat_client, text_model_id, case):
+    response = compat_client.responses.create(
+        model=text_model_id,
+        input=case.input,
+        tools=case.tools,
+        stream=False,
+    )
+    assert len(response.output) == 1
+    assert response.output[0].type == "function_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].name == "get_weather"
+    inputs = []
+    inputs.append(
+        {
+            "role": "user",
+            "content": case.input,
+        }
+    )
+    inputs.append(
+        {
+            "type": "function_call_output",
+            "output": "It is raining.",
+            "call_id": response.output[0].call_id,
+        }
+    )
+    response = compat_client.responses.create(
+        model=text_model_id, input=inputs, tools=case.tools, stream=False, previous_response_id=response.id
+    )
+    assert len(response.output) == 1
+
+
+def test_response_function_call_ordering_2(compat_client, text_model_id):
+    tools = [
+        {
+            "type": "function",
+            "name": "get_weather",
+            "description": "Get current temperature for a given location.",
+            "parameters": {
+                "additionalProperties": False,
+                "properties": {
+                    "location": {
+                        "description": "City and country e.g. Bogotá, Colombia",
+                        "type": "string",
+                    }
+                },
+                "required": ["location"],
+                "type": "object",
+            },
+        }
+    ]
+    inputs = [
+        {
+            "role": "user",
+            "content": "Is the weather better in San Francisco or Los Angeles?",
+        }
+    ]
+    response = compat_client.responses.create(
+        model=text_model_id,
+        input=inputs,
+        tools=tools,
+        stream=False,
+    )
+    for output in response.output:
+        if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
+            inputs.append(output)
+    for output in response.output:
+        if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
+            weather = "It is raining."
+            if "Los Angeles" in output.arguments:
+                weather = "It is cloudy."
+            inputs.append(
+                {
+                    "type": "function_call_output",
+                    "output": weather,
+                    "call_id": output.call_id,
+                }
+            )
+    response = compat_client.responses.create(
+        model=text_model_id,
+        input=inputs,
+        tools=tools,
+        stream=False,
+    )
+    assert len(response.output) == 1
+    assert "Los Angeles" in response.output_text
+
+
 @pytest.mark.parametrize("case", multi_turn_tool_execution_test_cases)
 def test_response_non_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
     """Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""

@@ -146,6 +146,47 @@ class TestConvertResponseInputToChatMessages:
         assert result[0].tool_calls[0].function.name == "test_function"
         assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
 
+    async def test_convert_function_call_ordering(self):
+        input_items = [
+            OpenAIResponseOutputMessageFunctionToolCall(
+                call_id="call_123",
+                name="test_function_a",
+                arguments='{"param": "value"}',
+            ),
+            OpenAIResponseOutputMessageFunctionToolCall(
+                call_id="call_456",
+                name="test_function_b",
+                arguments='{"param": "value"}',
+            ),
+            OpenAIResponseInputFunctionToolCallOutput(
+                output="AAA",
+                call_id="call_123",
+            ),
+            OpenAIResponseInputFunctionToolCallOutput(
+                output="BBB",
+                call_id="call_456",
+            ),
+        ]
+
+        result = await convert_response_input_to_chat_messages(input_items)
+
+        assert len(result) == 4
+        assert isinstance(result[0], OpenAIAssistantMessageParam)
+        assert len(result[0].tool_calls) == 1
+        assert result[0].tool_calls[0].id == "call_123"
+        assert result[0].tool_calls[0].function.name == "test_function_a"
+        assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
+        assert isinstance(result[1], OpenAIToolMessageParam)
+        assert result[1].content == "AAA"
+        assert result[1].tool_call_id == "call_123"
+        assert isinstance(result[2], OpenAIAssistantMessageParam)
+        assert len(result[2].tool_calls) == 1
+        assert result[2].tool_calls[0].id == "call_456"
+        assert result[2].tool_calls[0].function.name == "test_function_b"
+        assert result[2].tool_calls[0].function.arguments == '{"param": "value"}'
+        assert isinstance(result[3], OpenAIToolMessageParam)
+        assert result[3].content == "BBB"
+        assert result[3].tool_call_id == "call_456"
+
     async def test_convert_response_message(self):
         input_items = [
             OpenAIResponseMessage(