chore: refactor Agent toolgroup processing (#1381)

Summary:
Refactoring only.

Centralize the toolgroup preprocessing logic in one place.

Test Plan:
LLAMA_STACK_CONFIG=fireworks pytest -s -v
tests/api/agents/test_agents.py --safety-shield
meta-llama/Llama-Guard-3-8B
---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1381).
* #1384
* __->__ #1381
This commit is contained in:
ehhuang 2025-03-12 18:48:03 -07:00 committed by GitHub
parent 99bbe0e70b
commit 41c9bca1aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -12,7 +12,7 @@ import secrets
import string import string
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
@ -181,6 +181,7 @@ class ChatAgent(ShieldRunnerMixin):
return messages return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
await self._initialize_tools(request.toolgroups)
async with tracing.span("create_and_execute_turn") as span: async with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_id", self.agent_id)
@ -191,6 +192,7 @@ class ChatAgent(ShieldRunnerMixin):
yield chunk yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
await self._initialize_tools()
async with tracing.span("resume_turn") as span: async with tracing.span("resume_turn") as span:
span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
@ -275,7 +277,6 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params=self.agent_config.sampling_params, sampling_params=self.agent_config.sampling_params,
stream=request.stream, stream=request.stream,
documents=request.documents if not is_resume else None, documents=request.documents if not is_resume else None,
toolgroups_for_turn=request.toolgroups if not is_resume else None,
): ):
if isinstance(chunk, CompletionMessage): if isinstance(chunk, CompletionMessage):
output_message = chunk output_message = chunk
@ -327,7 +328,6 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params: SamplingParams, sampling_params: SamplingParams,
stream: bool = False, stream: bool = False,
documents: Optional[List[Document]] = None, documents: Optional[List[Document]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to # Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot # streaming. However, it also makes things complicated here because AsyncGenerators cannot
@ -350,7 +350,6 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params, sampling_params,
stream, stream,
documents, documents,
toolgroups_for_turn,
): ):
if isinstance(res, bool): if isinstance(res, bool):
return return
@ -451,30 +450,17 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params: SamplingParams, sampling_params: SamplingParams,
stream: bool = False, stream: bool = False,
documents: Optional[List[Document]] = None, documents: Optional[List[Document]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
# TODO: simplify all of this code, it can be simpler
toolgroup_args = {}
toolgroups = set()
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
if isinstance(toolgroup, AgentToolGroupWithArgs):
tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name)
toolgroups.add(tool_group_name)
toolgroup_args[tool_group_name] = toolgroup.args
else:
toolgroups.add(toolgroup)
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents: if documents:
await self.handle_documents(session_id, documents, input_messages, tool_defs) await self.handle_documents(session_id, documents, input_messages)
session_info = await self.storage.get_session_info(session_id) session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it # if the session has a memory bank id, let the memory tool use it
if session_info and session_info.vector_db_id: if session_info and session_info.vector_db_id:
if RAG_TOOL_GROUP not in toolgroup_args: if RAG_TOOL_GROUP not in self.toolgroup_to_args:
toolgroup_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]} self.toolgroup_to_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]}
else: else:
toolgroup_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id) self.toolgroup_to_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)
output_attachments = [] output_attachments = []
@ -504,7 +490,7 @@ class ChatAgent(ShieldRunnerMixin):
async for chunk in await self.inference_api.chat_completion( async for chunk in await self.inference_api.chat_completion(
self.agent_config.model, self.agent_config.model,
input_messages, input_messages,
tools=tool_defs, tools=self.tool_defs,
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format, tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
response_format=self.agent_config.response_format, response_format=self.agent_config.response_format,
stream=True, stream=True,
@ -686,12 +672,9 @@ class ChatAgent(ShieldRunnerMixin):
) as span: ) as span:
tool_execution_start_time = datetime.now().astimezone().isoformat() tool_execution_start_time = datetime.now().astimezone().isoformat()
tool_call = message.tool_calls[0] tool_call = message.tool_calls[0]
tool_result = await execute_tool_call_maybe( tool_result = await self.execute_tool_call_maybe(
self.tool_runtime_api,
session_id, session_id,
tool_call, tool_call,
toolgroup_args,
tool_to_group,
) )
if tool_result.content is None: if tool_result.content is None:
raise ValueError( raise ValueError(
@ -744,6 +727,15 @@ class ChatAgent(ShieldRunnerMixin):
input_messages = input_messages + [message, result_message] input_messages = input_messages + [message, result_message]
async def _initialize_tools(self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None):
    """Rebuild the tool state this agent uses while running a turn.

    Populates three attributes on ``self``:

    * ``toolgroup_to_args`` -- maps a tool group name to the args supplied by
      an ``AgentToolGroupWithArgs`` entry (later entries for the same group
      overwrite earlier ones, so per-turn toolgroups win over config ones).
    * ``tool_defs`` and ``tool_name_to_group_id`` -- whatever
      ``_get_tool_defs`` returns for ``toolgroups_for_turn``.

    Args:
        toolgroups_for_turn: Toolgroups supplied for this turn in addition to
            ``self.agent_config.toolgroups``. ``None`` means no extras.
    """
    # Function-scope import so the block does not depend on file-level imports.
    import copy

    self.toolgroup_to_args = {}
    for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
        if isinstance(toolgroup, AgentToolGroupWithArgs):
            # Only the group part of the name is needed for the args lookup.
            tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
            # Deep-copy the args: the turn loop appends the session's vector DB
            # id to the "vector_db_ids" list held in these args. Storing the
            # config/request object by reference would make that append mutate
            # the agent config itself and accumulate ids across turns.
            self.toolgroup_to_args[tool_group_name] = copy.deepcopy(toolgroup.args)

    self.tool_defs, self.tool_name_to_group_id = await self._get_tool_defs(toolgroups_for_turn)
async def _get_tool_defs( async def _get_tool_defs(
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
) -> Tuple[List[ToolDefinition], Dict[str, str]]: ) -> Tuple[List[ToolDefinition], Dict[str, str]]:
@ -756,7 +748,7 @@ class ChatAgent(ShieldRunnerMixin):
agent_config_toolgroups.append(name) agent_config_toolgroups.append(name)
tool_name_to_def = {} tool_name_to_def = {}
tool_to_group = {} tool_name_to_group_id = {}
for tool_def in self.agent_config.client_tools: for tool_def in self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None): if tool_name_to_def.get(tool_def.name, None):
@ -774,7 +766,7 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters for param in tool_def.parameters
}, },
) )
tool_to_group[tool_def.name] = "__client_tools__" tool_name_to_group_id[tool_def.name] = "__client_tools__"
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
@ -813,7 +805,7 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters for param in tool_def.parameters
}, },
) )
tool_to_group[built_in_type] = tool_def.toolgroup_id tool_name_to_group_id[built_in_type] = tool_def.toolgroup_id
continue continue
if tool_name_to_def.get(tool_def.identifier, None): if tool_name_to_def.get(tool_def.identifier, None):
@ -832,9 +824,9 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters for param in tool_def.parameters
}, },
) )
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id tool_name_to_group_id[tool_def.identifier] = tool_def.toolgroup_id
return list(tool_name_to_def.values()), tool_to_group return list(tool_name_to_def.values()), tool_name_to_group_id
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]: def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
"""Parse a toolgroup name into its components. """Parse a toolgroup name into its components.
@ -853,15 +845,44 @@ class ChatAgent(ShieldRunnerMixin):
tool_group, tool_name = split_names[0], None tool_group, tool_name = split_names[0], None
return tool_group, tool_name return tool_group, tool_name
async def execute_tool_call_maybe(
    self,
    session_id: str,
    tool_call: ToolCall,
) -> ToolInvocationResult:
    """Invoke the tool requested by ``tool_call`` via ``self.tool_runtime_api``.

    Args:
        session_id: Passed to the tool as the ``session_id`` keyword argument.
        tool_call: The call to execute; ``tool_call.tool_name`` selects the
            tool and ``tool_call.arguments`` supplies its keyword arguments.

    Returns:
        Whatever ``self.tool_runtime_api.invoke_tool`` returns.

    Raises:
        ValueError: If the tool name has no entry in
            ``self.tool_name_to_group_id``.
    """
    name = tool_call.tool_name
    group_name = self.tool_name_to_group_id.get(name, None)
    if group_name is None:
        # Keys are not guaranteed to be plain strings (built-in tools are
        # registered under their BuiltinTool key), and str.join() raises
        # TypeError on non-str items. Render each key explicitly so this path
        # always raises the intended ValueError.
        available_tools = ", ".join(str(getattr(key, "value", key)) for key in self.tool_name_to_group_id)
        raise ValueError(f"Tool {name} not found in any tool group, available tools: {available_tools}")

    # The runtime is addressed by string name: brave_search is invoked under
    # the generic web-search tool name, other built-ins under their enum value.
    if isinstance(name, BuiltinTool):
        if name == BuiltinTool.brave_search:
            name = WEB_SEARCH_TOOL
        else:
            name = name.value

    logger.info(f"executing tool call: {name} with args: {tool_call.arguments}")
    result = await self.tool_runtime_api.invoke_tool(
        tool_name=name,
        kwargs={
            "session_id": session_id,
            # get the arguments generated by the model and augment with toolgroup arg overrides for the agent
            **tool_call.arguments,
            **self.toolgroup_to_args.get(group_name, {}),
        },
    )
    logger.debug(f"tool call {name} completed with result: {result}")
    return result
async def handle_documents( async def handle_documents(
self, self,
session_id: str, session_id: str,
documents: List[Document], documents: List[Document],
input_messages: List[Message], input_messages: List[Message],
tool_defs: Dict[str, ToolDefinition],
) -> None: ) -> None:
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs) memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs)
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs) code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs)
content_items = [] content_items = []
url_items = [] url_items = []
pattern = re.compile("^(https?://|file://|data:)") pattern = re.compile("^(https?://|file://|data:)")
@ -994,37 +1015,6 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
) )
async def execute_tool_call_maybe(
    tool_runtime_api: ToolRuntime,
    session_id: str,
    tool_call: ToolCall,
    toolgroup_args: Dict[str, Dict[str, Any]],
    tool_to_group: Dict[str, str],
) -> ToolInvocationResult:
    """Run ``tool_call`` through ``tool_runtime_api`` and return its result.

    Args:
        tool_runtime_api: Object whose ``invoke_tool`` coroutine performs the call.
        session_id: Forwarded to the tool as the ``session_id`` keyword argument.
        tool_call: Supplies ``tool_name`` and the model-generated ``arguments``.
        toolgroup_args: Per-group argument overrides, keyed by tool group name.
        tool_to_group: Maps a tool name to the tool group it belongs to.

    Raises:
        ValueError: If ``tool_call.tool_name`` has no group in ``tool_to_group``.
    """
    requested = tool_call.tool_name
    owning_group = tool_to_group.get(requested)
    if owning_group is None:
        raise ValueError(f"Tool {requested} not found in any tool group")

    # Resolve the string name the runtime expects for this tool.
    if not isinstance(requested, BuiltinTool):
        runtime_name = requested
    elif requested == BuiltinTool.brave_search:
        runtime_name = WEB_SEARCH_TOOL
    else:
        runtime_name = requested.value

    logger.info(f"executing tool call: {runtime_name} with args: {tool_call.arguments}")

    # Later updates win: model arguments first, then the group's overrides.
    invocation_kwargs = {"session_id": session_id}
    invocation_kwargs.update(tool_call.arguments)
    invocation_kwargs.update(toolgroup_args.get(owning_group, {}))

    outcome = await tool_runtime_api.invoke_tool(tool_name=runtime_name, kwargs=invocation_kwargs)
    logger.info(f"tool call {runtime_name} completed with result: {outcome}")
    return outcome
def _interpret_content_as_attachment( def _interpret_content_as_attachment(
content: str, content: str,
) -> Optional[Attachment]: ) -> Optional[Attachment]: