mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 21:48:36 +00:00
feat: reuse previous mcp tool listings where possible (#3710)
# What does this PR do? This PR checks whether, if a previous response is linked, there are mcp_list_tools objects that can be reused instead of listing the tools explicitly every time. Closes #3106 ## Test Plan Tested manually. Added unit tests to cover new behaviour. --------- Signed-off-by: Gordon Sim <gsim@redhat.com> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
0066d986c5
commit
8bf07f91cb
12 changed files with 1835 additions and 983 deletions
|
@ -39,7 +39,7 @@ from llama_stack.providers.utils.responses.responses_store import (
|
|||
|
||||
from .streaming import StreamingResponseOrchestrator
|
||||
from .tool_executor import ToolExecutor
|
||||
from .types import ChatCompletionContext
|
||||
from .types import ChatCompletionContext, ToolContext
|
||||
from .utils import (
|
||||
convert_response_input_to_chat_messages,
|
||||
convert_response_text_to_chat_response_format,
|
||||
|
@ -91,13 +91,15 @@ class OpenAIResponsesImpl:
|
|||
async def _process_input_with_previous_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
tools: list[OpenAIResponseInputTool] | None,
|
||||
previous_response_id: str | None,
|
||||
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
|
||||
"""Process input with optional previous response context.
|
||||
|
||||
Returns:
|
||||
tuple: (all_input for storage, messages for chat completion)
|
||||
tuple: (all_input for storage, messages for chat completion, tool context)
|
||||
"""
|
||||
tool_context = ToolContext(tools)
|
||||
if previous_response_id:
|
||||
previous_response: _OpenAIResponseObjectWithInputAndMessages = (
|
||||
await self.responses_store.get_response_object(previous_response_id)
|
||||
|
@ -113,11 +115,13 @@ class OpenAIResponsesImpl:
|
|||
else:
|
||||
# Backward compatibility: reconstruct from inputs
|
||||
messages = await convert_response_input_to_chat_messages(all_input)
|
||||
|
||||
tool_context.recover_tools_from_previous_response(previous_response)
|
||||
else:
|
||||
all_input = input
|
||||
messages = await convert_response_input_to_chat_messages(input)
|
||||
|
||||
return all_input, messages
|
||||
return all_input, messages, tool_context
|
||||
|
||||
async def _prepend_instructions(self, messages, instructions):
|
||||
if instructions:
|
||||
|
@ -273,7 +277,9 @@ class OpenAIResponsesImpl:
|
|||
max_infer_iters: int | None = 10,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Input preprocessing
|
||||
all_input, messages = await self._process_input_with_previous_response(input, previous_response_id)
|
||||
all_input, messages, tool_context = await self._process_input_with_previous_response(
|
||||
input, tools, previous_response_id
|
||||
)
|
||||
await self._prepend_instructions(messages, instructions)
|
||||
|
||||
# Structured outputs
|
||||
|
@ -285,6 +291,7 @@ class OpenAIResponsesImpl:
|
|||
response_tools=tools,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
tool_context=tool_context,
|
||||
inputs=all_input,
|
||||
)
|
||||
|
||||
|
|
|
@ -99,7 +99,7 @@ class StreamingResponseOrchestrator:
|
|||
self.tool_executor = tool_executor
|
||||
self.sequence_number = 0
|
||||
# Store MCP tool mapping that gets built during tool processing
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
|
||||
# Track final messages after all tool executions
|
||||
self.final_messages: list[OpenAIMessageParam] = []
|
||||
# mapping for annotations
|
||||
|
@ -129,6 +129,7 @@ class StreamingResponseOrchestrator:
|
|||
status=status,
|
||||
output=self._clone_outputs(outputs),
|
||||
text=self.text,
|
||||
tools=self.ctx.available_tools(),
|
||||
error=error,
|
||||
)
|
||||
|
||||
|
@ -146,10 +147,8 @@ class StreamingResponseOrchestrator:
|
|||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Process all tools (including MCP tools) and emit streaming events
|
||||
if self.ctx.response_tools:
|
||||
async for stream_event in self._process_tools(self.ctx.response_tools, output_messages):
|
||||
yield stream_event
|
||||
async for stream_event in self._process_tools(output_messages):
|
||||
yield stream_event
|
||||
|
||||
n_iter = 0
|
||||
messages = self.ctx.messages.copy()
|
||||
|
@ -590,7 +589,7 @@ class StreamingResponseOrchestrator:
|
|||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async def _process_tools(
|
||||
async def _process_new_tools(
|
||||
self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Process all tools and emit appropriate streaming events."""
|
||||
|
@ -645,7 +644,6 @@ class StreamingResponseOrchestrator:
|
|||
yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
try:
|
||||
# Parse allowed/never allowed tools
|
||||
always_allowed = None
|
||||
|
@ -707,39 +705,26 @@ class StreamingResponseOrchestrator:
|
|||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Add the MCP list message to output
|
||||
output_messages.append(mcp_list_message)
|
||||
|
||||
# Emit output_item.added for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit mcp_list_tools.completed
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit output_item.done for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
async for stream_event in self._add_mcp_list_tools(mcp_list_message, output_messages):
|
||||
yield stream_event
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Emit mcp_list_tools.failed event if needed
|
||||
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
|
||||
raise
|
||||
|
||||
async def _process_tools(
|
||||
self, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Handle all mcp tool lists from previous response that are still valid:
|
||||
for tool in self.ctx.tool_context.previous_tool_listings:
|
||||
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
|
||||
yield evt
|
||||
# Process all remaining tools (including MCP tools) and emit streaming events
|
||||
if self.ctx.tool_context.tools_to_process:
|
||||
async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages):
|
||||
yield stream_event
|
||||
|
||||
def _approval_required(self, tool_name: str) -> bool:
|
||||
if tool_name not in self.mcp_tool_to_server:
|
||||
return False
|
||||
|
@ -774,7 +759,6 @@ class StreamingResponseOrchestrator:
|
|||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
|
@ -782,3 +766,60 @@ class StreamingResponseOrchestrator:
|
|||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async def _add_mcp_list_tools(
|
||||
self, mcp_list_message: OpenAIResponseOutputMessageMCPListTools, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Add the MCP list message to output
|
||||
output_messages.append(mcp_list_message)
|
||||
|
||||
# Emit output_item.added for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
# Emit mcp_list_tools.completed
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit output_item.done for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async def _reuse_mcp_list_tools(
|
||||
self, original: OpenAIResponseOutputMessageMCPListTools, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
for t in original.tools:
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
|
||||
# convert from input_schema to map of ToolParamDefinitions...
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=t.name,
|
||||
description=t.description,
|
||||
input_schema=t.input_schema,
|
||||
)
|
||||
# ...then can convert that to openai completions tool
|
||||
openai_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
self.ctx.chat_tools.append(openai_tool)
|
||||
|
||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||
id=f"mcp_list_{uuid.uuid4()}",
|
||||
server_label=original.server_label,
|
||||
tools=original.tools,
|
||||
)
|
||||
|
||||
async for stream_event in self._add_mcp_list_tools(mcp_list_message, output_messages):
|
||||
yield stream_event
|
||||
|
|
|
@ -12,10 +12,18 @@ from pydantic import BaseModel
|
|||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseMCPApprovalResponse,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseTool,
|
||||
OpenAIResponseToolMCP,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
|
||||
|
||||
|
@ -55,6 +63,86 @@ class ChatCompletionResult:
|
|||
return bool(self.tool_calls)
|
||||
|
||||
|
||||
class ToolContext(BaseModel):
|
||||
"""Holds information about tools from this and (if relevant)
|
||||
previous response in order to facilitate reuse of previous
|
||||
listings where appropriate."""
|
||||
|
||||
# tools argument passed into current request:
|
||||
current_tools: list[OpenAIResponseInputTool]
|
||||
# reconstructed map of tool -> mcp server from previous response:
|
||||
previous_tools: dict[str, OpenAIResponseInputToolMCP]
|
||||
# reusable mcp-list-tools objects from previous response:
|
||||
previous_tool_listings: list[OpenAIResponseOutputMessageMCPListTools]
|
||||
# tool arguments from current request that still need to be processed:
|
||||
tools_to_process: list[OpenAIResponseInputTool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
current_tools: list[OpenAIResponseInputTool] | None,
|
||||
):
|
||||
super().__init__(
|
||||
current_tools=current_tools or [],
|
||||
previous_tools={},
|
||||
previous_tool_listings=[],
|
||||
tools_to_process=current_tools or [],
|
||||
)
|
||||
|
||||
def recover_tools_from_previous_response(
|
||||
self,
|
||||
previous_response: OpenAIResponseObject,
|
||||
):
|
||||
"""Determine which mcp_list_tools objects from previous response we can reuse."""
|
||||
|
||||
if self.current_tools and previous_response.tools:
|
||||
previous_tools_by_label: dict[str, OpenAIResponseToolMCP] = {}
|
||||
for tool in previous_response.tools:
|
||||
if isinstance(tool, OpenAIResponseToolMCP):
|
||||
previous_tools_by_label[tool.server_label] = tool
|
||||
# collect tool definitions which are the same in current and previous requests:
|
||||
tools_to_process = []
|
||||
matched: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||
for tool in self.current_tools:
|
||||
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
|
||||
previous_tool = previous_tools_by_label[tool.server_label]
|
||||
if previous_tool.allowed_tools == tool.allowed_tools:
|
||||
matched[tool.server_label] = tool
|
||||
else:
|
||||
tools_to_process.append(tool)
|
||||
else:
|
||||
tools_to_process.append(tool)
|
||||
# tools that are not the same or were not previously defined need to be processed:
|
||||
self.tools_to_process = tools_to_process
|
||||
# for all matched definitions, get the mcp_list_tools objects from the previous output:
|
||||
self.previous_tool_listings = [
|
||||
obj for obj in previous_response.output if obj.type == "mcp_list_tools" and obj.server_label in matched
|
||||
]
|
||||
# reconstruct the tool to server mappings that can be reused:
|
||||
for listing in self.previous_tool_listings:
|
||||
definition = matched[listing.server_label]
|
||||
for tool in listing.tools:
|
||||
self.previous_tools[tool.name] = definition
|
||||
|
||||
def available_tools(self) -> list[OpenAIResponseTool]:
|
||||
if not self.current_tools:
|
||||
return []
|
||||
|
||||
def convert_tool(tool: OpenAIResponseInputTool) -> OpenAIResponseTool:
|
||||
if isinstance(tool, OpenAIResponseInputToolWebSearch):
|
||||
return tool
|
||||
if isinstance(tool, OpenAIResponseInputToolFileSearch):
|
||||
return tool
|
||||
if isinstance(tool, OpenAIResponseInputToolFunction):
|
||||
return tool
|
||||
if isinstance(tool, OpenAIResponseInputToolMCP):
|
||||
return OpenAIResponseToolMCP(
|
||||
server_label=tool.server_label,
|
||||
allowed_tools=tool.allowed_tools,
|
||||
)
|
||||
|
||||
return [convert_tool(tool) for tool in self.current_tools]
|
||||
|
||||
|
||||
class ChatCompletionContext(BaseModel):
|
||||
model: str
|
||||
messages: list[OpenAIMessageParam]
|
||||
|
@ -62,6 +150,7 @@ class ChatCompletionContext(BaseModel):
|
|||
chat_tools: list[ChatCompletionToolParam] | None = None
|
||||
temperature: float | None
|
||||
response_format: OpenAIResponseFormatParam
|
||||
tool_context: ToolContext | None
|
||||
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
|
||||
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
|
||||
|
||||
|
@ -72,6 +161,7 @@ class ChatCompletionContext(BaseModel):
|
|||
response_tools: list[OpenAIResponseInputTool] | None,
|
||||
temperature: float | None,
|
||||
response_format: OpenAIResponseFormatParam,
|
||||
tool_context: ToolContext,
|
||||
inputs: list[OpenAIResponseInput] | str,
|
||||
):
|
||||
super().__init__(
|
||||
|
@ -80,6 +170,7 @@ class ChatCompletionContext(BaseModel):
|
|||
response_tools=response_tools,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
tool_context=tool_context,
|
||||
)
|
||||
if not isinstance(inputs, str):
|
||||
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
|
||||
|
@ -96,3 +187,8 @@ class ChatCompletionContext(BaseModel):
|
|||
if request.name == tool_name and request.arguments == arguments:
|
||||
return request
|
||||
return None
|
||||
|
||||
def available_tools(self) -> list[OpenAIResponseTool]:
|
||||
if not self.tool_context:
|
||||
return []
|
||||
return self.tool_context.available_tools()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue