Merge branch 'main' into remove-deprecated-completion

This commit is contained in:
Matthew Farrellee 2025-09-30 17:35:19 -04:00
commit d6f1d32d3c
11 changed files with 523 additions and 23 deletions

View file

@ -237,6 +237,7 @@ class OpenAIResponsesImpl:
response_tools=tools,
temperature=temperature,
response_format=response_format,
inputs=input,
)
# Create orchestrator and delegate streaming logic

View file

@ -10,10 +10,12 @@ from typing import Any
from llama_stack.apis.agents.openai_responses import (
AllowedToolsFilter,
ApprovalFilter,
MCPListToolsTool,
OpenAIResponseContentPartOutputText,
OpenAIResponseInputTool,
OpenAIResponseInputToolMCP,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
@ -147,10 +149,17 @@ class StreamingResponseOrchestrator:
raise ValueError("Streaming chunk processor failed to return completion data")
current_response = self._build_chat_completion(completion_result_data)
function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
function_tool_calls, non_function_tool_calls, approvals, next_turn_messages = self._separate_tool_calls(
current_response, messages
)
# add any approval requests required
for tool_call in approvals:
async for evt in self._add_mcp_approval_request(
tool_call.function.name, tool_call.function.arguments, output_messages
):
yield evt
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
@ -194,10 +203,11 @@ class StreamingResponseOrchestrator:
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list, list]:
"""Separate tool calls into function and non-function categories."""
function_tool_calls = []
non_function_tool_calls = []
approvals = []
next_turn_messages = messages.copy()
for choice in current_response.choices:
@ -208,9 +218,23 @@ class StreamingResponseOrchestrator:
if is_function_tool_call(tool_call, self.ctx.response_tools):
function_tool_calls.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
if self._approval_required(tool_call.function.name):
approval_response = self.ctx.approval_response(
tool_call.function.name, tool_call.function.arguments
)
if approval_response:
if approval_response.approve:
logger.info(f"Approval granted for {tool_call.id} on {tool_call.function.name}")
non_function_tool_calls.append(tool_call)
else:
logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
else:
logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
approvals.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
return function_tool_calls, non_function_tool_calls, next_turn_messages
return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages
async def _process_streaming_chunks(
self, completion_result, output_messages: list[OpenAIResponseOutput]
@ -646,3 +670,46 @@ class StreamingResponseOrchestrator:
# TODO: Emit mcp_list_tools.failed event if needed
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
raise
def _approval_required(self, tool_name: str) -> bool:
if tool_name not in self.mcp_tool_to_server:
return False
mcp_server = self.mcp_tool_to_server[tool_name]
if mcp_server.require_approval == "always":
return True
if mcp_server.require_approval == "never":
return False
if isinstance(mcp_server, ApprovalFilter):
if tool_name in mcp_server.always:
return True
if tool_name in mcp_server.never:
return False
return True
async def _add_mcp_approval_request(
self, tool_name: str, arguments: str, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
mcp_server = self.mcp_tool_to_server[tool_name]
mcp_approval_request = OpenAIResponseMCPApprovalRequest(
arguments=arguments,
id=f"approval_{uuid.uuid4()}",
name=tool_name,
server_label=mcp_server.server_label,
)
output_messages.append(mcp_approval_request)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=mcp_approval_request,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=mcp_approval_request,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)

View file

@ -10,7 +10,10 @@ from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse,
OpenAIResponseObjectStream,
OpenAIResponseOutput,
)
@ -58,3 +61,37 @@ class ChatCompletionContext(BaseModel):
chat_tools: list[ChatCompletionToolParam] | None = None
temperature: float | None
response_format: OpenAIResponseFormatParam
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
def __init__(
self,
model: str,
messages: list[OpenAIMessageParam],
response_tools: list[OpenAIResponseInputTool] | None,
temperature: float | None,
response_format: OpenAIResponseFormatParam,
inputs: list[OpenAIResponseInput] | str,
):
super().__init__(
model=model,
messages=messages,
response_tools=response_tools,
temperature=temperature,
response_format=response_format,
)
if not isinstance(inputs, str):
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
self.approval_responses = {
input.approval_request_id: input for input in inputs if input.type == "mcp_approval_response"
}
def approval_response(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalResponse | None:
request = self._approval_request(tool_name, arguments)
return self.approval_responses.get(request.id, None) if request else None
def _approval_request(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalRequest | None:
for request in self.approval_requests:
if request.name == tool_name and request.arguments == arguments:
return request
return None

View file

@ -13,6 +13,8 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
@ -149,6 +151,11 @@ async def convert_response_input_to_chat_messages(
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
# the tool list will be handled separately
pass
elif isinstance(input_item, OpenAIResponseMCPApprovalRequest) or isinstance(
input_item, OpenAIResponseMCPApprovalResponse
):
# these are handled by the responses impl itself and not pass through to chat completions
pass
else:
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)