mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 20:12:33 +00:00
assert that the reused tool was passed to inference
This commit is contained in:
parent
88eeade897
commit
a0bdd7580d
1 changed files with 10 additions and 42 deletions
|
|
@ -954,46 +954,6 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
|
|
||||||
|
|
||||||
@patch("llama_stack.providers.utils.tools.mcp.list_mcp_tools")
|
|
||||||
async def test_stored_response_includes_tools(
|
|
||||||
mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api
|
|
||||||
):
|
|
||||||
"""Test that a stored response includes any tools that were specified."""
|
|
||||||
|
|
||||||
mock_inference_api.openai_chat_completion.return_value = fake_stream()
|
|
||||||
mock_list_mcp_tools.return_value = ListToolDefsResponse(
|
|
||||||
data=[ToolDef(name="test_tool", description="test tool", input_schema={}, output_schema={})]
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await openai_responses_impl.create_openai_response(
|
|
||||||
input="Now what is 3+3?",
|
|
||||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
|
||||||
store=True,
|
|
||||||
tools=[
|
|
||||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
|
||||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
store_call_args = mock_responses_store.store_response_object.call_args
|
|
||||||
stored_response = store_call_args.kwargs["response_object"]
|
|
||||||
|
|
||||||
assert len(stored_response.tools) == 2
|
|
||||||
|
|
||||||
assert stored_response.tools[0].type == "function"
|
|
||||||
assert stored_response.tools[0].name == "fake"
|
|
||||||
|
|
||||||
assert stored_response.tools[1].type == "mcp"
|
|
||||||
assert stored_response.tools[1].server_label == "alabel"
|
|
||||||
|
|
||||||
assert len(result.tools) == 2
|
|
||||||
assert result.tools[0].type == "function"
|
|
||||||
assert result.tools[0].name == "fake"
|
|
||||||
|
|
||||||
assert result.tools[1].type == "mcp"
|
|
||||||
assert result.tools[1].server_label == "alabel"
|
|
||||||
|
|
||||||
|
|
||||||
@patch("llama_stack.providers.utils.tools.mcp.list_mcp_tools")
|
@patch("llama_stack.providers.utils.tools.mcp.list_mcp_tools")
|
||||||
async def test_reuse_mcp_tool_list(
|
async def test_reuse_mcp_tool_list(
|
||||||
mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api
|
mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api
|
||||||
|
|
@ -1002,7 +962,7 @@ async def test_reuse_mcp_tool_list(
|
||||||
|
|
||||||
mock_inference_api.openai_chat_completion.return_value = fake_stream()
|
mock_inference_api.openai_chat_completion.return_value = fake_stream()
|
||||||
mock_list_mcp_tools.return_value = ListToolDefsResponse(
|
mock_list_mcp_tools.return_value = ListToolDefsResponse(
|
||||||
data=[ToolDef(name="test_tool", description="test tool", input_schema={}, output_schema={})]
|
data=[ToolDef(name="test_tool", description="a test tool", input_schema={}, output_schema={})]
|
||||||
)
|
)
|
||||||
|
|
||||||
res1 = await openai_responses_impl.create_openai_response(
|
res1 = await openai_responses_impl.create_openai_response(
|
||||||
|
|
@ -1027,14 +987,22 @@ async def test_reuse_mcp_tool_list(
|
||||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
store=True,
|
store=True,
|
||||||
tools=[
|
tools=[
|
||||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
|
||||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
assert len(mock_inference_api.openai_chat_completion.call_args_list) == 2
|
||||||
|
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
|
||||||
|
tools_seen = second_call.kwargs["tools"]
|
||||||
|
assert len(tools_seen) == 1
|
||||||
|
assert tools_seen[0]["function"]["name"] == "test_tool"
|
||||||
|
assert tools_seen[0]["function"]["description"] == "a test tool"
|
||||||
|
|
||||||
assert mock_list_mcp_tools.call_count == 1
|
assert mock_list_mcp_tools.call_count == 1
|
||||||
listings = [obj for obj in res2.output if obj.type == "mcp_list_tools"]
|
listings = [obj for obj in res2.output if obj.type == "mcp_list_tools"]
|
||||||
assert len(listings) == 1
|
assert len(listings) == 1
|
||||||
|
assert listings[0].server_label == "alabel"
|
||||||
|
assert len(listings[0].tools) == 1
|
||||||
|
assert listings[0].tools[0].name == "test_tool"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue