feat: add support for require_approval argument when creating response

Signed-off-by: Gordon Sim <gsim@redhat.com>
Gordon Sim 2025-09-26 20:36:42 +01:00
parent 7c466a7ec5
commit 449177d316
11 changed files with 362 additions and 36 deletions
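
For context, this is roughly how the new argument is used from a client, based on the integration test added below (the model id, base URL, and MCP server URL are illustrative placeholders, not values from the commit):

# Sketch of the approval flow this commit enables; assumes an
# OpenAI-compatible Responses client pointed at a Llama Stack deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
tools = [{
    "type": "mcp",
    "server_label": "localmcp",
    "server_url": "http://localhost:8000/sse",
    "require_approval": "always",  # new argument: "always", "never", or a per-tool filter
}]

# Turn 1: instead of an mcp_call, the output contains an mcp_approval_request.
response = client.responses.create(model="my-model", input="What boils at -100C?", tools=tools)
approval = next(o for o in response.output if o.type == "mcp_approval_request")

# Turn 2: answer the request; the tool call then proceeds (or is dropped).
final = client.responses.create(
    model="my-model",
    previous_response_id=response.id,
    input=[{
        "type": "mcp_approval_response",
        "approval_request_id": approval.id,
        "approve": True,
    }],
    tools=tools,
)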


@@ -8614,6 +8614,12 @@
         {
           "$ref": "#/components/schemas/OpenAIResponseInputFunctionToolCallOutput"
         },
+        {
+          "$ref": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
+        },
+        {
+          "$ref": "#/components/schemas/OpenAIResponseMCPApprovalResponse"
+        },
         {
           "$ref": "#/components/schemas/OpenAIResponseMessage"
         }
@@ -9031,6 +9037,68 @@
        "title": "OpenAIResponseInputToolWebSearch",
        "description": "Web search tool configuration for OpenAI response inputs."
      },
+     "OpenAIResponseMCPApprovalRequest": {
+       "type": "object",
+       "properties": {
+         "arguments": {
+           "type": "string"
+         },
+         "id": {
+           "type": "string"
+         },
+         "name": {
+           "type": "string"
+         },
+         "server_label": {
+           "type": "string"
+         },
+         "type": {
+           "type": "string",
+           "const": "mcp_approval_request",
+           "default": "mcp_approval_request"
+         }
+       },
+       "additionalProperties": false,
+       "required": [
+         "arguments",
+         "id",
+         "name",
+         "server_label",
+         "type"
+       ],
+       "title": "OpenAIResponseMCPApprovalRequest",
+       "description": "A request for human approval of a tool invocation."
+     },
+     "OpenAIResponseMCPApprovalResponse": {
+       "type": "object",
+       "properties": {
+         "approval_request_id": {
+           "type": "string"
+         },
+         "approve": {
+           "type": "boolean"
+         },
+         "type": {
+           "type": "string",
+           "const": "mcp_approval_response",
+           "default": "mcp_approval_response"
+         },
+         "id": {
+           "type": "string"
+         },
+         "reason": {
+           "type": "string"
+         }
+       },
+       "additionalProperties": false,
+       "required": [
+         "approval_request_id",
+         "approve",
+         "type"
+       ],
+       "title": "OpenAIResponseMCPApprovalResponse",
+       "description": "A response to an MCP approval request."
+     },
      "OpenAIResponseMessage": {
        "type": "object",
        "properties": {
@@ -9539,6 +9607,9 @@
        },
        {
          "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
+       },
+       {
+         "$ref": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
        }
      ],
      "discriminator": {
@@ -9549,7 +9620,8 @@
          "file_search_call": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall",
          "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
          "mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
-         "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
+         "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools",
+         "mcp_approval_request": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
        }
      }
    },
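
Concretely, items validating against the two new schemas look like this (ids, labels, and argument values are invented for illustration):

# Hypothetical wire-format payloads conforming to the schemas above.
approval_request = {
    "type": "mcp_approval_request",
    "id": "approval_3f9c",
    "server_label": "localmcp",
    "name": "get_boiling_point",
    "arguments": '{"liquid_name": "water", "celsius": true}',
}
approval_response = {
    "type": "mcp_approval_response",
    "approval_request_id": "approval_3f9c",
    "approve": False,
    "reason": "unexpected tool call",  # optional, like "id"
}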


@@ -6254,6 +6254,8 @@ components:
         - $ref: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall'
         - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
         - $ref: '#/components/schemas/OpenAIResponseInputFunctionToolCallOutput'
+        - $ref: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
+        - $ref: '#/components/schemas/OpenAIResponseMCPApprovalResponse'
         - $ref: '#/components/schemas/OpenAIResponseMessage'
     "OpenAIResponseInputFunctionToolCallOutput":
       type: object
@@ -6548,6 +6550,53 @@ components:
       title: OpenAIResponseInputToolWebSearch
       description: >-
         Web search tool configuration for OpenAI response inputs.
+    OpenAIResponseMCPApprovalRequest:
+      type: object
+      properties:
+        arguments:
+          type: string
+        id:
+          type: string
+        name:
+          type: string
+        server_label:
+          type: string
+        type:
+          type: string
+          const: mcp_approval_request
+          default: mcp_approval_request
+      additionalProperties: false
+      required:
+        - arguments
+        - id
+        - name
+        - server_label
+        - type
+      title: OpenAIResponseMCPApprovalRequest
+      description: >-
+        A request for human approval of a tool invocation.
+    OpenAIResponseMCPApprovalResponse:
+      type: object
+      properties:
+        approval_request_id:
+          type: string
+        approve:
+          type: boolean
+        type:
+          type: string
+          const: mcp_approval_response
+          default: mcp_approval_response
+        id:
+          type: string
+        reason:
+          type: string
+      additionalProperties: false
+      required:
+        - approval_request_id
+        - approve
+        - type
+      title: OpenAIResponseMCPApprovalResponse
+      description: A response to an MCP approval request.
     OpenAIResponseMessage:
       type: object
       properties:
@@ -6944,6 +6993,7 @@ components:
         - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
         - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
         - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
+        - $ref: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
       discriminator:
         propertyName: type
         mapping:
@@ -6953,6 +7003,7 @@ components:
           function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
           mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
           mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
+          mcp_approval_request: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
     OpenAIResponseOutputMessageMCPCall:
       type: object
       properties:


@@ -276,13 +276,40 @@ class OpenAIResponseOutputMessageMCPListTools(BaseModel):
     tools: list[MCPListToolsTool]
 
 
+@json_schema_type
+class OpenAIResponseMCPApprovalRequest(BaseModel):
+    """
+    A request for human approval of a tool invocation.
+    """
+
+    arguments: str
+    id: str
+    name: str
+    server_label: str
+    type: Literal["mcp_approval_request"] = "mcp_approval_request"
+
+
+@json_schema_type
+class OpenAIResponseMCPApprovalResponse(BaseModel):
+    """
+    A response to an MCP approval request.
+    """
+
+    approval_request_id: str
+    approve: bool
+    type: Literal["mcp_approval_response"] = "mcp_approval_response"
+    id: str | None = None
+    reason: str | None = None
+
+
 OpenAIResponseOutput = Annotated[
     OpenAIResponseMessage
     | OpenAIResponseOutputMessageWebSearchToolCall
     | OpenAIResponseOutputMessageFileSearchToolCall
     | OpenAIResponseOutputMessageFunctionToolCall
     | OpenAIResponseOutputMessageMCPCall
-    | OpenAIResponseOutputMessageMCPListTools,
+    | OpenAIResponseOutputMessageMCPListTools
+    | OpenAIResponseMCPApprovalRequest,
     Field(discriminator="type"),
 ]
 
 register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
@@ -725,6 +752,8 @@ OpenAIResponseInput = Annotated[
     | OpenAIResponseOutputMessageFileSearchToolCall
     | OpenAIResponseOutputMessageFunctionToolCall
     | OpenAIResponseInputFunctionToolCallOutput
+    | OpenAIResponseMCPApprovalRequest
+    | OpenAIResponseMCPApprovalResponse
     |
     # Fallback to the generic message type as a last resort
     OpenAIResponseMessage,
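
A quick sketch of how the new models participate in the discriminated unions, assuming pydantic v2's TypeAdapter (illustrative, not part of the commit):

# Round-trip through the OpenAIResponseOutput union; assumes pydantic v2
# and that the models above are importable.
from pydantic import TypeAdapter

item = OpenAIResponseMCPApprovalRequest(
    arguments='{"liquid_name": "water"}',
    id="approval_123",
    name="get_boiling_point",
    server_label="localmcp",
)
adapter = TypeAdapter(OpenAIResponseOutput)
roundtripped = adapter.validate_python(item.model_dump())
assert roundtripped.type == "mcp_approval_request"  # discriminator selects the new model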


@@ -238,6 +238,7 @@ class OpenAIResponsesImpl:
             temperature=temperature,
             response_format=response_format,
         )
+        ctx.extract_approvals(input)
 
         # Create orchestrator and delegate streaming logic
         response_id = f"resp-{uuid.uuid4()}"


@@ -10,10 +10,12 @@ from typing import Any
 
 from llama_stack.apis.agents.openai_responses import (
     AllowedToolsFilter,
+    ApprovalFilter,
     MCPListToolsTool,
     OpenAIResponseContentPartOutputText,
     OpenAIResponseInputTool,
     OpenAIResponseInputToolMCP,
+    OpenAIResponseMCPApprovalRequest,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
     OpenAIResponseObjectStreamResponseCompleted,
@@ -117,10 +119,17 @@ class StreamingResponseOrchestrator:
                 raise ValueError("Streaming chunk processor failed to return completion data")
 
             current_response = self._build_chat_completion(completion_result_data)
-            function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
+            function_tool_calls, non_function_tool_calls, approvals, next_turn_messages = self._separate_tool_calls(
                 current_response, messages
             )
 
+            # add any approval requests required
+            for tool_call in approvals:
+                async for evt in self._add_mcp_approval_request(
+                    tool_call.function.name, tool_call.function.arguments, output_messages
+                ):
+                    yield evt
+
             # Handle choices with no tool calls
             for choice in current_response.choices:
                 if not (choice.message.tool_calls and self.ctx.response_tools):
@@ -164,10 +173,11 @@ class StreamingResponseOrchestrator:
         # Emit response.completed
         yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
 
-    def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
+    def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list, list]:
         """Separate tool calls into function and non-function categories."""
         function_tool_calls = []
         non_function_tool_calls = []
+        approvals = []
         next_turn_messages = messages.copy()
 
         for choice in current_response.choices:
@@ -177,10 +187,24 @@ class StreamingResponseOrchestrator:
                 for tool_call in choice.message.tool_calls:
                     if is_function_tool_call(tool_call, self.ctx.response_tools):
                         function_tool_calls.append(tool_call)
                     else:
-                        non_function_tool_calls.append(tool_call)
+                        if self._approval_required(tool_call.function.name):
+                            approval_response = self.ctx.approval_response(
+                                tool_call.function.name, tool_call.function.arguments
+                            )
+                            if approval_response:
+                                if approval_response.approve:
+                                    logger.info(f"Approval granted for {tool_call.id} on {tool_call.function.name}")
+                                    non_function_tool_calls.append(tool_call)
+                                else:
+                                    logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
+                            else:
+                                logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
+                                approvals.append(tool_call)
+                        else:
+                            non_function_tool_calls.append(tool_call)
 
-        return function_tool_calls, non_function_tool_calls, next_turn_messages
+        return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages
 
     async def _process_streaming_chunks(
         self, completion_result, output_messages: list[OpenAIResponseOutput]
@@ -632,3 +656,46 @@
             # TODO: Emit mcp_list_tools.failed event if needed
             logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
             raise
+
+    def _approval_required(self, tool_name: str) -> bool:
+        if tool_name not in self.mcp_tool_to_server:
+            return False
+        mcp_server = self.mcp_tool_to_server[tool_name]
+        if mcp_server.require_approval == "always":
+            return True
+        if mcp_server.require_approval == "never":
+            return False
+        if isinstance(mcp_server.require_approval, ApprovalFilter):
+            if tool_name in mcp_server.require_approval.always:
+                return True
+            if tool_name in mcp_server.require_approval.never:
+                return False
+        return True
+
+    async def _add_mcp_approval_request(
+        self, tool_name: str, arguments: str, output_messages: list[OpenAIResponseOutput]
+    ) -> AsyncIterator[OpenAIResponseObjectStream]:
+        mcp_server = self.mcp_tool_to_server[tool_name]
+        mcp_approval_request = OpenAIResponseMCPApprovalRequest(
+            arguments=arguments,
+            id=f"approval_{uuid.uuid4()}",
+            name=tool_name,
+            server_label=mcp_server.server_label,
+        )
+        output_messages.append(mcp_approval_request)
+
+        self.sequence_number += 1
+        yield OpenAIResponseObjectStreamResponseOutputItemAdded(
+            response_id=self.response_id,
+            item=mcp_approval_request,
+            output_index=len(output_messages) - 1,
+            sequence_number=self.sequence_number,
+        )
+
+        self.sequence_number += 1
+        yield OpenAIResponseObjectStreamResponseOutputItemDone(
+            response_id=self.response_id,
+            item=mcp_approval_request,
+            output_index=len(output_messages) - 1,
+            sequence_number=self.sequence_number,
+        )
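
The decision in _approval_required supports three forms of require_approval. A standalone sketch of the expected behaviour, under the assumption that ApprovalFilter (defined in the API module, not shown in this diff) carries always/never lists of tool names:

# Illustrative re-statement of the decision logic above, runnable on its own.
from pydantic import BaseModel

class ApprovalFilter(BaseModel):  # assumed shape, see the API module
    always: list[str] = []
    never: list[str] = []

def approval_required(require_approval, tool_name: str) -> bool:
    if require_approval == "always":
        return True
    if require_approval == "never":
        return False
    if isinstance(require_approval, ApprovalFilter):
        if tool_name in require_approval.always:
            return True
        if tool_name in require_approval.never:
            return False
    return True  # unlisted tools default to requiring approval

assert approval_required("always", "get_boiling_point") is True
assert approval_required("never", "get_boiling_point") is False
flt = ApprovalFilter(always=["deploy"], never=["lint"])
assert approval_required(flt, "deploy") is True
assert approval_required(flt, "lint") is False
assert approval_required(flt, "anything_else") is True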


@@ -11,6 +11,8 @@ from pydantic import BaseModel
 
 from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputTool,
+    OpenAIResponseMCPApprovalRequest,
+    OpenAIResponseMCPApprovalResponse,
     OpenAIResponseObjectStream,
     OpenAIResponseOutput,
 )
@@ -58,3 +60,24 @@ class ChatCompletionContext(BaseModel):
     chat_tools: list[ChatCompletionToolParam] | None = None
     temperature: float | None
     response_format: OpenAIResponseFormatParam
+    approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
+    approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
+
+    def extract_approvals(self, inputs):
+        if not isinstance(inputs, str):
+            self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
+            self.approval_responses = {
+                input.approval_request_id: input for input in inputs if input.type == "mcp_approval_response"
+            }
+
+    def approval_response(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalResponse | None:
+        request = self._approval_request(tool_name, arguments)
+        if request and request.id in self.approval_responses:
+            return self.approval_responses[request.id]
+        return None
+
+    def _approval_request(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalRequest | None:
+        for request in self.approval_requests:
+            if request.name == tool_name and request.arguments == arguments:
+                return request
+        return None
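
Note that approval_response matches a pending request by the exact (name, arguments) pair and only then looks up the response by request id. A toy illustration of that matching, with invented values and assuming the models from this commit are importable:

# Toy restatement of the lookup performed by ChatCompletionContext above.
requests = [
    OpenAIResponseMCPApprovalRequest(
        arguments='{"liquid_name": "water"}',
        id="approval_1",
        name="get_boiling_point",
        server_label="localmcp",
    )
]
responses = {
    "approval_1": OpenAIResponseMCPApprovalResponse(approval_request_id="approval_1", approve=True)
}

def approval_for(tool_name: str, arguments: str):
    for r in requests:
        if r.name == tool_name and r.arguments == arguments:
            return responses.get(r.id)
    return None

assert approval_for("get_boiling_point", '{"liquid_name": "water"}').approve
assert approval_for("get_boiling_point", '{"liquid_name": "lava"}') is None  # arguments must match exactly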


@@ -13,6 +13,8 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
+    OpenAIResponseMCPApprovalRequest,
+    OpenAIResponseMCPApprovalResponse,
     OpenAIResponseMessage,
     OpenAIResponseOutputMessageContent,
     OpenAIResponseOutputMessageContentOutputText,
@@ -149,6 +151,11 @@ async def convert_response_input_to_chat_messages(
         elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
             # the tool list will be handled separately
             pass
+        elif isinstance(input_item, OpenAIResponseMCPApprovalRequest) or isinstance(
+            input_item, OpenAIResponseMCPApprovalResponse
+        ):
+            # these are handled by the responses impl itself and not passed through to chat completions
+            pass
         else:
             content = await convert_response_content_to_chat_content(input_item.content)
             message_type = await get_message_type_by_role(input_item.role)


@@ -246,6 +246,82 @@ def test_response_sequential_mcp_tool(compat_client, text_model_id, case):
     assert "boiling point" in text_content.lower()
 
 
+@pytest.mark.parametrize("case", mcp_tool_test_cases)
+@pytest.mark.parametrize("approve", [True, False])
+def test_response_mcp_tool_approval(compat_client, text_model_id, case, approve):
+    if not isinstance(compat_client, LlamaStackAsLibraryClient):
+        pytest.skip("in-process MCP server is only supported in library client")
+
+    with make_mcp_server() as mcp_server_info:
+        tools = setup_mcp_tools(case.tools, mcp_server_info)
+        for tool in tools:
+            tool["require_approval"] = "always"
+
+        response = compat_client.responses.create(
+            model=text_model_id,
+            input=case.input,
+            tools=tools,
+            stream=False,
+        )
+
+        assert len(response.output) >= 2
+        list_tools = response.output[0]
+        assert list_tools.type == "mcp_list_tools"
+        assert list_tools.server_label == "localmcp"
+        assert len(list_tools.tools) == 2
+        assert {t.name for t in list_tools.tools} == {
+            "get_boiling_point",
+            "greet_everyone",
+        }
+
+        approval_request = response.output[1]
+        assert approval_request.type == "mcp_approval_request"
+        assert approval_request.name == "get_boiling_point"
+        assert json.loads(approval_request.arguments) == {
+            "liquid_name": "myawesomeliquid",
+            "celsius": True,
+        }
+
+        # send approval response
+        response = compat_client.responses.create(
+            previous_response_id=response.id,
+            model=text_model_id,
+            input=[{"type": "mcp_approval_response", "approval_request_id": approval_request.id, "approve": approve}],
+            tools=tools,
+            stream=False,
+        )
+
+        if approve:
+            assert len(response.output) >= 3
+            list_tools = response.output[0]
+            assert list_tools.type == "mcp_list_tools"
+            assert list_tools.server_label == "localmcp"
+            assert len(list_tools.tools) == 2
+            assert {t.name for t in list_tools.tools} == {
+                "get_boiling_point",
+                "greet_everyone",
+            }
+
+            call = response.output[1]
+            assert call.type == "mcp_call"
+            assert call.name == "get_boiling_point"
+            assert json.loads(call.arguments) == {
+                "liquid_name": "myawesomeliquid",
+                "celsius": True,
+            }
+            assert call.error is None
+            assert "-100" in call.output
+
+            # sometimes the model will call the tool again, so we need to get the last message
+            message = response.output[-1]
+            text_content = message.content[0].text
+            assert "boiling point" in text_content.lower()
+        else:
+            assert len(response.output) >= 1
+            for output in response.output:
+                assert output.type != "mcp_call"
+
+
 @pytest.mark.parametrize("case", custom_tool_test_cases)
 def test_response_non_streaming_custom_tool(compat_client, text_model_id, case):
     response = compat_client.responses.create(