feat: add support for tool_choice to responses api (#4106)

# What does this PR do?
Adds support for enforcing tool usage via the Responses API `tool_choice` parameter. See
https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice
for details in the official documentation.
Note: at present this PR only supports `file_search` and `web_search` as
options for enforcing builtin tool usage.
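
For illustration, enforcing the builtin `web_search` tool through an OpenAI-compatible client might look like the sketch below; the endpoint URL, API key, and model name are placeholders and not taken from this PR:

```python
# Minimal sketch, assuming a Llama Stack server exposing the OpenAI-compatible
# Responses API at a placeholder URL. The model id is hypothetical.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

response = client.responses.create(
    model="llama3.2:3b",  # placeholder model id
    input="What changed in the latest Llama Stack release?",
    tools=[{"type": "web_search"}],
    tool_choice={"type": "web_search"},  # force the builtin web_search tool
)
print(response.output_text)
```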

Closes #3548 

## Test Plan
```
./scripts/unit-tests.sh tests/unit/providers/agents/meta_reference/test_response_tool_context.py
```

---------

Signed-off-by: Jaideep Rao <jrao@redhat.com>
Jaideep Rao 2025-12-16 00:52:06 +05:30 committed by GitHub
parent 62005dc1a9
commit 56f946f3f5
18 changed files with 49989 additions and 3 deletions

View file

@ -19,6 +19,7 @@ from llama_stack_api import (
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseInputToolChoice,
OpenAIResponseObject,
OpenAIResponsePrompt,
OpenAIResponseText,
@ -105,6 +106,7 @@ class MetaReferenceAgentsImpl(Agents):
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tool_choice: OpenAIResponseInputToolChoice | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
@ -124,6 +126,7 @@ class MetaReferenceAgentsImpl(Agents):
stream,
temperature,
text,
tool_choice,
tools,
include,
max_infer_iters,

View file

@ -32,6 +32,7 @@ from llama_stack_api import (
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseInputToolChoice,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
@ -333,6 +334,7 @@ class OpenAIResponsesImpl:
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tool_choice: OpenAIResponseInputToolChoice | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[ResponseItemInclude] | None = None,
max_infer_iters: int | None = 10,
@ -390,6 +392,7 @@ class OpenAIResponsesImpl:
temperature=temperature,
text=text,
tools=tools,
tool_choice=tool_choice,
max_infer_iters=max_infer_iters,
guardrail_ids=guardrail_ids,
parallel_tool_calls=parallel_tool_calls,
@ -444,6 +447,7 @@ class OpenAIResponsesImpl:
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
tool_choice: OpenAIResponseInputToolChoice | None = None,
max_infer_iters: int | None = 10,
guardrail_ids: list[str] | None = None,
parallel_tool_calls: bool | None = True,
@ -474,6 +478,7 @@ class OpenAIResponsesImpl:
model=model,
messages=messages,
response_tools=tools,
tool_choice=tool_choice,
temperature=temperature,
response_format=response_format,
tool_context=tool_context,

View file

@ -8,6 +8,7 @@ import uuid
from collections.abc import AsyncIterator
from typing import Any
from openai.types.chat import ChatCompletionToolParam
from opentelemetry import trace
from llama_stack.log import get_logger
@ -23,6 +24,10 @@ from llama_stack_api import (
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolChoice,
OpenAIChatCompletionToolChoiceAllowedTools,
OpenAIChatCompletionToolChoiceCustomTool,
OpenAIChatCompletionToolChoiceFunctionTool,
OpenAIChoice,
OpenAIChoiceLogprobs,
OpenAIMessageParam,
@ -31,6 +36,14 @@ from llama_stack_api import (
OpenAIResponseContentPartRefusal,
OpenAIResponseError,
OpenAIResponseInputTool,
OpenAIResponseInputToolChoice,
OpenAIResponseInputToolChoiceAllowedTools,
OpenAIResponseInputToolChoiceCustomTool,
OpenAIResponseInputToolChoiceFileSearch,
OpenAIResponseInputToolChoiceFunctionTool,
OpenAIResponseInputToolChoiceMCPTool,
OpenAIResponseInputToolChoiceMode,
OpenAIResponseInputToolChoiceWebSearch,
OpenAIResponseInputToolMCP,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMessage,
@ -77,6 +90,7 @@ from llama_stack_api import (
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import (
convert_chat_choice_to_response_message,
convert_mcp_tool_choice,
is_function_tool_call,
run_guardrails,
)
@ -148,6 +162,13 @@ class StreamingResponseOrchestrator:
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
ctx.tool_context.previous_tools if ctx.tool_context else {}
)
# Reverse mapping: server_label -> list of tool names for efficient lookup
self.server_label_to_tools: dict[str, list[str]] = {}
# Build initial reverse mapping from previous_tools
for tool_name, mcp_server in self.mcp_tool_to_server.items():
if mcp_server.server_label not in self.server_label_to_tools:
self.server_label_to_tools[mcp_server.server_label] = []
self.server_label_to_tools[mcp_server.server_label].append(tool_name)
# Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = []
# mapping for annotations
@ -200,6 +221,7 @@ class StreamingResponseOrchestrator:
output=self._clone_outputs(outputs),
text=self.text,
tools=self.ctx.available_tools(),
tool_choice=self.ctx.tool_choice,
error=error,
usage=self.accumulated_usage,
instructions=self.instructions,
@ -235,6 +257,34 @@ class StreamingResponseOrchestrator:
async for stream_event in self._process_tools(output_messages):
yield stream_event
chat_tool_choice = None
# Track allowed tools for filtering (persists across iterations)
allowed_tool_names: set[str] | None = None
if self.ctx.tool_choice and len(self.ctx.chat_tools) > 0:
processed_tool_choice = await _process_tool_choice(
self.ctx.chat_tools,
self.ctx.tool_choice,
self.server_label_to_tools,
)
# chat_tool_choice can be str, dict-like object, or None
if isinstance(processed_tool_choice, str | type(None)):
chat_tool_choice = processed_tool_choice
elif isinstance(processed_tool_choice, OpenAIChatCompletionToolChoiceAllowedTools):
# For allowed_tools: filter the tools list instead of using tool_choice.
# This maintains the constraint across all iterations while letting the model
# decide freely whether to call a tool or respond.
allowed_tool_names = {
tool["function"]["name"]
for tool in processed_tool_choice.allowed_tools.tools
if tool.get("type") == "function" and "function" in tool
}
# Use the mode (e.g., "required") for first iteration, then "auto"
chat_tool_choice = (
processed_tool_choice.allowed_tools.mode if processed_tool_choice.allowed_tools.mode else "auto"
)
else:
chat_tool_choice = processed_tool_choice.model_dump()
n_iter = 0
messages = self.ctx.messages.copy()
final_status = "completed"
@ -247,7 +297,15 @@ class StreamingResponseOrchestrator:
response_format = (
None if getattr(self.ctx.response_format, "type", None) == "text" else self.ctx.response_format
)
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
# Filter tools to only allowed ones if tool_choice specified an allowed list
effective_tools = self.ctx.chat_tools
if allowed_tool_names is not None:
effective_tools = [
tool
for tool in self.ctx.chat_tools
if tool.get("function", {}).get("name") in allowed_tool_names
]
logger.debug(f"calling openai_chat_completion with tools: {effective_tools}")
logprobs = (
True if self.include and ResponseItemInclude.message_output_text_logprobs in self.include else None
@ -257,7 +315,8 @@ class StreamingResponseOrchestrator:
model=self.ctx.model,
messages=messages,
# Pydantic models are dict-compatible but mypy treats them as distinct types
tools=self.ctx.chat_tools, # type: ignore[arg-type]
tools=effective_tools, # type: ignore[arg-type]
tool_choice=chat_tool_choice,
stream=True,
temperature=self.ctx.temperature,
response_format=response_format,
@ -335,6 +394,14 @@ class StreamingResponseOrchestrator:
break
n_iter += 1
# After the first iteration, reset tool_choice to "auto" to let the model decide freely
# based on tool results (prevents infinite loops when forcing specific tools).
# Note: when allowed_tool_names is set, the tools list is already filtered so the model
# can only call allowed tools - we just need to let it decide whether to call
# a tool or respond (hence "auto" mode).
if n_iter == 1 and chat_tool_choice and chat_tool_choice != "auto":
chat_tool_choice = "auto"
if n_iter >= self.max_infer_iters:
logger.info(
f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}"
@ -1165,6 +1232,11 @@ class StreamingResponseOrchestrator:
raise ValueError(f"Duplicate tool name {t.name} found for server {mcp_tool.server_label}")
self.mcp_tool_to_server[t.name] = mcp_tool
# Add to reverse mapping for efficient server_label lookup
if mcp_tool.server_label not in self.server_label_to_tools:
self.server_label_to_tools[mcp_tool.server_label] = []
self.server_label_to_tools[mcp_tool.server_label].append(t.name)
# Add to MCP list message
mcp_list_message.tools.append(
MCPListToolsTool(
@ -1304,3 +1376,112 @@ class StreamingResponseOrchestrator:
async for stream_event in self._add_mcp_list_tools(mcp_list_message, output_messages):
yield stream_event
async def _process_tool_choice(
chat_tools: list[ChatCompletionToolParam],
tool_choice: OpenAIResponseInputToolChoice,
server_label_to_tools: dict[str, list[str]],
) -> str | OpenAIChatCompletionToolChoice | None:
"""Process and validate the OpenAI Responses tool choice and return the appropriate chat completion tool choice object.
:param chat_tools: The list of chat tools to enforce tool choice against.
:param tool_choice: The OpenAI Responses tool choice to process.
:param server_label_to_tools: A dictionary mapping server labels to the list of tools available on that server.
:return: The appropriate chat completion tool choice object.
"""
# retrieve all function tool names from the chat tools
# Note: chat_tools contains dicts, not objects
chat_tool_names = [tool["function"]["name"] for tool in chat_tools if tool["type"] == "function"]
if isinstance(tool_choice, OpenAIResponseInputToolChoiceMode):
if tool_choice.value == "required":
if len(chat_tool_names) == 0:
return None
# add all function tools to the allowed tools list and set mode to required
return OpenAIChatCompletionToolChoiceAllowedTools(
tools=[{"type": "function", "function": {"name": tool}} for tool in chat_tool_names],
mode="required",
)
# return other modes as is
return tool_choice.value
elif isinstance(tool_choice, OpenAIResponseInputToolChoiceAllowedTools):
# ensure that the specified tool choices are available in the chat tools; if not, drop them from the list
final_tools = []
for tool in tool_choice.tools:
match tool.get("type"):
case "function":
final_tools.append({"type": "function", "function": {"name": tool.get("name")}})
case "custom":
final_tools.append({"type": "custom", "custom": {"name": tool.get("name")}})
case "mcp":
mcp_tools = convert_mcp_tool_choice(
chat_tool_names, tool.get("server_label"), server_label_to_tools, None
)
# convert_mcp_tool_choice can return a dict, list, or None
if isinstance(mcp_tools, list):
final_tools.extend(mcp_tools)
elif isinstance(mcp_tools, dict):
final_tools.append(mcp_tools)
# Skip if None or empty
case "file_search":
final_tools.append({"type": "function", "function": {"name": "file_search"}})
case _ if tool["type"] in WebSearchToolTypes:
final_tools.append({"type": "function", "function": {"name": "web_search"}})
case _:
logger.warning(f"Unsupported tool type: {tool['type']}, skipping tool choice enforcement for it")
continue
return OpenAIChatCompletionToolChoiceAllowedTools(
tools=final_tools,
mode=tool_choice.mode,
)
else:
# Handle specific tool choice by type
# Each case validates the tool exists in chat_tools before returning
tool_name = getattr(tool_choice, "name", None)
match tool_choice:
case OpenAIResponseInputToolChoiceCustomTool():
if tool_name and tool_name not in chat_tool_names:
logger.warning(f"Tool {tool_name} not found in chat tools")
return None
return OpenAIChatCompletionToolChoiceCustomTool(name=tool_name)
case OpenAIResponseInputToolChoiceFunctionTool():
if tool_name and tool_name not in chat_tool_names:
logger.warning(f"Tool {tool_name} not found in chat tools")
return None
return OpenAIChatCompletionToolChoiceFunctionTool(name=tool_name)
case OpenAIResponseInputToolChoiceFileSearch():
if "file_search" not in chat_tool_names:
logger.warning("Tool file_search not found in chat tools")
return None
return OpenAIChatCompletionToolChoiceFunctionTool(name="file_search")
case OpenAIResponseInputToolChoiceWebSearch():
if "web_search" not in chat_tool_names:
logger.warning("Tool web_search not found in chat tools")
return None
return OpenAIChatCompletionToolChoiceFunctionTool(name="web_search")
case OpenAIResponseInputToolChoiceMCPTool():
tool_choice = convert_mcp_tool_choice(
chat_tool_names,
tool_choice.server_label,
server_label_to_tools,
tool_name,
)
if isinstance(tool_choice, dict):
# for single tool choice, return as function tool choice
return OpenAIChatCompletionToolChoiceFunctionTool(name=tool_choice["function"]["name"])
elif isinstance(tool_choice, list):
# for multiple tool choices, return as allowed tools
return OpenAIChatCompletionToolChoiceAllowedTools(
tools=tool_choice,
mode="required",
)
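
As an illustrative aside (not part of the diff), the translation performed by `_process_tool_choice` can be sketched with plain dicts in place of the `llama_stack_api` Pydantic models; the helper name and the two simplified cases below are hypothetical:

```python
# Illustrative sketch only: how a Responses-style tool_choice is expected to
# translate into a chat-completions tool_choice. The helper name is hypothetical.

def sketch_translate(tool_choice: dict | str, chat_tool_names: list[str]) -> dict | str | None:
    # Builtin tools map to a forced function call, provided the tool was registered.
    if isinstance(tool_choice, dict) and tool_choice.get("type") in ("file_search", "web_search"):
        name = tool_choice["type"]
        if name not in chat_tool_names:
            return None
        return {"type": "function", "function": {"name": name}}
    # "required" becomes an allowed_tools constraint over every registered function tool.
    if tool_choice == "required" and chat_tool_names:
        return {
            "type": "allowed_tools",
            "allowed_tools": {
                "mode": "required",
                "tools": [{"type": "function", "function": {"name": n}} for n in chat_tool_names],
            },
        }
    # Other modes ("auto", "none") pass through unchanged.
    return tool_choice if isinstance(tool_choice, str) else None


print(sketch_translate({"type": "web_search"}, ["web_search", "file_search"]))
# {'type': 'function', 'function': {'name': 'web_search'}}
```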

View file

@ -16,6 +16,7 @@ from llama_stack_api import (
OpenAIResponseFormatParam,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseInputToolChoice,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
@ -162,6 +163,7 @@ class ChatCompletionContext(BaseModel):
temperature: float | None
response_format: OpenAIResponseFormatParam
tool_context: ToolContext | None
tool_choice: OpenAIResponseInputToolChoice | None = None
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
@ -174,6 +176,7 @@ class ChatCompletionContext(BaseModel):
response_format: OpenAIResponseFormatParam,
tool_context: ToolContext,
inputs: list[OpenAIResponseInput] | str,
tool_choice: OpenAIResponseInputToolChoice | None = None,
):
super().__init__(
model=model,
@ -182,6 +185,7 @@ class ChatCompletionContext(BaseModel):
temperature=temperature,
response_format=response_format,
tool_context=tool_context,
tool_choice=tool_choice,
)
if not isinstance(inputs, str):
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]

View file

@ -506,3 +506,28 @@ def extract_guardrail_ids(guardrails: list | None) -> list[str]:
raise ValueError(f"Unknown guardrail format: {guardrail}, expected str or ResponseGuardrailSpec")
return guardrail_ids
def convert_mcp_tool_choice(
chat_tool_names: list[str],
server_label: str | None = None,
server_label_to_tools: dict[str, list[str]] | None = None,
tool_name: str | None = None,
) -> dict | list[dict] | None:
"""Convert a Responses tool choice of type mcp to a chat-completions compatible function tool choice. Returns None (or an empty list) when no matching tool is found."""
if tool_name:
if tool_name not in chat_tool_names:
return None
return {"type": "function", "function": {"name": tool_name}}
elif server_label and server_label_to_tools:
# no tool name specified, so we need to enforce an allowed_tools with the function tools derived only from the given server label
# Use reverse mapping for lookup by server_label
# This already accounts for allowed_tools restrictions applied during _process_mcp_tool
tool_names = server_label_to_tools.get(server_label, [])
if not tool_names:
return None
matching_tools = [{"type": "function", "function": {"name": tool_name}} for tool_name in tool_names]
return matching_tools
return []
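
For illustration (not part of the diff), `convert_mcp_tool_choice` behaves roughly as follows; the server label and tool names below are made up, and the snippet assumes the function above is importable:

```python
# Hypothetical inputs for convert_mcp_tool_choice defined above.
chat_tool_names = ["get_weather", "get_time", "web_search"]
server_label_to_tools = {"my_mcp_server": ["get_weather", "get_time"]}

# A specific MCP tool name yields a single forced function tool choice.
print(convert_mcp_tool_choice(chat_tool_names, "my_mcp_server", server_label_to_tools, "get_weather"))
# {'type': 'function', 'function': {'name': 'get_weather'}}

# Only a server label yields one function entry per tool on that server,
# suitable for building an allowed_tools constraint.
print(convert_mcp_tool_choice(chat_tool_names, "my_mcp_server", server_label_to_tools))
# [{'type': 'function', 'function': {'name': 'get_weather'}},
#  {'type': 'function', 'function': {'name': 'get_time'}}]
```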