address review comments

This commit is contained in:
Gordon Sim 2025-09-30 21:58:09 +01:00
parent 142908a7e1
commit 5ac7e1fa12
2 changed files with 19 additions and 5 deletions

View file

@ -237,8 +237,8 @@ class OpenAIResponsesImpl:
response_tools=tools, response_tools=tools,
temperature=temperature, temperature=temperature,
response_format=response_format, response_format=response_format,
inputs=input,
) )
ctx.extract_approvals(input)
# Create orchestrator and delegate streaming logic # Create orchestrator and delegate streaming logic
response_id = f"resp-{uuid.uuid4()}" response_id = f"resp-{uuid.uuid4()}"

View file

@ -10,6 +10,7 @@ from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.agents.openai_responses import ( from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInput,
OpenAIResponseInputTool, OpenAIResponseInputTool,
OpenAIResponseMCPApprovalRequest, OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse, OpenAIResponseMCPApprovalResponse,
@ -63,7 +64,22 @@ class ChatCompletionContext(BaseModel):
approval_requests: list[OpenAIResponseMCPApprovalRequest] = [] approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {} approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
def extract_approvals(self, inputs): def __init__(
self,
model: str,
messages: list[OpenAIMessageParam],
response_tools: list[OpenAIResponseInputTool] | None,
temperature: float | None,
response_format: OpenAIResponseFormatParam,
inputs: list[OpenAIResponseInput] | str,
):
super().__init__(
model=model,
messages=messages,
response_tools=response_tools,
temperature=temperature,
response_format=response_format,
)
if not isinstance(inputs, str): if not isinstance(inputs, str):
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"] self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
self.approval_responses = { self.approval_responses = {
@ -72,9 +88,7 @@ class ChatCompletionContext(BaseModel):
def approval_response(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalResponse | None: def approval_response(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalResponse | None:
request = self._approval_request(tool_name, arguments) request = self._approval_request(tool_name, arguments)
if request and request.id in self.approval_responses: return self.approval_responses.get(request.id, None) if request else None
return self.approval_responses[request.id]
return None
def _approval_request(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalRequest | None: def _approval_request(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalRequest | None:
for request in self.approval_requests: for request in self.approval_requests: