feat: add support for tool_choice to responses api (#4106)

# What does this PR do?
Adds support for enforcing tool usage via the Responses API. See
https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice
for details from the official documentation.
Note: at present this PR only supports `file_search` and `web_search` as
built-in tool options for enforced tool usage.
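
For illustration, a minimal sketch of how a caller might enforce the built-in web search tool through the Responses API (assuming an OpenAI-compatible client pointed at a Llama Stack server; the base URL and model name below are placeholders):

```python
from openai import OpenAI

# Placeholder endpoint and model; any OpenAI-compatible client pointed at a
# Llama Stack deployment should behave the same way.
client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    input="Summarize today's top AI news.",
    tools=[{"type": "web_search"}],
    # Force the model to use the built-in web search tool.
    tool_choice={"type": "web_search"},
)
print(response.output_text)
```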

Closes #3548 

## Test Plan
`./scripts/unit-tests.sh tests/unit/providers/agents/meta_reference/test_response_tool_context.py`

---------

Signed-off-by: Jaideep Rao <jrao@redhat.com>

@@ -19,6 +19,7 @@ from llama_stack_api import (
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseInputToolChoice,
OpenAIResponseObject,
OpenAIResponsePrompt,
OpenAIResponseText,
@@ -105,6 +106,7 @@ class MetaReferenceAgentsImpl(Agents):
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tool_choice: OpenAIResponseInputToolChoice | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
@@ -124,6 +126,7 @@ class MetaReferenceAgentsImpl(Agents):
stream,
temperature,
text,
tool_choice,
tools,
include,
max_infer_iters,


@@ -32,6 +32,7 @@ from llama_stack_api import (
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseInputToolChoice,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
@@ -333,6 +334,7 @@ class OpenAIResponsesImpl:
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tool_choice: OpenAIResponseInputToolChoice | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[ResponseItemInclude] | None = None,
max_infer_iters: int | None = 10,
@@ -390,6 +392,7 @@
temperature=temperature,
text=text,
tools=tools,
tool_choice=tool_choice,
max_infer_iters=max_infer_iters,
guardrail_ids=guardrail_ids,
parallel_tool_calls=parallel_tool_calls,
@@ -444,6 +447,7 @@
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
tool_choice: OpenAIResponseInputToolChoice | None = None,
max_infer_iters: int | None = 10,
guardrail_ids: list[str] | None = None,
parallel_tool_calls: bool | None = True,
@@ -474,6 +478,7 @@
model=model,
messages=messages,
response_tools=tools,
tool_choice=tool_choice,
temperature=temperature,
response_format=response_format,
tool_context=tool_context,


@@ -8,6 +8,7 @@ import uuid
from collections.abc import AsyncIterator
from typing import Any
from openai.types.chat import ChatCompletionToolParam
from opentelemetry import trace
from llama_stack.log import get_logger
@@ -23,6 +24,10 @@ from llama_stack_api import (
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolChoice,
OpenAIChatCompletionToolChoiceAllowedTools,
OpenAIChatCompletionToolChoiceCustomTool,
OpenAIChatCompletionToolChoiceFunctionTool,
OpenAIChoice,
OpenAIChoiceLogprobs,
OpenAIMessageParam,
@@ -31,6 +36,14 @@ from llama_stack_api import (
OpenAIResponseContentPartRefusal,
OpenAIResponseError,
OpenAIResponseInputTool,
OpenAIResponseInputToolChoice,
OpenAIResponseInputToolChoiceAllowedTools,
OpenAIResponseInputToolChoiceCustomTool,
OpenAIResponseInputToolChoiceFileSearch,
OpenAIResponseInputToolChoiceFunctionTool,
OpenAIResponseInputToolChoiceMCPTool,
OpenAIResponseInputToolChoiceMode,
OpenAIResponseInputToolChoiceWebSearch,
OpenAIResponseInputToolMCP,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMessage,
@@ -77,6 +90,7 @@ from llama_stack_api import (
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import (
convert_chat_choice_to_response_message,
convert_mcp_tool_choice,
is_function_tool_call,
run_guardrails,
)
@@ -148,6 +162,13 @@ class StreamingResponseOrchestrator:
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
ctx.tool_context.previous_tools if ctx.tool_context else {}
)
# Reverse mapping: server_label -> list of tool names for efficient lookup
self.server_label_to_tools: dict[str, list[str]] = {}
# Build initial reverse mapping from previous_tools
for tool_name, mcp_server in self.mcp_tool_to_server.items():
if mcp_server.server_label not in self.server_label_to_tools:
self.server_label_to_tools[mcp_server.server_label] = []
self.server_label_to_tools[mcp_server.server_label].append(tool_name)
# Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = []
# mapping for annotations
@@ -200,6 +221,7 @@ class StreamingResponseOrchestrator:
output=self._clone_outputs(outputs),
text=self.text,
tools=self.ctx.available_tools(),
tool_choice=self.ctx.tool_choice,
error=error,
usage=self.accumulated_usage,
instructions=self.instructions,
@@ -235,6 +257,34 @@ class StreamingResponseOrchestrator:
async for stream_event in self._process_tools(output_messages):
yield stream_event
chat_tool_choice = None
# Track allowed tools for filtering (persists across iterations)
allowed_tool_names: set[str] | None = None
if self.ctx.tool_choice and len(self.ctx.chat_tools) > 0:
processed_tool_choice = await _process_tool_choice(
self.ctx.chat_tools,
self.ctx.tool_choice,
self.server_label_to_tools,
)
# chat_tool_choice can be str, dict-like object, or None
if isinstance(processed_tool_choice, str | type(None)):
chat_tool_choice = processed_tool_choice
elif isinstance(processed_tool_choice, OpenAIChatCompletionToolChoiceAllowedTools):
# For allowed_tools: filter the tools list instead of using tool_choice
# This maintains the constraint across all iterations while letting model
# decide freely whether to call a tool or respond
allowed_tool_names = {
tool["function"]["name"]
for tool in processed_tool_choice.allowed_tools.tools
if tool.get("type") == "function" and "function" in tool
}
# Use the mode (e.g., "required") for first iteration, then "auto"
chat_tool_choice = (
processed_tool_choice.allowed_tools.mode if processed_tool_choice.allowed_tools.mode else "auto"
)
else:
chat_tool_choice = processed_tool_choice.model_dump()
n_iter = 0
messages = self.ctx.messages.copy()
final_status = "completed"
@@ -247,7 +297,15 @@
response_format = (
None if getattr(self.ctx.response_format, "type", None) == "text" else self.ctx.response_format
)
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
# Filter tools to only allowed ones if tool_choice specified an allowed list
effective_tools = self.ctx.chat_tools
if allowed_tool_names is not None:
effective_tools = [
tool
for tool in self.ctx.chat_tools
if tool.get("function", {}).get("name") in allowed_tool_names
]
logger.debug(f"calling openai_chat_completion with tools: {effective_tools}")
logprobs = (
True if self.include and ResponseItemInclude.message_output_text_logprobs in self.include else None
@@ -257,7 +315,8 @@
model=self.ctx.model,
messages=messages,
# Pydantic models are dict-compatible but mypy treats them as distinct types
tools=self.ctx.chat_tools, # type: ignore[arg-type]
tools=effective_tools, # type: ignore[arg-type]
tool_choice=chat_tool_choice,
stream=True,
temperature=self.ctx.temperature,
response_format=response_format,
@@ -335,6 +394,14 @@
break
n_iter += 1
# After first iteration, reset tool_choice to "auto" to let model decide freely
# based on tool results (prevents infinite loops when forcing specific tools)
# Note: When allowed_tool_names is set, tools are already filtered so model
# can only call allowed tools - we just need to let it decide whether to call
# a tool or respond (hence "auto" mode)
if n_iter == 1 and chat_tool_choice and chat_tool_choice != "auto":
chat_tool_choice = "auto"
if n_iter >= self.max_infer_iters:
logger.info(
f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}"
@@ -1165,6 +1232,11 @@ class StreamingResponseOrchestrator:
raise ValueError(f"Duplicate tool name {t.name} found for server {mcp_tool.server_label}")
self.mcp_tool_to_server[t.name] = mcp_tool
# Add to reverse mapping for efficient server_label lookup
if mcp_tool.server_label not in self.server_label_to_tools:
self.server_label_to_tools[mcp_tool.server_label] = []
self.server_label_to_tools[mcp_tool.server_label].append(t.name)
# Add to MCP list message
mcp_list_message.tools.append(
MCPListToolsTool(
@@ -1304,3 +1376,112 @@ class StreamingResponseOrchestrator:
async for stream_event in self._add_mcp_list_tools(mcp_list_message, output_messages):
yield stream_event
async def _process_tool_choice(
chat_tools: list[ChatCompletionToolParam],
tool_choice: OpenAIResponseInputToolChoice,
server_label_to_tools: dict[str, list[str]],
) -> str | OpenAIChatCompletionToolChoice | None:
"""Process and validate the OpenAI Responses tool choice and return the appropriate chat completion tool choice object.
:param chat_tools: The list of chat tools to enforce tool choice against.
:param tool_choice: The OpenAI Responses tool choice to process.
:param server_label_to_tools: A dictionary mapping server labels to the list of tools available on that server.
:return: The appropriate chat completion tool choice object.
"""
# retrieve all function tool names from the chat tools
# Note: chat_tools contains dicts, not objects
chat_tool_names = [tool["function"]["name"] for tool in chat_tools if tool["type"] == "function"]
if isinstance(tool_choice, OpenAIResponseInputToolChoiceMode):
if tool_choice.value == "required":
if len(chat_tool_names) == 0:
return None
# add all function tools to the allowed tools list and set mode to required
return OpenAIChatCompletionToolChoiceAllowedTools(
tools=[{"type": "function", "function": {"name": tool}} for tool in chat_tool_names],
mode="required",
)
# return other modes as is
return tool_choice.value
elif isinstance(tool_choice, OpenAIResponseInputToolChoiceAllowedTools):
# translate each allowed tool entry into a chat-completion tool reference;
# MCP entries are expanded/validated against the tools actually available on that server
final_tools = []
for tool in tool_choice.tools:
match tool.get("type"):
case "function":
final_tools.append({"type": "function", "function": {"name": tool.get("name")}})
case "custom":
final_tools.append({"type": "custom", "custom": {"name": tool.get("name")}})
case "mcp":
mcp_tools = convert_mcp_tool_choice(
chat_tool_names, tool.get("server_label"), server_label_to_tools, None
)
# convert_mcp_tool_choice can return a dict, list, or None
if isinstance(mcp_tools, list):
final_tools.extend(mcp_tools)
elif isinstance(mcp_tools, dict):
final_tools.append(mcp_tools)
# Skip if None or empty
case "file_search":
final_tools.append({"type": "function", "function": {"name": "file_search"}})
case _ if tool["type"] in WebSearchToolTypes:
final_tools.append({"type": "function", "function": {"name": "web_search"}})
case _:
logger.warning(f"Unsupported tool type: {tool['type']}, skipping tool choice enforcement for it")
continue
return OpenAIChatCompletionToolChoiceAllowedTools(
tools=final_tools,
mode=tool_choice.mode,
)
else:
# Handle specific tool choice by type
# Each case validates the tool exists in chat_tools before returning
tool_name = getattr(tool_choice, "name", None)
match tool_choice:
case OpenAIResponseInputToolChoiceCustomTool():
if tool_name and tool_name not in chat_tool_names:
logger.warning(f"Tool {tool_name} not found in chat tools")
return None
return OpenAIChatCompletionToolChoiceCustomTool(name=tool_name)
case OpenAIResponseInputToolChoiceFunctionTool():
if tool_name and tool_name not in chat_tool_names:
logger.warning(f"Tool {tool_name} not found in chat tools")
return None
return OpenAIChatCompletionToolChoiceFunctionTool(name=tool_name)
case OpenAIResponseInputToolChoiceFileSearch():
if "file_search" not in chat_tool_names:
logger.warning("Tool file_search not found in chat tools")
return None
return OpenAIChatCompletionToolChoiceFunctionTool(name="file_search")
case OpenAIResponseInputToolChoiceWebSearch():
if "web_search" not in chat_tool_names:
logger.warning("Tool web_search not found in chat tools")
return None
return OpenAIChatCompletionToolChoiceFunctionTool(name="web_search")
case OpenAIResponseInputToolChoiceMCPTool():
tool_choice = convert_mcp_tool_choice(
chat_tool_names,
tool_choice.server_label,
server_label_to_tools,
tool_name,
)
if isinstance(tool_choice, dict):
# for single tool choice, return as function tool choice
return OpenAIChatCompletionToolChoiceFunctionTool(name=tool_choice["function"]["name"])
elif isinstance(tool_choice, list):
# for multiple tool choices, return as allowed tools
return OpenAIChatCompletionToolChoiceAllowedTools(
tools=tool_choice,
mode="required",
)
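
For quick reference, the translation `_process_tool_choice` performs, summarized with illustrative payloads (plain data shown for readability; the function actually returns the typed models defined in `inference.py`):

```python
# Responses tool_choice                  -> chat-completions tool_choice
# "auto" / "none"                        -> passed through as-is
# "required"                             -> allowed_tools over every function tool,
#                                           mode="required" (None if no tools)
# {"type": "web_search"}                 -> {"type": "function",
#                                            "function": {"name": "web_search"}}
# {"type": "file_search"}                -> {"type": "function",
#                                            "function": {"name": "file_search"}}
# {"type": "function", "name": "f"}      -> function tool choice for "f" if it is in
#                                           the chat tools, else None (with a warning)
# {"type": "custom", "name": "c"}        -> {"type": "custom", "custom": {"name": "c"}}
# {"type": "mcp", "server_label": "s"}   -> allowed_tools over the tools listed by
#                                           server "s", mode="required"
# {"type": "mcp", "server_label": "s",
#  "name": "t"}                          -> function tool choice for "t"
```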


@@ -16,6 +16,7 @@ from llama_stack_api import (
OpenAIResponseFormatParam,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseInputToolChoice,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
@@ -162,6 +163,7 @@ class ChatCompletionContext(BaseModel):
temperature: float | None
response_format: OpenAIResponseFormatParam
tool_context: ToolContext | None
tool_choice: OpenAIResponseInputToolChoice | None = None
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
@@ -174,6 +176,7 @@ class ChatCompletionContext(BaseModel):
response_format: OpenAIResponseFormatParam,
tool_context: ToolContext,
inputs: list[OpenAIResponseInput] | str,
tool_choice: OpenAIResponseInputToolChoice | None = None,
):
super().__init__(
model=model,
@@ -182,6 +185,7 @@
temperature=temperature,
response_format=response_format,
tool_context=tool_context,
tool_choice=tool_choice,
)
if not isinstance(inputs, str):
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]


@@ -506,3 +506,28 @@ def extract_guardrail_ids(guardrails: list | None) -> list[str]:
raise ValueError(f"Unknown guardrail format: {guardrail}, expected str or ResponseGuardrailSpec")
return guardrail_ids
def convert_mcp_tool_choice(
chat_tool_names: list[str],
server_label: str | None = None,
server_label_to_tools: dict[str, list[str]] | None = None,
tool_name: str | None = None,
) -> dict | list[dict] | None:
"""Convert a responses tool choice of type mcp to a chat completions compatible function tool choice."""
if tool_name:
if tool_name not in chat_tool_names:
return None
return {"type": "function", "function": {"name": tool_name}}
elif server_label and server_label_to_tools:
# No tool name specified: build an allowed_tools list from the function tools
# exposed by the given server label, via the reverse mapping (which already
# reflects any allowed_tools restrictions applied during _process_mcp_tool)
tool_names = server_label_to_tools.get(server_label, [])
if not tool_names:
return None
matching_tools = [{"type": "function", "function": {"name": tool_name}} for tool_name in tool_names]
return matching_tools
return []
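
A short usage sketch of `convert_mcp_tool_choice` as defined above (the server label and tool names are illustrative):

```python
chat_tool_names = ["get_docs", "search_code", "web_search"]
server_label_to_tools = {"deepwiki": ["get_docs", "search_code"]}

# A specific tool on the server -> a single function tool choice.
convert_mcp_tool_choice(chat_tool_names, "deepwiki", server_label_to_tools, "get_docs")
# {"type": "function", "function": {"name": "get_docs"}}

# No tool name -> one function entry per tool exposed by that server.
convert_mcp_tool_choice(chat_tool_names, "deepwiki", server_label_to_tools)
# [{"type": "function", "function": {"name": "get_docs"}},
#  {"type": "function", "function": {"name": "search_code"}}]

# Unknown tool or unknown server label -> None (the choice is dropped upstream).
convert_mcp_tool_choice(chat_tool_names, "deepwiki", server_label_to_tools, "missing")
# None
```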


@@ -167,6 +167,10 @@ from .inference import (
OpenAIChatCompletionTextOnlyMessageContent,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChatCompletionToolChoice,
OpenAIChatCompletionToolChoiceAllowedTools,
OpenAIChatCompletionToolChoiceCustomTool,
OpenAIChatCompletionToolChoiceFunctionTool,
OpenAIChatCompletionUsage,
OpenAIChatCompletionUsageCompletionTokensDetails,
OpenAIChatCompletionUsagePromptTokensDetails,
@@ -259,6 +263,15 @@ from .openai_responses import (
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseInputToolChoice,
OpenAIResponseInputToolChoiceAllowedTools,
OpenAIResponseInputToolChoiceCustomTool,
OpenAIResponseInputToolChoiceFileSearch,
OpenAIResponseInputToolChoiceFunctionTool,
OpenAIResponseInputToolChoiceMCPTool,
OpenAIResponseInputToolChoiceMode,
OpenAIResponseInputToolChoiceObject,
OpenAIResponseInputToolChoiceWebSearch,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
@@ -635,6 +648,10 @@ __all__ = [
"OpenAIChatCompletionUsage",
"OpenAIChatCompletionUsageCompletionTokensDetails",
"OpenAIChatCompletionUsagePromptTokensDetails",
"OpenAIChatCompletionToolChoiceAllowedTools",
"OpenAIChatCompletionToolChoiceFunctionTool",
"OpenAIChatCompletionToolChoiceCustomTool",
"OpenAIChatCompletionToolChoice",
"OpenAIChoice",
"OpenAIChoiceDelta",
"OpenAIChoiceLogprobs",
@@ -689,6 +706,15 @@ __all__ = [
"OpenAIResponseInputToolFunction",
"OpenAIResponseInputToolMCP",
"OpenAIResponseInputToolWebSearch",
"OpenAIResponseInputToolChoice",
"OpenAIResponseInputToolChoiceAllowedTools",
"OpenAIResponseInputToolChoiceFileSearch",
"OpenAIResponseInputToolChoiceWebSearch",
"OpenAIResponseInputToolChoiceFunctionTool",
"OpenAIResponseInputToolChoiceMCPTool",
"OpenAIResponseInputToolChoiceCustomTool",
"OpenAIResponseInputToolChoiceMode",
"OpenAIResponseInputToolChoiceObject",
"OpenAIResponseMCPApprovalRequest",
"OpenAIResponseMCPApprovalResponse",
"OpenAIResponseMessage",


@@ -20,6 +20,7 @@ from .openai_responses import (
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseInputToolChoice,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponsePrompt,
@@ -94,6 +95,7 @@ class Agents(Protocol):
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tool_choice: OpenAIResponseInputToolChoice | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[ResponseItemInclude] | None = None,
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API


@@ -555,6 +555,81 @@ OpenAIResponseFormatParam = Annotated[
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
@json_schema_type
class FunctionToolConfig(BaseModel):
"""Function tool configuration for OpenAI-compatible chat completion requests.
:param name: The name of the function to call
"""
name: str
@json_schema_type
class OpenAIChatCompletionToolChoiceFunctionTool(BaseModel):
"""Function tool choice for OpenAI-compatible chat completion requests.
:param type: Must be "function" to indicate function tool choice
:param function: The function tool configuration
"""
type: Literal["function"] = "function"
function: FunctionToolConfig
def __init__(self, name: str):
super().__init__(type="function", function=FunctionToolConfig(name=name))
@json_schema_type
class CustomToolConfig(BaseModel):
"""Custom tool configuration for OpenAI-compatible chat completion requests.
:param name: Name of the custom tool
"""
name: str
@json_schema_type
class OpenAIChatCompletionToolChoiceCustomTool(BaseModel):
"""Custom tool choice for OpenAI-compatible chat completion requests.
:param type: Must be "custom" to indicate custom tool choice
"""
type: Literal["custom"] = "custom"
custom: CustomToolConfig
def __init__(self, name: str):
super().__init__(type="custom", custom=CustomToolConfig(name=name))
@json_schema_type
class AllowedToolsConfig(BaseModel):
"""Allowed tools configuration for OpenAI-compatible chat completion requests.
:param tools: The list of tool definitions the model is allowed to call
:param mode: "auto" lets the model decide whether to call an allowed tool; "required" forces it to call one
"""
tools: list[dict[str, Any]]
mode: Literal["auto", "required"]
@json_schema_type
class OpenAIChatCompletionToolChoiceAllowedTools(BaseModel):
"""Allowed tools response format for OpenAI-compatible chat completion requests.
:param type: Must be "allowed_tools" to indicate allowed tools response format
"""
type: Literal["allowed_tools"] = "allowed_tools"
allowed_tools: AllowedToolsConfig
def __init__(self, tools: list[dict[str, Any]], mode: Literal["auto", "required"]):
super().__init__(type="allowed_tools", allowed_tools=AllowedToolsConfig(tools=tools, mode=mode))
# Define the object-level union with discriminator
OpenAIChatCompletionToolChoice = Annotated[
OpenAIChatCompletionToolChoiceAllowedTools
| OpenAIChatCompletionToolChoiceFunctionTool
| OpenAIChatCompletionToolChoiceCustomTool,
Field(discriminator="type"),
]
register_schema(OpenAIChatCompletionToolChoice, name="OpenAIChatCompletionToolChoice")
@json_schema_type
class OpenAITopLogProb(BaseModel):
"""The top log probability for a token from an OpenAI-compatible chat completion response.


@@ -5,6 +5,7 @@
# the root directory of this source tree.
from collections.abc import Sequence
from enum import Enum
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field, model_validator
@@ -541,6 +542,105 @@ OpenAIResponseTool = Annotated[
register_schema(OpenAIResponseTool, name="OpenAIResponseTool")
@json_schema_type
class OpenAIResponseInputToolChoiceAllowedTools(BaseModel):
"""Constrains the tools available to the model to a pre-defined set.
:param mode: "auto" lets the model pick any allowed tool or answer directly; "required" forces it to call one of them
:param tools: A list of tool definitions that the model should be allowed to call
:param type: Tool choice type identifier, always "allowed_tools"
"""
mode: Literal["auto", "required"] = "auto"
tools: list[dict[str, str]]
type: Literal["allowed_tools"] = "allowed_tools"
@json_schema_type
class OpenAIResponseInputToolChoiceFileSearch(BaseModel):
"""Indicates that the model should use file search to generate a response.
:param type: Tool choice type identifier, always "file_search"
"""
type: Literal["file_search"] = "file_search"
@json_schema_type
class OpenAIResponseInputToolChoiceWebSearch(BaseModel):
"""Indicates that the model should use web search to generate a response
:param type: Web search tool type variant to use
"""
type: (
Literal["web_search"]
| Literal["web_search_preview"]
| Literal["web_search_preview_2025_03_11"]
| Literal["web_search_2025_08_26"]
) = "web_search"
@json_schema_type
class OpenAIResponseInputToolChoiceFunctionTool(BaseModel):
"""Forces the model to call a specific function.
:param name: The name of the function to call
:param type: Tool choice type identifier, always "function"
"""
name: str
type: Literal["function"] = "function"
@json_schema_type
class OpenAIResponseInputToolChoiceMCPTool(BaseModel):
"""Forces the model to call a specific tool on a remote MCP server
:param server_label: The label of the MCP server to use.
:param type: Tool choice type identifier, always "mcp"
:param name: (Optional) The name of the tool to call on the server.
"""
server_label: str
type: Literal["mcp"] = "mcp"
name: str | None = None
@json_schema_type
class OpenAIResponseInputToolChoiceCustomTool(BaseModel):
"""Forces the model to call a custom tool.
:param type: Tool choice type identifier, always "custom"
:param name: The name of the custom tool to call.
"""
type: Literal["custom"] = "custom"
name: str
class OpenAIResponseInputToolChoiceMode(str, Enum):
"""Plain-string tool choice modes accepted by the Responses API: "auto", "required", or "none"."""
auto = "auto"
required = "required"
none = "none"
OpenAIResponseInputToolChoiceObject = Annotated[
OpenAIResponseInputToolChoiceAllowedTools
| OpenAIResponseInputToolChoiceFileSearch
| OpenAIResponseInputToolChoiceWebSearch
| OpenAIResponseInputToolChoiceFunctionTool
| OpenAIResponseInputToolChoiceMCPTool
| OpenAIResponseInputToolChoiceCustomTool,
Field(discriminator="type"),
]
# Final union: either a plain mode string or a typed tool-choice object
OpenAIResponseInputToolChoice = OpenAIResponseInputToolChoiceMode | OpenAIResponseInputToolChoiceObject
register_schema(OpenAIResponseInputToolChoice, name="OpenAIResponseInputToolChoice")
class OpenAIResponseUsageOutputTokensDetails(BaseModel):
"""Token details for output tokens in OpenAI response usage.
@@ -595,6 +695,7 @@ class OpenAIResponseObject(BaseModel):
:param text: Text formatting configuration for the response
:param top_p: (Optional) Nucleus sampling parameter used for generation
:param tools: (Optional) An array of tools the model may call while generating a response.
:param tool_choice: (Optional) Tool choice configuration for the response.
:param truncation: (Optional) Truncation strategy applied to the response
:param usage: (Optional) Token usage information for the response
:param instructions: (Optional) System message inserted into the model's context
@@ -618,6 +719,7 @@ class OpenAIResponseObject(BaseModel):
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
top_p: float | None = None
tools: Sequence[OpenAIResponseTool] | None = None
tool_choice: OpenAIResponseInputToolChoice | None = None
truncation: str | None = None
usage: OpenAIResponseUsage | None = None
instructions: str | None = None