fix(responses): fix subtle bugs in non-function tool calling (#3817)

We were generating "FunctionToolCall" items even for MCP (and
file-search, etc.) server-side calls, which caused ID mismatches and related inconsistencies.
This commit is contained in:
Ashwin Bharambe 2025-10-15 13:57:37 -07:00 committed by GitHub
parent d709eeb33f
commit 0a96a7faa5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 10660 additions and 51 deletions

View file

@ -44,8 +44,11 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStreamResponseRefusalDone,
OpenAIResponseOutput,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponseText,
OpenAIResponseUsage,
OpenAIResponseUsageInputTokensDetails,
@ -177,6 +180,7 @@ class StreamingResponseOrchestrator:
# (some providers don't support non-empty response_format when tools are present)
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.ctx.model,
messages=messages,
@ -613,19 +617,22 @@ class StreamingResponseOrchestrator:
# Emit output_item.added event for the new function call
self.sequence_number += 1
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
arguments="", # Will be filled incrementally via delta events
call_id=tool_call.id or "",
name=tool_call.function.name if tool_call.function else "",
id=tool_call_item_id,
status="in_progress",
)
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=function_call_item,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)
is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server
if not is_mcp_tool and tool_call.function.name not in ["web_search", "knowledge_search"]:
# for MCP tools (and even other non-function tools) we emit an output message item later
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
arguments="", # Will be filled incrementally via delta events
call_id=tool_call.id or "",
name=tool_call.function.name if tool_call.function else "",
id=tool_call_item_id,
status="in_progress",
)
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=function_call_item,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)
# Stream tool call arguments as they arrive (differentiate between MCP and function calls)
if tool_call.function and tool_call.function.arguments:
@ -806,6 +813,35 @@ class StreamingResponseOrchestrator:
if not matching_item_id:
matching_item_id = f"tc_{uuid.uuid4()}"
self.sequence_number += 1
if tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server:
item = OpenAIResponseOutputMessageMCPCall(
arguments="",
name=tool_call.function.name,
id=matching_item_id,
server_label=self.mcp_tool_to_server[tool_call.function.name].server_label,
status="in_progress",
)
elif tool_call.function.name == "web_search":
item = OpenAIResponseOutputMessageWebSearchToolCall(
id=matching_item_id,
status="in_progress",
)
elif tool_call.function.name == "knowledge_search":
item = OpenAIResponseOutputMessageFileSearchToolCall(
id=matching_item_id,
status="in_progress",
)
else:
raise ValueError(f"Unsupported tool call: {tool_call.function.name}")
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=item,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)
# Execute tool call with streaming
tool_call_log = None
tool_response_message = None
@ -1064,7 +1100,11 @@ class StreamingResponseOrchestrator:
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=mcp_list_message,
item=OpenAIResponseOutputMessageMCPListTools(
id=mcp_list_message.id,
server_label=mcp_list_message.server_label,
tools=[],
),
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)

View file

@ -93,7 +93,7 @@ class ToolExecutor:
# Build result messages from tool execution
output_message, input_message = await self._build_result_messages(
function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
function, tool_call_id, item_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
)
# Yield the final result
@ -356,6 +356,7 @@ class ToolExecutor:
self,
function,
tool_call_id: str,
item_id: str,
tool_kwargs: dict,
ctx: ChatCompletionContext,
error_exc: Exception | None,
@ -375,7 +376,7 @@ class ToolExecutor:
)
message = OpenAIResponseOutputMessageMCPCall(
id=tool_call_id,
id=item_id,
arguments=function.arguments,
name=function.name,
server_label=mcp_tool_to_server[function.name].server_label,
@ -389,14 +390,14 @@ class ToolExecutor:
else:
if function.name == "web_search":
message = OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
id=item_id,
status="completed",
)
if has_error:
message.status = "failed"
elif function.name == "knowledge_search":
message = OpenAIResponseOutputMessageFileSearchToolCall(
id=tool_call_id,
id=item_id,
queries=[tool_kwargs.get("query", "")],
status="completed",
)