mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-22 16:23:08 +00:00
refactor(agent): drop AgentToolGroup for responses tools
This commit is contained in:
parent
c56b2deb7d
commit
ce44b9d6f6
12 changed files with 4051 additions and 4225 deletions
|
@ -10,16 +10,13 @@ import re
|
|||
import uuid
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
OpenAIResponseInputTool,
|
||||
AgentToolGroup,
|
||||
AgentToolGroupWithArgs,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResponseEvent,
|
||||
AgentTurnResponseEventType,
|
||||
|
@ -33,12 +30,19 @@ from llama_stack.apis.agents import (
|
|||
Attachment,
|
||||
Document,
|
||||
InferenceStep,
|
||||
OpenAIResponseInputTool,
|
||||
ShieldCallStep,
|
||||
Step,
|
||||
StepType,
|
||||
ToolExecutionStep,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import URL, ToolCallDelta, ToolCallParseStatus
|
||||
from llama_stack.apis.common.errors import SessionNotFoundError
|
||||
from llama_stack.apis.inference import (
|
||||
|
@ -47,13 +51,12 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
Message,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionMessageContent,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIImageURL,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
|
@ -123,7 +126,9 @@ def _openai_tool_call_to_legacy(tool_call: OpenAIChatCompletionToolCall) -> Tool
|
|||
|
||||
|
||||
def _legacy_tool_call_to_openai(tool_call: ToolCall, index: int | None = None) -> OpenAIChatCompletionToolCall:
|
||||
function_name = tool_call.tool_name if not isinstance(tool_call.tool_name, BuiltinTool) else tool_call.tool_name.value
|
||||
function_name = (
|
||||
tool_call.tool_name if not isinstance(tool_call.tool_name, BuiltinTool) else tool_call.tool_name.value
|
||||
)
|
||||
return OpenAIChatCompletionToolCall(
|
||||
index=index,
|
||||
id=tool_call.call_id,
|
||||
|
@ -178,9 +183,9 @@ def _coerce_to_text(content: Any) -> str:
|
|||
if isinstance(content, list):
|
||||
return "\n".join(_coerce_to_text(item) for item in content)
|
||||
if hasattr(content, "text"):
|
||||
return getattr(content, "text")
|
||||
return content.text
|
||||
if hasattr(content, "image"):
|
||||
image = getattr(content, "image")
|
||||
image = content.image
|
||||
if hasattr(image, "url") and image.url:
|
||||
return getattr(image.url, "uri", "")
|
||||
return str(content)
|
||||
|
@ -200,10 +205,7 @@ def _openai_message_param_to_legacy(message: OpenAIMessageParam) -> Message:
|
|||
# Map developer messages to user role for legacy compatibility
|
||||
return UserMessage(content=_openai_message_content_to_text(message.content))
|
||||
if isinstance(message, OpenAIAssistantMessageParam):
|
||||
tool_calls = [
|
||||
_openai_tool_call_to_legacy(tool_call)
|
||||
for tool_call in message.tool_calls or []
|
||||
]
|
||||
tool_calls = [_openai_tool_call_to_legacy(tool_call) for tool_call in message.tool_calls or []]
|
||||
return CompletionMessage(
|
||||
content=_openai_message_content_to_text(message.content) if message.content is not None else "",
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
|
@ -279,6 +281,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self.created_at = created_at
|
||||
self.telemetry_enabled = telemetry_enabled
|
||||
|
||||
self.tool_defs: list[ToolDefinition] = []
|
||||
self.tool_name_to_args: dict[str | BuiltinTool, dict[str, Any]] = {}
|
||||
self.client_tools_config: list[OpenAIResponseInputTool | ToolDef] = []
|
||||
|
||||
ShieldRunnerMixin.__init__(
|
||||
self,
|
||||
safety_api,
|
||||
|
@ -367,7 +373,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
|
||||
await self._initialize_tools(request.toolgroups)
|
||||
await self._initialize_tools(request.tools)
|
||||
async for chunk in self._run_turn(request, turn_id):
|
||||
yield chunk
|
||||
|
||||
|
@ -682,12 +688,11 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
# Build a map of custom tools to their definitions for faster lookup
|
||||
client_tools: dict[str, OpenAIResponseInputTool | ToolDef] = {}
|
||||
if self.agent_config.client_tools:
|
||||
for tool in self.agent_config.client_tools:
|
||||
if isinstance(tool, ToolDef) and tool.name:
|
||||
client_tools[tool.name] = tool
|
||||
elif getattr(tool, "type", None) == "function" and getattr(tool, "name", None):
|
||||
client_tools[tool.name] = tool
|
||||
for tool in self.client_tools_config or []:
|
||||
if isinstance(tool, ToolDef) and tool.name:
|
||||
client_tools[tool.name] = tool
|
||||
elif getattr(tool, "type", None) == "function" and getattr(tool, "name", None):
|
||||
client_tools[tool.name] = tool
|
||||
while True:
|
||||
step_id = str(uuid.uuid4())
|
||||
inference_start_time = datetime.now(UTC).isoformat()
|
||||
|
@ -987,91 +992,124 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def _initialize_tools(
|
||||
self,
|
||||
toolgroups_for_turn: list[AgentToolGroup] | None = None,
|
||||
tools_for_turn: list[OpenAIResponseInputTool] | None = None,
|
||||
) -> None:
|
||||
toolgroup_to_args = {}
|
||||
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
|
||||
toolgroup_to_args[tool_group_name] = toolgroup.args
|
||||
|
||||
# Determine which tools to include
|
||||
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
|
||||
agent_config_toolgroups = []
|
||||
for toolgroup in tool_groups_to_include:
|
||||
name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
|
||||
if name not in agent_config_toolgroups:
|
||||
agent_config_toolgroups.append(name)
|
||||
|
||||
toolgroup_to_args = toolgroup_to_args or {}
|
||||
|
||||
tool_name_to_def: dict[str | BuiltinTool, ToolDefinition] = {}
|
||||
tool_name_to_args: dict[str | BuiltinTool, dict[str, Any]] = {}
|
||||
client_tools_map: dict[str, OpenAIResponseInputTool | ToolDef] = {}
|
||||
|
||||
def add_tool_definition(identifier: str | BuiltinTool, tool_definition: ToolDefinition) -> None:
|
||||
if identifier in tool_name_to_def:
|
||||
raise ValueError(f"Tool {identifier} already exists")
|
||||
tool_name_to_def[identifier] = tool_definition
|
||||
|
||||
def add_client_tool(tool: OpenAIResponseInputTool | ToolDef) -> None:
|
||||
name = getattr(tool, "name", None)
|
||||
if isinstance(tool, ToolDef):
|
||||
name = tool.name
|
||||
if not name:
|
||||
raise ValueError("Client tools must have a name")
|
||||
if name not in client_tools_map:
|
||||
client_tools_map[name] = tool
|
||||
tool_definition = _client_tool_to_tool_definition(tool)
|
||||
add_tool_definition(tool_definition.tool_name, tool_definition)
|
||||
|
||||
if self.agent_config.client_tools:
|
||||
for tool in self.agent_config.client_tools:
|
||||
tool_definition = _client_tool_to_tool_definition(tool)
|
||||
if tool_name_to_def.get(tool_definition.tool_name):
|
||||
raise ValueError(f"Tool {tool_definition.tool_name} already exists")
|
||||
tool_name_to_def[tool_definition.tool_name] = tool_definition
|
||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
||||
if not tools.data:
|
||||
available_tool_groups = ", ".join(
|
||||
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
|
||||
)
|
||||
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
|
||||
if input_tool_name is not None and not any(tool.name == input_tool_name for tool in tools.data):
|
||||
raise ValueError(
|
||||
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.name for tool in tools.data])}"
|
||||
)
|
||||
add_client_tool(tool)
|
||||
|
||||
for tool_def in tools.data:
|
||||
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
|
||||
identifier: str | BuiltinTool | None = tool_def.name
|
||||
if identifier == "web_search":
|
||||
identifier = BuiltinTool.brave_search
|
||||
else:
|
||||
identifier = BuiltinTool(identifier)
|
||||
else:
|
||||
# add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
|
||||
if input_tool_name in (None, tool_def.name):
|
||||
identifier = tool_def.name
|
||||
else:
|
||||
identifier = None
|
||||
effective_tools = tools_for_turn
|
||||
if effective_tools is None:
|
||||
effective_tools = self.agent_config.tools
|
||||
|
||||
if tool_name_to_def.get(identifier, None):
|
||||
raise ValueError(f"Tool {identifier} already exists")
|
||||
if identifier:
|
||||
tool_name_to_def[identifier] = ToolDefinition(
|
||||
for tool in effective_tools or []:
|
||||
if isinstance(tool, OpenAIResponseInputToolFunction):
|
||||
add_client_tool(tool)
|
||||
continue
|
||||
|
||||
resolved_tools = await self._resolve_response_tool(tool)
|
||||
for identifier, definition, args in resolved_tools:
|
||||
add_tool_definition(identifier, definition)
|
||||
if args:
|
||||
existing_args = tool_name_to_args.get(identifier, {})
|
||||
tool_name_to_args[identifier] = {**existing_args, **args}
|
||||
|
||||
self.tool_defs = list(tool_name_to_def.values())
|
||||
self.tool_name_to_args = tool_name_to_args
|
||||
self.client_tools_config = list(client_tools_map.values())
|
||||
|
||||
async def _resolve_response_tool(
|
||||
self,
|
||||
tool: OpenAIResponseInputTool,
|
||||
) -> list[tuple[str | BuiltinTool, ToolDefinition, dict[str, Any]]]:
|
||||
if isinstance(tool, OpenAIResponseInputToolWebSearch):
|
||||
tool_def = await self.tool_groups_api.get_tool(WEB_SEARCH_TOOL)
|
||||
if tool_def is None:
|
||||
raise ValueError("web_search tool is not registered")
|
||||
identifier: str | BuiltinTool = BuiltinTool.brave_search
|
||||
return [
|
||||
(
|
||||
identifier,
|
||||
ToolDefinition(
|
||||
tool_name=identifier,
|
||||
description=tool_def.description,
|
||||
input_schema=tool_def.input_schema,
|
||||
),
|
||||
{},
|
||||
)
|
||||
]
|
||||
|
||||
if isinstance(tool, OpenAIResponseInputToolFileSearch):
|
||||
tool_def = await self.tool_groups_api.get_tool(MEMORY_QUERY_TOOL)
|
||||
if tool_def is None:
|
||||
raise ValueError("knowledge_search tool is not registered")
|
||||
args: dict[str, Any] = {
|
||||
"vector_db_ids": tool.vector_store_ids,
|
||||
}
|
||||
if tool.filters is not None:
|
||||
args["filters"] = tool.filters
|
||||
if tool.max_num_results is not None:
|
||||
args["max_num_results"] = tool.max_num_results
|
||||
if tool.ranking_options is not None:
|
||||
args["ranking_options"] = tool.ranking_options.model_dump()
|
||||
|
||||
return [
|
||||
(
|
||||
tool_def.name,
|
||||
ToolDefinition(
|
||||
tool_name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
input_schema=tool_def.input_schema,
|
||||
),
|
||||
args,
|
||||
)
|
||||
]
|
||||
|
||||
if isinstance(tool, OpenAIResponseInputToolMCP):
|
||||
toolgroup_id = tool.server_label
|
||||
if not toolgroup_id.startswith("mcp::"):
|
||||
toolgroup_id = f"mcp::{toolgroup_id}"
|
||||
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_id)
|
||||
if not tools.data:
|
||||
raise ValueError(
|
||||
f"No tools registered for MCP server '{tool.server_label}'. Ensure the toolgroup '{toolgroup_id}' is registered."
|
||||
)
|
||||
resolved: list[tuple[str | BuiltinTool, ToolDefinition, dict[str, Any]]] = []
|
||||
for tool_def in tools.data:
|
||||
resolved.append(
|
||||
(
|
||||
tool_def.name,
|
||||
ToolDefinition(
|
||||
tool_name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
input_schema=tool_def.input_schema,
|
||||
),
|
||||
{},
|
||||
)
|
||||
tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
)
|
||||
return resolved
|
||||
|
||||
self.tool_defs, self.tool_name_to_args = (
|
||||
list(tool_name_to_def.values()),
|
||||
tool_name_to_args,
|
||||
)
|
||||
|
||||
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, str | None]:
|
||||
"""Parse a toolgroup name into its components.
|
||||
|
||||
Args:
|
||||
toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search")
|
||||
|
||||
Returns:
|
||||
A tuple of (tool_type, tool_group, tool_name)
|
||||
"""
|
||||
split_names = toolgroup_name_with_maybe_tool_name.split("/")
|
||||
if len(split_names) == 2:
|
||||
# e.g. "builtin::rag"
|
||||
tool_group, tool_name = split_names
|
||||
else:
|
||||
tool_group, tool_name = split_names[0], None
|
||||
return tool_group, tool_name
|
||||
raise ValueError(f"Unsupported tool type '{getattr(tool, 'type', None)}' in agent configuration")
|
||||
|
||||
async def execute_tool_call_maybe(
|
||||
self,
|
||||
|
|
|
@ -15,7 +15,6 @@ from llama_stack.apis.agents import (
|
|||
Agents,
|
||||
AgentSessionCreateResponse,
|
||||
AgentStepResponse,
|
||||
AgentToolGroup,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
|
@ -32,9 +31,9 @@ from llama_stack.apis.agents.openai_responses import OpenAIResponseText
|
|||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
ToolConfig,
|
||||
OpenAIMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
|
@ -156,7 +155,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
documents: list[Document] | None = None,
|
||||
stream: bool | None = False,
|
||||
tool_config: ToolConfig | None = None,
|
||||
|
@ -166,7 +165,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
session_id=session_id,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
toolgroups=toolgroups,
|
||||
tools=tools,
|
||||
documents=documents,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue