From ed6caead724aa8b5b1c4e53528e120888517a812 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 12 Mar 2025 18:51:18 -0700 Subject: [PATCH] chore: simplify _get_tool_defs (#1384) Summary: Test Plan: LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/integration/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct --- .../agents/meta_reference/agent_instance.py | 111 ++++++++---------- 1 file changed, 50 insertions(+), 61 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 1884094df..3f09cacc0 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -12,7 +12,7 @@ import secrets import string import uuid from datetime import datetime -from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union +from typing import AsyncGenerator, List, Optional, Union from urllib.parse import urlparse import httpx @@ -457,10 +457,12 @@ class ChatAgent(ShieldRunnerMixin): session_info = await self.storage.get_session_info(session_id) # if the session has a memory bank id, let the memory tool use it if session_info and session_info.vector_db_id: - if RAG_TOOL_GROUP not in self.toolgroup_to_args: - self.toolgroup_to_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]} - else: - self.toolgroup_to_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id) + for tool_name in self.tool_name_to_args.keys(): + if tool_name == MEMORY_QUERY_TOOL: + if "vector_db_ids" not in self.tool_name_to_args[tool_name]: + self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id] + else: + self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id) output_attachments = [] @@ -727,18 +729,16 @@ class ChatAgent(ShieldRunnerMixin): input_messages = input_messages + [message, result_message] - async def _initialize_tools(self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None): - self.toolgroup_to_args = {} - for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []): + async def _initialize_tools( + self, + toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, + ) -> None: + toolgroup_to_args = {} + for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []): if isinstance(toolgroup, AgentToolGroupWithArgs): tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name) - self.toolgroup_to_args[tool_group_name] = toolgroup.args + toolgroup_to_args[tool_group_name] = toolgroup.args - self.tool_defs, self.tool_name_to_group_id = await self._get_tool_defs(toolgroups_for_turn) - - async def _get_tool_defs( - self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None - ) -> Tuple[List[ToolDefinition], Dict[str, str]]: # Determine which tools to include tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or [] agent_config_toolgroups = [] @@ -747,8 +747,10 @@ class ChatAgent(ShieldRunnerMixin): if name not in agent_config_toolgroups: agent_config_toolgroups.append(name) + toolgroup_to_args = toolgroup_to_args or {} + tool_name_to_def = {} - tool_name_to_group_id = {} + tool_name_to_args = {} for tool_def in self.agent_config.client_tools: if tool_name_to_def.get(tool_def.name, None): @@ -766,53 +768,38 @@ class ChatAgent(ShieldRunnerMixin): for param in tool_def.parameters }, ) - tool_name_to_group_id[tool_def.name] = "__client_tools__" for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: - toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) + toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) if not tools.data: available_tool_groups = ", ".join( [t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data] ) raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}") - if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data): + if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data): raise ValueError( - f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}" + f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}" ) for tool_def in tools.data: if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP: - tool_name = tool_def.identifier - built_in_type = BuiltinTool.brave_search - if tool_name == "web_search": - built_in_type = BuiltinTool.brave_search + identifier: str | BuiltinTool | None = tool_def.identifier + if identifier == "web_search": + identifier = BuiltinTool.brave_search else: - built_in_type = BuiltinTool(tool_name) + identifier = BuiltinTool(identifier) + else: + # add if tool_name is unspecified or the tool_def identifier is the same as the tool_name + if input_tool_name in (None, tool_def.identifier): + identifier = tool_def.identifier + else: + identifier = None - if tool_name_to_def.get(built_in_type, None): - raise ValueError(f"Tool {built_in_type} already exists") - - tool_name_to_def[built_in_type] = ToolDefinition( - tool_name=built_in_type, - description=tool_def.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - for param in tool_def.parameters - }, - ) - tool_name_to_group_id[built_in_type] = tool_def.toolgroup_id - continue - - if tool_name_to_def.get(tool_def.identifier, None): - raise ValueError(f"Tool {tool_def.identifier} already exists") - if tool_name in (None, tool_def.identifier): + if tool_name_to_def.get(identifier, None): + raise ValueError(f"Tool {identifier} already exists") + if identifier: tool_name_to_def[tool_def.identifier] = ToolDefinition( - tool_name=tool_def.identifier, + tool_name=identifier, description=tool_def.description, parameters={ param.name: ToolParamDefinition( @@ -824,9 +811,9 @@ class ChatAgent(ShieldRunnerMixin): for param in tool_def.parameters }, ) - tool_name_to_group_id[tool_def.identifier] = tool_def.toolgroup_id + tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {}) - return list(tool_name_to_def.values()), tool_name_to_group_id + self.tool_defs, self.tool_name_to_args = list(tool_name_to_def.values()), tool_name_to_args def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]: """Parse a toolgroup name into its components. @@ -850,29 +837,31 @@ class ChatAgent(ShieldRunnerMixin): session_id: str, tool_call: ToolCall, ) -> ToolInvocationResult: - name = tool_call.tool_name - group_name = self.tool_name_to_group_id.get(name, None) - if group_name is None: + tool_name = tool_call.tool_name + registered_tool_names = [tool_def.tool_name for tool_def in self.tool_defs] + if tool_name not in registered_tool_names: raise ValueError( - f"Tool {name} not found in any tool group, available tools: {', '.join(self.tool_name_to_group_id.keys())}" + f"Tool {tool_name} not found in provided tools, registered tools: {', '.join([str(x) for x in registered_tool_names])}" ) - if isinstance(name, BuiltinTool): - if name == BuiltinTool.brave_search: - name = WEB_SEARCH_TOOL + if isinstance(tool_name, BuiltinTool): + if tool_name == BuiltinTool.brave_search: + tool_name_str = WEB_SEARCH_TOOL else: - name = name.value + tool_name_str = tool_name.value + else: + tool_name_str = tool_name - logger.info(f"executing tool call: {name} with args: {tool_call.arguments}") + logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}") result = await self.tool_runtime_api.invoke_tool( - tool_name=name, + tool_name=tool_name_str, kwargs={ "session_id": session_id, # get the arguments generated by the model and augment with toolgroup arg overrides for the agent **tool_call.arguments, - **self.toolgroup_to_args.get(group_name, {}), + **self.tool_name_to_args.get(tool_name_str, {}), }, ) - logger.debug(f"tool call {name} completed with result: {result}") + logger.debug(f"tool call {tool_name_str} completed with result: {result}") return result async def handle_documents(