feat: reuse previous mcp tool listings where possible (#3710)

# What does this PR do?
When a request is linked to a previous response, this PR checks whether that
response already contains mcp_list_tools objects that can be reused, instead
of explicitly listing the tools every time.

 Closes #3106 

## Test Plan
Tested manually.
Added unit tests to cover new behaviour.

---------

Signed-off-by: Gordon Sim <gsim@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
grs 2025-10-10 17:28:25 +01:00 committed by GitHub
parent 0066d986c5
commit 8bf07f91cb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 1835 additions and 983 deletions

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, patch
import pytest
from openai.types.chat.chat_completion_chunk import (
@ -20,6 +20,7 @@ from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContentOutputText,
@ -38,7 +39,7 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatJSONSchema,
OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.core.datatypes import ResponsesStoreConfig
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
@ -963,6 +964,57 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
assert result.status == "completed"
@patch("llama_stack.providers.utils.tools.mcp.list_mcp_tools")
async def test_reuse_mcp_tool_list(
    mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api
):
    """Check that a follow-up response reuses the MCP tool listing of the response it links to.

    A first response is created with a function tool and an MCP tool. The object
    handed to the responses store is then rebuilt and served back as the previous
    response for a second request that only asks for the same MCP server. The
    second inference call must still be offered the MCP tool, the patched
    ``list_mcp_tools`` must have been called exactly once overall, and the second
    response must carry a single ``mcp_list_tools`` output item.
    """
    model_id = "meta-llama/Llama-3.1-8B-Instruct"

    mock_inference_api.openai_chat_completion.return_value = fake_stream()
    mock_list_mcp_tools.return_value = ListToolDefsResponse(
        data=[ToolDef(name="test_tool", description="a test tool", input_schema={}, output_schema={})]
    )

    # First turn: the MCP server has to be listed explicitly.
    first_response = await openai_responses_impl.create_openai_response(
        input="What is 2+2?",
        model=model_id,
        store=True,
        tools=[
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
        ],
    )

    # Rebuild what was passed to the store so it can be returned as the previous response.
    store_kwargs = mock_responses_store.store_response_object.call_args.kwargs
    payload = store_kwargs["response_object"].model_dump()
    payload["input"] = [item.model_dump() for item in store_kwargs["input"]]
    payload["messages"] = [message.model_dump() for message in store_kwargs["messages"]]
    mock_responses_store.get_response_object.return_value = _OpenAIResponseObjectWithInputAndMessages(**payload)

    # Second turn: same MCP server, linked to the first response.
    second_response = await openai_responses_impl.create_openai_response(
        previous_response_id=first_response.id,
        input="Now what is 3+3?",
        model=model_id,
        store=True,
        tools=[
            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
        ],
    )

    completion_calls = mock_inference_api.openai_chat_completion.call_args_list
    assert len(completion_calls) == 2

    offered_tools = completion_calls[1].kwargs["tools"]
    assert len(offered_tools) == 1
    offered_function = offered_tools[0]["function"]
    assert offered_function["name"] == "test_tool"
    assert offered_function["description"] == "a test tool"

    # Only the first turn may have triggered a real listing.
    assert mock_list_mcp_tools.call_count == 1

    tool_listings = [item for item in second_response.output if item.type == "mcp_list_tools"]
    assert len(tool_listings) == 1
    reused_listing = tool_listings[0]
    assert reused_listing.server_label == "alabel"
    assert len(reused_listing.tools) == 1
    assert reused_listing.tools[0].name == "test_tool"
@pytest.mark.parametrize(
"text_format, response_format",
[

View file

@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseInputToolWebSearch,
OpenAIResponseObject,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseToolMCP,
)
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
class TestToolContext:
    """Tests for ``ToolContext.recover_tools_from_previous_response``.

    Each test builds a ``ToolContext`` from the tools of a new request, feeds it a
    hand-built previous response, and then checks three things: which tools are
    left in ``tools_to_process``, which tool names were recorded in
    ``previous_tools``, and which listings were kept in ``previous_tool_listings``.
    """

    @staticmethod
    def _previous_response(model, output, tools=None):
        """Build a previous response; ``tools`` is assigned after construction when given."""
        response = OpenAIResponseObject(created_at=1234, id="test", model=model, output=output, status="")
        if tools is not None:
            response.tools = tools
        return response

    def test_no_tools(self):
        """No requested tools and an empty previous response: nothing to process or recover."""
        context = ToolContext([])

        context.recover_tools_from_previous_response(self._previous_response("mymodel", []))

        assert len(context.tools_to_process) == 0
        assert len(context.previous_tools) == 0
        assert len(context.previous_tool_listings) == 0

    def test_no_previous_tools(self):
        """A previous response without listings leaves every requested tool to be processed."""
        context = ToolContext(
            [
                OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
                OpenAIResponseInputToolMCP(server_label="label", server_url="url"),
            ]
        )

        context.recover_tools_from_previous_response(self._previous_response("mymodel", []))

        assert len(context.tools_to_process) == 2
        assert len(context.previous_tools) == 0
        assert len(context.previous_tool_listings) == 0

    def test_reusable_server(self):
        """An MCP server already listed in the previous response is recovered, not re-processed."""
        context = ToolContext(
            [
                OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
                OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
            ]
        )
        listings = [
            OpenAIResponseOutputMessageMCPListTools(
                id="test", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
            )
        ]
        previous = self._previous_response(
            "fake",
            listings,
            tools=[
                OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
                OpenAIResponseToolMCP(server_label="alabel"),
            ],
        )

        context.recover_tools_from_previous_response(previous)

        # Only the non-MCP tool is left to process.
        assert [tool.type for tool in context.tools_to_process] == ["file_search"]

        assert len(context.previous_tools) == 1
        recovered = context.previous_tools["test_tool"]
        assert (recovered.server_label, recovered.server_url) == ("alabel", "aurl")

        assert [(entry.server_label, len(entry.tools)) for entry in context.previous_tool_listings] == [
            ("alabel", 1)
        ]

    def test_multiple_reusable_servers(self):
        """Two MCP servers listed in the previous response are both recovered."""
        context = ToolContext(
            [
                OpenAIResponseInputToolFunction(name="fake", parameters=None),
                OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
                OpenAIResponseInputToolWebSearch(),
                OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
            ]
        )
        listings = [
            OpenAIResponseOutputMessageMCPListTools(
                id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
            ),
            OpenAIResponseOutputMessageMCPListTools(
                id="test2",
                server_label="anotherlabel",
                tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
            ),
        ]
        previous = self._previous_response(
            "fake",
            listings,
            tools=[
                OpenAIResponseInputToolFunction(name="fake", parameters=None),
                OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
                OpenAIResponseInputToolWebSearch(type="web_search"),
                OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
            ],
        )

        context.recover_tools_from_previous_response(previous)

        assert [tool.type for tool in context.tools_to_process] == ["function", "web_search"]

        assert len(context.previous_tools) == 2
        first_tool = context.previous_tools["test_tool"]
        assert (first_tool.server_label, first_tool.server_url) == ("alabel", "aurl")
        second_tool = context.previous_tools["some_other_tool"]
        assert (second_tool.server_label, second_tool.server_url) == ("anotherlabel", "anotherurl")

        # Listings are expected in the order they appear in the previous output.
        assert [(entry.server_label, len(entry.tools)) for entry in context.previous_tool_listings] == [
            ("alabel", 1),
            ("anotherlabel", 1),
        ]

    def test_multiple_servers_only_one_reusable(self):
        """An MCP server absent from the previous response stays in the tools to process."""
        context = ToolContext(
            [
                OpenAIResponseInputToolFunction(name="fake", parameters=None),
                OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
                OpenAIResponseInputToolWebSearch(type="web_search"),
                OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
            ]
        )
        listings = [
            OpenAIResponseOutputMessageMCPListTools(
                id="test2",
                server_label="anotherlabel",
                tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
            )
        ]
        previous = self._previous_response(
            "fake",
            listings,
            tools=[
                OpenAIResponseInputToolFunction(name="fake", parameters=None),
                OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
                OpenAIResponseInputToolWebSearch(type="web_search"),
            ],
        )

        context.recover_tools_from_previous_response(previous)

        # "alabel" was not part of the previous response, so its MCP tool remains.
        assert [tool.type for tool in context.tools_to_process] == ["function", "web_search", "mcp"]

        assert len(context.previous_tools) == 1
        recovered = context.previous_tools["some_other_tool"]
        assert (recovered.server_label, recovered.server_url) == ("anotherlabel", "anotherurl")

        assert [(entry.server_label, len(entry.tools)) for entry in context.previous_tool_listings] == [
            ("anotherlabel", 1)
        ]

    def test_mismatched_allowed_tools(self):
        """An MCP server requested with different ``allowed_tools`` than before is not reused."""
        context = ToolContext(
            [
                OpenAIResponseInputToolFunction(name="fake", parameters=None),
                OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
                OpenAIResponseInputToolWebSearch(type="web_search"),
                OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl", allowed_tools=["test_tool_2"]),
            ]
        )
        listings = [
            OpenAIResponseOutputMessageMCPListTools(
                id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool_1", input_schema={})]
            ),
            OpenAIResponseOutputMessageMCPListTools(
                id="test2",
                server_label="anotherlabel",
                tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
            ),
        ]
        previous = self._previous_response(
            "fake",
            listings,
            tools=[
                OpenAIResponseInputToolFunction(name="fake", parameters=None),
                OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
                OpenAIResponseInputToolWebSearch(type="web_search"),
                OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
            ],
        )

        context.recover_tools_from_previous_response(previous)

        # The "alabel" MCP tool is still pending; only "anotherlabel" was recovered.
        assert [tool.type for tool in context.tools_to_process] == ["function", "web_search", "mcp"]

        assert len(context.previous_tools) == 1
        recovered = context.previous_tools["some_other_tool"]
        assert (recovered.server_label, recovered.server_url) == ("anotherlabel", "anotherurl")

        assert [(entry.server_label, len(entry.tools)) for entry in context.previous_tool_listings] == [
            ("anotherlabel", 1)
        ]