mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 12:06:04 +00:00
assert that the reused tool was passed to inference
parent 88eeade897
commit a0bdd7580d

1 changed file with 10 additions and 42 deletions
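The assertions this commit adds lean on unittest.mock call recording: every invocation of an AsyncMock is appended to call_args_list, and each recorded call exposes the keyword arguments it received. A minimal, self-contained sketch of that inspection pattern follows; the inference object and payload shapes are illustrative stand-ins, not llama-stack APIs.

import asyncio
from unittest.mock import AsyncMock

async def main():
    # Stand-in for the inference API; AsyncMock records every call made to it.
    inference = AsyncMock()

    # Two turns, both passing the same (hypothetical) tool definition.
    tool = {"type": "function", "function": {"name": "test_tool", "description": "a test tool"}}
    await inference.openai_chat_completion(messages=["what is 2+2?"], tools=[tool])
    await inference.openai_chat_completion(messages=["what is 3+3?"], tools=[tool])

    # The same inspection pattern the test below uses: pull the second call's kwargs.
    second_call = inference.openai_chat_completion.call_args_list[1]
    tools_seen = second_call.kwargs["tools"]
    assert tools_seen[0]["function"]["name"] == "test_tool"

asyncio.run(main())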
@@ -954,46 +954,6 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
     assert result.status == "completed"
 
 
-@patch("llama_stack.providers.utils.tools.mcp.list_mcp_tools")
-async def test_stored_response_includes_tools(
-    mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api
-):
-    """Test that a stored response includes any tools that were specified."""
-
-    mock_inference_api.openai_chat_completion.return_value = fake_stream()
-    mock_list_mcp_tools.return_value = ListToolDefsResponse(
-        data=[ToolDef(name="test_tool", description="test tool", input_schema={}, output_schema={})]
-    )
-
-    result = await openai_responses_impl.create_openai_response(
-        input="Now what is 3+3?",
-        model="meta-llama/Llama-3.1-8B-Instruct",
-        store=True,
-        tools=[
-            OpenAIResponseInputToolFunction(name="fake", parameters=None),
-            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
-        ],
-    )
-
-    store_call_args = mock_responses_store.store_response_object.call_args
-    stored_response = store_call_args.kwargs["response_object"]
-
-    assert len(stored_response.tools) == 2
-
-    assert stored_response.tools[0].type == "function"
-    assert stored_response.tools[0].name == "fake"
-
-    assert stored_response.tools[1].type == "mcp"
-    assert stored_response.tools[1].server_label == "alabel"
-
-    assert len(result.tools) == 2
-    assert result.tools[0].type == "function"
-    assert result.tools[0].name == "fake"
-
-    assert result.tools[1].type == "mcp"
-    assert result.tools[1].server_label == "alabel"
-
-
 @patch("llama_stack.providers.utils.tools.mcp.list_mcp_tools")
 async def test_reuse_mcp_tool_list(
     mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api
@@ -1002,7 +962,7 @@ async def test_reuse_mcp_tool_list(
 
     mock_inference_api.openai_chat_completion.return_value = fake_stream()
     mock_list_mcp_tools.return_value = ListToolDefsResponse(
-        data=[ToolDef(name="test_tool", description="test tool", input_schema={}, output_schema={})]
+        data=[ToolDef(name="test_tool", description="a test tool", input_schema={}, output_schema={})]
     )
 
     res1 = await openai_responses_impl.create_openai_response(
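The reuse under test surfaces in two ways in the next hunk: the mocked listing is invoked only once (call_count == 1) even though two responses are created, and the tool it returned still reaches the second inference call. A small sketch of the MagicMock mechanics behind the call_count check, using a hypothetical fetch function and a naive cache in place of the provider's actual listing reuse:

from unittest.mock import MagicMock

def run_two_turns(fetch_tools, cache):
    # Naive memoization standing in for the provider reusing a prior
    # mcp_list_tools result instead of listing tools again.
    if "tools" not in cache:
        cache["tools"] = fetch_tools()
    first = cache["tools"]
    second = cache["tools"]
    return first, second

mock_fetch = MagicMock(return_value=[{"name": "test_tool"}])
first, second = run_two_turns(mock_fetch, cache={})

# Listed once, reused on the second turn -- mirrors call_count == 1 below.
assert mock_fetch.call_count == 1
assert first is second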
@@ -1027,14 +987,22 @@ async def test_reuse_mcp_tool_list(
         model="meta-llama/Llama-3.1-8B-Instruct",
         store=True,
         tools=[
             OpenAIResponseInputToolFunction(name="fake", parameters=None),
             OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
         ],
     )
+    assert len(mock_inference_api.openai_chat_completion.call_args_list) == 2
+    second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
+    tools_seen = second_call.kwargs["tools"]
+    assert len(tools_seen) == 1
+    assert tools_seen[0]["function"]["name"] == "test_tool"
+    assert tools_seen[0]["function"]["description"] == "a test tool"
+
     assert mock_list_mcp_tools.call_count == 1
     listings = [obj for obj in res2.output if obj.type == "mcp_list_tools"]
     assert len(listings) == 1
     assert listings[0].server_label == "alabel"
+    assert len(listings[0].tools) == 1
+    assert listings[0].tools[0].name == "test_tool"
 
 
 @pytest.mark.parametrize(
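The tools_seen assertions above index into the chat-completions tool format, where a function tool is a dict that nests its metadata under a "function" key. For reference, a minimal example of that shape (the parameters block is illustrative, not taken from the diff):

# Chat-completions function-tool shape assumed by the assertions above.
tool = {
    "type": "function",
    "function": {
        "name": "test_tool",
        "description": "a test tool",
        "parameters": {"type": "object", "properties": {}},
    },
}

assert tool["function"]["name"] == "test_tool"
assert tool["function"]["description"] == "a test tool"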