Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-21 07:22:25 +00:00)
feat: add support for tool_choice to responses api (#4106)
# What does this PR do?

Adds support for enforcing tool usage via the Responses API. See https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice for details in the official documentation.

Note: at present this PR only supports `file_search` and `web_search` as options for enforcing builtin tool usage.

Closes #3548

## Test Plan

`./scripts/unit-tests.sh tests/unit/providers/agents/meta_reference/test_response_tool_context.py`

---------

Signed-off-by: Jaideep Rao <jrao@redhat.com>
parent 62005dc1a9
commit 56f946f3f5
18 changed files with 49989 additions and 3 deletions
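Below is a minimal usage sketch (not part of this commit) of what enforcing a builtin tool looks like from a client, assuming a Llama Stack server exposing the OpenAI-compatible Responses API at `http://localhost:8321/v1` and the OpenAI Python client; the model id and the exact `tool_choice` payload are illustrative and follow the tool_choice documentation linked above.

```python
from openai import OpenAI

# Assumed endpoint and credentials for a locally running Llama Stack server.
client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",  # placeholder model id
    input="Find the latest release notes for Llama Stack.",
    tools=[{"type": "web_search"}],
    # Force the builtin web_search tool; per this PR, only file_search and
    # web_search are supported for builtin tool enforcement.
    tool_choice={"type": "web_search"},
)
print(response.output_text)
```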
```diff
@@ -19,6 +19,7 @@ from llama_stack_api import (
     OpenAIDeleteResponseObject,
     OpenAIResponseInput,
     OpenAIResponseInputTool,
+    OpenAIResponseInputToolChoice,
     OpenAIResponseObject,
     OpenAIResponsePrompt,
     OpenAIResponseText,
@@ -105,6 +106,7 @@ class MetaReferenceAgentsImpl(Agents):
         stream: bool | None = False,
         temperature: float | None = None,
         text: OpenAIResponseText | None = None,
+        tool_choice: OpenAIResponseInputToolChoice | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
         include: list[str] | None = None,
         max_infer_iters: int | None = 10,
@@ -124,6 +126,7 @@ class MetaReferenceAgentsImpl(Agents):
             stream,
             temperature,
             text,
+            tool_choice,
             tools,
             include,
             max_infer_iters,
```
```diff
@@ -32,6 +32,7 @@ from llama_stack_api import (
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
+    OpenAIResponseInputToolChoice,
     OpenAIResponseMessage,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
@@ -333,6 +334,7 @@ class OpenAIResponsesImpl:
         stream: bool | None = False,
         temperature: float | None = None,
         text: OpenAIResponseText | None = None,
+        tool_choice: OpenAIResponseInputToolChoice | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
         include: list[ResponseItemInclude] | None = None,
         max_infer_iters: int | None = 10,
@@ -390,6 +392,7 @@ class OpenAIResponsesImpl:
             temperature=temperature,
             text=text,
             tools=tools,
+            tool_choice=tool_choice,
             max_infer_iters=max_infer_iters,
             guardrail_ids=guardrail_ids,
             parallel_tool_calls=parallel_tool_calls,
@@ -444,6 +447,7 @@ class OpenAIResponsesImpl:
         temperature: float | None = None,
         text: OpenAIResponseText | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
+        tool_choice: OpenAIResponseInputToolChoice | None = None,
         max_infer_iters: int | None = 10,
         guardrail_ids: list[str] | None = None,
         parallel_tool_calls: bool | None = True,
@@ -474,6 +478,7 @@ class OpenAIResponsesImpl:
             model=model,
             messages=messages,
             response_tools=tools,
+            tool_choice=tool_choice,
             temperature=temperature,
             response_format=response_format,
             tool_context=tool_context,
```
```diff
@@ -8,6 +8,7 @@ import uuid
 from collections.abc import AsyncIterator
 from typing import Any
 
+from openai.types.chat import ChatCompletionToolParam
 from opentelemetry import trace
 
 from llama_stack.log import get_logger
@@ -23,6 +24,10 @@ from llama_stack_api import (
     OpenAIChatCompletionChunk,
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAIChatCompletionToolCall,
+    OpenAIChatCompletionToolChoice,
+    OpenAIChatCompletionToolChoiceAllowedTools,
+    OpenAIChatCompletionToolChoiceCustomTool,
+    OpenAIChatCompletionToolChoiceFunctionTool,
     OpenAIChoice,
     OpenAIChoiceLogprobs,
     OpenAIMessageParam,
@@ -31,6 +36,14 @@ from llama_stack_api import (
     OpenAIResponseContentPartRefusal,
     OpenAIResponseError,
     OpenAIResponseInputTool,
+    OpenAIResponseInputToolChoice,
+    OpenAIResponseInputToolChoiceAllowedTools,
+    OpenAIResponseInputToolChoiceCustomTool,
+    OpenAIResponseInputToolChoiceFileSearch,
+    OpenAIResponseInputToolChoiceFunctionTool,
+    OpenAIResponseInputToolChoiceMCPTool,
+    OpenAIResponseInputToolChoiceMode,
+    OpenAIResponseInputToolChoiceWebSearch,
     OpenAIResponseInputToolMCP,
     OpenAIResponseMCPApprovalRequest,
     OpenAIResponseMessage,
@@ -77,6 +90,7 @@ from llama_stack_api import (
 from .types import ChatCompletionContext, ChatCompletionResult
 from .utils import (
     convert_chat_choice_to_response_message,
+    convert_mcp_tool_choice,
     is_function_tool_call,
     run_guardrails,
 )
@@ -148,6 +162,13 @@ class StreamingResponseOrchestrator:
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
             ctx.tool_context.previous_tools if ctx.tool_context else {}
         )
+        # Reverse mapping: server_label -> list of tool names for efficient lookup
+        self.server_label_to_tools: dict[str, list[str]] = {}
+        # Build initial reverse mapping from previous_tools
+        for tool_name, mcp_server in self.mcp_tool_to_server.items():
+            if mcp_server.server_label not in self.server_label_to_tools:
+                self.server_label_to_tools[mcp_server.server_label] = []
+            self.server_label_to_tools[mcp_server.server_label].append(tool_name)
         # Track final messages after all tool executions
         self.final_messages: list[OpenAIMessageParam] = []
         # mapping for annotations
@@ -200,6 +221,7 @@ class StreamingResponseOrchestrator:
             output=self._clone_outputs(outputs),
             text=self.text,
             tools=self.ctx.available_tools(),
+            tool_choice=self.ctx.tool_choice,
             error=error,
             usage=self.accumulated_usage,
             instructions=self.instructions,
```
```diff
@@ -235,6 +257,34 @@ class StreamingResponseOrchestrator:
         async for stream_event in self._process_tools(output_messages):
             yield stream_event
 
+        chat_tool_choice = None
+        # Track allowed tools for filtering (persists across iterations)
+        allowed_tool_names: set[str] | None = None
+        if self.ctx.tool_choice and len(self.ctx.chat_tools) > 0:
+            processed_tool_choice = await _process_tool_choice(
+                self.ctx.chat_tools,
+                self.ctx.tool_choice,
+                self.server_label_to_tools,
+            )
+            # chat_tool_choice can be str, dict-like object, or None
+            if isinstance(processed_tool_choice, str | type(None)):
+                chat_tool_choice = processed_tool_choice
+            elif isinstance(processed_tool_choice, OpenAIChatCompletionToolChoiceAllowedTools):
+                # For allowed_tools: filter the tools list instead of using tool_choice
+                # This maintains the constraint across all iterations while letting model
+                # decide freely whether to call a tool or respond
+                allowed_tool_names = {
+                    tool["function"]["name"]
+                    for tool in processed_tool_choice.allowed_tools.tools
+                    if tool.get("type") == "function" and "function" in tool
+                }
+                # Use the mode (e.g., "required") for first iteration, then "auto"
+                chat_tool_choice = (
+                    processed_tool_choice.allowed_tools.mode if processed_tool_choice.allowed_tools.mode else "auto"
+                )
+            else:
+                chat_tool_choice = processed_tool_choice.model_dump()
+
         n_iter = 0
         messages = self.ctx.messages.copy()
         final_status = "completed"
@@ -247,7 +297,15 @@
             response_format = (
                 None if getattr(self.ctx.response_format, "type", None) == "text" else self.ctx.response_format
             )
-            logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
+            # Filter tools to only allowed ones if tool_choice specified an allowed list
+            effective_tools = self.ctx.chat_tools
+            if allowed_tool_names is not None:
+                effective_tools = [
+                    tool
+                    for tool in self.ctx.chat_tools
+                    if tool.get("function", {}).get("name") in allowed_tool_names
+                ]
+            logger.debug(f"calling openai_chat_completion with tools: {effective_tools}")
 
             logprobs = (
                 True if self.include and ResponseItemInclude.message_output_text_logprobs in self.include else None
@@ -257,7 +315,8 @@
                 model=self.ctx.model,
                 messages=messages,
                 # Pydantic models are dict-compatible but mypy treats them as distinct types
-                tools=self.ctx.chat_tools,  # type: ignore[arg-type]
+                tools=effective_tools,  # type: ignore[arg-type]
+                tool_choice=chat_tool_choice,
                 stream=True,
                 temperature=self.ctx.temperature,
                 response_format=response_format,
@@ -335,6 +394,14 @@
                 break
 
             n_iter += 1
+            # After first iteration, reset tool_choice to "auto" to let model decide freely
+            # based on tool results (prevents infinite loops when forcing specific tools)
+            # Note: When allowed_tool_names is set, tools are already filtered so model
+            # can only call allowed tools - we just need to let it decide whether to call
+            # a tool or respond (hence "auto" mode)
+            if n_iter == 1 and chat_tool_choice and chat_tool_choice != "auto":
+                chat_tool_choice = "auto"
+
             if n_iter >= self.max_infer_iters:
                 logger.info(
                     f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}"
```
```diff
@@ -1165,6 +1232,11 @@
                 raise ValueError(f"Duplicate tool name {t.name} found for server {mcp_tool.server_label}")
             self.mcp_tool_to_server[t.name] = mcp_tool
 
+            # Add to reverse mapping for efficient server_label lookup
+            if mcp_tool.server_label not in self.server_label_to_tools:
+                self.server_label_to_tools[mcp_tool.server_label] = []
+            self.server_label_to_tools[mcp_tool.server_label].append(t.name)
+
             # Add to MCP list message
             mcp_list_message.tools.append(
                 MCPListToolsTool(
```
```diff
@@ -1304,3 +1376,112 @@
 
         async for stream_event in self._add_mcp_list_tools(mcp_list_message, output_messages):
             yield stream_event
+
+
+async def _process_tool_choice(
+    chat_tools: list[ChatCompletionToolParam],
+    tool_choice: OpenAIResponseInputToolChoice,
+    server_label_to_tools: dict[str, list[str]],
+) -> str | OpenAIChatCompletionToolChoice | None:
+    """Process and validate the OpenAI Responses tool choice and return the appropriate chat completion tool choice object.
+
+    :param chat_tools: The list of chat tools to enforce tool choice against.
+    :param tool_choice: The OpenAI Responses tool choice to process.
+    :param server_label_to_tools: A dictionary mapping server labels to the list of tools available on that server.
+    :return: The appropriate chat completion tool choice object.
+    """
+
+    # retrieve all function tool names from the chat tools
+    # Note: chat_tools contains dicts, not objects
+    chat_tool_names = [tool["function"]["name"] for tool in chat_tools if tool["type"] == "function"]
+
+    if isinstance(tool_choice, OpenAIResponseInputToolChoiceMode):
+        if tool_choice.value == "required":
+            if len(chat_tool_names) == 0:
+                return None
+
+            # add all function tools to the allowed tools list and set mode to required
+            return OpenAIChatCompletionToolChoiceAllowedTools(
+                tools=[{"type": "function", "function": {"name": tool}} for tool in chat_tool_names],
+                mode="required",
+            )
+        # return other modes as is
+        return tool_choice.value
+
+    elif isinstance(tool_choice, OpenAIResponseInputToolChoiceAllowedTools):
+        # ensure that specified tool choices are available in the chat tools, if not, remove them from the list
+        final_tools = []
+        for tool in tool_choice.tools:
+            match tool.get("type"):
+                case "function":
+                    final_tools.append({"type": "function", "function": {"name": tool.get("name")}})
+                case "custom":
+                    final_tools.append({"type": "custom", "custom": {"name": tool.get("name")}})
+                case "mcp":
+                    mcp_tools = convert_mcp_tool_choice(
+                        chat_tool_names, tool.get("server_label"), server_label_to_tools, None
+                    )
+                    # convert_mcp_tool_choice can return a dict, list, or None
+                    if isinstance(mcp_tools, list):
+                        final_tools.extend(mcp_tools)
+                    elif isinstance(mcp_tools, dict):
+                        final_tools.append(mcp_tools)
+                    # Skip if None or empty
+                case "file_search":
+                    final_tools.append({"type": "function", "function": {"name": "file_search"}})
+                case _ if tool["type"] in WebSearchToolTypes:
+                    final_tools.append({"type": "function", "function": {"name": "web_search"}})
+                case _:
+                    logger.warning(f"Unsupported tool type: {tool['type']}, skipping tool choice enforcement for it")
+                    continue
+
+        return OpenAIChatCompletionToolChoiceAllowedTools(
+            tools=final_tools,
+            mode=tool_choice.mode,
+        )
+
+    else:
+        # Handle specific tool choice by type
+        # Each case validates the tool exists in chat_tools before returning
+        tool_name = getattr(tool_choice, "name", None)
+        match tool_choice:
+            case OpenAIResponseInputToolChoiceCustomTool():
+                if tool_name and tool_name not in chat_tool_names:
+                    logger.warning(f"Tool {tool_name} not found in chat tools")
+                    return None
+                return OpenAIChatCompletionToolChoiceCustomTool(name=tool_name)
+
+            case OpenAIResponseInputToolChoiceFunctionTool():
+                if tool_name and tool_name not in chat_tool_names:
+                    logger.warning(f"Tool {tool_name} not found in chat tools")
+                    return None
+                return OpenAIChatCompletionToolChoiceFunctionTool(name=tool_name)
+
+            case OpenAIResponseInputToolChoiceFileSearch():
+                if "file_search" not in chat_tool_names:
+                    logger.warning("Tool file_search not found in chat tools")
+                    return None
+                return OpenAIChatCompletionToolChoiceFunctionTool(name="file_search")
+
+            case OpenAIResponseInputToolChoiceWebSearch():
+                if "web_search" not in chat_tool_names:
+                    logger.warning("Tool web_search not found in chat tools")
+                    return None
+                return OpenAIChatCompletionToolChoiceFunctionTool(name="web_search")
+
+            case OpenAIResponseInputToolChoiceMCPTool():
+                tool_choice = convert_mcp_tool_choice(
+                    chat_tool_names,
+                    tool_choice.server_label,
+                    server_label_to_tools,
+                    tool_name,
+                )
+                if isinstance(tool_choice, dict):
+                    # for single tool choice, return as function tool choice
+                    return OpenAIChatCompletionToolChoiceFunctionTool(name=tool_choice["function"]["name"])
+                elif isinstance(tool_choice, list):
+                    # for multiple tool choices, return as allowed tools
+                    return OpenAIChatCompletionToolChoiceAllowedTools(
+                        tools=tool_choice,
+                        mode="required",
+                    )
```
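For illustration only (not part of the diff), the sketch below spells out, as plain dicts, the chat-completions `tool_choice` payloads this conversion is expected to produce; the shapes are assumed from the `OpenAIChatCompletionToolChoiceFunctionTool` and `OpenAIChatCompletionToolChoiceAllowedTools` types imported above, and the tool names are hypothetical.

```python
# Illustrative payload shapes only; the field layout is assumed from the types used in the diff.

# A Responses tool_choice of {"type": "web_search"} becomes a forced function call:
forced_web_search = {"type": "function", "function": {"name": "web_search"}}

# A Responses tool_choice mode of "required" becomes an allowed-tools constraint
# covering every registered function tool (tool names here are hypothetical):
required_any_tool = {
    "type": "allowed_tools",
    "allowed_tools": {
        "mode": "required",
        "tools": [
            {"type": "function", "function": {"name": "web_search"}},
            {"type": "function", "function": {"name": "file_search"}},
        ],
    },
}

print(forced_web_search)
print(required_any_tool)
```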
```diff
@@ -16,6 +16,7 @@ from llama_stack_api import (
     OpenAIResponseFormatParam,
     OpenAIResponseInput,
     OpenAIResponseInputTool,
+    OpenAIResponseInputToolChoice,
     OpenAIResponseInputToolFileSearch,
     OpenAIResponseInputToolFunction,
     OpenAIResponseInputToolMCP,
@@ -162,6 +163,7 @@ class ChatCompletionContext(BaseModel):
     temperature: float | None
     response_format: OpenAIResponseFormatParam
     tool_context: ToolContext | None
+    tool_choice: OpenAIResponseInputToolChoice | None = None
     approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
     approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
 
@@ -174,6 +176,7 @@ class ChatCompletionContext(BaseModel):
         response_format: OpenAIResponseFormatParam,
         tool_context: ToolContext,
         inputs: list[OpenAIResponseInput] | str,
+        tool_choice: OpenAIResponseInputToolChoice | None = None,
     ):
         super().__init__(
             model=model,
@@ -182,6 +185,7 @@ class ChatCompletionContext(BaseModel):
             temperature=temperature,
             response_format=response_format,
             tool_context=tool_context,
+            tool_choice=tool_choice,
         )
         if not isinstance(inputs, str):
             self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
```
```diff
@@ -506,3 +506,28 @@ def extract_guardrail_ids(guardrails: list | None) -> list[str]:
             raise ValueError(f"Unknown guardrail format: {guardrail}, expected str or ResponseGuardrailSpec")
 
     return guardrail_ids
+
+
+def convert_mcp_tool_choice(
+    chat_tool_names: list[str],
+    server_label: str | None = None,
+    server_label_to_tools: dict[str, list[str]] | None = None,
+    tool_name: str | None = None,
+) -> dict[str, str] | list[dict[str, str]]:
+    """Convert a responses tool choice of type mcp to a chat completions compatible function tool choice."""
+
+    if tool_name:
+        if tool_name not in chat_tool_names:
+            return None
+        return {"type": "function", "function": {"name": tool_name}}
+
+    elif server_label and server_label_to_tools:
+        # no tool name specified, so we need to enforce an allowed_tools with the function tools derived only from the given server label
+        # Use reverse mapping for lookup by server_label
+        # This already accounts for allowed_tools restrictions applied during _process_mcp_tool
+        tool_names = server_label_to_tools.get(server_label, [])
+        if not tool_names:
+            return None
+        matching_tools = [{"type": "function", "function": {"name": tool_name}} for tool_name in tool_names]
+        return matching_tools
+    return []
```