mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 21:48:36 +00:00
feat: reuse previous mcp tool listings where possible (#3710)
## What does this PR do?

When a previous response is linked, this PR checks whether that response already contains `mcp_list_tools` objects that can be reused, instead of listing the tools explicitly on every request. Closes #3106.

## Test Plan

Tested manually; added unit tests to cover the new behaviour.

---------

Signed-off-by: Gordon Sim <gsim@redhat.com> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
0066d986c5
commit
8bf07f91cb
12 changed files with 1835 additions and 983 deletions
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
|
@ -20,6 +20,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
ListOpenAIResponseInputItem,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
|
@ -38,7 +39,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.datatypes import ResponsesStoreConfig
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
|
@ -963,6 +964,57 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
|
|||
assert result.status == "completed"
|
||||
|
||||
|
||||
@patch("llama_stack.providers.utils.tools.mcp.list_mcp_tools")
|
||||
async def test_reuse_mcp_tool_list(
|
||||
mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api
|
||||
):
|
||||
"""Test that mcp_list_tools can be reused where appropriate."""
|
||||
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream()
|
||||
mock_list_mcp_tools.return_value = ListToolDefsResponse(
|
||||
data=[ToolDef(name="test_tool", description="a test tool", input_schema={}, output_schema={})]
|
||||
)
|
||||
|
||||
res1 = await openai_responses_impl.create_openai_response(
|
||||
input="What is 2+2?",
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
store=True,
|
||||
tools=[
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
],
|
||||
)
|
||||
args = mock_responses_store.store_response_object.call_args
|
||||
data = args.kwargs["response_object"].model_dump()
|
||||
data["input"] = [input_item.model_dump() for input_item in args.kwargs["input"]]
|
||||
data["messages"] = [msg.model_dump() for msg in args.kwargs["messages"]]
|
||||
stored = _OpenAIResponseObjectWithInputAndMessages(**data)
|
||||
mock_responses_store.get_response_object.return_value = stored
|
||||
|
||||
res2 = await openai_responses_impl.create_openai_response(
|
||||
previous_response_id=res1.id,
|
||||
input="Now what is 3+3?",
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
store=True,
|
||||
tools=[
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
],
|
||||
)
|
||||
assert len(mock_inference_api.openai_chat_completion.call_args_list) == 2
|
||||
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
|
||||
tools_seen = second_call.kwargs["tools"]
|
||||
assert len(tools_seen) == 1
|
||||
assert tools_seen[0]["function"]["name"] == "test_tool"
|
||||
assert tools_seen[0]["function"]["description"] == "a test tool"
|
||||
|
||||
assert mock_list_mcp_tools.call_count == 1
|
||||
listings = [obj for obj in res2.output if obj.type == "mcp_list_tools"]
|
||||
assert len(listings) == 1
|
||||
assert listings[0].server_label == "alabel"
|
||||
assert len(listings[0].tools) == 1
|
||||
assert listings[0].tools[0].name == "test_tool"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text_format, response_format",
|
||||
[
|
||||
|
|
|
@ -0,0 +1,183 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
MCPListToolsTool,
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseToolMCP,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
|
||||
|
||||
|
||||
class TestToolContext:
    """Unit tests for ToolContext recovery of MCP tool listings from a previous response.

    An MCP server's listing is reusable only when the previous response carries both a
    matching ``mcp_list_tools`` output entry and a matching entry in its tool list
    (same server label, and an ``allowed_tools`` filter consistent with the listed
    tools). Anything not reusable stays in ``tools_to_process``.
    """

    def _previous_response(self, output, tools=None, model="fake"):
        """Build a minimal previous-response object for the tests.

        ``tools`` (if given) is attached after construction, mirroring how stored
        response objects carry the tool list that produced them.
        """
        response = OpenAIResponseObject(created_at=1234, id="test", model=model, output=output, status="")
        if tools is not None:
            response.tools = tools
        return response

    def test_no_tools(self):
        """No requested tools and no previous listings: nothing to process or reuse."""
        context = ToolContext([])
        context.recover_tools_from_previous_response(self._previous_response([], model="mymodel"))

        assert len(context.tools_to_process) == 0
        assert len(context.previous_tools) == 0
        assert len(context.previous_tool_listings) == 0

    def test_no_previous_tools(self):
        """Requested tools but no previous listings: every tool still needs processing."""
        tools = [
            OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
            OpenAIResponseInputToolMCP(server_label="label", server_url="url"),
        ]
        context = ToolContext(tools)
        context.recover_tools_from_previous_response(self._previous_response([], model="mymodel"))

        assert len(context.tools_to_process) == 2
        assert len(context.previous_tools) == 0
        assert len(context.previous_tool_listings) == 0

    def test_reusable_server(self):
        """A previously listed MCP server matching the current request is reused."""
        tools = [
            OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
        ]
        context = ToolContext(tools)
        output = [
            OpenAIResponseOutputMessageMCPListTools(
                id="test", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
            )
        ]
        previous_tools = [
            OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
            OpenAIResponseToolMCP(server_label="alabel"),
        ]
        context.recover_tools_from_previous_response(self._previous_response(output, previous_tools))

        # Only file_search still needs processing; the MCP listing is recovered.
        assert len(context.tools_to_process) == 1
        assert context.tools_to_process[0].type == "file_search"
        assert len(context.previous_tools) == 1
        # The server_url comes from the current request's tool definition.
        assert context.previous_tools["test_tool"].server_label == "alabel"
        assert context.previous_tools["test_tool"].server_url == "aurl"
        assert len(context.previous_tool_listings) == 1
        assert len(context.previous_tool_listings[0].tools) == 1
        assert context.previous_tool_listings[0].server_label == "alabel"

    def test_multiple_reusable_servers(self):
        """All previously listed MCP servers matching the request are reused."""
        tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(),
            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
        ]
        context = ToolContext(tools)
        output = [
            OpenAIResponseOutputMessageMCPListTools(
                id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
            ),
            OpenAIResponseOutputMessageMCPListTools(
                id="test2",
                server_label="anotherlabel",
                tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
            ),
        ]
        previous_tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(),
            OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
        ]
        context.recover_tools_from_previous_response(self._previous_response(output, previous_tools))

        # Only the non-MCP tools remain to be processed.
        assert len(context.tools_to_process) == 2
        assert context.tools_to_process[0].type == "function"
        assert context.tools_to_process[1].type == "web_search"
        assert len(context.previous_tools) == 2
        assert context.previous_tools["test_tool"].server_label == "alabel"
        assert context.previous_tools["test_tool"].server_url == "aurl"
        assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
        assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
        assert len(context.previous_tool_listings) == 2
        assert len(context.previous_tool_listings[0].tools) == 1
        assert context.previous_tool_listings[0].server_label == "alabel"
        assert len(context.previous_tool_listings[1].tools) == 1
        assert context.previous_tool_listings[1].server_label == "anotherlabel"

    def test_multiple_servers_only_one_reusable(self):
        """A server absent from the previous response's tool list is not reused."""
        tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(),
            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
        ]
        context = ToolContext(tools)
        output = [
            OpenAIResponseOutputMessageMCPListTools(
                id="test2",
                server_label="anotherlabel",
                tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
            )
        ]
        # "alabel" is missing from the previous tools, so its listing cannot be reused.
        previous_tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(),
        ]
        context.recover_tools_from_previous_response(self._previous_response(output, previous_tools))

        # The unmatched MCP server stays in the processing queue alongside the non-MCP tools.
        assert len(context.tools_to_process) == 3
        assert context.tools_to_process[0].type == "function"
        assert context.tools_to_process[1].type == "web_search"
        assert context.tools_to_process[2].type == "mcp"
        assert len(context.previous_tools) == 1
        assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
        assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
        assert len(context.previous_tool_listings) == 1
        assert len(context.previous_tool_listings[0].tools) == 1
        assert context.previous_tool_listings[0].server_label == "anotherlabel"

    def test_mismatched_allowed_tools(self):
        """A changed allowed_tools filter invalidates the previous listing for that server."""
        tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(),
            # Requests test_tool_2, but the previous listing only contains test_tool_1.
            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl", allowed_tools=["test_tool_2"]),
        ]
        context = ToolContext(tools)
        output = [
            OpenAIResponseOutputMessageMCPListTools(
                id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool_1", input_schema={})]
            ),
            OpenAIResponseOutputMessageMCPListTools(
                id="test2",
                server_label="anotherlabel",
                tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
            ),
        ]
        previous_tools = [
            OpenAIResponseInputToolFunction(name="fake", parameters=None),
            OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
            OpenAIResponseInputToolWebSearch(),
            OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
        ]
        context.recover_tools_from_previous_response(self._previous_response(output, previous_tools))

        # The mismatched server must be re-listed; only "anotherlabel" is reused.
        assert len(context.tools_to_process) == 3
        assert context.tools_to_process[0].type == "function"
        assert context.tools_to_process[1].type == "web_search"
        assert context.tools_to_process[2].type == "mcp"
        assert len(context.previous_tools) == 1
        assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
        assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
        assert len(context.previous_tool_listings) == 1
        assert len(context.previous_tool_listings[0].tools) == 1
        assert context.previous_tool_listings[0].server_label == "anotherlabel"
|
Loading…
Add table
Add a link
Reference in a new issue