chore: simplify _get_tool_defs (#1384)

Summary:

Test Plan:
LLAMA_STACK_CONFIG=fireworks pytest -s -v
tests/integration/agents/test_agents.py --safety-shield
meta-llama/Llama-Guard-3-8B --text-model
meta-llama/Llama-3.1-8B-Instruct
This commit is contained in:
ehhuang 2025-03-12 18:51:18 -07:00 committed by GitHub
parent 41c9bca1aa
commit ed6caead72
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -12,7 +12,7 @@ import secrets
import string import string
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union from typing import AsyncGenerator, List, Optional, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
@ -457,10 +457,12 @@ class ChatAgent(ShieldRunnerMixin):
session_info = await self.storage.get_session_info(session_id) session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it # if the session has a memory bank id, let the memory tool use it
if session_info and session_info.vector_db_id: if session_info and session_info.vector_db_id:
if RAG_TOOL_GROUP not in self.toolgroup_to_args: for tool_name in self.tool_name_to_args.keys():
self.toolgroup_to_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]} if tool_name == MEMORY_QUERY_TOOL:
if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id]
else: else:
self.toolgroup_to_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id) self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
output_attachments = [] output_attachments = []
@ -727,18 +729,16 @@ class ChatAgent(ShieldRunnerMixin):
input_messages = input_messages + [message, result_message] input_messages = input_messages + [message, result_message]
async def _initialize_tools(self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None): async def _initialize_tools(
self.toolgroup_to_args = {} self,
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []): toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> None:
toolgroup_to_args = {}
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
if isinstance(toolgroup, AgentToolGroupWithArgs): if isinstance(toolgroup, AgentToolGroupWithArgs):
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name) tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
self.toolgroup_to_args[tool_group_name] = toolgroup.args toolgroup_to_args[tool_group_name] = toolgroup.args
self.tool_defs, self.tool_name_to_group_id = await self._get_tool_defs(toolgroups_for_turn)
async def _get_tool_defs(
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
# Determine which tools to include # Determine which tools to include
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or [] tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
agent_config_toolgroups = [] agent_config_toolgroups = []
@ -747,8 +747,10 @@ class ChatAgent(ShieldRunnerMixin):
if name not in agent_config_toolgroups: if name not in agent_config_toolgroups:
agent_config_toolgroups.append(name) agent_config_toolgroups.append(name)
toolgroup_to_args = toolgroup_to_args or {}
tool_name_to_def = {} tool_name_to_def = {}
tool_name_to_group_id = {} tool_name_to_args = {}
for tool_def in self.agent_config.client_tools: for tool_def in self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None): if tool_name_to_def.get(tool_def.name, None):
@ -766,53 +768,38 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters for param in tool_def.parameters
}, },
) )
tool_name_to_group_id[tool_def.name] = "__client_tools__"
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
if not tools.data: if not tools.data:
available_tool_groups = ", ".join( available_tool_groups = ", ".join(
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data] [t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
) )
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}") raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data): if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
raise ValueError( raise ValueError(
f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}" f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
) )
for tool_def in tools.data: for tool_def in tools.data:
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP: if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
tool_name = tool_def.identifier identifier: str | BuiltinTool | None = tool_def.identifier
built_in_type = BuiltinTool.brave_search if identifier == "web_search":
if tool_name == "web_search": identifier = BuiltinTool.brave_search
built_in_type = BuiltinTool.brave_search
else: else:
built_in_type = BuiltinTool(tool_name) identifier = BuiltinTool(identifier)
else:
# add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
if input_tool_name in (None, tool_def.identifier):
identifier = tool_def.identifier
else:
identifier = None
if tool_name_to_def.get(built_in_type, None): if tool_name_to_def.get(identifier, None):
raise ValueError(f"Tool {built_in_type} already exists") raise ValueError(f"Tool {identifier} already exists")
if identifier:
tool_name_to_def[built_in_type] = ToolDefinition(
tool_name=built_in_type,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_name_to_group_id[built_in_type] = tool_def.toolgroup_id
continue
if tool_name_to_def.get(tool_def.identifier, None):
raise ValueError(f"Tool {tool_def.identifier} already exists")
if tool_name in (None, tool_def.identifier):
tool_name_to_def[tool_def.identifier] = ToolDefinition( tool_name_to_def[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier, tool_name=identifier,
description=tool_def.description, description=tool_def.description,
parameters={ parameters={
param.name: ToolParamDefinition( param.name: ToolParamDefinition(
@ -824,9 +811,9 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters for param in tool_def.parameters
}, },
) )
tool_name_to_group_id[tool_def.identifier] = tool_def.toolgroup_id tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
return list(tool_name_to_def.values()), tool_name_to_group_id self.tool_defs, self.tool_name_to_args = list(tool_name_to_def.values()), tool_name_to_args
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]: def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
"""Parse a toolgroup name into its components. """Parse a toolgroup name into its components.
@ -850,29 +837,31 @@ class ChatAgent(ShieldRunnerMixin):
session_id: str, session_id: str,
tool_call: ToolCall, tool_call: ToolCall,
) -> ToolInvocationResult: ) -> ToolInvocationResult:
name = tool_call.tool_name tool_name = tool_call.tool_name
group_name = self.tool_name_to_group_id.get(name, None) registered_tool_names = [tool_def.tool_name for tool_def in self.tool_defs]
if group_name is None: if tool_name not in registered_tool_names:
raise ValueError( raise ValueError(
f"Tool {name} not found in any tool group, available tools: {', '.join(self.tool_name_to_group_id.keys())}" f"Tool {tool_name} not found in provided tools, registered tools: {', '.join([str(x) for x in registered_tool_names])}"
) )
if isinstance(name, BuiltinTool): if isinstance(tool_name, BuiltinTool):
if name == BuiltinTool.brave_search: if tool_name == BuiltinTool.brave_search:
name = WEB_SEARCH_TOOL tool_name_str = WEB_SEARCH_TOOL
else: else:
name = name.value tool_name_str = tool_name.value
else:
tool_name_str = tool_name
logger.info(f"executing tool call: {name} with args: {tool_call.arguments}") logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
result = await self.tool_runtime_api.invoke_tool( result = await self.tool_runtime_api.invoke_tool(
tool_name=name, tool_name=tool_name_str,
kwargs={ kwargs={
"session_id": session_id, "session_id": session_id,
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent # get the arguments generated by the model and augment with toolgroup arg overrides for the agent
**tool_call.arguments, **tool_call.arguments,
**self.toolgroup_to_args.get(group_name, {}), **self.tool_name_to_args.get(tool_name_str, {}),
}, },
) )
logger.debug(f"tool call {name} completed with result: {result}") logger.debug(f"tool call {tool_name_str} completed with result: {result}")
return result return result
async def handle_documents( async def handle_documents(