mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
chore: simplify _get_tool_defs (#1384)
Summary: Test Plan: LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/integration/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct
This commit is contained in:
parent
41c9bca1aa
commit
ed6caead72
1 changed files with 50 additions and 61 deletions
|
@ -12,7 +12,7 @@ import secrets
|
||||||
import string
|
import string
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
|
from typing import AsyncGenerator, List, Optional, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
@ -457,10 +457,12 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
session_info = await self.storage.get_session_info(session_id)
|
session_info = await self.storage.get_session_info(session_id)
|
||||||
# if the session has a memory bank id, let the memory tool use it
|
# if the session has a memory bank id, let the memory tool use it
|
||||||
if session_info and session_info.vector_db_id:
|
if session_info and session_info.vector_db_id:
|
||||||
if RAG_TOOL_GROUP not in self.toolgroup_to_args:
|
for tool_name in self.tool_name_to_args.keys():
|
||||||
self.toolgroup_to_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]}
|
if tool_name == MEMORY_QUERY_TOOL:
|
||||||
else:
|
if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
|
||||||
self.toolgroup_to_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)
|
self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id]
|
||||||
|
else:
|
||||||
|
self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
|
||||||
|
|
||||||
output_attachments = []
|
output_attachments = []
|
||||||
|
|
||||||
|
@ -727,18 +729,16 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
input_messages = input_messages + [message, result_message]
|
input_messages = input_messages + [message, result_message]
|
||||||
|
|
||||||
async def _initialize_tools(self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None):
|
async def _initialize_tools(
|
||||||
self.toolgroup_to_args = {}
|
self,
|
||||||
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
|
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||||
|
) -> None:
|
||||||
|
toolgroup_to_args = {}
|
||||||
|
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
|
||||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||||
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
|
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
|
||||||
self.toolgroup_to_args[tool_group_name] = toolgroup.args
|
toolgroup_to_args[tool_group_name] = toolgroup.args
|
||||||
|
|
||||||
self.tool_defs, self.tool_name_to_group_id = await self._get_tool_defs(toolgroups_for_turn)
|
|
||||||
|
|
||||||
async def _get_tool_defs(
|
|
||||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
|
||||||
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
|
|
||||||
# Determine which tools to include
|
# Determine which tools to include
|
||||||
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
|
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
|
||||||
agent_config_toolgroups = []
|
agent_config_toolgroups = []
|
||||||
|
@ -747,8 +747,10 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
if name not in agent_config_toolgroups:
|
if name not in agent_config_toolgroups:
|
||||||
agent_config_toolgroups.append(name)
|
agent_config_toolgroups.append(name)
|
||||||
|
|
||||||
|
toolgroup_to_args = toolgroup_to_args or {}
|
||||||
|
|
||||||
tool_name_to_def = {}
|
tool_name_to_def = {}
|
||||||
tool_name_to_group_id = {}
|
tool_name_to_args = {}
|
||||||
|
|
||||||
for tool_def in self.agent_config.client_tools:
|
for tool_def in self.agent_config.client_tools:
|
||||||
if tool_name_to_def.get(tool_def.name, None):
|
if tool_name_to_def.get(tool_def.name, None):
|
||||||
|
@ -766,53 +768,38 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
for param in tool_def.parameters
|
for param in tool_def.parameters
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
tool_name_to_group_id[tool_def.name] = "__client_tools__"
|
|
||||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||||
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||||
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
||||||
if not tools.data:
|
if not tools.data:
|
||||||
available_tool_groups = ", ".join(
|
available_tool_groups = ", ".join(
|
||||||
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
|
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
|
||||||
)
|
)
|
||||||
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
|
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
|
||||||
if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
|
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||||
)
|
)
|
||||||
|
|
||||||
for tool_def in tools.data:
|
for tool_def in tools.data:
|
||||||
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
|
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
|
||||||
tool_name = tool_def.identifier
|
identifier: str | BuiltinTool | None = tool_def.identifier
|
||||||
built_in_type = BuiltinTool.brave_search
|
if identifier == "web_search":
|
||||||
if tool_name == "web_search":
|
identifier = BuiltinTool.brave_search
|
||||||
built_in_type = BuiltinTool.brave_search
|
|
||||||
else:
|
else:
|
||||||
built_in_type = BuiltinTool(tool_name)
|
identifier = BuiltinTool(identifier)
|
||||||
|
else:
|
||||||
|
# add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
|
||||||
|
if input_tool_name in (None, tool_def.identifier):
|
||||||
|
identifier = tool_def.identifier
|
||||||
|
else:
|
||||||
|
identifier = None
|
||||||
|
|
||||||
if tool_name_to_def.get(built_in_type, None):
|
if tool_name_to_def.get(identifier, None):
|
||||||
raise ValueError(f"Tool {built_in_type} already exists")
|
raise ValueError(f"Tool {identifier} already exists")
|
||||||
|
if identifier:
|
||||||
tool_name_to_def[built_in_type] = ToolDefinition(
|
|
||||||
tool_name=built_in_type,
|
|
||||||
description=tool_def.description,
|
|
||||||
parameters={
|
|
||||||
param.name: ToolParamDefinition(
|
|
||||||
param_type=param.parameter_type,
|
|
||||||
description=param.description,
|
|
||||||
required=param.required,
|
|
||||||
default=param.default,
|
|
||||||
)
|
|
||||||
for param in tool_def.parameters
|
|
||||||
},
|
|
||||||
)
|
|
||||||
tool_name_to_group_id[built_in_type] = tool_def.toolgroup_id
|
|
||||||
continue
|
|
||||||
|
|
||||||
if tool_name_to_def.get(tool_def.identifier, None):
|
|
||||||
raise ValueError(f"Tool {tool_def.identifier} already exists")
|
|
||||||
if tool_name in (None, tool_def.identifier):
|
|
||||||
tool_name_to_def[tool_def.identifier] = ToolDefinition(
|
tool_name_to_def[tool_def.identifier] = ToolDefinition(
|
||||||
tool_name=tool_def.identifier,
|
tool_name=identifier,
|
||||||
description=tool_def.description,
|
description=tool_def.description,
|
||||||
parameters={
|
parameters={
|
||||||
param.name: ToolParamDefinition(
|
param.name: ToolParamDefinition(
|
||||||
|
@ -824,9 +811,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
for param in tool_def.parameters
|
for param in tool_def.parameters
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
tool_name_to_group_id[tool_def.identifier] = tool_def.toolgroup_id
|
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||||
|
|
||||||
return list(tool_name_to_def.values()), tool_name_to_group_id
|
self.tool_defs, self.tool_name_to_args = list(tool_name_to_def.values()), tool_name_to_args
|
||||||
|
|
||||||
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
|
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
|
||||||
"""Parse a toolgroup name into its components.
|
"""Parse a toolgroup name into its components.
|
||||||
|
@ -850,29 +837,31 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
session_id: str,
|
session_id: str,
|
||||||
tool_call: ToolCall,
|
tool_call: ToolCall,
|
||||||
) -> ToolInvocationResult:
|
) -> ToolInvocationResult:
|
||||||
name = tool_call.tool_name
|
tool_name = tool_call.tool_name
|
||||||
group_name = self.tool_name_to_group_id.get(name, None)
|
registered_tool_names = [tool_def.tool_name for tool_def in self.tool_defs]
|
||||||
if group_name is None:
|
if tool_name not in registered_tool_names:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Tool {name} not found in any tool group, available tools: {', '.join(self.tool_name_to_group_id.keys())}"
|
f"Tool {tool_name} not found in provided tools, registered tools: {', '.join([str(x) for x in registered_tool_names])}"
|
||||||
)
|
)
|
||||||
if isinstance(name, BuiltinTool):
|
if isinstance(tool_name, BuiltinTool):
|
||||||
if name == BuiltinTool.brave_search:
|
if tool_name == BuiltinTool.brave_search:
|
||||||
name = WEB_SEARCH_TOOL
|
tool_name_str = WEB_SEARCH_TOOL
|
||||||
else:
|
else:
|
||||||
name = name.value
|
tool_name_str = tool_name.value
|
||||||
|
else:
|
||||||
|
tool_name_str = tool_name
|
||||||
|
|
||||||
logger.info(f"executing tool call: {name} with args: {tool_call.arguments}")
|
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
|
||||||
result = await self.tool_runtime_api.invoke_tool(
|
result = await self.tool_runtime_api.invoke_tool(
|
||||||
tool_name=name,
|
tool_name=tool_name_str,
|
||||||
kwargs={
|
kwargs={
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||||
**tool_call.arguments,
|
**tool_call.arguments,
|
||||||
**self.toolgroup_to_args.get(group_name, {}),
|
**self.tool_name_to_args.get(tool_name_str, {}),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
logger.debug(f"tool call {name} completed with result: {result}")
|
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def handle_documents(
|
async def handle_documents(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue