From c8a20b8ed0e0100ada7dfb8b3eec5065f454005c Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 26 Feb 2025 14:07:05 -0800 Subject: [PATCH] feat: allow specifying specific tool within toolgroup (#1239) Summary: E.g. `builtin::rag::knowledge_search` Test Plan: ``` LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/ --safety-shield meta-llama/Llama-Guard-3-8B ``` --- docs/getting_started.ipynb | 4 +- .../agent_execution_loop.md | 30 +++--- docs/source/building_applications/rag.md | 2 +- docs/source/getting_started/index.md | 2 +- .../distribution/ui/page/playground/rag.py | 2 +- .../agents/meta_reference/agent_instance.py | 93 +++++++++++-------- tests/client-sdk/agents/test_agents.py | 11 +-- 7 files changed, 80 insertions(+), 64 deletions(-) diff --git a/docs/getting_started.ipynb b/docs/getting_started.ipynb index 3b3059285..329734f4c 100644 --- a/docs/getting_started.ipynb +++ b/docs/getting_started.ipynb @@ -803,7 +803,7 @@ } ], "source": [ - "model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n", + "model_id = \"meta-llama/Llama-3.3-70B-Instruct\"\n", "\n", "model_id\n" ] @@ -1688,7 +1688,7 @@ " enable_session_persistence=False,\n", " toolgroups = [\n", " {\n", - " \"name\": \"builtin::rag\",\n", + " \"name\": \"builtin::rag/knowledge_search\",\n", " \"args\" : {\n", " \"vector_db_ids\": [vector_db_id],\n", " }\n", diff --git a/docs/source/building_applications/agent_execution_loop.md b/docs/source/building_applications/agent_execution_loop.md index 6b3f64423..0d212df7a 100644 --- a/docs/source/building_applications/agent_execution_loop.md +++ b/docs/source/building_applications/agent_execution_loop.md @@ -7,12 +7,12 @@ Each agent turn follows these key steps: 1. **Initial Safety Check**: The user's input is first screened through configured safety shields 2. **Context Retrieval**: - - If RAG is enabled, the agent queries relevant documents from memory banks - - For new documents, they are first inserted into the memory bank - - Retrieved context is augmented to the user's prompt + - If RAG is enabled, the agent can choose to query relevant documents from memory banks. You can use the `instructions` field to steer the agent. + - For new documents, they are first inserted into the memory bank. + - Retrieved context is provided to the LLM as a tool response in the message history. 3. **Inference Loop**: The agent enters its main execution loop: - - The LLM receives the augmented prompt (with context and/or previous tool outputs) + - The LLM receives a user prompt (with previous tool outputs) - The LLM generates a response, potentially with tool calls - If tool calls are present: - Tool inputs are safety-checked @@ -40,19 +40,16 @@ sequenceDiagram S->>E: Input Safety Check deactivate S - E->>M: 2.1 Query Context - M-->>E: 2.2 Retrieved Documents - loop Inference Loop - E->>L: 3.1 Augment with Context - L-->>E: 3.2 Response (with/without tool calls) + E->>L: 2.1 Augment with Context + L-->>E: 2.2 Response (with/without tool calls) alt Has Tool Calls E->>S: Check Tool Input - S->>T: 4.1 Execute Tool - T-->>E: 4.2 Tool Response - E->>L: 5.1 Tool Response - L-->>E: 5.2 Synthesized Response + S->>T: 3.1 Execute Tool + T-->>E: 3.2 Tool Response + E->>L: 4.1 Tool Response + L-->>E: 4.2 Synthesized Response end opt Stop Conditions @@ -64,7 +61,7 @@ sequenceDiagram end E->>S: Output Safety Check - S->>U: 6. Final Response + S->>U: 5. Final Response ``` Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution: @@ -77,7 +74,10 @@ agent_config = AgentConfig( instructions="You are a helpful assistant", # Enable both RAG and tool usage toolgroups=[ - {"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}}, + { + "name": "builtin::rag/knowledge_search", + "args": {"vector_db_ids": ["my_docs"]}, + }, "builtin::code_interpreter", ], # Configure safety diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md index e6d628193..e2e5fd6b5 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -91,7 +91,7 @@ agent_config = AgentConfig( enable_session_persistence=False, toolgroups=[ { - "name": "builtin::rag", + "name": "builtin::rag/knowledge_search", "args": { "vector_db_ids": [vector_db_id], }, diff --git a/docs/source/getting_started/index.md b/docs/source/getting_started/index.md index 554f4354a..f017a9723 100644 --- a/docs/source/getting_started/index.md +++ b/docs/source/getting_started/index.md @@ -243,7 +243,7 @@ agent_config = AgentConfig( # Define tools available to the agent toolgroups=[ { - "name": "builtin::rag", + "name": "builtin::rag/knowledge_search", "args": { "vector_db_ids": [vector_db_id], }, diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py index d84418241..202c9322f 100644 --- a/llama_stack/distribution/ui/page/playground/rag.py +++ b/llama_stack/distribution/ui/page/playground/rag.py @@ -132,7 +132,7 @@ def rag_chat_page(): }, toolgroups=[ dict( - name="builtin::rag", + name="builtin::rag/knowledge_search", args={ "vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs], }, diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index c910598b1..b17179463 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -497,19 +497,13 @@ class ChatAgent(ShieldRunnerMixin): # TODO: simplify all of this code, it can be simpler toolgroup_args = {} toolgroups = set() - for toolgroup in self.agent_config.toolgroups: + for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []): if isinstance(toolgroup, AgentToolGroupWithArgs): - toolgroups.add(toolgroup.name) - toolgroup_args[toolgroup.name] = toolgroup.args + tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name) + toolgroups.add(tool_group_name) + toolgroup_args[tool_group_name] = toolgroup.args else: toolgroups.add(toolgroup) - if toolgroups_for_turn: - for toolgroup in toolgroups_for_turn: - if isinstance(toolgroup, AgentToolGroupWithArgs): - toolgroups.add(toolgroup.name) - toolgroup_args[toolgroup.name] = toolgroup.args - else: - toolgroups.add(toolgroup) tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn) if documents: @@ -542,7 +536,7 @@ class ChatAgent(ShieldRunnerMixin): async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, - tools=[tool for tool in tool_defs.values()], + tools=tool_defs, tool_prompt_format=self.agent_config.tool_config.tool_prompt_format, response_format=self.agent_config.response_format, stream=True, @@ -768,7 +762,7 @@ class ChatAgent(ShieldRunnerMixin): async def _get_tool_defs( self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None - ) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]: + ) -> Tuple[List[ToolDefinition], Dict[str, str]]: # Determine which tools to include agent_config_toolgroups = set( (toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup) @@ -783,13 +777,13 @@ class ChatAgent(ShieldRunnerMixin): } ) - tool_def_map = {} + tool_name_to_def = {} tool_to_group = {} for tool_def in self.agent_config.client_tools: - if tool_def_map.get(tool_def.name, None): + if tool_name_to_def.get(tool_def.name, None): raise ValueError(f"Tool {tool_def.name} already exists") - tool_def_map[tool_def.name] = ToolDefinition( + tool_name_to_def[tool_def.name] = ToolDefinition( tool_name=tool_def.name, description=tool_def.description, parameters={ @@ -803,10 +797,17 @@ class ChatAgent(ShieldRunnerMixin): }, ) tool_to_group[tool_def.name] = "__client_tools__" - for toolgroup_name in agent_config_toolgroups: - if toolgroup_name not in toolgroups_for_turn_set: + for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: + if toolgroup_name_with_maybe_tool_name not in toolgroups_for_turn_set: continue + + toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) + if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data): + raise ValueError( + f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}" + ) + for tool_def in tools.data: if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP: tool_name = tool_def.identifier @@ -816,10 +817,10 @@ class ChatAgent(ShieldRunnerMixin): else: built_in_type = BuiltinTool(tool_name) - if tool_def_map.get(built_in_type, None): + if tool_name_to_def.get(built_in_type, None): raise ValueError(f"Tool {built_in_type} already exists") - tool_def_map[built_in_type] = ToolDefinition( + tool_name_to_def[built_in_type] = ToolDefinition( tool_name=built_in_type, description=tool_def.description, parameters={ @@ -835,24 +836,42 @@ class ChatAgent(ShieldRunnerMixin): tool_to_group[built_in_type] = tool_def.toolgroup_id continue - if tool_def_map.get(tool_def.identifier, None): + if tool_name_to_def.get(tool_def.identifier, None): raise ValueError(f"Tool {tool_def.identifier} already exists") - tool_def_map[tool_def.identifier] = ToolDefinition( - tool_name=tool_def.identifier, - description=tool_def.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - for param in tool_def.parameters - }, - ) - tool_to_group[tool_def.identifier] = tool_def.toolgroup_id + if tool_name in (None, tool_def.identifier): + tool_name_to_def[tool_def.identifier] = ToolDefinition( + tool_name=tool_def.identifier, + description=tool_def.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in tool_def.parameters + }, + ) + tool_to_group[tool_def.identifier] = tool_def.toolgroup_id - return tool_def_map, tool_to_group + return list(tool_name_to_def.values()), tool_to_group + + def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]: + """Parse a toolgroup name into its components. + + Args: + toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search") + + Returns: + A tuple of (tool_type, tool_group, tool_name) + """ + split_names = toolgroup_name_with_maybe_tool_name.split("/") + if len(split_names) == 2: + # e.g. "builtin::rag" + tool_group, tool_name = split_names + else: + tool_group, tool_name = split_names[0], None + return tool_group, tool_name async def handle_documents( self, @@ -861,8 +880,8 @@ class ChatAgent(ShieldRunnerMixin): input_messages: List[Message], tool_defs: Dict[str, ToolDefinition], ) -> None: - memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None) - code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None) + memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs) + code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs) content_items = [] url_items = [] pattern = re.compile("^(https?://|file://|data:)") diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 8e2c793e6..6e3dc0739 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -441,7 +441,8 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config): assert "get_boiling_point" in logs_str -def test_rag_agent(llama_stack_client, agent_config): +@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"]) +def test_rag_agent(llama_stack_client, agent_config, rag_tool_name): urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"] documents = [ Document( @@ -469,7 +470,7 @@ def test_rag_agent(llama_stack_client, agent_config): **agent_config, "toolgroups": [ dict( - name="builtin::rag", + name=rag_tool_name, args={ "vector_db_ids": [vector_db_id], }, @@ -483,10 +484,6 @@ def test_rag_agent(llama_stack_client, agent_config): "Instead of the standard multi-head attention, what attention type does Llama3-8B use?", "grouped", ), - ( - "What `tune` command to use for getting access to Llama3-8B-Instruct ?", - "download", - ), ] for prompt, expected_kw in user_prompts: response = rag_agent.create_turn( @@ -541,7 +538,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config): **agent_config, "toolgroups": [ dict( - name="builtin::rag", + name="builtin::rag/knowledge_search", args={"vector_db_ids": [vector_db_id]}, ), "builtin::code_interpreter",