mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
feat(responses): add mcp list tool streaming event (#3159)
# What does this PR do? Adds proper streaming events for MCP tool listing (`mcp_list_tools.in_progress` and `mcp_list_tools.completed`). Also refactors things a bit more. ## Test Plan Verified existing integration tests pass with the refactored code. The test `test_response_streaming_multi_turn_tool_execution` has been updated to check for the new MCP list tools streaming events
This commit is contained in:
parent
9324e902f1
commit
ba664474de
5 changed files with 260 additions and 145 deletions
|
@ -8,40 +8,29 @@ import time
|
|||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
AllowedToolsFilter,
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
MCPListToolsTool,
|
||||
OpenAIDeleteResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
WebSearchToolTypes,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAISystemMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools import Tool, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_tooldef_to_openai_tool,
|
||||
)
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
||||
from .streaming import StreamingResponseOrchestrator
|
||||
|
@ -242,17 +231,10 @@ class OpenAIResponsesImpl:
|
|||
# Structured outputs
|
||||
response_format = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
# Tool setup, TODO: refactor this slightly since this can also yield events
|
||||
chat_tools, mcp_tool_to_server, mcp_list_message = (
|
||||
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
|
||||
)
|
||||
|
||||
ctx = ChatCompletionContext(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_tools=tools,
|
||||
chat_tools=chat_tools,
|
||||
mcp_tool_to_server=mcp_tool_to_server,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
@ -269,7 +251,6 @@ class OpenAIResponsesImpl:
|
|||
text=text,
|
||||
max_infer_iters=max_infer_iters,
|
||||
tool_executor=self.tool_executor,
|
||||
mcp_list_message=mcp_list_message,
|
||||
)
|
||||
|
||||
# Stream the response
|
||||
|
@ -288,98 +269,3 @@ class OpenAIResponsesImpl:
|
|||
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
return await self.responses_store.delete_response_object(response_id)
|
||||
|
||||
async def _convert_response_tools_to_chat_tools(
|
||||
self, tools: list[OpenAIResponseInputTool]
|
||||
) -> tuple[
|
||||
list[ChatCompletionToolParam],
|
||||
dict[str, OpenAIResponseInputToolMCP],
|
||||
OpenAIResponseOutput | None,
|
||||
]:
|
||||
mcp_tool_to_server = {}
|
||||
|
||||
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool.parameters
|
||||
},
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(tool_def)
|
||||
|
||||
mcp_list_message = None
|
||||
chat_tools: list[ChatCompletionToolParam] = []
|
||||
for input_tool in tools:
|
||||
# TODO: Handle other tool types
|
||||
if input_tool.type == "function":
|
||||
chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
|
||||
elif input_tool.type in WebSearchToolTypes:
|
||||
tool_name = "web_search"
|
||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
chat_tools.append(make_openai_tool(tool_name, tool))
|
||||
elif input_tool.type == "file_search":
|
||||
tool_name = "knowledge_search"
|
||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
chat_tools.append(make_openai_tool(tool_name, tool))
|
||||
elif input_tool.type == "mcp":
|
||||
from llama_stack.providers.utils.tools.mcp import list_mcp_tools
|
||||
|
||||
always_allowed = None
|
||||
never_allowed = None
|
||||
if input_tool.allowed_tools:
|
||||
if isinstance(input_tool.allowed_tools, list):
|
||||
always_allowed = input_tool.allowed_tools
|
||||
elif isinstance(input_tool.allowed_tools, AllowedToolsFilter):
|
||||
always_allowed = input_tool.allowed_tools.always
|
||||
never_allowed = input_tool.allowed_tools.never
|
||||
|
||||
tool_defs = await list_mcp_tools(
|
||||
endpoint=input_tool.server_url,
|
||||
headers=input_tool.headers or {},
|
||||
)
|
||||
|
||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||
id=f"mcp_list_{uuid.uuid4()}",
|
||||
status="completed",
|
||||
server_label=input_tool.server_label,
|
||||
tools=[],
|
||||
)
|
||||
for t in tool_defs.data:
|
||||
if never_allowed and t.name in never_allowed:
|
||||
continue
|
||||
if not always_allowed or t.name in always_allowed:
|
||||
chat_tools.append(make_openai_tool(t.name, t))
|
||||
if t.name in mcp_tool_to_server:
|
||||
raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}")
|
||||
mcp_tool_to_server[t.name] = input_tool
|
||||
mcp_list_message.tools.append(
|
||||
MCPListToolsTool(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
p.name: {
|
||||
"type": p.parameter_type,
|
||||
"description": p.description,
|
||||
}
|
||||
for p in t.parameters
|
||||
},
|
||||
"required": [p.name for p in t.parameters if p.required],
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
|
||||
return chat_tools, mcp_tool_to_server, mcp_list_message
|
||||
|
|
|
@ -9,7 +9,11 @@ from collections.abc import AsyncIterator
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
AllowedToolsFilter,
|
||||
MCPListToolsTool,
|
||||
OpenAIResponseContentPartOutputText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
|
@ -20,12 +24,16 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
|
||||
OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
|
||||
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
|
||||
OpenAIResponseObjectStreamResponseMcpListToolsCompleted,
|
||||
OpenAIResponseObjectStreamResponseMcpListToolsInProgress,
|
||||
OpenAIResponseObjectStreamResponseOutputItemAdded,
|
||||
OpenAIResponseObjectStreamResponseOutputItemDone,
|
||||
OpenAIResponseObjectStreamResponseOutputTextDelta,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseText,
|
||||
WebSearchToolTypes,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
|
@ -52,7 +60,6 @@ class StreamingResponseOrchestrator:
|
|||
text: OpenAIResponseText,
|
||||
max_infer_iters: int,
|
||||
tool_executor, # Will be the tool execution logic from the main class
|
||||
mcp_list_message: OpenAIResponseOutput | None = None,
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.ctx = ctx
|
||||
|
@ -62,13 +69,12 @@ class StreamingResponseOrchestrator:
|
|||
self.max_infer_iters = max_infer_iters
|
||||
self.tool_executor = tool_executor
|
||||
self.sequence_number = 0
|
||||
self.mcp_list_message = mcp_list_message
|
||||
# Store MCP tool mapping that gets built during tool processing
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||
|
||||
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Initialize output messages with MCP list message if present
|
||||
# Initialize output messages
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
if self.mcp_list_message:
|
||||
output_messages.append(self.mcp_list_message)
|
||||
# Create initial response and emit response.created immediately
|
||||
initial_response = OpenAIResponseObject(
|
||||
created_at=self.created_at,
|
||||
|
@ -82,6 +88,11 @@ class StreamingResponseOrchestrator:
|
|||
|
||||
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
|
||||
|
||||
# Process all tools (including MCP tools) and emit streaming events
|
||||
if self.ctx.response_tools:
|
||||
async for stream_event in self._process_tools(self.ctx.response_tools, output_messages):
|
||||
yield stream_event
|
||||
|
||||
n_iter = 0
|
||||
messages = self.ctx.messages.copy()
|
||||
|
||||
|
@ -261,9 +272,7 @@ class StreamingResponseOrchestrator:
|
|||
self.sequence_number += 1
|
||||
|
||||
# Check if this is an MCP tool call
|
||||
is_mcp_tool = (
|
||||
tool_call.function.name and tool_call.function.name in self.ctx.mcp_tool_to_server
|
||||
)
|
||||
is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server
|
||||
if is_mcp_tool:
|
||||
# Emit MCP-specific argument delta event
|
||||
yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(
|
||||
|
@ -294,9 +303,7 @@ class StreamingResponseOrchestrator:
|
|||
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
|
||||
|
||||
# Check if this is an MCP tool call
|
||||
is_mcp_tool = (
|
||||
self.ctx.mcp_tool_to_server and tool_call_name and tool_call_name in self.ctx.mcp_tool_to_server
|
||||
)
|
||||
is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
|
||||
self.sequence_number += 1
|
||||
done_event_cls = (
|
||||
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone
|
||||
|
@ -391,7 +398,12 @@ class StreamingResponseOrchestrator:
|
|||
tool_call_log = None
|
||||
tool_response_message = None
|
||||
async for result in self.tool_executor.execute_tool_call(
|
||||
tool_call, self.ctx, self.sequence_number, len(output_messages), matching_item_id
|
||||
tool_call,
|
||||
self.ctx,
|
||||
self.sequence_number,
|
||||
len(output_messages),
|
||||
matching_item_id,
|
||||
self.mcp_tool_to_server,
|
||||
):
|
||||
if result.stream_event:
|
||||
# Forward streaming events
|
||||
|
@ -449,3 +461,174 @@ class StreamingResponseOrchestrator:
|
|||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async def _process_tools(
|
||||
self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Process all tools and emit appropriate streaming events."""
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from llama_stack.apis.tools import Tool
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
|
||||
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool.parameters
|
||||
},
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(tool_def)
|
||||
|
||||
# Initialize chat_tools if not already set
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
|
||||
for input_tool in tools:
|
||||
if input_tool.type == "function":
|
||||
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
|
||||
elif input_tool.type in WebSearchToolTypes:
|
||||
tool_name = "web_search"
|
||||
# Need to access tool_groups_api from tool_executor
|
||||
tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
|
||||
elif input_tool.type == "file_search":
|
||||
tool_name = "knowledge_search"
|
||||
tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
|
||||
elif input_tool.type == "mcp":
|
||||
async for stream_event in self._process_mcp_tool(input_tool, output_messages):
|
||||
yield stream_event
|
||||
else:
|
||||
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
|
||||
|
||||
async def _process_mcp_tool(
|
||||
self, mcp_tool: OpenAIResponseInputToolMCP, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Process an MCP tool configuration and emit appropriate streaming events."""
|
||||
from llama_stack.providers.utils.tools.mcp import list_mcp_tools
|
||||
|
||||
# Emit mcp_list_tools.in_progress
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
try:
|
||||
# Parse allowed/never allowed tools
|
||||
always_allowed = None
|
||||
never_allowed = None
|
||||
if mcp_tool.allowed_tools:
|
||||
if isinstance(mcp_tool.allowed_tools, list):
|
||||
always_allowed = mcp_tool.allowed_tools
|
||||
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
|
||||
always_allowed = mcp_tool.allowed_tools.always
|
||||
never_allowed = mcp_tool.allowed_tools.never
|
||||
|
||||
# Call list_mcp_tools
|
||||
tool_defs = await list_mcp_tools(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
)
|
||||
|
||||
# Create the MCP list tools message
|
||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||
id=f"mcp_list_{uuid.uuid4()}",
|
||||
server_label=mcp_tool.server_label,
|
||||
tools=[],
|
||||
)
|
||||
|
||||
# Process tools and update context
|
||||
for t in tool_defs.data:
|
||||
if never_allowed and t.name in never_allowed:
|
||||
continue
|
||||
if not always_allowed or t.name in always_allowed:
|
||||
# Add to chat tools for inference
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=t.name,
|
||||
description=t.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in t.parameters
|
||||
},
|
||||
)
|
||||
openai_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
self.ctx.chat_tools.append(openai_tool)
|
||||
|
||||
# Add to MCP tool mapping
|
||||
if t.name in self.mcp_tool_to_server:
|
||||
raise ValueError(f"Duplicate tool name {t.name} found for server {mcp_tool.server_label}")
|
||||
self.mcp_tool_to_server[t.name] = mcp_tool
|
||||
|
||||
# Add to MCP list message
|
||||
mcp_list_message.tools.append(
|
||||
MCPListToolsTool(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
p.name: {
|
||||
"type": p.parameter_type,
|
||||
"description": p.description,
|
||||
}
|
||||
for p in t.parameters
|
||||
},
|
||||
"required": [p.name for p in t.parameters if p.required],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Add the MCP list message to output
|
||||
output_messages.append(mcp_list_message)
|
||||
|
||||
# Emit output_item.added for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit mcp_list_tools.completed
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit output_item.done for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Emit mcp_list_tools.failed event if needed
|
||||
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
|
||||
raise
|
||||
|
|
|
@ -10,6 +10,7 @@ from collections.abc import AsyncIterator
|
|||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseObjectStreamResponseMcpCallCompleted,
|
||||
OpenAIResponseObjectStreamResponseMcpCallFailed,
|
||||
OpenAIResponseObjectStreamResponseMcpCallInProgress,
|
||||
|
@ -58,6 +59,7 @@ class ToolExecutor:
|
|||
sequence_number: int,
|
||||
output_index: int,
|
||||
item_id: str,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
tool_call_id = tool_call.id
|
||||
function = tool_call.function
|
||||
|
@ -69,25 +71,25 @@ class ToolExecutor:
|
|||
|
||||
# Emit progress events for tool execution start
|
||||
async for event_result in self._emit_progress_events(
|
||||
function.name, ctx, sequence_number, output_index, item_id
|
||||
function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server
|
||||
):
|
||||
sequence_number = event_result.sequence_number
|
||||
yield event_result
|
||||
|
||||
# Execute the actual tool call
|
||||
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx)
|
||||
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
|
||||
|
||||
# Emit completion events for tool execution
|
||||
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
|
||||
async for event_result in self._emit_completion_events(
|
||||
function.name, ctx, sequence_number, output_index, item_id, has_error
|
||||
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
|
||||
):
|
||||
sequence_number = event_result.sequence_number
|
||||
yield event_result
|
||||
|
||||
# Build result messages from tool execution
|
||||
output_message, input_message = await self._build_result_messages(
|
||||
function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error
|
||||
function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
|
||||
)
|
||||
|
||||
# Yield the final result
|
||||
|
@ -161,12 +163,18 @@ class ToolExecutor:
|
|||
)
|
||||
|
||||
async def _emit_progress_events(
|
||||
self, function_name: str, ctx: ChatCompletionContext, sequence_number: int, output_index: int, item_id: str
|
||||
self,
|
||||
function_name: str,
|
||||
ctx: ChatCompletionContext,
|
||||
sequence_number: int,
|
||||
output_index: int,
|
||||
item_id: str,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
"""Emit progress events for tool execution start."""
|
||||
# Emit in_progress event based on tool type (only for tools with specific streaming events)
|
||||
progress_event = None
|
||||
if ctx.mcp_tool_to_server and function_name in ctx.mcp_tool_to_server:
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
||||
item_id=item_id,
|
||||
|
@ -196,17 +204,21 @@ class ToolExecutor:
|
|||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
async def _execute_tool(
|
||||
self, function_name: str, tool_kwargs: dict, ctx: ChatCompletionContext
|
||||
self,
|
||||
function_name: str,
|
||||
tool_kwargs: dict,
|
||||
ctx: ChatCompletionContext,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> tuple[Exception | None, any]:
|
||||
"""Execute the tool and return error exception and result."""
|
||||
error_exc = None
|
||||
result = None
|
||||
|
||||
try:
|
||||
if ctx.mcp_tool_to_server and function_name in ctx.mcp_tool_to_server:
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
|
||||
|
||||
mcp_tool = ctx.mcp_tool_to_server[function_name]
|
||||
mcp_tool = mcp_tool_to_server[function_name]
|
||||
result = await invoke_mcp_tool(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
|
@ -244,11 +256,12 @@ class ToolExecutor:
|
|||
output_index: int,
|
||||
item_id: str,
|
||||
has_error: bool,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
"""Emit completion or failure events for tool execution."""
|
||||
completion_event = None
|
||||
|
||||
if ctx.mcp_tool_to_server and function_name in ctx.mcp_tool_to_server:
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
if has_error:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
||||
|
@ -279,6 +292,7 @@ class ToolExecutor:
|
|||
error_exc: Exception | None,
|
||||
result: any,
|
||||
has_error: bool,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> tuple[any, any]:
|
||||
"""Build output and input messages from tool execution results."""
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
|
@ -286,7 +300,7 @@ class ToolExecutor:
|
|||
)
|
||||
|
||||
# Build output message
|
||||
if function.name in ctx.mcp_tool_to_server:
|
||||
if mcp_tool_to_server and function.name in mcp_tool_to_server:
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
)
|
||||
|
@ -295,7 +309,7 @@ class ToolExecutor:
|
|||
id=tool_call_id,
|
||||
arguments=function.arguments,
|
||||
name=function.name,
|
||||
server_label=ctx.mcp_tool_to_server[function.name].server_label,
|
||||
server_label=mcp_tool_to_server[function.name].server_label,
|
||||
)
|
||||
if error_exc:
|
||||
message.error = str(error_exc)
|
||||
|
|
|
@ -11,7 +11,6 @@ from pydantic import BaseModel
|
|||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseOutput,
|
||||
)
|
||||
|
@ -57,6 +56,5 @@ class ChatCompletionContext(BaseModel):
|
|||
messages: list[OpenAIMessageParam]
|
||||
response_tools: list[OpenAIResponseInputTool] | None = None
|
||||
chat_tools: list[ChatCompletionToolParam] | None = None
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
|
||||
temperature: float | None
|
||||
response_format: OpenAIResponseFormatParam
|
||||
|
|
|
@ -610,6 +610,14 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
|
|||
mcp_in_progress_events = [chunk for chunk in chunks if chunk.type == "response.mcp_call.in_progress"]
|
||||
mcp_completed_events = [chunk for chunk in chunks if chunk.type == "response.mcp_call.completed"]
|
||||
|
||||
# Should have MCP list tools streaming events
|
||||
mcp_list_tools_in_progress_events = [
|
||||
chunk for chunk in chunks if chunk.type == "response.mcp_list_tools.in_progress"
|
||||
]
|
||||
mcp_list_tools_completed_events = [
|
||||
chunk for chunk in chunks if chunk.type == "response.mcp_list_tools.completed"
|
||||
]
|
||||
|
||||
# Verify we have substantial streaming activity (not just batch events)
|
||||
assert len(chunks) > 10, f"Expected rich streaming with many events, got only {len(chunks)} chunks"
|
||||
|
||||
|
@ -632,6 +640,14 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
|
|||
assert len(mcp_completed_events) > 0, (
|
||||
f"Expected response.mcp_call.completed events, got chunk types: {chunk_types}"
|
||||
)
|
||||
|
||||
# Should have MCP list tools streaming events
|
||||
assert len(mcp_list_tools_in_progress_events) > 0, (
|
||||
f"Expected response.mcp_list_tools.in_progress events, got chunk types: {chunk_types}"
|
||||
)
|
||||
assert len(mcp_list_tools_completed_events) > 0, (
|
||||
f"Expected response.mcp_list_tools.completed events, got chunk types: {chunk_types}"
|
||||
)
|
||||
# MCP failed events are optional (only if errors occur)
|
||||
|
||||
# Verify progress events have proper structure
|
||||
|
@ -643,6 +659,17 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
|
|||
for completed_event in mcp_completed_events:
|
||||
assert hasattr(completed_event, "sequence_number"), "Completed event should have 'sequence_number' field"
|
||||
|
||||
# Verify MCP list tools events have proper structure
|
||||
for list_tools_progress_event in mcp_list_tools_in_progress_events:
|
||||
assert hasattr(list_tools_progress_event, "sequence_number"), (
|
||||
"MCP list tools progress event should have 'sequence_number' field"
|
||||
)
|
||||
|
||||
for list_tools_completed_event in mcp_list_tools_completed_events:
|
||||
assert hasattr(list_tools_completed_event, "sequence_number"), (
|
||||
"MCP list tools completed event should have 'sequence_number' field"
|
||||
)
|
||||
|
||||
# Verify delta events have proper structure
|
||||
for delta_event in delta_events:
|
||||
assert hasattr(delta_event, "delta"), "Delta event should have 'delta' field"
|
||||
|
@ -662,8 +689,12 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
|
|||
assert hasattr(added_event, "output_index"), "Added event should have 'output_index' field"
|
||||
assert hasattr(added_event, "sequence_number"), "Added event should have 'sequence_number' field"
|
||||
assert hasattr(added_event, "response_id"), "Added event should have 'response_id' field"
|
||||
assert added_event.item.type in ["function_call", "mcp_call"], "Added item should be a tool call"
|
||||
assert added_event.item.status == "in_progress", "Added item should be in progress"
|
||||
assert added_event.item.type in ["function_call", "mcp_call", "mcp_list_tools"], (
|
||||
"Added item should be a tool call or MCP list tools"
|
||||
)
|
||||
if added_event.item.type in ["function_call", "mcp_call"]:
|
||||
assert added_event.item.status == "in_progress", "Added tool call should be in progress"
|
||||
# Note: mcp_list_tools doesn't have a status field, it's implicitly completed when added
|
||||
assert added_event.response_id, "Response ID should not be empty"
|
||||
assert isinstance(added_event.output_index, int), "Output index should be integer"
|
||||
assert added_event.output_index >= 0, "Output index should be non-negative"
|
||||
|
@ -674,10 +705,13 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
|
|||
assert hasattr(done_event, "output_index"), "Done event should have 'output_index' field"
|
||||
assert hasattr(done_event, "sequence_number"), "Done event should have 'sequence_number' field"
|
||||
assert hasattr(done_event, "response_id"), "Done event should have 'response_id' field"
|
||||
assert done_event.item.type in ["function_call", "mcp_call"], "Done item should be a tool call"
|
||||
# Note: MCP calls don't have a status field, only function calls do
|
||||
assert done_event.item.type in ["function_call", "mcp_call", "mcp_list_tools"], (
|
||||
"Done item should be a tool call or MCP list tools"
|
||||
)
|
||||
# Note: MCP calls and mcp_list_tools don't have a status field, only function calls do
|
||||
if done_event.item.type == "function_call":
|
||||
assert done_event.item.status == "completed", "Function call should be completed"
|
||||
# Note: mcp_call and mcp_list_tools don't have status fields
|
||||
assert done_event.response_id, "Response ID should not be empty"
|
||||
assert isinstance(done_event.output_index, int), "Output index should be integer"
|
||||
assert done_event.output_index >= 0, "Output index should be non-negative"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue