simplify toolgroups registration

This commit is contained in:
Dinesh Yeduguru 2025-01-07 15:37:52 -08:00
parent ba242c04cc
commit f9a98c278a
15 changed files with 350 additions and 256 deletions

View file

@ -13,7 +13,7 @@ import secrets
import string
import uuid
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional
from typing import Any, AsyncGenerator, Dict, List, Optional
from urllib.parse import urlparse
import httpx
@ -21,8 +21,8 @@ from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDe
from llama_stack.apis.agents import (
AgentConfig,
AgentTool,
AgentToolWithArgs,
AgentToolGroup,
AgentToolGroupWithArgs,
AgentTurnCreateRequest,
AgentTurnResponseEvent,
AgentTurnResponseEventType,
@ -76,6 +76,10 @@ def make_random_string(length: int = 8):
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_TOOL_GROUP_ID = "builtin::memory"
MEMORY_QUERY_TOOL = "query_memory"
CODE_INTERPRETER_TOOL = "code_interpreter"
WEB_SEARCH_TOOL = "web_search"
class ChatAgent(ShieldRunnerMixin):
@ -192,7 +196,7 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
documents=request.documents,
tools_for_turn=request.tools,
toolgroups_for_turn=request.toolgroups,
):
if isinstance(chunk, CompletionMessage):
log.info(
@ -243,7 +247,7 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params: SamplingParams,
stream: bool = False,
documents: Optional[List[Document]] = None,
tools_for_turn: Optional[List[AgentTool]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
@ -266,7 +270,7 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params,
stream,
documents,
tools_for_turn,
toolgroups_for_turn,
):
if isinstance(res, bool):
return
@ -362,21 +366,24 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params: SamplingParams,
stream: bool = False,
documents: Optional[List[Document]] = None,
tools_for_turn: Optional[List[AgentTool]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator:
tool_args = {}
if tools_for_turn:
for tool in tools_for_turn:
if isinstance(tool, AgentToolWithArgs):
tool_args[tool.name] = tool.args
toolgroup_args = {}
for toolgroup in self.agent_config.toolgroups:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroup_args[toolgroup.name] = toolgroup.args
if toolgroups_for_turn:
for toolgroup in toolgroups_for_turn:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroup_args[toolgroup.name] = toolgroup.args
tool_defs = await self._get_tool_defs(tools_for_turn)
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents:
await self.handle_documents(
session_id, documents, input_messages, tool_defs
)
if "memory" in tool_defs and len(input_messages) > 0:
with tracing.span("memory_tool") as span:
if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
with tracing.span(MEMORY_QUERY_TOOL) as span:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -386,18 +393,16 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
extra_args = tool_args.get("memory", {})
tool_args = {
# Query memory with the last message's content
"query": input_messages[-1],
**extra_args,
query_args = {
"messages": [msg.content for msg in input_messages],
**toolgroup_args.get(MEMORY_TOOL_GROUP_ID, {}),
}
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info.memory_bank_id:
tool_args["memory_bank_id"] = session_info.memory_bank_id
serialized_args = tracing.serialize_value(tool_args)
query_args["memory_bank_id"] = session_info.memory_bank_id
serialized_args = tracing.serialize_value(query_args)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
@ -415,8 +420,8 @@ class ChatAgent(ShieldRunnerMixin):
)
)
result = await self.tool_runtime_api.invoke_tool(
tool_name="memory",
args=tool_args,
tool_name=MEMORY_QUERY_TOOL,
args=query_args,
)
yield AgentTurnResponseStreamChunk(
@ -485,7 +490,8 @@ class ChatAgent(ShieldRunnerMixin):
tools=[
tool
for tool in tool_defs.values()
if tool.tool_name != "memory"
if tool_to_group.get(tool.tool_name, None)
!= MEMORY_TOOL_GROUP_ID
],
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
@ -632,6 +638,8 @@ class ChatAgent(ShieldRunnerMixin):
self.tool_runtime_api,
session_id,
[message],
toolgroup_args,
tool_to_group,
)
assert (
len(result_messages) == 1
@ -690,26 +698,37 @@ class ChatAgent(ShieldRunnerMixin):
n_iter += 1
async def _get_tool_defs(
self, tools_for_turn: Optional[List[AgentTool]]
self, toolgroups_for_turn: Optional[List[AgentToolGroup]]
) -> Dict[str, ToolDefinition]:
# Determine which tools to include
agent_config_tools = set(
tool.name if isinstance(tool, AgentToolWithArgs) else tool
for tool in self.agent_config.tools
agent_config_toolgroups = set(
(
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
for toolgroup in self.agent_config.toolgroups
)
tools_for_turn_set = (
agent_config_tools
if tools_for_turn is None
toolgroups_for_turn_set = (
agent_config_toolgroups
if toolgroups_for_turn is None
else {
tool.name if isinstance(tool, AgentToolWithArgs) else tool
for tool in tools_for_turn
(
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
for toolgroup in toolgroups_for_turn
}
)
ret = {}
tool_def_map = {}
tool_to_group = {}
for tool_def in self.agent_config.client_tools:
ret[tool_def.name] = ToolDefinition(
if tool_def_map.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
tool_def_map[tool_def.name] = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
@ -722,41 +741,42 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters
},
)
for tool_name in agent_config_tools:
if tool_name not in tools_for_turn_set:
tool_to_group[tool_def.name] = "__client_tools__"
for toolgroup_name in agent_config_toolgroups:
if toolgroup_name not in toolgroups_for_turn_set:
continue
tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
for tool_def in tools:
if tool_def.built_in_type:
if tool_def_map.get(tool_def.built_in_type, None):
raise ValueError(
f"Tool {tool_def.built_in_type} already exists"
)
tool_def = await self.tool_groups_api.get_tool(tool_name)
if tool_def is None:
raise ValueError(f"Tool {tool_name} not found")
if tool_def.identifier.startswith("builtin::"):
built_in_type = tool_def.identifier[len("builtin::") :]
if built_in_type == "web_search":
built_in_type = "brave_search"
if built_in_type not in BuiltinTool.__members__:
raise ValueError(f"Unknown built-in tool: {built_in_type}")
ret[built_in_type] = ToolDefinition(
tool_name=BuiltinTool(built_in_type)
)
continue
ret[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
tool_def_map[tool_def.built_in_type] = ToolDefinition(
tool_name=tool_def.built_in_type
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.built_in_type] = tool_def.toolgroup_id
continue
return ret
if tool_def_map.get(tool_def.identifier, None):
raise ValueError(f"Tool {tool_def.identifier} already exists")
tool_def_map[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
return tool_def_map, tool_to_group
async def handle_documents(
self,
@ -765,8 +785,8 @@ class ChatAgent(ShieldRunnerMixin):
input_messages: List[Message],
tool_defs: Dict[str, ToolDefinition],
) -> None:
memory_tool = tool_defs.get("memory", None)
code_interpreter_tool = tool_defs.get("code_interpreter", None)
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
code_interpreter_tool = tool_defs.get(CODE_INTERPRETER_TOOL, None)
content_items = []
url_items = []
pattern = re.compile("^(https?://|file://|data:)")
@ -903,7 +923,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
async def execute_tool_call_maybe(
tool_runtime_api: ToolRuntime, session_id: str, messages: List[CompletionMessage]
tool_runtime_api: ToolRuntime,
session_id: str,
messages: List[CompletionMessage],
toolgroup_args: Dict[str, Dict[str, Any]],
tool_to_group: Dict[str, str],
) -> List[ToolResponseMessage]:
# While Tools.run interface takes a list of messages,
# All tools currently only run on a single message
@ -915,18 +939,26 @@ async def execute_tool_call_maybe(
tool_call = message.tool_calls[0]
name = tool_call.tool_name
group_name = tool_to_group.get(name, None)
if group_name is None:
raise ValueError(f"Tool {name} not found in any tool group")
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
tool_call_args = tool_call.arguments
tool_call_args.update(toolgroup_args.get(group_name, {}))
if isinstance(name, BuiltinTool):
if name == BuiltinTool.brave_search:
name = "builtin::web_search"
name = WEB_SEARCH_TOOL
else:
name = "builtin::" + name.value
name = name.value
result = await tool_runtime_api.invoke_tool(
tool_name=name,
args=dict(
session_id=session_id,
**tool_call.arguments,
**tool_call_args,
),
)
return [
ToolResponseMessage(
call_id=tool_call.call_id,

View file

@ -19,7 +19,7 @@ from llama_stack.apis.agents import (
Agents,
AgentSessionCreateResponse,
AgentStepResponse,
AgentTool,
AgentToolGroup,
AgentTurnCreateRequest,
Document,
Session,
@ -147,7 +147,7 @@ class MetaReferenceAgentsImpl(Agents):
ToolResponseMessage,
]
],
tools: Optional[List[AgentTool]] = None,
tools: Optional[List[AgentToolGroup]] = None,
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator: