mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-21 07:22:25 +00:00
feat: add support for tool_choice to responses api (#4106)
# What does this PR do? Adds support for enforcing tool usage via responses api. See https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice for details from official documentation. Note: at present this PR only supports `file_search` and `web_search` as options to enforce builtin tool usage <!-- If resolving an issue, uncomment and update the line below --> Closes #3548 ## Test Plan <!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* --> `./scripts/unit-tests.sh tests/unit/providers/agents/meta_reference/test_response_tool_context.py ` --------- Signed-off-by: Jaideep Rao <jrao@redhat.com>
This commit is contained in:
parent
62005dc1a9
commit
56f946f3f5
18 changed files with 49989 additions and 3 deletions
|
|
@ -167,6 +167,10 @@ from .inference import (
|
|||
OpenAIChatCompletionTextOnlyMessageContent,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChatCompletionToolChoice,
|
||||
OpenAIChatCompletionToolChoiceAllowedTools,
|
||||
OpenAIChatCompletionToolChoiceCustomTool,
|
||||
OpenAIChatCompletionToolChoiceFunctionTool,
|
||||
OpenAIChatCompletionUsage,
|
||||
OpenAIChatCompletionUsageCompletionTokensDetails,
|
||||
OpenAIChatCompletionUsagePromptTokensDetails,
|
||||
|
|
@ -259,6 +263,15 @@ from .openai_responses import (
|
|||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolChoice,
|
||||
OpenAIResponseInputToolChoiceAllowedTools,
|
||||
OpenAIResponseInputToolChoiceCustomTool,
|
||||
OpenAIResponseInputToolChoiceFileSearch,
|
||||
OpenAIResponseInputToolChoiceFunctionTool,
|
||||
OpenAIResponseInputToolChoiceMCPTool,
|
||||
OpenAIResponseInputToolChoiceMode,
|
||||
OpenAIResponseInputToolChoiceObject,
|
||||
OpenAIResponseInputToolChoiceWebSearch,
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
|
|
@ -635,6 +648,10 @@ __all__ = [
|
|||
"OpenAIChatCompletionUsage",
|
||||
"OpenAIChatCompletionUsageCompletionTokensDetails",
|
||||
"OpenAIChatCompletionUsagePromptTokensDetails",
|
||||
"OpenAIChatCompletionToolChoiceAllowedTools",
|
||||
"OpenAIChatCompletionToolChoiceFunctionTool",
|
||||
"OpenAIChatCompletionToolChoiceCustomTool",
|
||||
"OpenAIChatCompletionToolChoice",
|
||||
"OpenAIChoice",
|
||||
"OpenAIChoiceDelta",
|
||||
"OpenAIChoiceLogprobs",
|
||||
|
|
@ -689,6 +706,15 @@ __all__ = [
|
|||
"OpenAIResponseInputToolFunction",
|
||||
"OpenAIResponseInputToolMCP",
|
||||
"OpenAIResponseInputToolWebSearch",
|
||||
"OpenAIResponseInputToolChoice",
|
||||
"OpenAIResponseInputToolChoiceAllowedTools",
|
||||
"OpenAIResponseInputToolChoiceFileSearch",
|
||||
"OpenAIResponseInputToolChoiceWebSearch",
|
||||
"OpenAIResponseInputToolChoiceFunctionTool",
|
||||
"OpenAIResponseInputToolChoiceMCPTool",
|
||||
"OpenAIResponseInputToolChoiceCustomTool",
|
||||
"OpenAIResponseInputToolChoiceMode",
|
||||
"OpenAIResponseInputToolChoiceObject",
|
||||
"OpenAIResponseMCPApprovalRequest",
|
||||
"OpenAIResponseMCPApprovalResponse",
|
||||
"OpenAIResponseMessage",
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from .openai_responses import (
|
|||
OpenAIDeleteResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolChoice,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponsePrompt,
|
||||
|
|
@ -94,6 +95,7 @@ class Agents(Protocol):
|
|||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tool_choice: OpenAIResponseInputToolChoice | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[ResponseItemInclude] | None = None,
|
||||
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
|
||||
|
|
|
|||
|
|
@ -555,6 +555,81 @@ OpenAIResponseFormatParam = Annotated[
|
|||
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
|
||||
|
||||
|
||||
@json_schema_type
class FunctionToolConfig(BaseModel):
    """Function tool configuration for OpenAI-compatible chat completion requests.

    :param name: Name of the function the model is forced to call
    """

    name: str
|
||||
|
||||
|
||||
@json_schema_type
class OpenAIChatCompletionToolChoiceFunctionTool(BaseModel):
    """Function tool choice for OpenAI-compatible chat completion requests.

    :param type: Must be "function" to indicate function tool choice
    :param function: The function tool configuration
    """

    type: Literal["function"] = "function"
    function: FunctionToolConfig

    # Convenience constructor: build directly from the function name.
    # NOTE(review): overriding BaseModel.__init__ means keyword construction
    # (e.g. cls(function=FunctionToolConfig(...))) no longer works; dict
    # validation through the discriminated union presumably bypasses this
    # __init__ in pydantic v2 — confirm against the validator behavior.
    def __init__(self, name: str) -> None:
        super().__init__(type="function", function=FunctionToolConfig(name=name))
|
||||
|
||||
|
||||
@json_schema_type
class CustomToolConfig(BaseModel):
    """Custom tool configuration for OpenAI-compatible chat completion requests.

    :param name: Name of the custom tool
    """

    name: str
|
||||
|
||||
|
||||
@json_schema_type
class OpenAIChatCompletionToolChoiceCustomTool(BaseModel):
    """Custom tool choice for OpenAI-compatible chat completion requests.

    :param type: Must be "custom" to indicate custom tool choice
    :param custom: The custom tool configuration
    """

    type: Literal["custom"] = "custom"
    custom: CustomToolConfig

    # Convenience constructor: build directly from the custom tool name.
    # NOTE(review): same __init__-override caveat as the function variant —
    # keyword construction is unavailable; confirm dict validation still works.
    def __init__(self, name: str) -> None:
        super().__init__(type="custom", custom=CustomToolConfig(name=name))
|
||||
|
||||
|
||||
@json_schema_type
class AllowedToolsConfig(BaseModel):
    """Configuration payload for an allowed-tools tool choice.

    :param tools: Tool definitions the model is allowed to call
    :param mode: "auto" lets the model pick among the allowed tools;
        "required" forces it to call one of them
    """

    tools: list[dict[str, Any]]
    mode: Literal["auto", "required"]
|
||||
|
||||
|
||||
@json_schema_type
class OpenAIChatCompletionToolChoiceAllowedTools(BaseModel):
    """Allowed-tools tool choice for OpenAI-compatible chat completion requests.

    Constrains the model to a pre-defined subset of tools.

    :param type: Must be "allowed_tools" to indicate allowed-tools tool choice
    :param allowed_tools: The allowed tools configuration (tool list and mode)
    """

    type: Literal["allowed_tools"] = "allowed_tools"
    allowed_tools: AllowedToolsConfig

    # Convenience constructor: build directly from a tool list and mode.
    def __init__(self, tools: list[dict[str, Any]], mode: Literal["auto", "required"]) -> None:
        super().__init__(type="allowed_tools", allowed_tools=AllowedToolsConfig(tools=tools, mode=mode))
|
||||
|
||||
|
||||
# Object-level union of the structured tool_choice variants for chat
# completions; "type" is the pydantic discriminator used to pick the variant
# during validation.
OpenAIChatCompletionToolChoice = Annotated[
    OpenAIChatCompletionToolChoiceAllowedTools
    | OpenAIChatCompletionToolChoiceFunctionTool
    | OpenAIChatCompletionToolChoiceCustomTool,
    Field(discriminator="type"),
]

register_schema(OpenAIChatCompletionToolChoice, name="OpenAIChatCompletionToolChoice")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAITopLogProb(BaseModel):
|
||||
"""The top log probability for a token from an OpenAI-compatible chat completion response.
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
|
@ -541,6 +542,105 @@ OpenAIResponseTool = Annotated[
|
|||
register_schema(OpenAIResponseTool, name="OpenAIResponseTool")
|
||||
|
||||
|
||||
@json_schema_type
class OpenAIResponseInputToolChoiceAllowedTools(BaseModel):
    """Constrains the tools available to the model to a pre-defined set.

    :param mode: "auto" lets the model pick among the allowed tools;
        "required" forces it to call one of them
    :param tools: A list of tool definitions that the model should be allowed to call
    :param type: Tool choice type identifier, always "allowed_tools"
    """

    mode: Literal["auto", "required"] = "auto"
    tools: list[dict[str, str]]
    type: Literal["allowed_tools"] = "allowed_tools"
|
||||
|
||||
|
||||
@json_schema_type
class OpenAIResponseInputToolChoiceFileSearch(BaseModel):
    """Indicates that the model should use file search to generate a response.

    :param type: Tool choice type identifier, always "file_search"
    """

    type: Literal["file_search"] = "file_search"
|
||||
|
||||
|
||||
@json_schema_type
class OpenAIResponseInputToolChoiceWebSearch(BaseModel):
    """Indicates that the model should use web search to generate a response.

    :param type: Web search tool type variant to use (current and dated
        preview identifiers accepted by the OpenAI Responses API)
    """

    # NOTE(review): a single Literal["web_search", "web_search_preview", ...]
    # would be equivalent for validation, but the union-of-Literals form
    # presumably emits a different JSON schema (anyOf vs enum) — left as-is.
    type: (
        Literal["web_search"]
        | Literal["web_search_preview"]
        | Literal["web_search_preview_2025_03_11"]
        | Literal["web_search_2025_08_26"]
    ) = "web_search"
|
||||
|
||||
|
||||
@json_schema_type
class OpenAIResponseInputToolChoiceFunctionTool(BaseModel):
    """Forces the model to call a specific function.

    :param name: The name of the function to call
    :param type: Tool choice type identifier, always "function"
    """

    name: str
    type: Literal["function"] = "function"
|
||||
|
||||
|
||||
@json_schema_type
class OpenAIResponseInputToolChoiceMCPTool(BaseModel):
    """Forces the model to call a specific tool on a remote MCP server.

    :param server_label: The label of the MCP server to use
    :param type: Tool choice type identifier, always "mcp"
    :param name: (Optional) The name of the tool to call on the server;
        when omitted, any tool on the labeled server may be called
    """

    server_label: str
    type: Literal["mcp"] = "mcp"
    name: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
class OpenAIResponseInputToolChoiceCustomTool(BaseModel):
    """Forces the model to call a custom tool.

    :param type: Tool choice type identifier, always "custom"
    :param name: The name of the custom tool to call
    """

    type: Literal["custom"] = "custom"
    name: str
|
||||
|
||||
|
||||
class OpenAIResponseInputToolChoiceMode(str, Enum):
    """String modes for the Responses API `tool_choice` parameter.

    Mirrors the OpenAI string-valued tool_choice options: "auto" (model
    decides), "required" (model must call a tool), "none" (model must not
    call any tool).
    """

    auto = "auto"
    required = "required"
    none = "none"
|
||||
|
||||
|
||||
# Object-level union of the structured tool_choice variants; "type" is the
# pydantic discriminator used to pick the variant during validation.
OpenAIResponseInputToolChoiceObject = Annotated[
    OpenAIResponseInputToolChoiceAllowedTools
    | OpenAIResponseInputToolChoiceFileSearch
    | OpenAIResponseInputToolChoiceWebSearch
    | OpenAIResponseInputToolChoiceFunctionTool
    | OpenAIResponseInputToolChoiceMCPTool
    | OpenAIResponseInputToolChoiceCustomTool,
    Field(discriminator="type"),
]

# Final union: a plain string mode ("auto"/"required"/"none") or a structured
# tool-choice object. Optionality (None) is left to the call sites.
OpenAIResponseInputToolChoice = OpenAIResponseInputToolChoiceMode | OpenAIResponseInputToolChoiceObject

register_schema(OpenAIResponseInputToolChoice, name="OpenAIResponseInputToolChoice")
|
||||
|
||||
|
||||
class OpenAIResponseUsageOutputTokensDetails(BaseModel):
|
||||
"""Token details for output tokens in OpenAI response usage.
|
||||
|
||||
|
|
@ -595,6 +695,7 @@ class OpenAIResponseObject(BaseModel):
|
|||
:param text: Text formatting configuration for the response
|
||||
:param top_p: (Optional) Nucleus sampling parameter used for generation
|
||||
:param tools: (Optional) An array of tools the model may call while generating a response.
|
||||
:param tool_choice: (Optional) Tool choice configuration for the response.
|
||||
:param truncation: (Optional) Truncation strategy applied to the response
|
||||
:param usage: (Optional) Token usage information for the response
|
||||
:param instructions: (Optional) System message inserted into the model's context
|
||||
|
|
@ -618,6 +719,7 @@ class OpenAIResponseObject(BaseModel):
|
|||
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
|
||||
top_p: float | None = None
|
||||
tools: Sequence[OpenAIResponseTool] | None = None
|
||||
tool_choice: OpenAIResponseInputToolChoice | None = None
|
||||
truncation: str | None = None
|
||||
usage: OpenAIResponseUsage | None = None
|
||||
instructions: str | None = None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue