feat: allow specifying specific tool within toolgroup (#1239)

Summary:

E.g. `builtin::rag/knowledge_search`
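
For context, here is a minimal sketch of how the new syntax is used in an agent config (plain dict form; the model name and vector DB id are placeholders borrowed from the docs examples further down in this diff):

```python
# Sketch only: an agent config using the new "<toolgroup>/<tool>" selection syntax.
agent_config = dict(
    model="meta-llama/Llama-3.3-70B-Instruct",  # placeholder
    instructions="You are a helpful assistant",
    toolgroups=[
        {
            # New: pick a single tool out of a toolgroup
            "name": "builtin::rag/knowledge_search",
            "args": {"vector_db_ids": ["my_docs"]},
        },
        # Unchanged: a bare toolgroup name still enables every tool in the group
        "builtin::code_interpreter",
    ],
)
```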

Test Plan:
```
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/ --safety-shield meta-llama/Llama-Guard-3-8B
```
ehhuang 2025-02-26 14:07:05 -08:00 committed by GitHub
parent 657efc67bc
commit c8a20b8ed0
7 changed files with 80 additions and 64 deletions


@@ -803,7 +803,7 @@
    }
   ],
   "source": [
-   "model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
+   "model_id = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "\n",
    "model_id\n"
   ]
@@ -1688,7 +1688,7 @@
    " enable_session_persistence=False,\n",
    " toolgroups = [\n",
    " {\n",
-   " \"name\": \"builtin::rag\",\n",
+   " \"name\": \"builtin::rag/knowledge_search\",\n",
    " \"args\" : {\n",
    " \"vector_db_ids\": [vector_db_id],\n",
    " }\n",


@@ -7,12 +7,12 @@ Each agent turn follows these key steps:
 1. **Initial Safety Check**: The user's input is first screened through configured safety shields
 2. **Context Retrieval**:
-   - If RAG is enabled, the agent queries relevant documents from memory banks
-   - For new documents, they are first inserted into the memory bank
-   - Retrieved context is augmented to the user's prompt
+   - If RAG is enabled, the agent can choose to query relevant documents from memory banks. You can use the `instructions` field to steer the agent.
+   - For new documents, they are first inserted into the memory bank.
+   - Retrieved context is provided to the LLM as a tool response in the message history.
 3. **Inference Loop**: The agent enters its main execution loop:
-   - The LLM receives the augmented prompt (with context and/or previous tool outputs)
+   - The LLM receives a user prompt (with previous tool outputs)
    - The LLM generates a response, potentially with tool calls
    - If tool calls are present:
      - Tool inputs are safety-checked
@@ -40,19 +40,16 @@ sequenceDiagram
     S->>E: Input Safety Check
     deactivate S
-    E->>M: 2.1 Query Context
-    M-->>E: 2.2 Retrieved Documents
     loop Inference Loop
-        E->>L: 3.1 Augment with Context
-        L-->>E: 3.2 Response (with/without tool calls)
+        E->>L: 2.1 Augment with Context
+        L-->>E: 2.2 Response (with/without tool calls)
         alt Has Tool Calls
             E->>S: Check Tool Input
-            S->>T: 4.1 Execute Tool
-            T-->>E: 4.2 Tool Response
-            E->>L: 5.1 Tool Response
-            L-->>E: 5.2 Synthesized Response
+            S->>T: 3.1 Execute Tool
+            T-->>E: 3.2 Tool Response
+            E->>L: 4.1 Tool Response
+            L-->>E: 4.2 Synthesized Response
         end
         opt Stop Conditions
@@ -64,7 +61,7 @@ sequenceDiagram
     end
     E->>S: Output Safety Check
-    S->>U: 6. Final Response
+    S->>U: 5. Final Response
 ```

 Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
@@ -77,7 +74,10 @@ agent_config = AgentConfig(
     instructions="You are a helpful assistant",
     # Enable both RAG and tool usage
     toolgroups=[
-        {"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}},
+        {
+            "name": "builtin::rag/knowledge_search",
+            "args": {"vector_db_ids": ["my_docs"]},
+        },
         "builtin::code_interpreter",
     ],
     # Configure safety


@@ -91,7 +91,7 @@ agent_config = AgentConfig(
     enable_session_persistence=False,
     toolgroups=[
         {
-            "name": "builtin::rag",
+            "name": "builtin::rag/knowledge_search",
             "args": {
                 "vector_db_ids": [vector_db_id],
             },


@@ -243,7 +243,7 @@ agent_config = AgentConfig(
     # Define tools available to the agent
     toolgroups=[
         {
-            "name": "builtin::rag",
+            "name": "builtin::rag/knowledge_search",
             "args": {
                 "vector_db_ids": [vector_db_id],
             },


@@ -132,7 +132,7 @@ def rag_chat_page():
             },
             toolgroups=[
                 dict(
-                    name="builtin::rag",
+                    name="builtin::rag/knowledge_search",
                    args={
                        "vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
                    },


@@ -497,19 +497,13 @@ class ChatAgent(ShieldRunnerMixin):
         # TODO: simplify all of this code, it can be simpler
         toolgroup_args = {}
         toolgroups = set()
-        for toolgroup in self.agent_config.toolgroups:
+        for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
             if isinstance(toolgroup, AgentToolGroupWithArgs):
-                toolgroups.add(toolgroup.name)
-                toolgroup_args[toolgroup.name] = toolgroup.args
+                tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name)
+                toolgroups.add(tool_group_name)
+                toolgroup_args[tool_group_name] = toolgroup.args
             else:
                 toolgroups.add(toolgroup)
-        if toolgroups_for_turn:
-            for toolgroup in toolgroups_for_turn:
-                if isinstance(toolgroup, AgentToolGroupWithArgs):
-                    toolgroups.add(toolgroup.name)
-                    toolgroup_args[toolgroup.name] = toolgroup.args
-                else:
-                    toolgroups.add(toolgroup)

         tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
         if documents:
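
The net effect of the loop change above, shown as a standalone sketch (hypothetical plain-dict inputs standing in for `AgentToolGroupWithArgs`, not code from the repo): per-turn toolgroups are now folded into the same pass as the agent-config toolgroups, and a `group/tool` entry is registered under its group name.

```python
# Hypothetical illustration of the new bookkeeping.
config_toolgroups = [{"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": ["my_docs"]}}]
turn_toolgroups = ["builtin::websearch"]

toolgroups, toolgroup_args = set(), {}
for tg in config_toolgroups + (turn_toolgroups or []):
    if isinstance(tg, dict):
        group_name, _, _tool = tg["name"].partition("/")
        toolgroups.add(group_name)
        toolgroup_args[group_name] = tg["args"]
    else:
        toolgroups.add(tg)

print(toolgroups)      # {'builtin::rag', 'builtin::websearch'}
print(toolgroup_args)  # {'builtin::rag': {'vector_db_ids': ['my_docs']}}
```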
@@ -542,7 +536,7 @@ class ChatAgent(ShieldRunnerMixin):
         async for chunk in await self.inference_api.chat_completion(
             self.agent_config.model,
             input_messages,
-            tools=[tool for tool in tool_defs.values()],
+            tools=tool_defs,
             tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
             response_format=self.agent_config.response_format,
             stream=True,
@@ -768,7 +762,7 @@ class ChatAgent(ShieldRunnerMixin):
     async def _get_tool_defs(
         self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
-    ) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
+    ) -> Tuple[List[ToolDefinition], Dict[str, str]]:
         # Determine which tools to include
         agent_config_toolgroups = set(
             (toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
@@ -783,13 +777,13 @@ class ChatAgent(ShieldRunnerMixin):
             }
         )
-        tool_def_map = {}
+        tool_name_to_def = {}
         tool_to_group = {}
         for tool_def in self.agent_config.client_tools:
-            if tool_def_map.get(tool_def.name, None):
+            if tool_name_to_def.get(tool_def.name, None):
                 raise ValueError(f"Tool {tool_def.name} already exists")
-            tool_def_map[tool_def.name] = ToolDefinition(
+            tool_name_to_def[tool_def.name] = ToolDefinition(
                 tool_name=tool_def.name,
                 description=tool_def.description,
                 parameters={
@@ -803,10 +797,17 @@ class ChatAgent(ShieldRunnerMixin):
                 },
             )
             tool_to_group[tool_def.name] = "__client_tools__"
-        for toolgroup_name in agent_config_toolgroups:
-            if toolgroup_name not in toolgroups_for_turn_set:
+        for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
+            if toolgroup_name_with_maybe_tool_name not in toolgroups_for_turn_set:
                 continue
+            toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
             tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
+            if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
+                raise ValueError(
+                    f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
+                )
             for tool_def in tools.data:
                 if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
                     tool_name = tool_def.identifier
@@ -816,10 +817,10 @@ class ChatAgent(ShieldRunnerMixin):
                     else:
                         built_in_type = BuiltinTool(tool_name)
-                    if tool_def_map.get(built_in_type, None):
+                    if tool_name_to_def.get(built_in_type, None):
                         raise ValueError(f"Tool {built_in_type} already exists")
-                    tool_def_map[built_in_type] = ToolDefinition(
+                    tool_name_to_def[built_in_type] = ToolDefinition(
                         tool_name=built_in_type,
                         description=tool_def.description,
                         parameters={
@@ -835,24 +836,42 @@ class ChatAgent(ShieldRunnerMixin):
                 tool_to_group[built_in_type] = tool_def.toolgroup_id
                 continue
-            if tool_def_map.get(tool_def.identifier, None):
+            if tool_name_to_def.get(tool_def.identifier, None):
                 raise ValueError(f"Tool {tool_def.identifier} already exists")
-            tool_def_map[tool_def.identifier] = ToolDefinition(
-                tool_name=tool_def.identifier,
-                description=tool_def.description,
-                parameters={
-                    param.name: ToolParamDefinition(
-                        param_type=param.parameter_type,
-                        description=param.description,
-                        required=param.required,
-                        default=param.default,
-                    )
-                    for param in tool_def.parameters
-                },
-            )
-            tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
+            if tool_name in (None, tool_def.identifier):
+                tool_name_to_def[tool_def.identifier] = ToolDefinition(
+                    tool_name=tool_def.identifier,
+                    description=tool_def.description,
+                    parameters={
+                        param.name: ToolParamDefinition(
+                            param_type=param.parameter_type,
+                            description=param.description,
+                            required=param.required,
+                            default=param.default,
+                        )
+                        for param in tool_def.parameters
+                    },
+                )
+                tool_to_group[tool_def.identifier] = tool_def.toolgroup_id

-        return tool_def_map, tool_to_group
+        return list(tool_name_to_def.values()), tool_to_group
+
+    def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
+        """Parse a toolgroup name into its components.
+
+        Args:
+            toolgroup_name_with_maybe_tool_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search")
+
+        Returns:
+            A tuple of (tool_group, tool_name); tool_name is None when no specific tool is given.
+        """
+        split_names = toolgroup_name_with_maybe_tool_name.split("/")
+        if len(split_names) == 2:
+            # e.g. "builtin::rag/knowledge_search"
+            tool_group, tool_name = split_names
+        else:
+            tool_group, tool_name = split_names[0], None
+        return tool_group, tool_name
+
     async def handle_documents(
         self,
@@ -861,8 +880,8 @@ class ChatAgent(ShieldRunnerMixin):
         input_messages: List[Message],
         tool_defs: Dict[str, ToolDefinition],
     ) -> None:
-        memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
-        code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
+        memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs)
+        code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs)
         content_items = []
         url_items = []
         pattern = re.compile("^(https?://|file://|data:)")
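
For reference, the parsing rule the new `_parse_toolgroup_name` helper implements, as a standalone sketch (hypothetical free function with the same split-on-`/` behavior):

```python
from typing import Optional, Tuple

def parse_toolgroup_name(name: str) -> Tuple[str, Optional[str]]:
    """Split "toolgroup[/tool]" into (toolgroup, tool or None)."""
    parts = name.split("/")
    if len(parts) == 2:
        return parts[0], parts[1]
    return parts[0], None

assert parse_toolgroup_name("builtin::rag/knowledge_search") == ("builtin::rag", "knowledge_search")
assert parse_toolgroup_name("builtin::rag") == ("builtin::rag", None)
```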


@@ -441,7 +441,8 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config):
     assert "get_boiling_point" in logs_str


-def test_rag_agent(llama_stack_client, agent_config):
+@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
+def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -469,7 +470,7 @@ def test_rag_agent(llama_stack_client, agent_config):
         **agent_config,
         "toolgroups": [
             dict(
-                name="builtin::rag",
+                name=rag_tool_name,
                 args={
                     "vector_db_ids": [vector_db_id],
                 },
@@ -483,10 +484,6 @@ def test_rag_agent(llama_stack_client, agent_config):
             "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
             "grouped",
         ),
-        (
-            "What `tune` command to use for getting access to Llama3-8B-Instruct ?",
-            "download",
-        ),
     ]
     for prompt, expected_kw in user_prompts:
         response = rag_agent.create_turn(
@@ -541,7 +538,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
         **agent_config,
         "toolgroups": [
             dict(
-                name="builtin::rag",
+                name="builtin::rag/knowledge_search",
                 args={"vector_db_ids": [vector_db_id]},
             ),
             "builtin::code_interpreter",