From 41c9bca1aa7a44cf2048b6c9371cd7740d2e47c1 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 12 Mar 2025 18:48:03 -0700 Subject: [PATCH] chore: refactor Agent toolgroup processing (#1381) Summary: Refactoring only. Centralize logic to preprocess toolgroup to one place. Test Plan: LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/api/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1381). * #1384 * __->__ #1381 --- .../agents/meta_reference/agent_instance.py | 120 ++++++++---------- 1 file changed, 55 insertions(+), 65 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 1d9f54e96..1884094df 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -12,7 +12,7 @@ import secrets import string import uuid from datetime import datetime -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union +from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union from urllib.parse import urlparse import httpx @@ -181,6 +181,7 @@ class ChatAgent(ShieldRunnerMixin): return messages async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: + await self._initialize_tools(request.toolgroups) async with tracing.span("create_and_execute_turn") as span: span.set_attribute("session_id", request.session_id) span.set_attribute("agent_id", self.agent_id) @@ -191,6 +192,7 @@ class ChatAgent(ShieldRunnerMixin): yield chunk async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: + await self._initialize_tools() async with tracing.span("resume_turn") as span: span.set_attribute("agent_id", self.agent_id) span.set_attribute("session_id", request.session_id) @@ -275,7 +277,6 @@ class ChatAgent(ShieldRunnerMixin): sampling_params=self.agent_config.sampling_params, stream=request.stream, documents=request.documents if not is_resume else None, - toolgroups_for_turn=request.toolgroups if not is_resume else None, ): if isinstance(chunk, CompletionMessage): output_message = chunk @@ -327,7 +328,6 @@ class ChatAgent(ShieldRunnerMixin): sampling_params: SamplingParams, stream: bool = False, documents: Optional[List[Document]] = None, - toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, ) -> AsyncGenerator: # Doing async generators makes downstream code much simpler and everything amenable to # streaming. However, it also makes things complicated here because AsyncGenerators cannot @@ -350,7 +350,6 @@ class ChatAgent(ShieldRunnerMixin): sampling_params, stream, documents, - toolgroups_for_turn, ): if isinstance(res, bool): return @@ -451,30 +450,17 @@ class ChatAgent(ShieldRunnerMixin): sampling_params: SamplingParams, stream: bool = False, documents: Optional[List[Document]] = None, - toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, ) -> AsyncGenerator: - # TODO: simplify all of this code, it can be simpler - toolgroup_args = {} - toolgroups = set() - for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []): - if isinstance(toolgroup, AgentToolGroupWithArgs): - tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name) - toolgroups.add(tool_group_name) - toolgroup_args[tool_group_name] = toolgroup.args - else: - toolgroups.add(toolgroup) - - tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn) if documents: - await self.handle_documents(session_id, documents, input_messages, tool_defs) + await self.handle_documents(session_id, documents, input_messages) session_info = await self.storage.get_session_info(session_id) # if the session has a memory bank id, let the memory tool use it if session_info and session_info.vector_db_id: - if RAG_TOOL_GROUP not in toolgroup_args: - toolgroup_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]} + if RAG_TOOL_GROUP not in self.toolgroup_to_args: + self.toolgroup_to_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]} else: - toolgroup_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id) + self.toolgroup_to_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id) output_attachments = [] @@ -504,7 +490,7 @@ class ChatAgent(ShieldRunnerMixin): async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, - tools=tool_defs, + tools=self.tool_defs, tool_prompt_format=self.agent_config.tool_config.tool_prompt_format, response_format=self.agent_config.response_format, stream=True, @@ -686,12 +672,9 @@ class ChatAgent(ShieldRunnerMixin): ) as span: tool_execution_start_time = datetime.now().astimezone().isoformat() tool_call = message.tool_calls[0] - tool_result = await execute_tool_call_maybe( - self.tool_runtime_api, + tool_result = await self.execute_tool_call_maybe( session_id, tool_call, - toolgroup_args, - tool_to_group, ) if tool_result.content is None: raise ValueError( @@ -744,6 +727,15 @@ class ChatAgent(ShieldRunnerMixin): input_messages = input_messages + [message, result_message] + async def _initialize_tools(self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None): + self.toolgroup_to_args = {} + for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []): + if isinstance(toolgroup, AgentToolGroupWithArgs): + tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name) + self.toolgroup_to_args[tool_group_name] = toolgroup.args + + self.tool_defs, self.tool_name_to_group_id = await self._get_tool_defs(toolgroups_for_turn) + async def _get_tool_defs( self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None ) -> Tuple[List[ToolDefinition], Dict[str, str]]: @@ -756,7 +748,7 @@ class ChatAgent(ShieldRunnerMixin): agent_config_toolgroups.append(name) tool_name_to_def = {} - tool_to_group = {} + tool_name_to_group_id = {} for tool_def in self.agent_config.client_tools: if tool_name_to_def.get(tool_def.name, None): @@ -774,7 +766,7 @@ class ChatAgent(ShieldRunnerMixin): for param in tool_def.parameters }, ) - tool_to_group[tool_def.name] = "__client_tools__" + tool_name_to_group_id[tool_def.name] = "__client_tools__" for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) @@ -813,7 +805,7 @@ class ChatAgent(ShieldRunnerMixin): for param in tool_def.parameters }, ) - tool_to_group[built_in_type] = tool_def.toolgroup_id + tool_name_to_group_id[built_in_type] = tool_def.toolgroup_id continue if tool_name_to_def.get(tool_def.identifier, None): @@ -832,9 +824,9 @@ class ChatAgent(ShieldRunnerMixin): for param in tool_def.parameters }, ) - tool_to_group[tool_def.identifier] = tool_def.toolgroup_id + tool_name_to_group_id[tool_def.identifier] = tool_def.toolgroup_id - return list(tool_name_to_def.values()), tool_to_group + return list(tool_name_to_def.values()), tool_name_to_group_id def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]: """Parse a toolgroup name into its components. @@ -853,15 +845,44 @@ class ChatAgent(ShieldRunnerMixin): tool_group, tool_name = split_names[0], None return tool_group, tool_name + async def execute_tool_call_maybe( + self, + session_id: str, + tool_call: ToolCall, + ) -> ToolInvocationResult: + name = tool_call.tool_name + group_name = self.tool_name_to_group_id.get(name, None) + if group_name is None: + raise ValueError( + f"Tool {name} not found in any tool group, available tools: {', '.join(self.tool_name_to_group_id.keys())}" + ) + if isinstance(name, BuiltinTool): + if name == BuiltinTool.brave_search: + name = WEB_SEARCH_TOOL + else: + name = name.value + + logger.info(f"executing tool call: {name} with args: {tool_call.arguments}") + result = await self.tool_runtime_api.invoke_tool( + tool_name=name, + kwargs={ + "session_id": session_id, + # get the arguments generated by the model and augment with toolgroup arg overrides for the agent + **tool_call.arguments, + **self.toolgroup_to_args.get(group_name, {}), + }, + ) + logger.debug(f"tool call {name} completed with result: {result}") + return result + async def handle_documents( self, session_id: str, documents: List[Document], input_messages: List[Message], - tool_defs: Dict[str, ToolDefinition], ) -> None: - memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs) - code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs) + memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs) + code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs) content_items = [] url_items = [] pattern = re.compile("^(https?://|file://|data:)") @@ -994,37 +1015,6 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa ) -async def execute_tool_call_maybe( - tool_runtime_api: ToolRuntime, - session_id: str, - tool_call: ToolCall, - toolgroup_args: Dict[str, Dict[str, Any]], - tool_to_group: Dict[str, str], -) -> ToolInvocationResult: - name = tool_call.tool_name - group_name = tool_to_group.get(name, None) - if group_name is None: - raise ValueError(f"Tool {name} not found in any tool group") - if isinstance(name, BuiltinTool): - if name == BuiltinTool.brave_search: - name = WEB_SEARCH_TOOL - else: - name = name.value - - logger.info(f"executing tool call: {name} with args: {tool_call.arguments}") - result = await tool_runtime_api.invoke_tool( - tool_name=name, - kwargs={ - "session_id": session_id, - # get the arguments generated by the model and augment with toolgroup arg overrides for the agent - **tool_call.arguments, - **toolgroup_args.get(group_name, {}), - }, - ) - logger.info(f"tool call {name} completed with result: {result}") - return result - - def _interpret_content_as_attachment( content: str, ) -> Optional[Attachment]: