assert that the reused tool was passed to inference

This commit is contained in:
Gordon Sim 2025-10-06 23:24:50 +01:00
parent 88eeade897
commit a0bdd7580d

View file

@ -954,46 +954,6 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
assert result.status == "completed"
@patch("llama_stack.providers.utils.tools.mcp.list_mcp_tools")
async def test_stored_response_includes_tools(
mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api
):
"""Test that a stored response includes any tools that were specified."""
mock_inference_api.openai_chat_completion.return_value = fake_stream()
mock_list_mcp_tools.return_value = ListToolDefsResponse(
data=[ToolDef(name="test_tool", description="test tool", input_schema={}, output_schema={})]
)
result = await openai_responses_impl.create_openai_response(
input="Now what is 3+3?",
model="meta-llama/Llama-3.1-8B-Instruct",
store=True,
tools=[
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
],
)
store_call_args = mock_responses_store.store_response_object.call_args
stored_response = store_call_args.kwargs["response_object"]
assert len(stored_response.tools) == 2
assert stored_response.tools[0].type == "function"
assert stored_response.tools[0].name == "fake"
assert stored_response.tools[1].type == "mcp"
assert stored_response.tools[1].server_label == "alabel"
assert len(result.tools) == 2
assert result.tools[0].type == "function"
assert result.tools[0].name == "fake"
assert result.tools[1].type == "mcp"
assert result.tools[1].server_label == "alabel"
@patch("llama_stack.providers.utils.tools.mcp.list_mcp_tools")
async def test_reuse_mcp_tool_list(
mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api
@ -1002,7 +962,7 @@ async def test_reuse_mcp_tool_list(
mock_inference_api.openai_chat_completion.return_value = fake_stream()
mock_list_mcp_tools.return_value = ListToolDefsResponse(
data=[ToolDef(name="test_tool", description="test tool", input_schema={}, output_schema={})]
data=[ToolDef(name="test_tool", description="a test tool", input_schema={}, output_schema={})]
)
res1 = await openai_responses_impl.create_openai_response(
@ -1027,14 +987,22 @@ async def test_reuse_mcp_tool_list(
model="meta-llama/Llama-3.1-8B-Instruct",
store=True,
tools=[
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
],
)
assert len(mock_inference_api.openai_chat_completion.call_args_list) == 2
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
tools_seen = second_call.kwargs["tools"]
assert len(tools_seen) == 1
assert tools_seen[0]["function"]["name"] == "test_tool"
assert tools_seen[0]["function"]["description"] == "a test tool"
assert mock_list_mcp_tools.call_count == 1
listings = [obj for obj in res2.output if obj.type == "mcp_list_tools"]
assert len(listings) == 1
assert listings[0].server_label == "alabel"
assert len(listings[0].tools) == 1
assert listings[0].tools[0].name == "test_tool"
@pytest.mark.parametrize(