refactor(agent): drop AgentToolGroup for responses tools

This commit is contained in:
Ashwin Bharambe 2025-10-10 13:43:43 -07:00
parent c56b2deb7d
commit ce44b9d6f6
12 changed files with 4051 additions and 4225 deletions

View file

@ -10,16 +10,13 @@ import re
import uuid
import warnings
from collections.abc import AsyncGenerator
from typing import Any
from datetime import UTC, datetime
from typing import Any
import httpx
from llama_stack.apis.agents import (
AgentConfig,
OpenAIResponseInputTool,
AgentToolGroup,
AgentToolGroupWithArgs,
AgentTurnCreateRequest,
AgentTurnResponseEvent,
AgentTurnResponseEventType,
@ -33,12 +30,19 @@ from llama_stack.apis.agents import (
Attachment,
Document,
InferenceStep,
OpenAIResponseInputTool,
ShieldCallStep,
Step,
StepType,
ToolExecutionStep,
Turn,
)
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseInputToolWebSearch,
)
from llama_stack.apis.common.content_types import URL, ToolCallDelta, ToolCallParseStatus
from llama_stack.apis.common.errors import SessionNotFoundError
from llama_stack.apis.inference import (
@ -47,13 +51,12 @@ from llama_stack.apis.inference import (
Inference,
Message,
OpenAIAssistantMessageParam,
OpenAIDeveloperMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionMessageContent,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIImageURL,
OpenAIDeveloperMessageParam,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
@ -123,7 +126,9 @@ def _openai_tool_call_to_legacy(tool_call: OpenAIChatCompletionToolCall) -> Tool
def _legacy_tool_call_to_openai(tool_call: ToolCall, index: int | None = None) -> OpenAIChatCompletionToolCall:
function_name = tool_call.tool_name if not isinstance(tool_call.tool_name, BuiltinTool) else tool_call.tool_name.value
function_name = (
tool_call.tool_name if not isinstance(tool_call.tool_name, BuiltinTool) else tool_call.tool_name.value
)
return OpenAIChatCompletionToolCall(
index=index,
id=tool_call.call_id,
@ -178,9 +183,9 @@ def _coerce_to_text(content: Any) -> str:
if isinstance(content, list):
return "\n".join(_coerce_to_text(item) for item in content)
if hasattr(content, "text"):
return getattr(content, "text")
return content.text
if hasattr(content, "image"):
image = getattr(content, "image")
image = content.image
if hasattr(image, "url") and image.url:
return getattr(image.url, "uri", "")
return str(content)
@ -200,10 +205,7 @@ def _openai_message_param_to_legacy(message: OpenAIMessageParam) -> Message:
# Map developer messages to user role for legacy compatibility
return UserMessage(content=_openai_message_content_to_text(message.content))
if isinstance(message, OpenAIAssistantMessageParam):
tool_calls = [
_openai_tool_call_to_legacy(tool_call)
for tool_call in message.tool_calls or []
]
tool_calls = [_openai_tool_call_to_legacy(tool_call) for tool_call in message.tool_calls or []]
return CompletionMessage(
content=_openai_message_content_to_text(message.content) if message.content is not None else "",
stop_reason=StopReason.end_of_turn,
@ -279,6 +281,10 @@ class ChatAgent(ShieldRunnerMixin):
self.created_at = created_at
self.telemetry_enabled = telemetry_enabled
self.tool_defs: list[ToolDefinition] = []
self.tool_name_to_args: dict[str | BuiltinTool, dict[str, Any]] = {}
self.client_tools_config: list[OpenAIResponseInputTool | ToolDef] = []
ShieldRunnerMixin.__init__(
self,
safety_api,
@ -367,7 +373,7 @@ class ChatAgent(ShieldRunnerMixin):
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
await self._initialize_tools(request.toolgroups)
await self._initialize_tools(request.tools)
async for chunk in self._run_turn(request, turn_id):
yield chunk
@ -682,12 +688,11 @@ class ChatAgent(ShieldRunnerMixin):
# Build a map of custom tools to their definitions for faster lookup
client_tools: dict[str, OpenAIResponseInputTool | ToolDef] = {}
if self.agent_config.client_tools:
for tool in self.agent_config.client_tools:
if isinstance(tool, ToolDef) and tool.name:
client_tools[tool.name] = tool
elif getattr(tool, "type", None) == "function" and getattr(tool, "name", None):
client_tools[tool.name] = tool
for tool in self.client_tools_config or []:
if isinstance(tool, ToolDef) and tool.name:
client_tools[tool.name] = tool
elif getattr(tool, "type", None) == "function" and getattr(tool, "name", None):
client_tools[tool.name] = tool
while True:
step_id = str(uuid.uuid4())
inference_start_time = datetime.now(UTC).isoformat()
@ -987,91 +992,124 @@ class ChatAgent(ShieldRunnerMixin):
async def _initialize_tools(
self,
toolgroups_for_turn: list[AgentToolGroup] | None = None,
tools_for_turn: list[OpenAIResponseInputTool] | None = None,
) -> None:
toolgroup_to_args = {}
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
if isinstance(toolgroup, AgentToolGroupWithArgs):
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
toolgroup_to_args[tool_group_name] = toolgroup.args
# Determine which tools to include
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
agent_config_toolgroups = []
for toolgroup in tool_groups_to_include:
name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
if name not in agent_config_toolgroups:
agent_config_toolgroups.append(name)
toolgroup_to_args = toolgroup_to_args or {}
tool_name_to_def: dict[str | BuiltinTool, ToolDefinition] = {}
tool_name_to_args: dict[str | BuiltinTool, dict[str, Any]] = {}
client_tools_map: dict[str, OpenAIResponseInputTool | ToolDef] = {}
def add_tool_definition(identifier: str | BuiltinTool, tool_definition: ToolDefinition) -> None:
if identifier in tool_name_to_def:
raise ValueError(f"Tool {identifier} already exists")
tool_name_to_def[identifier] = tool_definition
def add_client_tool(tool: OpenAIResponseInputTool | ToolDef) -> None:
name = getattr(tool, "name", None)
if isinstance(tool, ToolDef):
name = tool.name
if not name:
raise ValueError("Client tools must have a name")
if name not in client_tools_map:
client_tools_map[name] = tool
tool_definition = _client_tool_to_tool_definition(tool)
add_tool_definition(tool_definition.tool_name, tool_definition)
if self.agent_config.client_tools:
for tool in self.agent_config.client_tools:
tool_definition = _client_tool_to_tool_definition(tool)
if tool_name_to_def.get(tool_definition.tool_name):
raise ValueError(f"Tool {tool_definition.tool_name} already exists")
tool_name_to_def[tool_definition.tool_name] = tool_definition
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
if not tools.data:
available_tool_groups = ", ".join(
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
)
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
if input_tool_name is not None and not any(tool.name == input_tool_name for tool in tools.data):
raise ValueError(
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.name for tool in tools.data])}"
)
add_client_tool(tool)
for tool_def in tools.data:
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
identifier: str | BuiltinTool | None = tool_def.name
if identifier == "web_search":
identifier = BuiltinTool.brave_search
else:
identifier = BuiltinTool(identifier)
else:
# add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
if input_tool_name in (None, tool_def.name):
identifier = tool_def.name
else:
identifier = None
effective_tools = tools_for_turn
if effective_tools is None:
effective_tools = self.agent_config.tools
if tool_name_to_def.get(identifier, None):
raise ValueError(f"Tool {identifier} already exists")
if identifier:
tool_name_to_def[identifier] = ToolDefinition(
for tool in effective_tools or []:
if isinstance(tool, OpenAIResponseInputToolFunction):
add_client_tool(tool)
continue
resolved_tools = await self._resolve_response_tool(tool)
for identifier, definition, args in resolved_tools:
add_tool_definition(identifier, definition)
if args:
existing_args = tool_name_to_args.get(identifier, {})
tool_name_to_args[identifier] = {**existing_args, **args}
self.tool_defs = list(tool_name_to_def.values())
self.tool_name_to_args = tool_name_to_args
self.client_tools_config = list(client_tools_map.values())
async def _resolve_response_tool(
self,
tool: OpenAIResponseInputTool,
) -> list[tuple[str | BuiltinTool, ToolDefinition, dict[str, Any]]]:
if isinstance(tool, OpenAIResponseInputToolWebSearch):
tool_def = await self.tool_groups_api.get_tool(WEB_SEARCH_TOOL)
if tool_def is None:
raise ValueError("web_search tool is not registered")
identifier: str | BuiltinTool = BuiltinTool.brave_search
return [
(
identifier,
ToolDefinition(
tool_name=identifier,
description=tool_def.description,
input_schema=tool_def.input_schema,
),
{},
)
]
if isinstance(tool, OpenAIResponseInputToolFileSearch):
tool_def = await self.tool_groups_api.get_tool(MEMORY_QUERY_TOOL)
if tool_def is None:
raise ValueError("knowledge_search tool is not registered")
args: dict[str, Any] = {
"vector_db_ids": tool.vector_store_ids,
}
if tool.filters is not None:
args["filters"] = tool.filters
if tool.max_num_results is not None:
args["max_num_results"] = tool.max_num_results
if tool.ranking_options is not None:
args["ranking_options"] = tool.ranking_options.model_dump()
return [
(
tool_def.name,
ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
input_schema=tool_def.input_schema,
),
args,
)
]
if isinstance(tool, OpenAIResponseInputToolMCP):
toolgroup_id = tool.server_label
if not toolgroup_id.startswith("mcp::"):
toolgroup_id = f"mcp::{toolgroup_id}"
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_id)
if not tools.data:
raise ValueError(
f"No tools registered for MCP server '{tool.server_label}'. Ensure the toolgroup '{toolgroup_id}' is registered."
)
resolved: list[tuple[str | BuiltinTool, ToolDefinition, dict[str, Any]]] = []
for tool_def in tools.data:
resolved.append(
(
tool_def.name,
ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
input_schema=tool_def.input_schema,
),
{},
)
tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
)
return resolved
self.tool_defs, self.tool_name_to_args = (
list(tool_name_to_def.values()),
tool_name_to_args,
)
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, str | None]:
"""Parse a toolgroup name into its components.
Args:
toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search")
Returns:
A tuple of (tool_type, tool_group, tool_name)
"""
split_names = toolgroup_name_with_maybe_tool_name.split("/")
if len(split_names) == 2:
# e.g. "builtin::rag"
tool_group, tool_name = split_names
else:
tool_group, tool_name = split_names[0], None
return tool_group, tool_name
raise ValueError(f"Unsupported tool type '{getattr(tool, 'type', None)}' in agent configuration")
async def execute_tool_call_maybe(
self,

View file

@ -15,7 +15,6 @@ from llama_stack.apis.agents import (
Agents,
AgentSessionCreateResponse,
AgentStepResponse,
AgentToolGroup,
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
@ -32,9 +31,9 @@ from llama_stack.apis.agents.openai_responses import OpenAIResponseText
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import (
Inference,
ToolConfig,
OpenAIMessageParam,
OpenAIToolMessageParam,
ToolConfig,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
@ -156,7 +155,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_id: str,
messages: list[OpenAIMessageParam],
toolgroups: list[AgentToolGroup] | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
documents: list[Document] | None = None,
stream: bool | None = False,
tool_config: ToolConfig | None = None,
@ -166,7 +165,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id=session_id,
messages=messages,
stream=True,
toolgroups=toolgroups,
tools=tools,
documents=documents,
tool_config=tool_config,
)