Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 12:07:34 +00:00)
fix: ensure assistant message is followed by tool call message as expected by openai (#3224)
Some checks failed
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Pre-commit / pre-commit (push) Failing after 4s
Python Package Build Test / build (3.13) (push) Failing after 3s
Test Llama Stack Build / build-single-provider (push) Failing after 5s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 4s
Python Package Build Test / build (3.12) (push) Failing after 5s
Unit Tests / unit-tests (3.13) (push) Failing after 4s
UI Tests / ui-tests (22) (push) Failing after 5s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
Test External API and Providers / test-external (venv) (push) Failing after 8s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 12s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 15s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 17s
Test Llama Stack Build / generate-matrix (push) Failing after 21s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 23s
Test Llama Stack Build / build (push) Has been skipped
Update ReadTheDocs / update-readthedocs (push) Failing after 20s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 24s
# What does this PR do?

As described in #3134, a langchain example works against OpenAI's responses implementation, but not against Llama Stack's. This turned out to be due to the order of the inputs: the langchain example has the two function calls first, followed by each call's result in turn. This ordering is valid, as it is accepted by OpenAI's implementation. In Llama Stack, however, these inputs are converted to chat completion inputs, and the resulting order for that API is not accepted by OpenAI. This PR fixes the issue by ensuring that the converted chat completion inputs are in the expected order.

Closes #3134

## Test Plan

Added unit and integration tests. Verified this fixes the original issue as reported.

---------

Signed-off-by: Gordon Sim <gsim@redhat.com>
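For context, the constraint at play can be sketched with plain chat-completions message dicts (a hedged illustration, not code from this PR). OpenAI's chat completions API expects each `tool` result message to directly follow the assistant message whose `tool_calls` entry carries the matching id; grouping all the calls first and all the results after is rejected:

```python
# Illustrative only: the shapes follow the public chat completions API,
# but the ids and values here are invented for this sketch.

accepted = [
    {"role": "user", "content": "Is the weather better in San Francisco or Los Angeles?"},
    {"role": "assistant", "tool_calls": [{"id": "call_123", "type": "function",
        "function": {"name": "get_weather", "arguments": '{"location": "San Francisco, USA"}'}}]},
    {"role": "tool", "tool_call_id": "call_123", "content": "It is raining."},  # directly follows its call
    {"role": "assistant", "tool_calls": [{"id": "call_456", "type": "function",
        "function": {"name": "get_weather", "arguments": '{"location": "Los Angeles, USA"}'}}]},
    {"role": "tool", "tool_call_id": "call_456", "content": "It is cloudy."},
]

# What a naive in-order conversion produces for the langchain-style input
# (both calls first, both results after): the tool results are grouped after
# both assistant messages, so their call ids no longer match the directly
# preceding assistant message and the request is rejected.
rejected = [
    {"role": "user", "content": "Is the weather better in San Francisco or Los Angeles?"},
    {"role": "assistant", "tool_calls": [{"id": "call_123", "type": "function",
        "function": {"name": "get_weather", "arguments": '{"location": "San Francisco, USA"}'}}]},
    {"role": "assistant", "tool_calls": [{"id": "call_456", "type": "function",
        "function": {"name": "get_weather", "arguments": '{"location": "Los Angeles, USA"}'}}]},
    {"role": "tool", "tool_call_id": "call_123", "content": "It is raining."},
    {"role": "tool", "tool_call_id": "call_456", "content": "It is cloudy."},
]
```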
This commit is contained in: parent b0797e4982, commit da73f1a180

3 changed files with 163 additions and 10 deletions
**Converter** (`convert_response_input_to_chat_messages`):

```diff
@@ -101,14 +101,22 @@ async def convert_response_input_to_chat_messages(
     """
     messages: list[OpenAIMessageParam] = []
     if isinstance(input, list):
+        # extract all OpenAIResponseInputFunctionToolCallOutput items
+        # so their corresponding OpenAIToolMessageParam instances can
+        # be added immediately following the corresponding
+        # OpenAIAssistantMessageParam
+        tool_call_results = {}
         for input_item in input:
             if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
-                messages.append(
-                    OpenAIToolMessageParam(
-                        content=input_item.output,
-                        tool_call_id=input_item.call_id,
-                    )
-                )
+                tool_call_results[input_item.call_id] = OpenAIToolMessageParam(
+                    content=input_item.output,
+                    tool_call_id=input_item.call_id,
+                )
+
+        for input_item in input:
+            if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
+                # skip as these have been extracted and inserted in order
+                pass
             elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
                 tool_call = OpenAIChatCompletionToolCall(
                     index=0,
@@ -119,6 +127,9 @@ async def convert_response_input_to_chat_messages(
                 ),
             )
             messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
+            if input_item.call_id in tool_call_results:
+                messages.append(tool_call_results[input_item.call_id])
+                del tool_call_results[input_item.call_id]
         elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall):
             tool_call = OpenAIChatCompletionToolCall(
                 index=0,
@@ -146,6 +157,10 @@ async def convert_response_input_to_chat_messages(
                 f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
             )
             messages.append(message_type(content=content))
+        if len(tool_call_results):
+            raise ValueError(
+                f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
+            )
     else:
         messages.append(OpenAIUserMessageParam(content=input))
     return messages
```
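Reduced to a standalone sketch, the two-pass strategy above looks roughly like this (hypothetical simplified types standing in for the `OpenAIResponse*` classes; a minimal illustration, not the project's actual code):

```python
from dataclasses import dataclass


# Hypothetical stand-ins for the real input item types, for illustration only.
@dataclass
class FunctionCall:
    call_id: str


@dataclass
class FunctionCallOutput:
    call_id: str
    output: str


def reorder(items: list) -> list:
    # Pass 1: index every tool result by its call_id.
    results = {i.call_id: i for i in items if isinstance(i, FunctionCallOutput)}
    ordered = []
    # Pass 2: walk the input in order, but place each result
    # immediately after the call that produced it.
    for item in items:
        if isinstance(item, FunctionCallOutput):
            continue  # already indexed; emitted right after its call
        ordered.append(item)
        if isinstance(item, FunctionCall) and item.call_id in results:
            ordered.append(results.pop(item.call_id))
    if results:  # leftover outputs mean the input had no matching call
        raise ValueError(f"function_call_output without function_call: {list(results)}")
    return ordered
```

Given calls `a` and `b` followed by both outputs, this yields call `a`, output `a`, call `b`, output `b`, which is the interleaving the chat completions API accepts; leftover outputs with no matching call raise, mirroring the `ValueError` added in the last hunk above.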
**Integration tests:**

```diff
@@ -260,6 +260,94 @@ def test_response_non_streaming_custom_tool(compat_client, text_model_id, case):
     assert response.output[0].name == "get_weather"
 
 
+@pytest.mark.parametrize("case", custom_tool_test_cases)
+def test_response_function_call_ordering_1(compat_client, text_model_id, case):
+    response = compat_client.responses.create(
+        model=text_model_id,
+        input=case.input,
+        tools=case.tools,
+        stream=False,
+    )
+    assert len(response.output) == 1
+    assert response.output[0].type == "function_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].name == "get_weather"
+    inputs = []
+    inputs.append(
+        {
+            "role": "user",
+            "content": case.input,
+        }
+    )
+    inputs.append(
+        {
+            "type": "function_call_output",
+            "output": "It is raining.",
+            "call_id": response.output[0].call_id,
+        }
+    )
+    response = compat_client.responses.create(
+        model=text_model_id, input=inputs, tools=case.tools, stream=False, previous_response_id=response.id
+    )
+    assert len(response.output) == 1
+
+
+def test_response_function_call_ordering_2(compat_client, text_model_id):
+    tools = [
+        {
+            "type": "function",
+            "name": "get_weather",
+            "description": "Get current temperature for a given location.",
+            "parameters": {
+                "additionalProperties": False,
+                "properties": {
+                    "location": {
+                        "description": "City and country e.g. Bogotá, Colombia",
+                        "type": "string",
+                    }
+                },
+                "required": ["location"],
+                "type": "object",
+            },
+        }
+    ]
+    inputs = [
+        {
+            "role": "user",
+            "content": "Is the weather better in San Francisco or Los Angeles?",
+        }
+    ]
+    response = compat_client.responses.create(
+        model=text_model_id,
+        input=inputs,
+        tools=tools,
+        stream=False,
+    )
+    for output in response.output:
+        if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
+            inputs.append(output)
+    for output in response.output:
+        if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
+            weather = "It is raining."
+            if "Los Angeles" in output.arguments:
+                weather = "It is cloudy."
+            inputs.append(
+                {
+                    "type": "function_call_output",
+                    "output": weather,
+                    "call_id": output.call_id,
+                }
+            )
+    response = compat_client.responses.create(
+        model=text_model_id,
+        input=inputs,
+        tools=tools,
+        stream=False,
+    )
+    assert len(response.output) == 1
+    assert "Los Angeles" in response.output_text
+
+
 @pytest.mark.parametrize("case", multi_turn_tool_execution_test_cases)
 def test_response_non_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
     """Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
```
**Unit tests:**

```diff
@@ -115,18 +115,27 @@ class TestConvertResponseInputToChatMessages:
 
     async def test_convert_function_tool_call_output(self):
         input_items = [
+            OpenAIResponseOutputMessageFunctionToolCall(
+                call_id="call_123",
+                name="test_function",
+                arguments='{"param": "value"}',
+            ),
             OpenAIResponseInputFunctionToolCallOutput(
                 output="Tool output",
                 call_id="call_123",
-            )
+            ),
         ]
 
         result = await convert_response_input_to_chat_messages(input_items)
 
-        assert len(result) == 1
-        assert isinstance(result[0], OpenAIToolMessageParam)
-        assert result[0].content == "Tool output"
-        assert result[0].tool_call_id == "call_123"
+        assert len(result) == 2
+        assert isinstance(result[0], OpenAIAssistantMessageParam)
+        assert result[0].tool_calls[0].id == "call_123"
+        assert result[0].tool_calls[0].function.name == "test_function"
+        assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
+        assert isinstance(result[1], OpenAIToolMessageParam)
+        assert result[1].content == "Tool output"
+        assert result[1].tool_call_id == "call_123"
 
     async def test_convert_function_tool_call(self):
         input_items = [
@@ -146,6 +155,47 @@ class TestConvertResponseInputToChatMessages:
         assert result[0].tool_calls[0].function.name == "test_function"
         assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
 
+    async def test_convert_function_call_ordering(self):
+        input_items = [
+            OpenAIResponseOutputMessageFunctionToolCall(
+                call_id="call_123",
+                name="test_function_a",
+                arguments='{"param": "value"}',
+            ),
+            OpenAIResponseOutputMessageFunctionToolCall(
+                call_id="call_456",
+                name="test_function_b",
+                arguments='{"param": "value"}',
+            ),
+            OpenAIResponseInputFunctionToolCallOutput(
+                output="AAA",
+                call_id="call_123",
+            ),
+            OpenAIResponseInputFunctionToolCallOutput(
+                output="BBB",
+                call_id="call_456",
+            ),
+        ]
+
+        result = await convert_response_input_to_chat_messages(input_items)
+        assert len(result) == 4
+        assert isinstance(result[0], OpenAIAssistantMessageParam)
+        assert len(result[0].tool_calls) == 1
+        assert result[0].tool_calls[0].id == "call_123"
+        assert result[0].tool_calls[0].function.name == "test_function_a"
+        assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
+        assert isinstance(result[1], OpenAIToolMessageParam)
+        assert result[1].content == "AAA"
+        assert result[1].tool_call_id == "call_123"
+        assert isinstance(result[2], OpenAIAssistantMessageParam)
+        assert len(result[2].tool_calls) == 1
+        assert result[2].tool_calls[0].id == "call_456"
+        assert result[2].tool_calls[0].function.name == "test_function_b"
+        assert result[2].tool_calls[0].function.arguments == '{"param": "value"}'
+        assert isinstance(result[3], OpenAIToolMessageParam)
+        assert result[3].content == "BBB"
+        assert result[3].tool_call_id == "call_456"
+
     async def test_convert_response_message(self):
         input_items = [
             OpenAIResponseMessage(
```