chore: simplify _get_tool_defs (#1384)

Summary:

Test Plan:
LLAMA_STACK_CONFIG=fireworks pytest -s -v
tests/integration/agents/test_agents.py --safety-shield
meta-llama/Llama-Guard-3-8B --text-model
meta-llama/Llama-3.1-8B-Instruct
This commit is contained in:
ehhuang 2025-03-12 18:51:18 -07:00 committed by GitHub
parent 41c9bca1aa
commit ed6caead72
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -12,7 +12,7 @@ import secrets
import string
import uuid
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
from typing import AsyncGenerator, List, Optional, Union
from urllib.parse import urlparse
import httpx
@ -457,10 +457,12 @@ class ChatAgent(ShieldRunnerMixin):
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info and session_info.vector_db_id:
if RAG_TOOL_GROUP not in self.toolgroup_to_args:
self.toolgroup_to_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]}
else:
self.toolgroup_to_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)
for tool_name in self.tool_name_to_args.keys():
if tool_name == MEMORY_QUERY_TOOL:
if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id]
else:
self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
output_attachments = []
@ -727,18 +729,16 @@ class ChatAgent(ShieldRunnerMixin):
input_messages = input_messages + [message, result_message]
async def _initialize_tools(self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None):
self.toolgroup_to_args = {}
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
async def _initialize_tools(
self,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> None:
toolgroup_to_args = {}
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
if isinstance(toolgroup, AgentToolGroupWithArgs):
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
self.toolgroup_to_args[tool_group_name] = toolgroup.args
toolgroup_to_args[tool_group_name] = toolgroup.args
self.tool_defs, self.tool_name_to_group_id = await self._get_tool_defs(toolgroups_for_turn)
async def _get_tool_defs(
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
# Determine which tools to include
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
agent_config_toolgroups = []
@ -747,8 +747,10 @@ class ChatAgent(ShieldRunnerMixin):
if name not in agent_config_toolgroups:
agent_config_toolgroups.append(name)
toolgroup_to_args = toolgroup_to_args or {}
tool_name_to_def = {}
tool_name_to_group_id = {}
tool_name_to_args = {}
for tool_def in self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None):
@ -766,53 +768,38 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters
},
)
tool_name_to_group_id[tool_def.name] = "__client_tools__"
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
if not tools.data:
available_tool_groups = ", ".join(
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
)
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
raise ValueError(
f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
)
for tool_def in tools.data:
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
tool_name = tool_def.identifier
built_in_type = BuiltinTool.brave_search
if tool_name == "web_search":
built_in_type = BuiltinTool.brave_search
identifier: str | BuiltinTool | None = tool_def.identifier
if identifier == "web_search":
identifier = BuiltinTool.brave_search
else:
built_in_type = BuiltinTool(tool_name)
identifier = BuiltinTool(identifier)
else:
# add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
if input_tool_name in (None, tool_def.identifier):
identifier = tool_def.identifier
else:
identifier = None
if tool_name_to_def.get(built_in_type, None):
raise ValueError(f"Tool {built_in_type} already exists")
tool_name_to_def[built_in_type] = ToolDefinition(
tool_name=built_in_type,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_name_to_group_id[built_in_type] = tool_def.toolgroup_id
continue
if tool_name_to_def.get(tool_def.identifier, None):
raise ValueError(f"Tool {tool_def.identifier} already exists")
if tool_name in (None, tool_def.identifier):
if tool_name_to_def.get(identifier, None):
raise ValueError(f"Tool {identifier} already exists")
if identifier:
tool_name_to_def[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
tool_name=identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
@ -824,9 +811,9 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters
},
)
tool_name_to_group_id[tool_def.identifier] = tool_def.toolgroup_id
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
return list(tool_name_to_def.values()), tool_name_to_group_id
self.tool_defs, self.tool_name_to_args = list(tool_name_to_def.values()), tool_name_to_args
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
"""Parse a toolgroup name into its components.
@ -850,29 +837,31 @@ class ChatAgent(ShieldRunnerMixin):
session_id: str,
tool_call: ToolCall,
) -> ToolInvocationResult:
name = tool_call.tool_name
group_name = self.tool_name_to_group_id.get(name, None)
if group_name is None:
tool_name = tool_call.tool_name
registered_tool_names = [tool_def.tool_name for tool_def in self.tool_defs]
if tool_name not in registered_tool_names:
raise ValueError(
f"Tool {name} not found in any tool group, available tools: {', '.join(self.tool_name_to_group_id.keys())}"
f"Tool {tool_name} not found in provided tools, registered tools: {', '.join([str(x) for x in registered_tool_names])}"
)
if isinstance(name, BuiltinTool):
if name == BuiltinTool.brave_search:
name = WEB_SEARCH_TOOL
if isinstance(tool_name, BuiltinTool):
if tool_name == BuiltinTool.brave_search:
tool_name_str = WEB_SEARCH_TOOL
else:
name = name.value
tool_name_str = tool_name.value
else:
tool_name_str = tool_name
logger.info(f"executing tool call: {name} with args: {tool_call.arguments}")
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
result = await self.tool_runtime_api.invoke_tool(
tool_name=name,
tool_name=tool_name_str,
kwargs={
"session_id": session_id,
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
**tool_call.arguments,
**self.toolgroup_to_args.get(group_name, {}),
**self.tool_name_to_args.get(tool_name_str, {}),
},
)
logger.debug(f"tool call {name} completed with result: {result}")
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
return result
async def handle_documents(