mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 11:19:47 +00:00
fix: ensure assistant message is followed by tool call message as expected by openai
Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
parent
58e164b8bc
commit
012088e084
3 changed files with 147 additions and 5 deletions
|
|
@ -101,14 +101,22 @@ async def convert_response_input_to_chat_messages(
|
|||
"""
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if isinstance(input, list):
|
||||
# extract all OpenAIResponseInputFunctionToolCallOutput items
|
||||
# so their corresponding OpenAIToolMessageParam instances can
|
||||
# be added immediately following the corresponding
|
||||
# OpenAIAssistantMessageParam
|
||||
tool_call_results = {}
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
|
||||
messages.append(
|
||||
OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
tool_call_id=input_item.call_id,
|
||||
)
|
||||
tool_call_results[input_item.call_id] = OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
tool_call_id=input_item.call_id,
|
||||
)
|
||||
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
|
||||
# skip as these have been extracted and inserted in order
|
||||
pass
|
||||
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
|
|
@ -119,6 +127,9 @@ async def convert_response_input_to_chat_messages(
|
|||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
if input_item.call_id in tool_call_results:
|
||||
messages.append(tool_call_results[input_item.call_id])
|
||||
del tool_call_results[input_item.call_id]
|
||||
elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
|
|
@ -146,6 +157,8 @@ async def convert_response_input_to_chat_messages(
|
|||
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||
)
|
||||
messages.append(message_type(content=content))
|
||||
for result in tool_call_results.values():
|
||||
messages.append(result)
|
||||
else:
|
||||
messages.append(OpenAIUserMessageParam(content=input))
|
||||
return messages
|
||||
|
|
|
|||
|
|
@ -260,6 +260,94 @@ def test_response_non_streaming_custom_tool(compat_client, text_model_id, case):
|
|||
assert response.output[0].name == "get_weather"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", custom_tool_test_cases)
|
||||
def test_response_function_call_ordering_1(compat_client, text_model_id, case):
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case.input,
|
||||
tools=case.tools,
|
||||
stream=False,
|
||||
)
|
||||
assert len(response.output) == 1
|
||||
assert response.output[0].type == "function_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].name == "get_weather"
|
||||
inputs = []
|
||||
inputs.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": case.input,
|
||||
}
|
||||
)
|
||||
inputs.append(
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"output": "It is raining.",
|
||||
"call_id": response.output[0].call_id,
|
||||
}
|
||||
)
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id, input=inputs, tools=case.tools, stream=False, previous_response_id=response.id
|
||||
)
|
||||
assert len(response.output) == 1
|
||||
|
||||
|
||||
def test_response_function_call_ordering_2(compat_client, text_model_id):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get current temperature for a given location.",
|
||||
"parameters": {
|
||||
"additionalProperties": False,
|
||||
"properties": {
|
||||
"location": {
|
||||
"description": "City and country e.g. Bogotá, Colombia",
|
||||
"type": "string",
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
"type": "object",
|
||||
},
|
||||
}
|
||||
]
|
||||
inputs = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Is the weather better in San Francisco or Los Angeles?",
|
||||
}
|
||||
]
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=inputs,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
for output in response.output:
|
||||
if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
|
||||
inputs.append(output)
|
||||
for output in response.output:
|
||||
if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
|
||||
weather = "It is raining."
|
||||
if "Los Angeles" in output.arguments:
|
||||
weather = "It is cloudy."
|
||||
inputs.append(
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"output": weather,
|
||||
"call_id": output.call_id,
|
||||
}
|
||||
)
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=inputs,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
assert len(response.output) == 1
|
||||
assert "Los Angeles" in response.output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", multi_turn_tool_execution_test_cases)
|
||||
def test_response_non_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
|
||||
"""Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
|
||||
|
|
|
|||
|
|
@ -146,6 +146,47 @@ class TestConvertResponseInputToChatMessages:
|
|||
assert result[0].tool_calls[0].function.name == "test_function"
|
||||
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
|
||||
async def test_convert_function_call_ordering(self):
|
||||
input_items = [
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_123",
|
||||
name="test_function_a",
|
||||
arguments='{"param": "value"}',
|
||||
),
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_456",
|
||||
name="test_function_b",
|
||||
arguments='{"param": "value"}',
|
||||
),
|
||||
OpenAIResponseInputFunctionToolCallOutput(
|
||||
output="AAA",
|
||||
call_id="call_123",
|
||||
),
|
||||
OpenAIResponseInputFunctionToolCallOutput(
|
||||
output="BBB",
|
||||
call_id="call_456",
|
||||
),
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
assert len(result) == 4
|
||||
assert isinstance(result[0], OpenAIAssistantMessageParam)
|
||||
assert len(result[0].tool_calls) == 1
|
||||
assert result[0].tool_calls[0].id == "call_123"
|
||||
assert result[0].tool_calls[0].function.name == "test_function_a"
|
||||
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
assert isinstance(result[1], OpenAIToolMessageParam)
|
||||
assert result[1].content == "AAA"
|
||||
assert result[1].tool_call_id == "call_123"
|
||||
assert isinstance(result[2], OpenAIAssistantMessageParam)
|
||||
assert len(result[2].tool_calls) == 1
|
||||
assert result[2].tool_calls[0].id == "call_456"
|
||||
assert result[2].tool_calls[0].function.name == "test_function_b"
|
||||
assert result[2].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
assert isinstance(result[3], OpenAIToolMessageParam)
|
||||
assert result[3].content == "BBB"
|
||||
assert result[3].tool_call_id == "call_456"
|
||||
|
||||
async def test_convert_response_message(self):
|
||||
input_items = [
|
||||
OpenAIResponseMessage(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue