Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-15 14:08:00 +00:00)
feat(responses): add mcp list tool streaming event (#3159)
# What does this PR do?

Adds proper streaming events for MCP tool listing (`mcp_list_tools.in_progress` and `mcp_list_tools.completed`). Also refactors the tool-processing code a bit further.

## Test Plan

Verified that the existing integration tests pass with the refactored code. The test `test_response_streaming_multi_turn_tool_execution` has been updated to check for the new MCP list tools streaming events.
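For context, a minimal sketch (not part of this PR) of how a streaming client could observe the new events. It assumes an OpenAI-compatible client pointed at a running Llama Stack server; the base URL, model id, and MCP server URL below are placeholders:

```python
# Hypothetical consumer of the new streaming events; the endpoint, model id, and
# MCP server URL are placeholders, not values taken from this PR.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

stream = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",  # placeholder model id
    input="What tools do you have available?",
    tools=[
        {
            "type": "mcp",
            "server_label": "demo",
            "server_url": "http://localhost:3000/sse",  # placeholder MCP server
        }
    ],
    stream=True,
)

for chunk in stream:
    # The new events appear alongside the existing response.mcp_call.* events.
    if chunk.type in ("response.mcp_list_tools.in_progress", "response.mcp_list_tools.completed"):
        print(chunk.type, chunk.sequence_number)
```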
Parent: 9324e902f1 · Commit: ba664474de

5 changed files with 260 additions and 145 deletions
File 1 of 5 (`OpenAIResponsesImpl`, responses implementation):

```diff
@@ -8,40 +8,29 @@ import time
 import uuid
 from collections.abc import AsyncIterator
 
-from openai.types.chat import ChatCompletionToolParam
 from pydantic import BaseModel
 
 from llama_stack.apis.agents import Order
 from llama_stack.apis.agents.openai_responses import (
-    AllowedToolsFilter,
     ListOpenAIResponseInputItem,
     ListOpenAIResponseObject,
-    MCPListToolsTool,
     OpenAIDeleteResponseObject,
     OpenAIResponseInput,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
-    OpenAIResponseInputToolMCP,
     OpenAIResponseMessage,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
-    OpenAIResponseOutput,
-    OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseText,
     OpenAIResponseTextFormat,
-    WebSearchToolTypes,
 )
 from llama_stack.apis.inference import (
     Inference,
     OpenAISystemMessageParam,
 )
-from llama_stack.apis.tools import Tool, ToolGroups, ToolRuntime
+from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.log import get_logger
-from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
-from llama_stack.providers.utils.inference.openai_compat import (
-    convert_tooldef_to_openai_tool,
-)
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore
 
 from .streaming import StreamingResponseOrchestrator
@@ -242,17 +231,10 @@ class OpenAIResponsesImpl:
         # Structured outputs
         response_format = await convert_response_text_to_chat_response_format(text)
 
-        # Tool setup, TODO: refactor this slightly since this can also yield events
-        chat_tools, mcp_tool_to_server, mcp_list_message = (
-            await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
-        )
-
         ctx = ChatCompletionContext(
             model=model,
             messages=messages,
             response_tools=tools,
-            chat_tools=chat_tools,
-            mcp_tool_to_server=mcp_tool_to_server,
             temperature=temperature,
             response_format=response_format,
         )
@@ -269,7 +251,6 @@ class OpenAIResponsesImpl:
             text=text,
             max_infer_iters=max_infer_iters,
             tool_executor=self.tool_executor,
-            mcp_list_message=mcp_list_message,
         )
 
         # Stream the response
@@ -288,98 +269,3 @@ class OpenAIResponsesImpl:
 
     async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
         return await self.responses_store.delete_response_object(response_id)
-
-    async def _convert_response_tools_to_chat_tools(
-        self, tools: list[OpenAIResponseInputTool]
-    ) -> tuple[
-        list[ChatCompletionToolParam],
-        dict[str, OpenAIResponseInputToolMCP],
-        OpenAIResponseOutput | None,
-    ]:
-        mcp_tool_to_server = {}
-
-        def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
-            tool_def = ToolDefinition(
-                tool_name=tool_name,
-                description=tool.description,
-                parameters={
-                    param.name: ToolParamDefinition(
-                        param_type=param.parameter_type,
-                        description=param.description,
-                        required=param.required,
-                        default=param.default,
-                    )
-                    for param in tool.parameters
-                },
-            )
-            return convert_tooldef_to_openai_tool(tool_def)
-
-        mcp_list_message = None
-        chat_tools: list[ChatCompletionToolParam] = []
-        for input_tool in tools:
-            # TODO: Handle other tool types
-            if input_tool.type == "function":
-                chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
-            elif input_tool.type in WebSearchToolTypes:
-                tool_name = "web_search"
-                tool = await self.tool_groups_api.get_tool(tool_name)
-                if not tool:
-                    raise ValueError(f"Tool {tool_name} not found")
-                chat_tools.append(make_openai_tool(tool_name, tool))
-            elif input_tool.type == "file_search":
-                tool_name = "knowledge_search"
-                tool = await self.tool_groups_api.get_tool(tool_name)
-                if not tool:
-                    raise ValueError(f"Tool {tool_name} not found")
-                chat_tools.append(make_openai_tool(tool_name, tool))
-            elif input_tool.type == "mcp":
-                from llama_stack.providers.utils.tools.mcp import list_mcp_tools
-
-                always_allowed = None
-                never_allowed = None
-                if input_tool.allowed_tools:
-                    if isinstance(input_tool.allowed_tools, list):
-                        always_allowed = input_tool.allowed_tools
-                    elif isinstance(input_tool.allowed_tools, AllowedToolsFilter):
-                        always_allowed = input_tool.allowed_tools.always
-                        never_allowed = input_tool.allowed_tools.never
-
-                tool_defs = await list_mcp_tools(
-                    endpoint=input_tool.server_url,
-                    headers=input_tool.headers or {},
-                )
-
-                mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
-                    id=f"mcp_list_{uuid.uuid4()}",
-                    status="completed",
-                    server_label=input_tool.server_label,
-                    tools=[],
-                )
-                for t in tool_defs.data:
-                    if never_allowed and t.name in never_allowed:
-                        continue
-                    if not always_allowed or t.name in always_allowed:
-                        chat_tools.append(make_openai_tool(t.name, t))
-                        if t.name in mcp_tool_to_server:
-                            raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}")
-                        mcp_tool_to_server[t.name] = input_tool
-                        mcp_list_message.tools.append(
-                            MCPListToolsTool(
-                                name=t.name,
-                                description=t.description,
-                                input_schema={
-                                    "type": "object",
-                                    "properties": {
-                                        p.name: {
-                                            "type": p.parameter_type,
-                                            "description": p.description,
-                                        }
-                                        for p in t.parameters
-                                    },
-                                    "required": [p.name for p in t.parameters if p.required],
-                                },
-                            )
-                        )
-            else:
-                raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
-        return chat_tools, mcp_tool_to_server, mcp_list_message
```
File 2 of 5 (`StreamingResponseOrchestrator`):

```diff
@@ -9,7 +9,11 @@ from collections.abc import AsyncIterator
 from typing import Any
 
 from llama_stack.apis.agents.openai_responses import (
+    AllowedToolsFilter,
+    MCPListToolsTool,
     OpenAIResponseContentPartOutputText,
+    OpenAIResponseInputTool,
+    OpenAIResponseInputToolMCP,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
     OpenAIResponseObjectStreamResponseCompleted,
@@ -20,12 +24,16 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
     OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
     OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
+    OpenAIResponseObjectStreamResponseMcpListToolsCompleted,
+    OpenAIResponseObjectStreamResponseMcpListToolsInProgress,
     OpenAIResponseObjectStreamResponseOutputItemAdded,
     OpenAIResponseObjectStreamResponseOutputItemDone,
     OpenAIResponseObjectStreamResponseOutputTextDelta,
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageFunctionToolCall,
+    OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseText,
+    WebSearchToolTypes,
 )
 from llama_stack.apis.inference import (
     Inference,
@@ -52,7 +60,6 @@ class StreamingResponseOrchestrator:
         text: OpenAIResponseText,
         max_infer_iters: int,
         tool_executor,  # Will be the tool execution logic from the main class
-        mcp_list_message: OpenAIResponseOutput | None = None,
     ):
         self.inference_api = inference_api
         self.ctx = ctx
@@ -62,13 +69,12 @@ class StreamingResponseOrchestrator:
         self.max_infer_iters = max_infer_iters
         self.tool_executor = tool_executor
         self.sequence_number = 0
-        self.mcp_list_message = mcp_list_message
+        # Store MCP tool mapping that gets built during tool processing
+        self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
 
     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
-        # Initialize output messages with MCP list message if present
+        # Initialize output messages
         output_messages: list[OpenAIResponseOutput] = []
-        if self.mcp_list_message:
-            output_messages.append(self.mcp_list_message)
         # Create initial response and emit response.created immediately
         initial_response = OpenAIResponseObject(
             created_at=self.created_at,
@@ -82,6 +88,11 @@ class StreamingResponseOrchestrator:
 
         yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
 
+        # Process all tools (including MCP tools) and emit streaming events
+        if self.ctx.response_tools:
+            async for stream_event in self._process_tools(self.ctx.response_tools, output_messages):
+                yield stream_event
+
         n_iter = 0
         messages = self.ctx.messages.copy()
 
@@ -261,9 +272,7 @@ class StreamingResponseOrchestrator:
                     self.sequence_number += 1
 
                     # Check if this is an MCP tool call
-                    is_mcp_tool = (
-                        tool_call.function.name and tool_call.function.name in self.ctx.mcp_tool_to_server
-                    )
+                    is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server
                     if is_mcp_tool:
                         # Emit MCP-specific argument delta event
                         yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(
@@ -294,9 +303,7 @@ class StreamingResponseOrchestrator:
                 tool_call_name = chat_response_tool_calls[tool_call_index].function.name
 
                 # Check if this is an MCP tool call
-                is_mcp_tool = (
-                    self.ctx.mcp_tool_to_server and tool_call_name and tool_call_name in self.ctx.mcp_tool_to_server
-                )
+                is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
                 self.sequence_number += 1
                 done_event_cls = (
                     OpenAIResponseObjectStreamResponseMcpCallArgumentsDone
@@ -391,7 +398,12 @@ class StreamingResponseOrchestrator:
             tool_call_log = None
             tool_response_message = None
             async for result in self.tool_executor.execute_tool_call(
-                tool_call, self.ctx, self.sequence_number, len(output_messages), matching_item_id
+                tool_call,
+                self.ctx,
+                self.sequence_number,
+                len(output_messages),
+                matching_item_id,
+                self.mcp_tool_to_server,
             ):
                 if result.stream_event:
                     # Forward streaming events
@@ -449,3 +461,174 @@ class StreamingResponseOrchestrator:
             output_index=len(output_messages) - 1,
             sequence_number=self.sequence_number,
         )
+
+    async def _process_tools(
+        self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput]
+    ) -> AsyncIterator[OpenAIResponseObjectStream]:
+        """Process all tools and emit appropriate streaming events."""
+        from openai.types.chat import ChatCompletionToolParam
+
+        from llama_stack.apis.tools import Tool
+        from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
+        from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
+
+        def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
+            tool_def = ToolDefinition(
+                tool_name=tool_name,
+                description=tool.description,
+                parameters={
+                    param.name: ToolParamDefinition(
+                        param_type=param.parameter_type,
+                        description=param.description,
+                        required=param.required,
+                        default=param.default,
+                    )
+                    for param in tool.parameters
+                },
+            )
+            return convert_tooldef_to_openai_tool(tool_def)
+
+        # Initialize chat_tools if not already set
+        if self.ctx.chat_tools is None:
+            self.ctx.chat_tools = []
+
+        for input_tool in tools:
+            if input_tool.type == "function":
+                self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
+            elif input_tool.type in WebSearchToolTypes:
+                tool_name = "web_search"
+                # Need to access tool_groups_api from tool_executor
+                tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
+                if not tool:
+                    raise ValueError(f"Tool {tool_name} not found")
+                self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
+            elif input_tool.type == "file_search":
+                tool_name = "knowledge_search"
+                tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
+                if not tool:
+                    raise ValueError(f"Tool {tool_name} not found")
+                self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
+            elif input_tool.type == "mcp":
+                async for stream_event in self._process_mcp_tool(input_tool, output_messages):
+                    yield stream_event
+            else:
+                raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
+
+    async def _process_mcp_tool(
+        self, mcp_tool: OpenAIResponseInputToolMCP, output_messages: list[OpenAIResponseOutput]
+    ) -> AsyncIterator[OpenAIResponseObjectStream]:
+        """Process an MCP tool configuration and emit appropriate streaming events."""
+        from llama_stack.providers.utils.tools.mcp import list_mcp_tools
+
+        # Emit mcp_list_tools.in_progress
+        self.sequence_number += 1
+        yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
+            sequence_number=self.sequence_number,
+        )
+
+        try:
+            # Parse allowed/never allowed tools
+            always_allowed = None
+            never_allowed = None
+            if mcp_tool.allowed_tools:
+                if isinstance(mcp_tool.allowed_tools, list):
+                    always_allowed = mcp_tool.allowed_tools
+                elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
+                    always_allowed = mcp_tool.allowed_tools.always
+                    never_allowed = mcp_tool.allowed_tools.never
+
+            # Call list_mcp_tools
+            tool_defs = await list_mcp_tools(
+                endpoint=mcp_tool.server_url,
+                headers=mcp_tool.headers or {},
+            )
+
+            # Create the MCP list tools message
+            mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
+                id=f"mcp_list_{uuid.uuid4()}",
+                server_label=mcp_tool.server_label,
+                tools=[],
+            )
+
+            # Process tools and update context
+            for t in tool_defs.data:
+                if never_allowed and t.name in never_allowed:
+                    continue
+                if not always_allowed or t.name in always_allowed:
+                    # Add to chat tools for inference
+                    from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
+                    from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
+
+                    tool_def = ToolDefinition(
+                        tool_name=t.name,
+                        description=t.description,
+                        parameters={
+                            param.name: ToolParamDefinition(
+                                param_type=param.parameter_type,
+                                description=param.description,
+                                required=param.required,
+                                default=param.default,
+                            )
+                            for param in t.parameters
+                        },
+                    )
+                    openai_tool = convert_tooldef_to_openai_tool(tool_def)
+                    if self.ctx.chat_tools is None:
+                        self.ctx.chat_tools = []
+                    self.ctx.chat_tools.append(openai_tool)
+
+                    # Add to MCP tool mapping
+                    if t.name in self.mcp_tool_to_server:
+                        raise ValueError(f"Duplicate tool name {t.name} found for server {mcp_tool.server_label}")
+                    self.mcp_tool_to_server[t.name] = mcp_tool
+
+                    # Add to MCP list message
+                    mcp_list_message.tools.append(
+                        MCPListToolsTool(
+                            name=t.name,
+                            description=t.description,
+                            input_schema={
+                                "type": "object",
+                                "properties": {
+                                    p.name: {
+                                        "type": p.parameter_type,
+                                        "description": p.description,
+                                    }
+                                    for p in t.parameters
+                                },
+                                "required": [p.name for p in t.parameters if p.required],
+                            },
+                        )
+                    )
+
+            # Add the MCP list message to output
+            output_messages.append(mcp_list_message)
+
+            # Emit output_item.added for the MCP list tools message
+            self.sequence_number += 1
+            yield OpenAIResponseObjectStreamResponseOutputItemAdded(
+                response_id=self.response_id,
+                item=mcp_list_message,
+                output_index=len(output_messages) - 1,
+                sequence_number=self.sequence_number,
+            )
+
+            # Emit mcp_list_tools.completed
+            self.sequence_number += 1
+            yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
+                sequence_number=self.sequence_number,
+            )
+
+            # Emit output_item.done for the MCP list tools message
+            self.sequence_number += 1
+            yield OpenAIResponseObjectStreamResponseOutputItemDone(
+                response_id=self.response_id,
+                item=mcp_list_message,
+                output_index=len(output_messages) - 1,
+                sequence_number=self.sequence_number,
+            )
+
+        except Exception as e:
+            # TODO: Emit mcp_list_tools.failed event if needed
+            logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
+            raise
```
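For readers tracing the new `_process_mcp_tool` above: on a successful listing it emits, per MCP server, `mcp_list_tools.in_progress`, then `output_item.added` carrying the `mcp_list_tools` item, then `mcp_list_tools.completed`, then `output_item.done`. A small illustrative check of that relative ordering (the helper is hypothetical; the `response.output_item.*` type strings are assumed to follow the convention used elsewhere in the streaming API, while the `response.mcp_list_tools.*` strings come from this PR's test):

```python
# Illustrative only: verify the relative order of the four events emitted by
# _process_mcp_tool for one MCP server. The response.output_item.* type strings
# are assumed; the mcp_list_tools.* strings are taken from the updated test.
EXPECTED_ORDER = [
    "response.mcp_list_tools.in_progress",
    "response.output_item.added",
    "response.mcp_list_tools.completed",
    "response.output_item.done",
]


def mcp_list_tools_events_in_order(chunk_types: list[str]) -> bool:
    """True if each expected event is present and their first occurrences are in order."""
    positions = []
    for event_type in EXPECTED_ORDER:
        if event_type not in chunk_types:
            return False
        positions.append(chunk_types.index(event_type))
    return positions == sorted(positions)
```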
File 3 of 5 (`ToolExecutor`):

```diff
@@ -10,6 +10,7 @@ from collections.abc import AsyncIterator
 
 from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputToolFileSearch,
+    OpenAIResponseInputToolMCP,
     OpenAIResponseObjectStreamResponseMcpCallCompleted,
     OpenAIResponseObjectStreamResponseMcpCallFailed,
     OpenAIResponseObjectStreamResponseMcpCallInProgress,
@@ -58,6 +59,7 @@ class ToolExecutor:
         sequence_number: int,
         output_index: int,
         item_id: str,
+        mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
     ) -> AsyncIterator[ToolExecutionResult]:
         tool_call_id = tool_call.id
         function = tool_call.function
@@ -69,25 +71,25 @@ class ToolExecutor:
 
         # Emit progress events for tool execution start
         async for event_result in self._emit_progress_events(
-            function.name, ctx, sequence_number, output_index, item_id
+            function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server
         ):
             sequence_number = event_result.sequence_number
             yield event_result
 
         # Execute the actual tool call
-        error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx)
+        error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
 
         # Emit completion events for tool execution
         has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
         async for event_result in self._emit_completion_events(
-            function.name, ctx, sequence_number, output_index, item_id, has_error
+            function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
        ):
             sequence_number = event_result.sequence_number
             yield event_result
 
         # Build result messages from tool execution
         output_message, input_message = await self._build_result_messages(
-            function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error
+            function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
         )
 
         # Yield the final result
@@ -161,12 +163,18 @@ class ToolExecutor:
         )
 
     async def _emit_progress_events(
-        self, function_name: str, ctx: ChatCompletionContext, sequence_number: int, output_index: int, item_id: str
+        self,
+        function_name: str,
+        ctx: ChatCompletionContext,
+        sequence_number: int,
+        output_index: int,
+        item_id: str,
+        mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
     ) -> AsyncIterator[ToolExecutionResult]:
         """Emit progress events for tool execution start."""
         # Emit in_progress event based on tool type (only for tools with specific streaming events)
         progress_event = None
-        if ctx.mcp_tool_to_server and function_name in ctx.mcp_tool_to_server:
+        if mcp_tool_to_server and function_name in mcp_tool_to_server:
             sequence_number += 1
             progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
                 item_id=item_id,
@@ -196,17 +204,21 @@ class ToolExecutor:
         yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
 
     async def _execute_tool(
-        self, function_name: str, tool_kwargs: dict, ctx: ChatCompletionContext
+        self,
+        function_name: str,
+        tool_kwargs: dict,
+        ctx: ChatCompletionContext,
+        mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
     ) -> tuple[Exception | None, any]:
         """Execute the tool and return error exception and result."""
         error_exc = None
         result = None
 
         try:
-            if ctx.mcp_tool_to_server and function_name in ctx.mcp_tool_to_server:
+            if mcp_tool_to_server and function_name in mcp_tool_to_server:
                 from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
 
-                mcp_tool = ctx.mcp_tool_to_server[function_name]
+                mcp_tool = mcp_tool_to_server[function_name]
                 result = await invoke_mcp_tool(
                     endpoint=mcp_tool.server_url,
                     headers=mcp_tool.headers or {},
@@ -244,11 +256,12 @@ class ToolExecutor:
         output_index: int,
         item_id: str,
         has_error: bool,
+        mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
     ) -> AsyncIterator[ToolExecutionResult]:
         """Emit completion or failure events for tool execution."""
         completion_event = None
 
-        if ctx.mcp_tool_to_server and function_name in ctx.mcp_tool_to_server:
+        if mcp_tool_to_server and function_name in mcp_tool_to_server:
             sequence_number += 1
             if has_error:
                 completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
@@ -279,6 +292,7 @@ class ToolExecutor:
         error_exc: Exception | None,
         result: any,
         has_error: bool,
+        mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
     ) -> tuple[any, any]:
         """Build output and input messages from tool execution results."""
         from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -286,7 +300,7 @@ class ToolExecutor:
         )
 
         # Build output message
-        if function.name in ctx.mcp_tool_to_server:
+        if mcp_tool_to_server and function.name in mcp_tool_to_server:
             from llama_stack.apis.agents.openai_responses import (
                 OpenAIResponseOutputMessageMCPCall,
             )
@@ -295,7 +309,7 @@ class ToolExecutor:
                 id=tool_call_id,
                 arguments=function.arguments,
                 name=function.name,
-                server_label=ctx.mcp_tool_to_server[function.name].server_label,
+                server_label=mcp_tool_to_server[function.name].server_label,
             )
             if error_exc:
                 message.error = str(error_exc)
```
File 4 of 5 (`ChatCompletionContext`):

```diff
@@ -11,7 +11,6 @@ from pydantic import BaseModel
 
 from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputTool,
-    OpenAIResponseInputToolMCP,
     OpenAIResponseObjectStream,
     OpenAIResponseOutput,
 )
@@ -57,6 +56,5 @@ class ChatCompletionContext(BaseModel):
     messages: list[OpenAIMessageParam]
     response_tools: list[OpenAIResponseInputTool] | None = None
     chat_tools: list[ChatCompletionToolParam] | None = None
-    mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
     temperature: float | None
     response_format: OpenAIResponseFormatParam
```
File 5 of 5 (integration test `test_response_streaming_multi_turn_tool_execution`):

```diff
@@ -610,6 +610,14 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
     mcp_in_progress_events = [chunk for chunk in chunks if chunk.type == "response.mcp_call.in_progress"]
     mcp_completed_events = [chunk for chunk in chunks if chunk.type == "response.mcp_call.completed"]
 
+    # Should have MCP list tools streaming events
+    mcp_list_tools_in_progress_events = [
+        chunk for chunk in chunks if chunk.type == "response.mcp_list_tools.in_progress"
+    ]
+    mcp_list_tools_completed_events = [
+        chunk for chunk in chunks if chunk.type == "response.mcp_list_tools.completed"
+    ]
+
     # Verify we have substantial streaming activity (not just batch events)
     assert len(chunks) > 10, f"Expected rich streaming with many events, got only {len(chunks)} chunks"
 
@@ -632,6 +640,14 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
     assert len(mcp_completed_events) > 0, (
         f"Expected response.mcp_call.completed events, got chunk types: {chunk_types}"
     )
+
+    # Should have MCP list tools streaming events
+    assert len(mcp_list_tools_in_progress_events) > 0, (
+        f"Expected response.mcp_list_tools.in_progress events, got chunk types: {chunk_types}"
+    )
+    assert len(mcp_list_tools_completed_events) > 0, (
+        f"Expected response.mcp_list_tools.completed events, got chunk types: {chunk_types}"
+    )
     # MCP failed events are optional (only if errors occur)
 
     # Verify progress events have proper structure
@@ -643,6 +659,17 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
     for completed_event in mcp_completed_events:
         assert hasattr(completed_event, "sequence_number"), "Completed event should have 'sequence_number' field"
 
+    # Verify MCP list tools events have proper structure
+    for list_tools_progress_event in mcp_list_tools_in_progress_events:
+        assert hasattr(list_tools_progress_event, "sequence_number"), (
+            "MCP list tools progress event should have 'sequence_number' field"
+        )
+
+    for list_tools_completed_event in mcp_list_tools_completed_events:
+        assert hasattr(list_tools_completed_event, "sequence_number"), (
+            "MCP list tools completed event should have 'sequence_number' field"
+        )
+
     # Verify delta events have proper structure
     for delta_event in delta_events:
         assert hasattr(delta_event, "delta"), "Delta event should have 'delta' field"
@@ -662,8 +689,12 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
         assert hasattr(added_event, "output_index"), "Added event should have 'output_index' field"
         assert hasattr(added_event, "sequence_number"), "Added event should have 'sequence_number' field"
         assert hasattr(added_event, "response_id"), "Added event should have 'response_id' field"
-        assert added_event.item.type in ["function_call", "mcp_call"], "Added item should be a tool call"
-        assert added_event.item.status == "in_progress", "Added item should be in progress"
+        assert added_event.item.type in ["function_call", "mcp_call", "mcp_list_tools"], (
+            "Added item should be a tool call or MCP list tools"
+        )
+        if added_event.item.type in ["function_call", "mcp_call"]:
+            assert added_event.item.status == "in_progress", "Added tool call should be in progress"
+        # Note: mcp_list_tools doesn't have a status field, it's implicitly completed when added
         assert added_event.response_id, "Response ID should not be empty"
         assert isinstance(added_event.output_index, int), "Output index should be integer"
         assert added_event.output_index >= 0, "Output index should be non-negative"
@@ -674,10 +705,13 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
         assert hasattr(done_event, "output_index"), "Done event should have 'output_index' field"
         assert hasattr(done_event, "sequence_number"), "Done event should have 'sequence_number' field"
         assert hasattr(done_event, "response_id"), "Done event should have 'response_id' field"
-        assert done_event.item.type in ["function_call", "mcp_call"], "Done item should be a tool call"
-        # Note: MCP calls don't have a status field, only function calls do
+        assert done_event.item.type in ["function_call", "mcp_call", "mcp_list_tools"], (
+            "Done item should be a tool call or MCP list tools"
+        )
+        # Note: MCP calls and mcp_list_tools don't have a status field, only function calls do
         if done_event.item.type == "function_call":
             assert done_event.item.status == "completed", "Function call should be completed"
+        # Note: mcp_call and mcp_list_tools don't have status fields
         assert done_event.response_id, "Response ID should not be empty"
         assert isinstance(done_event.output_index, int), "Output index should be integer"
         assert done_event.output_index >= 0, "Output index should be non-negative"
```