mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
feat: add support for require_approval argument when creating response (#3608)
# What does this PR do? This PR adds support for the require_approval on an mcp tool definition passed to create response in the Responses API. This allows the caller to indicate whether they want to approve calls to that server, or let them be called without approval. Closes #3443 ## Test Plan Tested both approval and denial. Added automated integration test for both cases. --------- Signed-off-by: Gordon Sim <gsim@redhat.com> Co-authored-by: Matthew Farrellee <matt@cs.wisc.edu>
This commit is contained in:
parent
0837fa7bef
commit
d350e3662b
8 changed files with 360 additions and 8 deletions
86
docs/static/llama-stack-spec.html
vendored
86
docs/static/llama-stack-spec.html
vendored
|
@ -9028,6 +9028,12 @@
|
|||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseInputFunctionToolCallOutput"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseMCPApprovalResponse"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseMessage"
|
||||
}
|
||||
|
@ -9445,6 +9451,68 @@
|
|||
"title": "OpenAIResponseInputToolWebSearch",
|
||||
"description": "Web search tool configuration for OpenAI response inputs."
|
||||
},
|
||||
"OpenAIResponseMCPApprovalRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"arguments": {
|
||||
"type": "string"
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"server_label": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "mcp_approval_request",
|
||||
"default": "mcp_approval_request"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"arguments",
|
||||
"id",
|
||||
"name",
|
||||
"server_label",
|
||||
"type"
|
||||
],
|
||||
"title": "OpenAIResponseMCPApprovalRequest",
|
||||
"description": "A request for human approval of a tool invocation."
|
||||
},
|
||||
"OpenAIResponseMCPApprovalResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"approval_request_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"approve": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "mcp_approval_response",
|
||||
"default": "mcp_approval_response"
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"reason": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"approval_request_id",
|
||||
"approve",
|
||||
"type"
|
||||
],
|
||||
"title": "OpenAIResponseMCPApprovalResponse",
|
||||
"description": "A response to an MCP approval request."
|
||||
},
|
||||
"OpenAIResponseMessage": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -9949,6 +10017,9 @@
|
|||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
|
@ -9959,7 +10030,8 @@
|
|||
"file_search_call": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall",
|
||||
"function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
|
||||
"mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
|
||||
"mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
|
||||
"mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools",
|
||||
"mcp_approval_request": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -10658,6 +10730,9 @@
|
|||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
|
@ -10668,7 +10743,8 @@
|
|||
"file_search_call": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall",
|
||||
"function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
|
||||
"mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
|
||||
"mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
|
||||
"mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools",
|
||||
"mcp_approval_request": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
|
||||
}
|
||||
},
|
||||
"description": "The output item that was added (message, tool call, etc.)"
|
||||
|
@ -10725,6 +10801,9 @@
|
|||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
|
@ -10735,7 +10814,8 @@
|
|||
"file_search_call": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall",
|
||||
"function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
|
||||
"mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
|
||||
"mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
|
||||
"mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools",
|
||||
"mcp_approval_request": "#/components/schemas/OpenAIResponseMCPApprovalRequest"
|
||||
}
|
||||
},
|
||||
"description": "The completed output item (message, tool call, etc.)"
|
||||
|
|
55
docs/static/llama-stack-spec.yaml
vendored
55
docs/static/llama-stack-spec.yaml
vendored
|
@ -6541,6 +6541,8 @@ components:
|
|||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall'
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
|
||||
- $ref: '#/components/schemas/OpenAIResponseInputFunctionToolCallOutput'
|
||||
- $ref: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
|
||||
- $ref: '#/components/schemas/OpenAIResponseMCPApprovalResponse'
|
||||
- $ref: '#/components/schemas/OpenAIResponseMessage'
|
||||
"OpenAIResponseInputFunctionToolCallOutput":
|
||||
type: object
|
||||
|
@ -6835,6 +6837,53 @@ components:
|
|||
title: OpenAIResponseInputToolWebSearch
|
||||
description: >-
|
||||
Web search tool configuration for OpenAI response inputs.
|
||||
OpenAIResponseMCPApprovalRequest:
|
||||
type: object
|
||||
properties:
|
||||
arguments:
|
||||
type: string
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
server_label:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
const: mcp_approval_request
|
||||
default: mcp_approval_request
|
||||
additionalProperties: false
|
||||
required:
|
||||
- arguments
|
||||
- id
|
||||
- name
|
||||
- server_label
|
||||
- type
|
||||
title: OpenAIResponseMCPApprovalRequest
|
||||
description: >-
|
||||
A request for human approval of a tool invocation.
|
||||
OpenAIResponseMCPApprovalResponse:
|
||||
type: object
|
||||
properties:
|
||||
approval_request_id:
|
||||
type: string
|
||||
approve:
|
||||
type: boolean
|
||||
type:
|
||||
type: string
|
||||
const: mcp_approval_response
|
||||
default: mcp_approval_response
|
||||
id:
|
||||
type: string
|
||||
reason:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- approval_request_id
|
||||
- approve
|
||||
- type
|
||||
title: OpenAIResponseMCPApprovalResponse
|
||||
description: A response to an MCP approval request.
|
||||
OpenAIResponseMessage:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -7227,6 +7276,7 @@ components:
|
|||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
|
||||
- $ref: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
|
@ -7236,6 +7286,7 @@ components:
|
|||
function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
|
||||
mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
|
||||
mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
|
||||
mcp_approval_request: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
|
||||
OpenAIResponseOutputMessageMCPCall:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -7785,6 +7836,7 @@ components:
|
|||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
|
||||
- $ref: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
|
@ -7794,6 +7846,7 @@ components:
|
|||
function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
|
||||
mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
|
||||
mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
|
||||
mcp_approval_request: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
|
||||
description: >-
|
||||
The output item that was added (message, tool call, etc.)
|
||||
output_index:
|
||||
|
@ -7836,6 +7889,7 @@ components:
|
|||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
|
||||
- $ref: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
|
@ -7845,6 +7899,7 @@ components:
|
|||
function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
|
||||
mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
|
||||
mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
|
||||
mcp_approval_request: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
|
||||
description: >-
|
||||
The completed output item (message, tool call, etc.)
|
||||
output_index:
|
||||
|
|
|
@ -276,13 +276,40 @@ class OpenAIResponseOutputMessageMCPListTools(BaseModel):
|
|||
tools: list[MCPListToolsTool]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseMCPApprovalRequest(BaseModel):
|
||||
"""
|
||||
A request for human approval of a tool invocation.
|
||||
"""
|
||||
|
||||
arguments: str
|
||||
id: str
|
||||
name: str
|
||||
server_label: str
|
||||
type: Literal["mcp_approval_request"] = "mcp_approval_request"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseMCPApprovalResponse(BaseModel):
|
||||
"""
|
||||
A response to an MCP approval request.
|
||||
"""
|
||||
|
||||
approval_request_id: str
|
||||
approve: bool
|
||||
type: Literal["mcp_approval_response"] = "mcp_approval_response"
|
||||
id: str | None = None
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
OpenAIResponseOutput = Annotated[
|
||||
OpenAIResponseMessage
|
||||
| OpenAIResponseOutputMessageWebSearchToolCall
|
||||
| OpenAIResponseOutputMessageFileSearchToolCall
|
||||
| OpenAIResponseOutputMessageFunctionToolCall
|
||||
| OpenAIResponseOutputMessageMCPCall
|
||||
| OpenAIResponseOutputMessageMCPListTools,
|
||||
| OpenAIResponseOutputMessageMCPListTools
|
||||
| OpenAIResponseMCPApprovalRequest,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
|
||||
|
@ -723,6 +750,8 @@ OpenAIResponseInput = Annotated[
|
|||
| OpenAIResponseOutputMessageFileSearchToolCall
|
||||
| OpenAIResponseOutputMessageFunctionToolCall
|
||||
| OpenAIResponseInputFunctionToolCallOutput
|
||||
| OpenAIResponseMCPApprovalRequest
|
||||
| OpenAIResponseMCPApprovalResponse
|
||||
|
|
||||
# Fallback to the generic message type as a last resort
|
||||
OpenAIResponseMessage,
|
||||
|
|
|
@ -237,6 +237,7 @@ class OpenAIResponsesImpl:
|
|||
response_tools=tools,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
inputs=input,
|
||||
)
|
||||
|
||||
# Create orchestrator and delegate streaming logic
|
||||
|
|
|
@ -10,10 +10,12 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
AllowedToolsFilter,
|
||||
ApprovalFilter,
|
||||
MCPListToolsTool,
|
||||
OpenAIResponseContentPartOutputText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
|
@ -147,10 +149,17 @@ class StreamingResponseOrchestrator:
|
|||
raise ValueError("Streaming chunk processor failed to return completion data")
|
||||
current_response = self._build_chat_completion(completion_result_data)
|
||||
|
||||
function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
|
||||
function_tool_calls, non_function_tool_calls, approvals, next_turn_messages = self._separate_tool_calls(
|
||||
current_response, messages
|
||||
)
|
||||
|
||||
# add any approval requests required
|
||||
for tool_call in approvals:
|
||||
async for evt in self._add_mcp_approval_request(
|
||||
tool_call.function.name, tool_call.function.arguments, output_messages
|
||||
):
|
||||
yield evt
|
||||
|
||||
# Handle choices with no tool calls
|
||||
for choice in current_response.choices:
|
||||
if not (choice.message.tool_calls and self.ctx.response_tools):
|
||||
|
@ -194,10 +203,11 @@ class StreamingResponseOrchestrator:
|
|||
# Emit response.completed
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
|
||||
|
||||
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
|
||||
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list, list]:
|
||||
"""Separate tool calls into function and non-function categories."""
|
||||
function_tool_calls = []
|
||||
non_function_tool_calls = []
|
||||
approvals = []
|
||||
next_turn_messages = messages.copy()
|
||||
|
||||
for choice in current_response.choices:
|
||||
|
@ -208,9 +218,23 @@ class StreamingResponseOrchestrator:
|
|||
if is_function_tool_call(tool_call, self.ctx.response_tools):
|
||||
function_tool_calls.append(tool_call)
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
if self._approval_required(tool_call.function.name):
|
||||
approval_response = self.ctx.approval_response(
|
||||
tool_call.function.name, tool_call.function.arguments
|
||||
)
|
||||
if approval_response:
|
||||
if approval_response.approve:
|
||||
logger.info(f"Approval granted for {tool_call.id} on {tool_call.function.name}")
|
||||
non_function_tool_calls.append(tool_call)
|
||||
else:
|
||||
logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
|
||||
else:
|
||||
logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
|
||||
approvals.append(tool_call)
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
|
||||
return function_tool_calls, non_function_tool_calls, next_turn_messages
|
||||
return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages
|
||||
|
||||
async def _process_streaming_chunks(
|
||||
self, completion_result, output_messages: list[OpenAIResponseOutput]
|
||||
|
@ -646,3 +670,46 @@ class StreamingResponseOrchestrator:
|
|||
# TODO: Emit mcp_list_tools.failed event if needed
|
||||
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
|
||||
raise
|
||||
|
||||
def _approval_required(self, tool_name: str) -> bool:
|
||||
if tool_name not in self.mcp_tool_to_server:
|
||||
return False
|
||||
mcp_server = self.mcp_tool_to_server[tool_name]
|
||||
if mcp_server.require_approval == "always":
|
||||
return True
|
||||
if mcp_server.require_approval == "never":
|
||||
return False
|
||||
if isinstance(mcp_server, ApprovalFilter):
|
||||
if tool_name in mcp_server.always:
|
||||
return True
|
||||
if tool_name in mcp_server.never:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _add_mcp_approval_request(
|
||||
self, tool_name: str, arguments: str, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
mcp_server = self.mcp_tool_to_server[tool_name]
|
||||
mcp_approval_request = OpenAIResponseMCPApprovalRequest(
|
||||
arguments=arguments,
|
||||
id=f"approval_{uuid.uuid4()}",
|
||||
name=tool_name,
|
||||
server_label=mcp_server.server_label,
|
||||
)
|
||||
output_messages.append(mcp_approval_request)
|
||||
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=mcp_approval_request,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=mcp_approval_request,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
|
|
@ -10,7 +10,10 @@ from openai.types.chat import ChatCompletionToolParam
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseMCPApprovalResponse,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseOutput,
|
||||
)
|
||||
|
@ -58,3 +61,37 @@ class ChatCompletionContext(BaseModel):
|
|||
chat_tools: list[ChatCompletionToolParam] | None = None
|
||||
temperature: float | None
|
||||
response_format: OpenAIResponseFormatParam
|
||||
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
|
||||
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
response_tools: list[OpenAIResponseInputTool] | None,
|
||||
temperature: float | None,
|
||||
response_format: OpenAIResponseFormatParam,
|
||||
inputs: list[OpenAIResponseInput] | str,
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_tools=response_tools,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
)
|
||||
if not isinstance(inputs, str):
|
||||
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
|
||||
self.approval_responses = {
|
||||
input.approval_request_id: input for input in inputs if input.type == "mcp_approval_response"
|
||||
}
|
||||
|
||||
def approval_response(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalResponse | None:
|
||||
request = self._approval_request(tool_name, arguments)
|
||||
return self.approval_responses.get(request.id, None) if request else None
|
||||
|
||||
def _approval_request(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalRequest | None:
|
||||
for request in self.approval_requests:
|
||||
if request.name == tool_name and request.arguments == arguments:
|
||||
return request
|
||||
return None
|
||||
|
|
|
@ -13,6 +13,8 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseMCPApprovalResponse,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContent,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
|
@ -149,6 +151,11 @@ async def convert_response_input_to_chat_messages(
|
|||
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
|
||||
# the tool list will be handled separately
|
||||
pass
|
||||
elif isinstance(input_item, OpenAIResponseMCPApprovalRequest) or isinstance(
|
||||
input_item, OpenAIResponseMCPApprovalResponse
|
||||
):
|
||||
# these are handled by the responses impl itself and not pass through to chat completions
|
||||
pass
|
||||
else:
|
||||
content = await convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await get_message_type_by_role(input_item.role)
|
||||
|
|
|
@ -246,6 +246,82 @@ def test_response_sequential_mcp_tool(compat_client, text_model_id, case):
|
|||
assert "boiling point" in text_content.lower()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", mcp_tool_test_cases)
|
||||
@pytest.mark.parametrize("approve", [True, False])
|
||||
def test_response_mcp_tool_approval(compat_client, text_model_id, case, approve):
|
||||
if not isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("in-process MCP server is only supported in library client")
|
||||
|
||||
with make_mcp_server() as mcp_server_info:
|
||||
tools = setup_mcp_tools(case.tools, mcp_server_info)
|
||||
for tool in tools:
|
||||
tool["require_approval"] = "always"
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case.input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert len(response.output) >= 2
|
||||
list_tools = response.output[0]
|
||||
assert list_tools.type == "mcp_list_tools"
|
||||
assert list_tools.server_label == "localmcp"
|
||||
assert len(list_tools.tools) == 2
|
||||
assert {t.name for t in list_tools.tools} == {
|
||||
"get_boiling_point",
|
||||
"greet_everyone",
|
||||
}
|
||||
|
||||
approval_request = response.output[1]
|
||||
assert approval_request.type == "mcp_approval_request"
|
||||
assert approval_request.name == "get_boiling_point"
|
||||
assert json.loads(approval_request.arguments) == {
|
||||
"liquid_name": "myawesomeliquid",
|
||||
"celsius": True,
|
||||
}
|
||||
|
||||
# send approval response
|
||||
response = compat_client.responses.create(
|
||||
previous_response_id=response.id,
|
||||
model=text_model_id,
|
||||
input=[{"type": "mcp_approval_response", "approval_request_id": approval_request.id, "approve": approve}],
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
if approve:
|
||||
assert len(response.output) >= 3
|
||||
list_tools = response.output[0]
|
||||
assert list_tools.type == "mcp_list_tools"
|
||||
assert list_tools.server_label == "localmcp"
|
||||
assert len(list_tools.tools) == 2
|
||||
assert {t.name for t in list_tools.tools} == {
|
||||
"get_boiling_point",
|
||||
"greet_everyone",
|
||||
}
|
||||
|
||||
call = response.output[1]
|
||||
assert call.type == "mcp_call"
|
||||
assert call.name == "get_boiling_point"
|
||||
assert json.loads(call.arguments) == {
|
||||
"liquid_name": "myawesomeliquid",
|
||||
"celsius": True,
|
||||
}
|
||||
assert call.error is None
|
||||
assert "-100" in call.output
|
||||
|
||||
# sometimes the model will call the tool again, so we need to get the last message
|
||||
message = response.output[-1]
|
||||
text_content = message.content[0].text
|
||||
assert "boiling point" in text_content.lower()
|
||||
else:
|
||||
assert len(response.output) >= 1
|
||||
for output in response.output:
|
||||
assert output.type != "mcp_call"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", custom_tool_test_cases)
|
||||
def test_response_non_streaming_custom_tool(compat_client, text_model_id, case):
|
||||
response = compat_client.responses.create(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue