feat: add support for tool_choice to responses api (#4106)

# What does this PR do?
Adds support for enforcing tool usage via responses api. See
https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice
for details from official documentation.
Note: at present this PR only supports `file_search` and `web_search` as options to enforce built-in tool usage.

<!-- If resolving an issue, uncomment and update the line below -->
Closes #3548 

## Test Plan
<!-- Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.* -->
`./scripts/unit-tests.sh tests/unit/providers/agents/meta_reference/test_response_tool_context.py`

---------

Signed-off-by: Jaideep Rao <jrao@redhat.com>
This commit is contained in:
Jaideep Rao 2025-12-16 00:52:06 +05:30 committed by GitHub
parent 62005dc1a9
commit 56f946f3f5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 49989 additions and 3 deletions

View file

@ -167,6 +167,10 @@ from .inference import (
OpenAIChatCompletionTextOnlyMessageContent,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChatCompletionToolChoice,
OpenAIChatCompletionToolChoiceAllowedTools,
OpenAIChatCompletionToolChoiceCustomTool,
OpenAIChatCompletionToolChoiceFunctionTool,
OpenAIChatCompletionUsage,
OpenAIChatCompletionUsageCompletionTokensDetails,
OpenAIChatCompletionUsagePromptTokensDetails,
@ -259,6 +263,15 @@ from .openai_responses import (
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseInputToolChoice,
OpenAIResponseInputToolChoiceAllowedTools,
OpenAIResponseInputToolChoiceCustomTool,
OpenAIResponseInputToolChoiceFileSearch,
OpenAIResponseInputToolChoiceFunctionTool,
OpenAIResponseInputToolChoiceMCPTool,
OpenAIResponseInputToolChoiceMode,
OpenAIResponseInputToolChoiceObject,
OpenAIResponseInputToolChoiceWebSearch,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
@ -635,6 +648,10 @@ __all__ = [
"OpenAIChatCompletionUsage",
"OpenAIChatCompletionUsageCompletionTokensDetails",
"OpenAIChatCompletionUsagePromptTokensDetails",
"OpenAIChatCompletionToolChoiceAllowedTools",
"OpenAIChatCompletionToolChoiceFunctionTool",
"OpenAIChatCompletionToolChoiceCustomTool",
"OpenAIChatCompletionToolChoice",
"OpenAIChoice",
"OpenAIChoiceDelta",
"OpenAIChoiceLogprobs",
@ -689,6 +706,15 @@ __all__ = [
"OpenAIResponseInputToolFunction",
"OpenAIResponseInputToolMCP",
"OpenAIResponseInputToolWebSearch",
"OpenAIResponseInputToolChoice",
"OpenAIResponseInputToolChoiceAllowedTools",
"OpenAIResponseInputToolChoiceFileSearch",
"OpenAIResponseInputToolChoiceWebSearch",
"OpenAIResponseInputToolChoiceFunctionTool",
"OpenAIResponseInputToolChoiceMCPTool",
"OpenAIResponseInputToolChoiceCustomTool",
"OpenAIResponseInputToolChoiceMode",
"OpenAIResponseInputToolChoiceObject",
"OpenAIResponseMCPApprovalRequest",
"OpenAIResponseMCPApprovalResponse",
"OpenAIResponseMessage",

View file

@ -20,6 +20,7 @@ from .openai_responses import (
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseInputToolChoice,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponsePrompt,
@ -94,6 +95,7 @@ class Agents(Protocol):
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tool_choice: OpenAIResponseInputToolChoice | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[ResponseItemInclude] | None = None,
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API

View file

@ -555,6 +555,81 @@ OpenAIResponseFormatParam = Annotated[
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
@json_schema_type
class FunctionToolConfig(BaseModel):
    """Names the function the model is being directed to call.

    Nested payload of :class:`OpenAIChatCompletionToolChoiceFunctionTool`.

    :param name: Name of the function to call
    """

    name: str
@json_schema_type
class OpenAIChatCompletionToolChoiceFunctionTool(BaseModel):
    """Function tool choice for OpenAI-compatible chat completion requests.

    :param type: Must be "function" to indicate function tool choice
    :param function: The function tool configuration
    """

    type: Literal["function"] = "function"
    function: FunctionToolConfig

    def __init__(self, name: str | None = None, **kwargs: Any):
        # Convenience form: OpenAIChatCompletionToolChoiceFunctionTool("my_fn").
        # Falling through to plain keyword construction when `name` is absent keeps
        # standard pydantic usage working (e.g. Model(**m.model_dump())), which a
        # mandatory positional `name` would break.
        if name is not None:
            kwargs["function"] = FunctionToolConfig(name=name)
        super().__init__(**kwargs)
@json_schema_type
class CustomToolConfig(BaseModel):
    """Custom tool configuration for OpenAI-compatible chat completion requests.

    Nested payload of :class:`OpenAIChatCompletionToolChoiceCustomTool`.

    :param name: Name of the custom tool
    """

    name: str
@json_schema_type
class OpenAIChatCompletionToolChoiceCustomTool(BaseModel):
    """Custom tool choice for OpenAI-compatible chat completion requests.

    :param type: Must be "custom" to indicate custom tool choice
    :param custom: The custom tool configuration
    """

    type: Literal["custom"] = "custom"
    custom: CustomToolConfig

    def __init__(self, name: str | None = None, **kwargs: Any):
        # Convenience form: OpenAIChatCompletionToolChoiceCustomTool("my_tool").
        # When `name` is not given, defer to plain keyword construction so pydantic
        # round-trips (e.g. Model(**m.model_dump())) still work.
        if name is not None:
            kwargs["custom"] = CustomToolConfig(name=name)
        super().__init__(**kwargs)
@json_schema_type
class AllowedToolsConfig(BaseModel):
    """Restricted tool set for OpenAI-compatible chat completion tool choice.

    Nested payload of :class:`OpenAIChatCompletionToolChoiceAllowedTools`.

    :param tools: Tool definitions the model is allowed to call
    :param mode: Selection mode, either "auto" or "required"
    """

    tools: list[dict[str, Any]]
    mode: Literal["auto", "required"]
@json_schema_type
class OpenAIChatCompletionToolChoiceAllowedTools(BaseModel):
    """Allowed tools response format for OpenAI-compatible chat completion requests.

    :param type: Must be "allowed_tools" to indicate allowed tools response format
    :param allowed_tools: The restricted tool set and selection mode
    """

    type: Literal["allowed_tools"] = "allowed_tools"
    allowed_tools: AllowedToolsConfig

    def __init__(
        self,
        tools: list[dict[str, Any]] | None = None,
        mode: Literal["auto", "required"] | None = None,
        **kwargs: Any,
    ):
        # Convenience form: OpenAIChatCompletionToolChoiceAllowedTools(tools, mode).
        # When `tools` is not given, defer to plain keyword construction so pydantic
        # round-trips (e.g. Model(**m.model_dump())) still work.
        if tools is not None:
            kwargs["allowed_tools"] = AllowedToolsConfig(tools=tools, mode=mode or "auto")
        super().__init__(**kwargs)
# Define the object-level union with discriminator: pydantic dispatches on the
# `type` field to select the matching tool-choice variant during validation.
OpenAIChatCompletionToolChoice = Annotated[
    OpenAIChatCompletionToolChoiceAllowedTools
    | OpenAIChatCompletionToolChoiceFunctionTool
    | OpenAIChatCompletionToolChoiceCustomTool,
    Field(discriminator="type"),
]
register_schema(OpenAIChatCompletionToolChoice, name="OpenAIChatCompletionToolChoice")
@json_schema_type
class OpenAITopLogProb(BaseModel):
"""The top log probability for a token from an OpenAI-compatible chat completion response.

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
from collections.abc import Sequence
from enum import Enum
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field, model_validator
@ -541,6 +542,105 @@ OpenAIResponseTool = Annotated[
register_schema(OpenAIResponseTool, name="OpenAIResponseTool")
@json_schema_type
class OpenAIResponseInputToolChoiceAllowedTools(BaseModel):
    """Constrains the tools available to the model to a pre-defined set.

    :param mode: Whether the model may pick among the allowed tools ("auto") or must call one ("required")
    :param tools: A list of tool definitions that the model should be allowed to call
    :param type: Tool choice type identifier, always "allowed_tools"
    """

    mode: Literal["auto", "required"] = "auto"
    # dict[str, Any] (not dict[str, str]) so tool entries with non-string values are
    # accepted, consistent with OpenAIChatCompletionToolChoiceAllowedTools / AllowedToolsConfig.
    tools: list[dict[str, Any]]
    type: Literal["allowed_tools"] = "allowed_tools"
@json_schema_type
class OpenAIResponseInputToolChoiceFileSearch(BaseModel):
    """Indicates that the model should use file search to generate a response.

    :param type: Tool choice type identifier, always "file_search"
    """

    type: Literal["file_search"] = "file_search"
@json_schema_type
class OpenAIResponseInputToolChoiceWebSearch(BaseModel):
    """Indicates that the model should use web search to generate a response.

    :param type: Web search tool type variant to use
    """

    # A single multi-value Literal is exactly equivalent to the original union of
    # four one-value Literals, and is the idiomatic spelling.
    type: Literal[
        "web_search",
        "web_search_preview",
        "web_search_preview_2025_03_11",
        "web_search_2025_08_26",
    ] = "web_search"
@json_schema_type
class OpenAIResponseInputToolChoiceFunctionTool(BaseModel):
    """Forces the model to call a specific function.

    :param name: The name of the function to call
    :param type: Tool choice type identifier, always "function"
    """

    name: str
    type: Literal["function"] = "function"
@json_schema_type
class OpenAIResponseInputToolChoiceMCPTool(BaseModel):
    """Forces the model to call a specific tool on a remote MCP server.

    :param server_label: The label of the MCP server to use.
    :param type: Tool choice type identifier, always "mcp"
    :param name: (Optional) The name of the tool to call on the server.
    """

    server_label: str
    type: Literal["mcp"] = "mcp"
    name: str | None = None
@json_schema_type
class OpenAIResponseInputToolChoiceCustomTool(BaseModel):
    """Forces the model to call a custom tool.

    :param type: Tool choice type identifier, always "custom"
    :param name: The name of the custom tool to call.
    """

    type: Literal["custom"] = "custom"
    name: str
class OpenAIResponseInputToolChoiceMode(str, Enum):
    """String-valued `tool_choice` modes ("auto" / "required" / "none")."""

    auto = "auto"
    required = "required"
    none = "none"
# Object-level union with discriminator: pydantic dispatches on the `type` field
# to select the matching tool-choice variant during validation.
OpenAIResponseInputToolChoiceObject = Annotated[
    OpenAIResponseInputToolChoiceAllowedTools
    | OpenAIResponseInputToolChoiceFileSearch
    | OpenAIResponseInputToolChoiceWebSearch
    | OpenAIResponseInputToolChoiceFunctionTool
    | OpenAIResponseInputToolChoiceMCPTool
    | OpenAIResponseInputToolChoiceCustomTool,
    Field(discriminator="type"),
]

# Final tool_choice type: either a bare mode string ("auto"/"required"/"none") or
# one of the discriminated tool-choice objects above. Registered under a stable
# schema name for the OpenAPI spec.
OpenAIResponseInputToolChoice = OpenAIResponseInputToolChoiceMode | OpenAIResponseInputToolChoiceObject
register_schema(OpenAIResponseInputToolChoice, name="OpenAIResponseInputToolChoice")
class OpenAIResponseUsageOutputTokensDetails(BaseModel):
"""Token details for output tokens in OpenAI response usage.
@ -595,6 +695,7 @@ class OpenAIResponseObject(BaseModel):
:param text: Text formatting configuration for the response
:param top_p: (Optional) Nucleus sampling parameter used for generation
:param tools: (Optional) An array of tools the model may call while generating a response.
:param tool_choice: (Optional) Tool choice configuration for the response.
:param truncation: (Optional) Truncation strategy applied to the response
:param usage: (Optional) Token usage information for the response
:param instructions: (Optional) System message inserted into the model's context
@ -618,6 +719,7 @@ class OpenAIResponseObject(BaseModel):
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
top_p: float | None = None
tools: Sequence[OpenAIResponseTool] | None = None
tool_choice: OpenAIResponseInputToolChoice | None = None
truncation: str | None = None
usage: OpenAIResponseUsage | None = None
instructions: str | None = None