mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
feat: add support for require_approval argument when creating response
Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
parent
7c466a7ec5
commit
449177d316
11 changed files with 362 additions and 36 deletions
|
@ -276,13 +276,40 @@ class OpenAIResponseOutputMessageMCPListTools(BaseModel):
|
|||
tools: list[MCPListToolsTool]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseMCPApprovalRequest(BaseModel):
|
||||
"""
|
||||
A request for human approval of a tool invocation.
|
||||
"""
|
||||
|
||||
arguments: str
|
||||
id: str
|
||||
name: str
|
||||
server_label: str
|
||||
type: Literal["mcp_approval_request"] = "mcp_approval_request"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseMCPApprovalResponse(BaseModel):
|
||||
"""
|
||||
A response to an MCP approval request.
|
||||
"""
|
||||
|
||||
approval_request_id: str
|
||||
approve: bool
|
||||
type: Literal["mcp_approval_response"] = "mcp_approval_response"
|
||||
id: str | None = None
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
OpenAIResponseOutput = Annotated[
|
||||
OpenAIResponseMessage
|
||||
| OpenAIResponseOutputMessageWebSearchToolCall
|
||||
| OpenAIResponseOutputMessageFileSearchToolCall
|
||||
| OpenAIResponseOutputMessageFunctionToolCall
|
||||
| OpenAIResponseOutputMessageMCPCall
|
||||
| OpenAIResponseOutputMessageMCPListTools,
|
||||
| OpenAIResponseOutputMessageMCPListTools
|
||||
| OpenAIResponseMCPApprovalRequest,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
|
||||
|
@ -725,6 +752,8 @@ OpenAIResponseInput = Annotated[
|
|||
| OpenAIResponseOutputMessageFileSearchToolCall
|
||||
| OpenAIResponseOutputMessageFunctionToolCall
|
||||
| OpenAIResponseInputFunctionToolCallOutput
|
||||
| OpenAIResponseMCPApprovalRequest
|
||||
| OpenAIResponseMCPApprovalResponse
|
||||
|
|
||||
# Fallback to the generic message type as a last resort
|
||||
OpenAIResponseMessage,
|
||||
|
|
|
@ -238,6 +238,7 @@ class OpenAIResponsesImpl:
|
|||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
)
|
||||
ctx.extract_approvals(input)
|
||||
|
||||
# Create orchestrator and delegate streaming logic
|
||||
response_id = f"resp-{uuid.uuid4()}"
|
||||
|
|
|
@ -10,10 +10,12 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
AllowedToolsFilter,
|
||||
ApprovalFilter,
|
||||
MCPListToolsTool,
|
||||
OpenAIResponseContentPartOutputText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
|
@ -117,10 +119,17 @@ class StreamingResponseOrchestrator:
|
|||
raise ValueError("Streaming chunk processor failed to return completion data")
|
||||
current_response = self._build_chat_completion(completion_result_data)
|
||||
|
||||
function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
|
||||
function_tool_calls, non_function_tool_calls, approvals, next_turn_messages = self._separate_tool_calls(
|
||||
current_response, messages
|
||||
)
|
||||
|
||||
# add any approval requests required
|
||||
for tool_call in approvals:
|
||||
async for evt in self._add_mcp_approval_request(
|
||||
tool_call.function.name, tool_call.function.arguments, output_messages
|
||||
):
|
||||
yield evt
|
||||
|
||||
# Handle choices with no tool calls
|
||||
for choice in current_response.choices:
|
||||
if not (choice.message.tool_calls and self.ctx.response_tools):
|
||||
|
@ -164,10 +173,11 @@ class StreamingResponseOrchestrator:
|
|||
# Emit response.completed
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
|
||||
|
||||
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
|
||||
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list, list]:
|
||||
"""Separate tool calls into function and non-function categories."""
|
||||
function_tool_calls = []
|
||||
non_function_tool_calls = []
|
||||
approvals = []
|
||||
next_turn_messages = messages.copy()
|
||||
|
||||
for choice in current_response.choices:
|
||||
|
@ -178,9 +188,23 @@ class StreamingResponseOrchestrator:
|
|||
if is_function_tool_call(tool_call, self.ctx.response_tools):
|
||||
function_tool_calls.append(tool_call)
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
if self._approval_required(tool_call.function.name):
|
||||
approval_response = self.ctx.approval_response(
|
||||
tool_call.function.name, tool_call.function.arguments
|
||||
)
|
||||
if approval_response:
|
||||
if approval_response.approve:
|
||||
logger.info(f"Approval granted for {tool_call.id} on {tool_call.function.name}")
|
||||
non_function_tool_calls.append(tool_call)
|
||||
else:
|
||||
logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
|
||||
else:
|
||||
logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
|
||||
approvals.append(tool_call)
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
|
||||
return function_tool_calls, non_function_tool_calls, next_turn_messages
|
||||
return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages
|
||||
|
||||
async def _process_streaming_chunks(
|
||||
self, completion_result, output_messages: list[OpenAIResponseOutput]
|
||||
|
@ -632,3 +656,46 @@ class StreamingResponseOrchestrator:
|
|||
# TODO: Emit mcp_list_tools.failed event if needed
|
||||
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
|
||||
raise
|
||||
|
||||
def _approval_required(self, tool_name: str) -> bool:
|
||||
if tool_name not in self.mcp_tool_to_server:
|
||||
return False
|
||||
mcp_server = self.mcp_tool_to_server[tool_name]
|
||||
if mcp_server.require_approval == "always":
|
||||
return True
|
||||
if mcp_server.require_approval == "never":
|
||||
return False
|
||||
if isinstance(mcp_server, ApprovalFilter):
|
||||
if tool_name in mcp_server.always:
|
||||
return True
|
||||
if tool_name in mcp_server.never:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _add_mcp_approval_request(
|
||||
self, tool_name: str, arguments: str, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
mcp_server = self.mcp_tool_to_server[tool_name]
|
||||
mcp_approval_request = OpenAIResponseMCPApprovalRequest(
|
||||
arguments=arguments,
|
||||
id=f"approval_{uuid.uuid4()}",
|
||||
name=tool_name,
|
||||
server_label=mcp_server.server_label,
|
||||
)
|
||||
output_messages.append(mcp_approval_request)
|
||||
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=mcp_approval_request,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=mcp_approval_request,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
|
|
@ -11,6 +11,8 @@ from pydantic import BaseModel
|
|||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseMCPApprovalResponse,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseOutput,
|
||||
)
|
||||
|
@ -58,3 +60,24 @@ class ChatCompletionContext(BaseModel):
|
|||
chat_tools: list[ChatCompletionToolParam] | None = None
|
||||
temperature: float | None
|
||||
response_format: OpenAIResponseFormatParam
|
||||
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
|
||||
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
|
||||
|
||||
def extract_approvals(self, inputs):
|
||||
if not isinstance(inputs, str):
|
||||
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
|
||||
self.approval_responses = {
|
||||
input.approval_request_id: input for input in inputs if input.type == "mcp_approval_response"
|
||||
}
|
||||
|
||||
def approval_response(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalResponse | None:
|
||||
request = self._approval_request(tool_name, arguments)
|
||||
if request and request.id in self.approval_responses:
|
||||
return self.approval_responses[request.id]
|
||||
return None
|
||||
|
||||
def _approval_request(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalRequest | None:
|
||||
for request in self.approval_requests:
|
||||
if request.name == tool_name and request.arguments == arguments:
|
||||
return request
|
||||
return None
|
||||
|
|
|
@ -13,6 +13,8 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseMCPApprovalResponse,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContent,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
|
@ -149,6 +151,11 @@ async def convert_response_input_to_chat_messages(
|
|||
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
|
||||
# the tool list will be handled separately
|
||||
pass
|
||||
elif isinstance(input_item, OpenAIResponseMCPApprovalRequest) or isinstance(
|
||||
input_item, OpenAIResponseMCPApprovalResponse
|
||||
):
|
||||
# these are handled by the responses impl itself and not pass through to chat completions
|
||||
pass
|
||||
else:
|
||||
content = await convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await get_message_type_by_role(input_item.role)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue