mirror of https://github.com/meta-llama/llama-stack.git
fix: ensure assistant message is followed by tool call message as expected by openai
Signed-off-by: Gordon Sim <gsim@redhat.com>
parent 58e164b8bc
commit 012088e084
3 changed files with 147 additions and 5 deletions
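Background: OpenAI's chat completions API requires that each role="tool" result message follow the assistant message whose tool_calls list contains the matching tool_call_id; a result that appears before, or detached from, its call is rejected. A minimal sketch of the expected ordering, with illustrative values that are not part of this diff:

    messages = [
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_123",
                    "type": "function",
                    "function": {"name": "get_weather", "arguments": '{"location": "San Francisco"}'},
                }
            ],
        },
        # the matching result must come directly after the call above
        {"role": "tool", "tool_call_id": "call_123", "content": "It is raining."},
    ]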
@@ -101,14 +101,22 @@ async def convert_response_input_to_chat_messages(
     """
     messages: list[OpenAIMessageParam] = []
     if isinstance(input, list):
+        # extract all OpenAIResponseInputFunctionToolCallOutput items
+        # so their corresponding OpenAIToolMessageParam instances can
+        # be added immediately following the corresponding
+        # OpenAIAssistantMessageParam
+        tool_call_results = {}
         for input_item in input:
             if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
-                messages.append(
-                    OpenAIToolMessageParam(
-                        content=input_item.output,
-                        tool_call_id=input_item.call_id,
-                    )
-                )
+                tool_call_results[input_item.call_id] = OpenAIToolMessageParam(
+                    content=input_item.output,
+                    tool_call_id=input_item.call_id,
+                )
+
+        for input_item in input:
+            if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
+                # skip as these have been extracted and inserted in order
+                pass
             elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
                 tool_call = OpenAIChatCompletionToolCall(
                     index=0,
@@ -119,6 +127,9 @@ async def convert_response_input_to_chat_messages(
                     ),
                 )
                 messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
+                if input_item.call_id in tool_call_results:
+                    messages.append(tool_call_results[input_item.call_id])
+                    del tool_call_results[input_item.call_id]
             elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall):
                 tool_call = OpenAIChatCompletionToolCall(
                     index=0,
@@ -146,6 +157,8 @@ async def convert_response_input_to_chat_messages(
                     f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
                 )
             messages.append(message_type(content=content))
+        for result in tool_call_results.values():
+            messages.append(result)
     else:
         messages.append(OpenAIUserMessageParam(content=input))
     return messages
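Read as an algorithm, the change above is a two-pass reordering: first index every tool result by call_id, then walk the input again, emitting each assistant tool-call message with its result directly behind it, and finally flush any leftover results from earlier turns. A standalone sketch of the same idea, assuming simplified dict-shaped items rather than the Llama Stack message classes:

    def reorder_tool_results(items: list[dict]) -> list[dict]:
        # Pass 1: index every function_call_output by its call_id.
        results = {
            item["call_id"]: {"role": "tool", "tool_call_id": item["call_id"], "content": item["output"]}
            for item in items
            if item["type"] == "function_call_output"
        }
        messages: list[dict] = []
        # Pass 2: emit each assistant tool call, then its result immediately behind it.
        for item in items:
            if item["type"] == "function_call":
                messages.append(
                    {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "id": item["call_id"],
                                "type": "function",
                                "function": {"name": item["name"], "arguments": item["arguments"]},
                            }
                        ],
                    }
                )
                if item["call_id"] in results:
                    messages.append(results.pop(item["call_id"]))
        # Results whose call was issued in an earlier turn are appended at the end.
        messages.extend(results.values())
        return messages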
@@ -260,6 +260,94 @@ def test_response_non_streaming_custom_tool(compat_client, text_model_id, case):
     assert response.output[0].name == "get_weather"
 
 
+@pytest.mark.parametrize("case", custom_tool_test_cases)
+def test_response_function_call_ordering_1(compat_client, text_model_id, case):
+    response = compat_client.responses.create(
+        model=text_model_id,
+        input=case.input,
+        tools=case.tools,
+        stream=False,
+    )
+    assert len(response.output) == 1
+    assert response.output[0].type == "function_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].name == "get_weather"
+    inputs = []
+    inputs.append(
+        {
+            "role": "user",
+            "content": case.input,
+        }
+    )
+    inputs.append(
+        {
+            "type": "function_call_output",
+            "output": "It is raining.",
+            "call_id": response.output[0].call_id,
+        }
+    )
+    response = compat_client.responses.create(
+        model=text_model_id, input=inputs, tools=case.tools, stream=False, previous_response_id=response.id
+    )
+    assert len(response.output) == 1
+
+
+def test_response_function_call_ordering_2(compat_client, text_model_id):
+    tools = [
+        {
+            "type": "function",
+            "name": "get_weather",
+            "description": "Get current temperature for a given location.",
+            "parameters": {
+                "additionalProperties": False,
+                "properties": {
+                    "location": {
+                        "description": "City and country e.g. Bogotá, Colombia",
+                        "type": "string",
+                    }
+                },
+                "required": ["location"],
+                "type": "object",
+            },
+        }
+    ]
+    inputs = [
+        {
+            "role": "user",
+            "content": "Is the weather better in San Francisco or Los Angeles?",
+        }
+    ]
+    response = compat_client.responses.create(
+        model=text_model_id,
+        input=inputs,
+        tools=tools,
+        stream=False,
+    )
+    for output in response.output:
+        if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
+            inputs.append(output)
+    for output in response.output:
+        if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
+            weather = "It is raining."
+            if "Los Angeles" in output.arguments:
+                weather = "It is cloudy."
+            inputs.append(
+                {
+                    "type": "function_call_output",
+                    "output": weather,
+                    "call_id": output.call_id,
+                }
+            )
+    response = compat_client.responses.create(
+        model=text_model_id,
+        input=inputs,
+        tools=tools,
+        stream=False,
+    )
+    assert len(response.output) == 1
+    assert "Los Angeles" in response.output_text
+
+
 @pytest.mark.parametrize("case", multi_turn_tool_execution_test_cases)
 def test_response_non_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
     """Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
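Both new integration tests drive the same round trip: issue a request, replay each function_call item back along with a function_call_output carrying the locally computed result, then request again. Reduced to its core (client and run_tool are hypothetical stand-ins for an OpenAI-compatible Responses client and a local tool dispatcher, not code from this commit):

    def tool_round_trip(client, model, inputs, tools, run_tool):
        response = client.responses.create(model=model, input=inputs, tools=tools, stream=False)
        for output in response.output:
            if output.type == "function_call":
                inputs.append(output)  # replay the call so the server can pair it...
                inputs.append(
                    {
                        "type": "function_call_output",
                        "output": run_tool(output.name, output.arguments),
                        "call_id": output.call_id,  # ...with this result, via call_id
                    }
                )
        return client.responses.create(model=model, input=inputs, tools=tools, stream=False)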
@@ -146,6 +146,47 @@ class TestConvertResponseInputToChatMessages:
         assert result[0].tool_calls[0].function.name == "test_function"
         assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
 
+    async def test_convert_function_call_ordering(self):
+        input_items = [
+            OpenAIResponseOutputMessageFunctionToolCall(
+                call_id="call_123",
+                name="test_function_a",
+                arguments='{"param": "value"}',
+            ),
+            OpenAIResponseOutputMessageFunctionToolCall(
+                call_id="call_456",
+                name="test_function_b",
+                arguments='{"param": "value"}',
+            ),
+            OpenAIResponseInputFunctionToolCallOutput(
+                output="AAA",
+                call_id="call_123",
+            ),
+            OpenAIResponseInputFunctionToolCallOutput(
+                output="BBB",
+                call_id="call_456",
+            ),
+        ]
+
+        result = await convert_response_input_to_chat_messages(input_items)
+        assert len(result) == 4
+        assert isinstance(result[0], OpenAIAssistantMessageParam)
+        assert len(result[0].tool_calls) == 1
+        assert result[0].tool_calls[0].id == "call_123"
+        assert result[0].tool_calls[0].function.name == "test_function_a"
+        assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
+        assert isinstance(result[1], OpenAIToolMessageParam)
+        assert result[1].content == "AAA"
+        assert result[1].tool_call_id == "call_123"
+        assert isinstance(result[2], OpenAIAssistantMessageParam)
+        assert len(result[2].tool_calls) == 1
+        assert result[2].tool_calls[0].id == "call_456"
+        assert result[2].tool_calls[0].function.name == "test_function_b"
+        assert result[2].tool_calls[0].function.arguments == '{"param": "value"}'
+        assert isinstance(result[3], OpenAIToolMessageParam)
+        assert result[3].content == "BBB"
+        assert result[3].tool_call_id == "call_456"
+
     async def test_convert_response_message(self):
         input_items = [
             OpenAIResponseMessage(