From d350e3662b541228aa10bb1123c36b4ae11ccd97 Mon Sep 17 00:00:00 2001
From: grs
Date: Tue, 30 Sep 2025 22:18:34 +0100
Subject: [PATCH] feat: add support for require_approval argument when creating response (#3608)

# What does this PR do?

This PR adds support for the `require_approval` argument on an MCP tool
definition passed to create response in the Responses API. It lets the
caller indicate whether calls to that server must be approved before they
are executed, or may be made without approval.

Closes #3443

## Test Plan

Tested both approval and denial. Added an automated integration test for
both cases.

---------

Signed-off-by: Gordon Sim
Co-authored-by: Matthew Farrellee
---
 docs/static/llama-stack-spec.html            | 86 ++++++++++++++++++-
 docs/static/llama-stack-spec.yaml            | 55 ++++++++++++
 llama_stack/apis/agents/openai_responses.py  | 31 ++++++-
 .../responses/openai_responses.py            |  1 +
 .../meta_reference/responses/streaming.py    | 75 +++++++++++++++-
 .../agents/meta_reference/responses/types.py | 37 ++++++++
 .../agents/meta_reference/responses/utils.py |  7 ++
 .../responses/test_tool_responses.py         | 76 ++++++++++++++++
 8 files changed, 360 insertions(+), 8 deletions(-)

diff --git a/docs/static/llama-stack-spec.html b/docs/static/llama-stack-spec.html
index 97671f084..20f05a110 100644
--- a/docs/static/llama-stack-spec.html
+++ b/docs/static/llama-stack-spec.html
@@ -9028,6 +9028,12 @@
           {
             "$ref": "#/components/schemas/OpenAIResponseInputFunctionToolCallOutput"
           },
+          {
+            "$ref": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
+          },
+          {
+            "$ref": "#/components/schemas/OpenAIResponseMCPApprovalResponse"
+          },
           {
             "$ref": "#/components/schemas/OpenAIResponseMessage"
           }
@@ -9445,6 +9451,68 @@
         "title": "OpenAIResponseInputToolWebSearch",
         "description": "Web search tool configuration for OpenAI response inputs."
       },
+      "OpenAIResponseMCPApprovalRequest": {
+        "type": "object",
+        "properties": {
+          "arguments": {
+            "type": "string"
+          },
+          "id": {
+            "type": "string"
+          },
+          "name": {
+            "type": "string"
+          },
+          "server_label": {
+            "type": "string"
+          },
+          "type": {
+            "type": "string",
+            "const": "mcp_approval_request",
+            "default": "mcp_approval_request"
+          }
+        },
+        "additionalProperties": false,
+        "required": [
+          "arguments",
+          "id",
+          "name",
+          "server_label",
+          "type"
+        ],
+        "title": "OpenAIResponseMCPApprovalRequest",
+        "description": "A request for human approval of a tool invocation."
+      },
+      "OpenAIResponseMCPApprovalResponse": {
+        "type": "object",
+        "properties": {
+          "approval_request_id": {
+            "type": "string"
+          },
+          "approve": {
+            "type": "boolean"
+          },
+          "type": {
+            "type": "string",
+            "const": "mcp_approval_response",
+            "default": "mcp_approval_response"
+          },
+          "id": {
+            "type": "string"
+          },
+          "reason": {
+            "type": "string"
+          }
+        },
+        "additionalProperties": false,
+        "required": [
+          "approval_request_id",
+          "approve",
+          "type"
+        ],
+        "title": "OpenAIResponseMCPApprovalResponse",
+        "description": "A response to an MCP approval request."
+      },
       "OpenAIResponseMessage": {
         "type": "object",
         "properties": {
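For reference, the two schemas added above correspond one-to-one to the Pydantic models introduced later in this patch in `llama_stack/apis/agents/openai_responses.py`. A minimal standalone sketch (model names shortened here; the field shapes come straight from the spec's `properties` and `required` lists) shows the payloads they describe:

```python
# Standalone sketch of the two new wire items, mirroring the JSON schemas
# above. The real models live in llama_stack/apis/agents/openai_responses.py.
from typing import Literal

from pydantic import BaseModel


class MCPApprovalRequest(BaseModel):
    """A request for human approval of a tool invocation."""

    arguments: str  # JSON-encoded arguments of the pending tool call
    id: str  # referenced later by the approval response
    name: str  # tool name on the MCP server
    server_label: str
    type: Literal["mcp_approval_request"] = "mcp_approval_request"


class MCPApprovalResponse(BaseModel):
    """A response to an MCP approval request."""

    approval_request_id: str
    approve: bool
    type: Literal["mcp_approval_response"] = "mcp_approval_response"
    id: str | None = None
    reason: str | None = None


# Round-trip a sample payload through the models.
request = MCPApprovalRequest(
    arguments='{"liquid_name": "water"}',
    id="approval_123",
    name="get_boiling_point",
    server_label="localmcp",
)
approval = MCPApprovalResponse(approval_request_id=request.id, approve=True)
assert approval.model_dump()["type"] == "mcp_approval_response"
```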
@@ -9949,6 +10017,9 @@
           },
           {
             "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
+          },
+          {
+            "$ref": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
           }
         ],
         "discriminator": {
@@ -9959,7 +10030,8 @@
             "file_search_call": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall",
             "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
             "mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
-            "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
+            "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools",
+            "mcp_approval_request": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
           }
         }
       },
@@ -10658,6 +10730,9 @@
           },
           {
             "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
+          },
+          {
+            "$ref": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
           }
         ],
         "discriminator": {
@@ -10668,7 +10743,8 @@
             "file_search_call": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall",
             "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
             "mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
-            "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
+            "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools",
+            "mcp_approval_request": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
           }
         },
         "description": "The output item that was added (message, tool call, etc.)"
@@ -10725,6 +10801,9 @@
           },
           {
             "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
+          },
+          {
+            "$ref": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
           }
         ],
         "discriminator": {
@@ -10735,7 +10814,8 @@
             "file_search_call": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall",
             "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
             "mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
-            "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
+            "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools",
+            "mcp_approval_request": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
           }
         },
         "description": "The completed output item (message, tool call, etc.)"
diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml
index 33a7e66d8..bf8357333 100644
--- a/docs/static/llama-stack-spec.yaml
+++ b/docs/static/llama-stack-spec.yaml
@@ -6541,6 +6541,8 @@ components:
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall'
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
       - $ref: '#/components/schemas/OpenAIResponseInputFunctionToolCallOutput'
+      - $ref: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
+      - $ref: '#/components/schemas/OpenAIResponseMCPApprovalResponse'
      - $ref: '#/components/schemas/OpenAIResponseMessage'
     "OpenAIResponseInputFunctionToolCallOutput":
       type: object
@@ -6835,6 +6837,53 @@ components:
       title: OpenAIResponseInputToolWebSearch
       description: >-
         Web search tool configuration for OpenAI response inputs.
+    OpenAIResponseMCPApprovalRequest:
+      type: object
+      properties:
+        arguments:
+          type: string
+        id:
+          type: string
+        name:
+          type: string
+        server_label:
+          type: string
+        type:
+          type: string
+          const: mcp_approval_request
+          default: mcp_approval_request
+      additionalProperties: false
+      required:
+        - arguments
+        - id
+        - name
+        - server_label
+        - type
+      title: OpenAIResponseMCPApprovalRequest
+      description: >-
+        A request for human approval of a tool invocation.
+    OpenAIResponseMCPApprovalResponse:
+      type: object
+      properties:
+        approval_request_id:
+          type: string
+        approve:
+          type: boolean
+        type:
+          type: string
+          const: mcp_approval_response
+          default: mcp_approval_response
+        id:
+          type: string
+        reason:
+          type: string
+      additionalProperties: false
+      required:
+        - approval_request_id
+        - approve
+        - type
+      title: OpenAIResponseMCPApprovalResponse
+      description: A response to an MCP approval request.
     OpenAIResponseMessage:
       type: object
       properties:
@@ -7227,6 +7276,7 @@ components:
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
+      - $ref: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
     discriminator:
       propertyName: type
       mapping:
@@ -7236,6 +7286,7 @@ components:
         function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
         mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
         mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
+        mcp_approval_request: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
     OpenAIResponseOutputMessageMCPCall:
       type: object
       properties:
@@ -7785,6 +7836,7 @@ components:
           - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
           - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
           - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
+          - $ref: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
           discriminator:
             propertyName: type
             mapping:
@@ -7794,6 +7846,7 @@ components:
               function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
               mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
               mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
+              mcp_approval_request: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
           description: >-
             The output item that was added (message, tool call, etc.)
         output_index:
@@ -7836,6 +7889,7 @@ components:
           - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
           - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
           - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
+          - $ref: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
           discriminator:
             propertyName: type
             mapping:
@@ -7845,6 +7899,7 @@ components:
               function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
               mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
               mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
+              mcp_approval_request: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
           description: >-
             The completed output item (message, tool call, etc.)
         output_index:
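Taken together, the spec changes above enable a two-step flow from the client's point of view. Below is a hedged usage sketch based on the integration test at the end of this patch; the client construction, base URL, server URL, and model id are all placeholders, not values prescribed by this PR:

```python
# Hedged sketch of the approval round-trip (all names/URLs are placeholders).
# Step 1: declare the MCP tool with require_approval="always"; the response
# then contains an mcp_approval_request item instead of an mcp_call.
# Step 2: answer it with an mcp_approval_response input item, chained via
# previous_response_id, after which the call proceeds (or is dropped).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")  # placeholder

tools = [
    {
        "type": "mcp",
        "server_label": "localmcp",
        "server_url": "http://localhost:8000/sse",  # placeholder MCP server
        "require_approval": "always",  # or "never"; a per-tool filter also exists
    }
]

first = client.responses.create(
    model="my-model",  # placeholder
    input="What is the boiling point of water?",
    tools=tools,
)
approval_request = next(item for item in first.output if item.type == "mcp_approval_request")

second = client.responses.create(
    model="my-model",
    previous_response_id=first.id,
    input=[
        {
            "type": "mcp_approval_response",
            "approval_request_id": approval_request.id,
            "approve": True,  # False would deny and drop the call
        }
    ],
    tools=tools,
)
```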
diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py
index b26b11f4f..190e35fd0 100644
--- a/llama_stack/apis/agents/openai_responses.py
+++ b/llama_stack/apis/agents/openai_responses.py
@@ -276,13 +276,40 @@ class OpenAIResponseOutputMessageMCPListTools(BaseModel):
     tools: list[MCPListToolsTool]
 
 
+@json_schema_type
+class OpenAIResponseMCPApprovalRequest(BaseModel):
+    """
+    A request for human approval of a tool invocation.
+    """
+
+    arguments: str
+    id: str
+    name: str
+    server_label: str
+    type: Literal["mcp_approval_request"] = "mcp_approval_request"
+
+
+@json_schema_type
+class OpenAIResponseMCPApprovalResponse(BaseModel):
+    """
+    A response to an MCP approval request.
+    """
+
+    approval_request_id: str
+    approve: bool
+    type: Literal["mcp_approval_response"] = "mcp_approval_response"
+    id: str | None = None
+    reason: str | None = None
+
+
 OpenAIResponseOutput = Annotated[
     OpenAIResponseMessage
     | OpenAIResponseOutputMessageWebSearchToolCall
     | OpenAIResponseOutputMessageFileSearchToolCall
     | OpenAIResponseOutputMessageFunctionToolCall
     | OpenAIResponseOutputMessageMCPCall
-    | OpenAIResponseOutputMessageMCPListTools,
+    | OpenAIResponseOutputMessageMCPListTools
+    | OpenAIResponseMCPApprovalRequest,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
@@ -723,6 +750,8 @@ OpenAIResponseInput = Annotated[
     | OpenAIResponseOutputMessageFileSearchToolCall
     | OpenAIResponseOutputMessageFunctionToolCall
     | OpenAIResponseInputFunctionToolCallOutput
+    | OpenAIResponseMCPApprovalRequest
+    | OpenAIResponseMCPApprovalResponse
     |
     # Fallback to the generic message type as a last resort
     OpenAIResponseMessage,
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
index c632e61aa..c27dc8467 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
@@ -237,6 +237,7 @@ class OpenAIResponsesImpl:
             response_tools=tools,
             temperature=temperature,
             response_format=response_format,
+            inputs=input,
         )
 
         # Create orchestrator and delegate streaming logic
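The new items participate in the `OpenAIResponseOutput` and `OpenAIResponseInput` discriminated unions above, so parsing dispatches on the `type` field. A small self-contained illustration of that mechanism, using simplified stand-in models and pydantic v2's `TypeAdapter` (the real unions carry many more members):

```python
# Minimal illustration of discriminator-based dispatch, as used by the
# OpenAIResponseInput/OpenAIResponseOutput unions above (models simplified).
from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class ApprovalRequest(BaseModel):
    type: Literal["mcp_approval_request"] = "mcp_approval_request"
    id: str
    name: str
    arguments: str
    server_label: str


class ApprovalResponse(BaseModel):
    type: Literal["mcp_approval_response"] = "mcp_approval_response"
    approval_request_id: str
    approve: bool


Item = Annotated[ApprovalRequest | ApprovalResponse, Field(discriminator="type")]

items = TypeAdapter(list[Item]).validate_python(
    [{"type": "mcp_approval_response", "approval_request_id": "approval_1", "approve": True}]
)
assert isinstance(items[0], ApprovalResponse)  # selected by the "type" discriminator
```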
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
index 2f45ad2a3..1df37d1e6 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
@@ -10,10 +10,12 @@ from typing import Any
 
 from llama_stack.apis.agents.openai_responses import (
     AllowedToolsFilter,
+    ApprovalFilter,
     MCPListToolsTool,
     OpenAIResponseContentPartOutputText,
     OpenAIResponseInputTool,
     OpenAIResponseInputToolMCP,
+    OpenAIResponseMCPApprovalRequest,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
     OpenAIResponseObjectStreamResponseCompleted,
@@ -147,10 +149,17 @@ class StreamingResponseOrchestrator:
                 raise ValueError("Streaming chunk processor failed to return completion data")
 
             current_response = self._build_chat_completion(completion_result_data)
-            function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
+            function_tool_calls, non_function_tool_calls, approvals, next_turn_messages = self._separate_tool_calls(
                 current_response, messages
             )
 
+            # add any approval requests required
+            for tool_call in approvals:
+                async for evt in self._add_mcp_approval_request(
+                    tool_call.function.name, tool_call.function.arguments, output_messages
+                ):
+                    yield evt
+
             # Handle choices with no tool calls
             for choice in current_response.choices:
                 if not (choice.message.tool_calls and self.ctx.response_tools):
@@ -194,10 +203,11 @@ class StreamingResponseOrchestrator:
         # Emit response.completed
         yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
 
-    def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
+    def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list, list]:
         """Separate tool calls into function and non-function categories."""
         function_tool_calls = []
         non_function_tool_calls = []
+        approvals = []
         next_turn_messages = messages.copy()
 
         for choice in current_response.choices:
@@ -208,9 +218,23 @@ class StreamingResponseOrchestrator:
                     if is_function_tool_call(tool_call, self.ctx.response_tools):
                         function_tool_calls.append(tool_call)
                     else:
-                        non_function_tool_calls.append(tool_call)
+                        if self._approval_required(tool_call.function.name):
+                            approval_response = self.ctx.approval_response(
+                                tool_call.function.name, tool_call.function.arguments
+                            )
+                            if approval_response:
+                                if approval_response.approve:
+                                    logger.info(f"Approval granted for {tool_call.id} on {tool_call.function.name}")
+                                    non_function_tool_calls.append(tool_call)
+                                else:
+                                    logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
+                            else:
+                                logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
+                                approvals.append(tool_call)
+                        else:
+                            non_function_tool_calls.append(tool_call)
 
-        return function_tool_calls, non_function_tool_calls, next_turn_messages
+        return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages
 
     async def _process_streaming_chunks(
         self, completion_result, output_messages: list[OpenAIResponseOutput]
@@ -646,3 +670,46 @@ class StreamingResponseOrchestrator:
                 # TODO: Emit mcp_list_tools.failed event if needed
                 logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
                 raise
+
+    def _approval_required(self, tool_name: str) -> bool:
+        if tool_name not in self.mcp_tool_to_server:
+            return False
+        mcp_server = self.mcp_tool_to_server[tool_name]
+        if mcp_server.require_approval == "always":
+            return True
+        if mcp_server.require_approval == "never":
+            return False
+        if isinstance(mcp_server.require_approval, ApprovalFilter):
+            if tool_name in mcp_server.require_approval.always:
+                return True
+            if tool_name in mcp_server.require_approval.never:
+                return False
+        return True
+
+    async def _add_mcp_approval_request(
+        self, tool_name: str, arguments: str, output_messages: list[OpenAIResponseOutput]
+    ) -> AsyncIterator[OpenAIResponseObjectStream]:
+        mcp_server = self.mcp_tool_to_server[tool_name]
+        mcp_approval_request = OpenAIResponseMCPApprovalRequest(
+            arguments=arguments,
+            id=f"approval_{uuid.uuid4()}",
+            name=tool_name,
+            server_label=mcp_server.server_label,
+        )
+        output_messages.append(mcp_approval_request)
+
+        self.sequence_number += 1
+        yield OpenAIResponseObjectStreamResponseOutputItemAdded(
+            response_id=self.response_id,
+            item=mcp_approval_request,
+            output_index=len(output_messages) - 1,
+            sequence_number=self.sequence_number,
+        )
+
+        self.sequence_number += 1
+        yield OpenAIResponseObjectStreamResponseOutputItemDone(
+            response_id=self.response_id,
+            item=mcp_approval_request,
+            output_index=len(output_messages) - 1,
+            sequence_number=self.sequence_number,
+        )
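`_approval_required` resolves the policy in a fixed order: the literal `"always"`/`"never"` values win, then an `ApprovalFilter` is consulted, and anything unmatched fails safe by requiring approval. A standalone sketch of that resolution follows; `ApprovalFilter` carrying plain lists of tool names in `always`/`never` is an assumption inferred from the membership checks above:

```python
# Standalone sketch of the approval-policy resolution implemented by
# _approval_required above. ApprovalFilter's always/never being lists of
# tool names is an assumption inferred from the membership checks.
from pydantic import BaseModel


class ApprovalFilter(BaseModel):
    always: list[str] = []
    never: list[str] = []


def approval_required(require_approval: str | ApprovalFilter, tool_name: str) -> bool:
    if require_approval == "always":
        return True
    if require_approval == "never":
        return False
    if isinstance(require_approval, ApprovalFilter):
        if tool_name in require_approval.always:
            return True
        if tool_name in require_approval.never:
            return False
    # Unmatched tools fail safe: require approval.
    return True


assert approval_required("always", "get_boiling_point") is True
assert approval_required(ApprovalFilter(never=["greet_everyone"]), "greet_everyone") is False
assert approval_required(ApprovalFilter(), "get_boiling_point") is True  # default
```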
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/types.py b/llama_stack/providers/inline/agents/meta_reference/responses/types.py
index 89086c262..d3b5a16bd 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/types.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/types.py
@@ -10,7 +10,10 @@ from openai.types.chat import ChatCompletionToolParam
 from pydantic import BaseModel
 
 from llama_stack.apis.agents.openai_responses import (
+    OpenAIResponseInput,
     OpenAIResponseInputTool,
+    OpenAIResponseMCPApprovalRequest,
+    OpenAIResponseMCPApprovalResponse,
     OpenAIResponseObjectStream,
     OpenAIResponseOutput,
 )
@@ -58,3 +61,37 @@ class ChatCompletionContext(BaseModel):
     chat_tools: list[ChatCompletionToolParam] | None = None
     temperature: float | None
     response_format: OpenAIResponseFormatParam
+    approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
+    approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
+
+    def __init__(
+        self,
+        model: str,
+        messages: list[OpenAIMessageParam],
+        response_tools: list[OpenAIResponseInputTool] | None,
+        temperature: float | None,
+        response_format: OpenAIResponseFormatParam,
+        inputs: list[OpenAIResponseInput] | str,
+    ):
+        super().__init__(
+            model=model,
+            messages=messages,
+            response_tools=response_tools,
+            temperature=temperature,
+            response_format=response_format,
+        )
+        if not isinstance(inputs, str):
+            self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
+            self.approval_responses = {
+                input.approval_request_id: input for input in inputs if input.type == "mcp_approval_response"
+            }
+
+    def approval_response(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalResponse | None:
+        request = self._approval_request(tool_name, arguments)
+        return self.approval_responses.get(request.id, None) if request else None
+
+    def _approval_request(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalRequest | None:
+        for request in self.approval_requests:
+            if request.name == tool_name and request.arguments == arguments:
+                return request
+        return None
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
index 7aaeb4cd5..310a88298 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
@@ -13,6 +13,8 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
+    OpenAIResponseMCPApprovalRequest,
+    OpenAIResponseMCPApprovalResponse,
     OpenAIResponseMessage,
     OpenAIResponseOutputMessageContent,
     OpenAIResponseOutputMessageContentOutputText,
@@ -149,6 +151,11 @@ async def convert_response_input_to_chat_messages(
         elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
             # the tool list will be handled separately
             pass
+        elif isinstance(input_item, OpenAIResponseMCPApprovalRequest) or isinstance(
+            input_item, OpenAIResponseMCPApprovalResponse
+        ):
+            # these are handled by the responses impl itself and not passed through to chat completions
+            pass
         else:
             content = await convert_response_content_to_chat_content(input_item.content)
             message_type = await get_message_type_by_role(input_item.role)
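The `ChatCompletionContext` changes above match a prior `mcp_approval_request` to an incoming `mcp_approval_response` indirectly: the orchestrator only knows the (tool name, arguments) pair it is about to execute, so the context first finds the request with that pair, then looks the response up by the request's id. A condensed, dict-based sketch of that two-step lookup:

```python
# Condensed sketch of the lookup performed by ChatCompletionContext above:
# (tool name, arguments) -> approval request -> response keyed by request id.
requests = [
    {"id": "approval_1", "name": "get_boiling_point", "arguments": '{"liquid_name": "water"}'},
]
responses = {
    "approval_1": {"approval_request_id": "approval_1", "approve": True},
}


def approval_response(tool_name: str, arguments: str) -> dict | None:
    request = next(
        (r for r in requests if r["name"] == tool_name and r["arguments"] == arguments),
        None,
    )
    return responses.get(request["id"]) if request else None


assert approval_response("get_boiling_point", '{"liquid_name": "water"}')["approve"] is True
assert approval_response("get_boiling_point", "{}") is None  # no matching request
```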
diff --git a/tests/integration/responses/test_tool_responses.py b/tests/integration/responses/test_tool_responses.py
index c5c9e6fc1..f23734892 100644
--- a/tests/integration/responses/test_tool_responses.py
+++ b/tests/integration/responses/test_tool_responses.py
@@ -246,6 +246,82 @@ def test_response_sequential_mcp_tool(compat_client, text_model_id, case):
     assert "boiling point" in text_content.lower()
 
 
+@pytest.mark.parametrize("case", mcp_tool_test_cases)
+@pytest.mark.parametrize("approve", [True, False])
+def test_response_mcp_tool_approval(compat_client, text_model_id, case, approve):
+    if not isinstance(compat_client, LlamaStackAsLibraryClient):
+        pytest.skip("in-process MCP server is only supported in library client")
+
+    with make_mcp_server() as mcp_server_info:
+        tools = setup_mcp_tools(case.tools, mcp_server_info)
+        for tool in tools:
+            tool["require_approval"] = "always"
+
+        response = compat_client.responses.create(
+            model=text_model_id,
+            input=case.input,
+            tools=tools,
+            stream=False,
+        )
+
+        assert len(response.output) >= 2
+        list_tools = response.output[0]
+        assert list_tools.type == "mcp_list_tools"
+        assert list_tools.server_label == "localmcp"
+        assert len(list_tools.tools) == 2
+        assert {t.name for t in list_tools.tools} == {
+            "get_boiling_point",
+            "greet_everyone",
+        }
+
+        approval_request = response.output[1]
+        assert approval_request.type == "mcp_approval_request"
+        assert approval_request.name == "get_boiling_point"
+        assert json.loads(approval_request.arguments) == {
+            "liquid_name": "myawesomeliquid",
+            "celsius": True,
+        }
+
+        # send approval response
+        response = compat_client.responses.create(
+            previous_response_id=response.id,
+            model=text_model_id,
+            input=[{"type": "mcp_approval_response", "approval_request_id": approval_request.id, "approve": approve}],
+            tools=tools,
+            stream=False,
+        )
+
+        if approve:
+            assert len(response.output) >= 3
+            list_tools = response.output[0]
+            assert list_tools.type == "mcp_list_tools"
+            assert list_tools.server_label == "localmcp"
+            assert len(list_tools.tools) == 2
+            assert {t.name for t in list_tools.tools} == {
+                "get_boiling_point",
+                "greet_everyone",
+            }
+
+            call = response.output[1]
+            assert call.type == "mcp_call"
+            assert call.name == "get_boiling_point"
+            assert json.loads(call.arguments) == {
+                "liquid_name": "myawesomeliquid",
+                "celsius": True,
+            }
+            assert call.error is None
+            assert "-100" in call.output
+
+            # sometimes the model will call the tool again, so we need to get the last message
+            message = response.output[-1]
+            text_content = message.content[0].text
+            assert "boiling point" in text_content.lower()
+        else:
+            assert len(response.output) >= 1
+            for output in response.output:
+                assert output.type != "mcp_call"
+
+
 @pytest.mark.parametrize("case", custom_tool_test_cases)
 def test_response_non_streaming_custom_tool(compat_client, text_model_id, case):
     response = compat_client.responses.create(