forked from phoenix-oss/llama-stack-mirror
feat: allow specifying specific tool within toolgroup (#1239)
Summary: E.g. `builtin::rag::knowledge_search` Test Plan: ``` LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/ --safety-shield meta-llama/Llama-Guard-3-8B ```
This commit is contained in:
parent
657efc67bc
commit
c8a20b8ed0
7 changed files with 80 additions and 64 deletions
|
@ -803,7 +803,7 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
|
"model_id = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model_id\n"
|
"model_id\n"
|
||||||
]
|
]
|
||||||
|
@ -1688,7 +1688,7 @@
|
||||||
" enable_session_persistence=False,\n",
|
" enable_session_persistence=False,\n",
|
||||||
" toolgroups = [\n",
|
" toolgroups = [\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
" \"name\": \"builtin::rag\",\n",
|
" \"name\": \"builtin::rag/knowledge_search\",\n",
|
||||||
" \"args\" : {\n",
|
" \"args\" : {\n",
|
||||||
" \"vector_db_ids\": [vector_db_id],\n",
|
" \"vector_db_ids\": [vector_db_id],\n",
|
||||||
" }\n",
|
" }\n",
|
||||||
|
|
|
@ -7,12 +7,12 @@ Each agent turn follows these key steps:
|
||||||
1. **Initial Safety Check**: The user's input is first screened through configured safety shields
|
1. **Initial Safety Check**: The user's input is first screened through configured safety shields
|
||||||
|
|
||||||
2. **Context Retrieval**:
|
2. **Context Retrieval**:
|
||||||
- If RAG is enabled, the agent queries relevant documents from memory banks
|
- If RAG is enabled, the agent can choose to query relevant documents from memory banks. You can use the `instructions` field to steer the agent.
|
||||||
- For new documents, they are first inserted into the memory bank
|
- For new documents, they are first inserted into the memory bank.
|
||||||
- Retrieved context is augmented to the user's prompt
|
- Retrieved context is provided to the LLM as a tool response in the message history.
|
||||||
|
|
||||||
3. **Inference Loop**: The agent enters its main execution loop:
|
3. **Inference Loop**: The agent enters its main execution loop:
|
||||||
- The LLM receives the augmented prompt (with context and/or previous tool outputs)
|
- The LLM receives a user prompt (with previous tool outputs)
|
||||||
- The LLM generates a response, potentially with tool calls
|
- The LLM generates a response, potentially with tool calls
|
||||||
- If tool calls are present:
|
- If tool calls are present:
|
||||||
- Tool inputs are safety-checked
|
- Tool inputs are safety-checked
|
||||||
|
@ -40,19 +40,16 @@ sequenceDiagram
|
||||||
S->>E: Input Safety Check
|
S->>E: Input Safety Check
|
||||||
deactivate S
|
deactivate S
|
||||||
|
|
||||||
E->>M: 2.1 Query Context
|
|
||||||
M-->>E: 2.2 Retrieved Documents
|
|
||||||
|
|
||||||
loop Inference Loop
|
loop Inference Loop
|
||||||
E->>L: 3.1 Augment with Context
|
E->>L: 2.1 Augment with Context
|
||||||
L-->>E: 3.2 Response (with/without tool calls)
|
L-->>E: 2.2 Response (with/without tool calls)
|
||||||
|
|
||||||
alt Has Tool Calls
|
alt Has Tool Calls
|
||||||
E->>S: Check Tool Input
|
E->>S: Check Tool Input
|
||||||
S->>T: 4.1 Execute Tool
|
S->>T: 3.1 Execute Tool
|
||||||
T-->>E: 4.2 Tool Response
|
T-->>E: 3.2 Tool Response
|
||||||
E->>L: 5.1 Tool Response
|
E->>L: 4.1 Tool Response
|
||||||
L-->>E: 5.2 Synthesized Response
|
L-->>E: 4.2 Synthesized Response
|
||||||
end
|
end
|
||||||
|
|
||||||
opt Stop Conditions
|
opt Stop Conditions
|
||||||
|
@ -64,7 +61,7 @@ sequenceDiagram
|
||||||
end
|
end
|
||||||
|
|
||||||
E->>S: Output Safety Check
|
E->>S: Output Safety Check
|
||||||
S->>U: 6. Final Response
|
S->>U: 5. Final Response
|
||||||
```
|
```
|
||||||
|
|
||||||
Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
|
Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
|
||||||
|
@ -77,7 +74,10 @@ agent_config = AgentConfig(
|
||||||
instructions="You are a helpful assistant",
|
instructions="You are a helpful assistant",
|
||||||
# Enable both RAG and tool usage
|
# Enable both RAG and tool usage
|
||||||
toolgroups=[
|
toolgroups=[
|
||||||
{"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}},
|
{
|
||||||
|
"name": "builtin::rag/knowledge_search",
|
||||||
|
"args": {"vector_db_ids": ["my_docs"]},
|
||||||
|
},
|
||||||
"builtin::code_interpreter",
|
"builtin::code_interpreter",
|
||||||
],
|
],
|
||||||
# Configure safety
|
# Configure safety
|
||||||
|
|
|
@ -91,7 +91,7 @@ agent_config = AgentConfig(
|
||||||
enable_session_persistence=False,
|
enable_session_persistence=False,
|
||||||
toolgroups=[
|
toolgroups=[
|
||||||
{
|
{
|
||||||
"name": "builtin::rag",
|
"name": "builtin::rag/knowledge_search",
|
||||||
"args": {
|
"args": {
|
||||||
"vector_db_ids": [vector_db_id],
|
"vector_db_ids": [vector_db_id],
|
||||||
},
|
},
|
||||||
|
|
|
@ -243,7 +243,7 @@ agent_config = AgentConfig(
|
||||||
# Define tools available to the agent
|
# Define tools available to the agent
|
||||||
toolgroups=[
|
toolgroups=[
|
||||||
{
|
{
|
||||||
"name": "builtin::rag",
|
"name": "builtin::rag/knowledge_search",
|
||||||
"args": {
|
"args": {
|
||||||
"vector_db_ids": [vector_db_id],
|
"vector_db_ids": [vector_db_id],
|
||||||
},
|
},
|
||||||
|
|
|
@ -132,7 +132,7 @@ def rag_chat_page():
|
||||||
},
|
},
|
||||||
toolgroups=[
|
toolgroups=[
|
||||||
dict(
|
dict(
|
||||||
name="builtin::rag",
|
name="builtin::rag/knowledge_search",
|
||||||
args={
|
args={
|
||||||
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
|
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
|
||||||
},
|
},
|
||||||
|
|
|
@ -497,19 +497,13 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
# TODO: simplify all of this code, it can be simpler
|
# TODO: simplify all of this code, it can be simpler
|
||||||
toolgroup_args = {}
|
toolgroup_args = {}
|
||||||
toolgroups = set()
|
toolgroups = set()
|
||||||
for toolgroup in self.agent_config.toolgroups:
|
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
|
||||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||||
toolgroups.add(toolgroup.name)
|
tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name)
|
||||||
toolgroup_args[toolgroup.name] = toolgroup.args
|
toolgroups.add(tool_group_name)
|
||||||
|
toolgroup_args[tool_group_name] = toolgroup.args
|
||||||
else:
|
else:
|
||||||
toolgroups.add(toolgroup)
|
toolgroups.add(toolgroup)
|
||||||
if toolgroups_for_turn:
|
|
||||||
for toolgroup in toolgroups_for_turn:
|
|
||||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
|
||||||
toolgroups.add(toolgroup.name)
|
|
||||||
toolgroup_args[toolgroup.name] = toolgroup.args
|
|
||||||
else:
|
|
||||||
toolgroups.add(toolgroup)
|
|
||||||
|
|
||||||
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
||||||
if documents:
|
if documents:
|
||||||
|
@ -542,7 +536,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
async for chunk in await self.inference_api.chat_completion(
|
async for chunk in await self.inference_api.chat_completion(
|
||||||
self.agent_config.model,
|
self.agent_config.model,
|
||||||
input_messages,
|
input_messages,
|
||||||
tools=[tool for tool in tool_defs.values()],
|
tools=tool_defs,
|
||||||
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
||||||
response_format=self.agent_config.response_format,
|
response_format=self.agent_config.response_format,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
@ -768,7 +762,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
async def _get_tool_defs(
|
async def _get_tool_defs(
|
||||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||||
) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
|
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
|
||||||
# Determine which tools to include
|
# Determine which tools to include
|
||||||
agent_config_toolgroups = set(
|
agent_config_toolgroups = set(
|
||||||
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
|
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
|
||||||
|
@ -783,13 +777,13 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_def_map = {}
|
tool_name_to_def = {}
|
||||||
tool_to_group = {}
|
tool_to_group = {}
|
||||||
|
|
||||||
for tool_def in self.agent_config.client_tools:
|
for tool_def in self.agent_config.client_tools:
|
||||||
if tool_def_map.get(tool_def.name, None):
|
if tool_name_to_def.get(tool_def.name, None):
|
||||||
raise ValueError(f"Tool {tool_def.name} already exists")
|
raise ValueError(f"Tool {tool_def.name} already exists")
|
||||||
tool_def_map[tool_def.name] = ToolDefinition(
|
tool_name_to_def[tool_def.name] = ToolDefinition(
|
||||||
tool_name=tool_def.name,
|
tool_name=tool_def.name,
|
||||||
description=tool_def.description,
|
description=tool_def.description,
|
||||||
parameters={
|
parameters={
|
||||||
|
@ -803,10 +797,17 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
tool_to_group[tool_def.name] = "__client_tools__"
|
tool_to_group[tool_def.name] = "__client_tools__"
|
||||||
for toolgroup_name in agent_config_toolgroups:
|
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||||
if toolgroup_name not in toolgroups_for_turn_set:
|
if toolgroup_name_with_maybe_tool_name not in toolgroups_for_turn_set:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||||
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
||||||
|
if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
|
||||||
|
raise ValueError(
|
||||||
|
f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||||
|
)
|
||||||
|
|
||||||
for tool_def in tools.data:
|
for tool_def in tools.data:
|
||||||
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
|
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
|
||||||
tool_name = tool_def.identifier
|
tool_name = tool_def.identifier
|
||||||
|
@ -816,10 +817,10 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
else:
|
else:
|
||||||
built_in_type = BuiltinTool(tool_name)
|
built_in_type = BuiltinTool(tool_name)
|
||||||
|
|
||||||
if tool_def_map.get(built_in_type, None):
|
if tool_name_to_def.get(built_in_type, None):
|
||||||
raise ValueError(f"Tool {built_in_type} already exists")
|
raise ValueError(f"Tool {built_in_type} already exists")
|
||||||
|
|
||||||
tool_def_map[built_in_type] = ToolDefinition(
|
tool_name_to_def[built_in_type] = ToolDefinition(
|
||||||
tool_name=built_in_type,
|
tool_name=built_in_type,
|
||||||
description=tool_def.description,
|
description=tool_def.description,
|
||||||
parameters={
|
parameters={
|
||||||
|
@ -835,24 +836,42 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
tool_to_group[built_in_type] = tool_def.toolgroup_id
|
tool_to_group[built_in_type] = tool_def.toolgroup_id
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if tool_def_map.get(tool_def.identifier, None):
|
if tool_name_to_def.get(tool_def.identifier, None):
|
||||||
raise ValueError(f"Tool {tool_def.identifier} already exists")
|
raise ValueError(f"Tool {tool_def.identifier} already exists")
|
||||||
tool_def_map[tool_def.identifier] = ToolDefinition(
|
if tool_name in (None, tool_def.identifier):
|
||||||
tool_name=tool_def.identifier,
|
tool_name_to_def[tool_def.identifier] = ToolDefinition(
|
||||||
description=tool_def.description,
|
tool_name=tool_def.identifier,
|
||||||
parameters={
|
description=tool_def.description,
|
||||||
param.name: ToolParamDefinition(
|
parameters={
|
||||||
param_type=param.parameter_type,
|
param.name: ToolParamDefinition(
|
||||||
description=param.description,
|
param_type=param.parameter_type,
|
||||||
required=param.required,
|
description=param.description,
|
||||||
default=param.default,
|
required=param.required,
|
||||||
)
|
default=param.default,
|
||||||
for param in tool_def.parameters
|
)
|
||||||
},
|
for param in tool_def.parameters
|
||||||
)
|
},
|
||||||
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
|
)
|
||||||
|
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
|
||||||
|
|
||||||
return tool_def_map, tool_to_group
|
return list(tool_name_to_def.values()), tool_to_group
|
||||||
|
|
||||||
|
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
|
||||||
|
"""Parse a toolgroup name into its components.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (tool_type, tool_group, tool_name)
|
||||||
|
"""
|
||||||
|
split_names = toolgroup_name_with_maybe_tool_name.split("/")
|
||||||
|
if len(split_names) == 2:
|
||||||
|
# e.g. "builtin::rag"
|
||||||
|
tool_group, tool_name = split_names
|
||||||
|
else:
|
||||||
|
tool_group, tool_name = split_names[0], None
|
||||||
|
return tool_group, tool_name
|
||||||
|
|
||||||
async def handle_documents(
|
async def handle_documents(
|
||||||
self,
|
self,
|
||||||
|
@ -861,8 +880,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
input_messages: List[Message],
|
input_messages: List[Message],
|
||||||
tool_defs: Dict[str, ToolDefinition],
|
tool_defs: Dict[str, ToolDefinition],
|
||||||
) -> None:
|
) -> None:
|
||||||
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
|
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs)
|
||||||
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
|
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs)
|
||||||
content_items = []
|
content_items = []
|
||||||
url_items = []
|
url_items = []
|
||||||
pattern = re.compile("^(https?://|file://|data:)")
|
pattern = re.compile("^(https?://|file://|data:)")
|
||||||
|
|
|
@ -441,7 +441,8 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config):
|
||||||
assert "get_boiling_point" in logs_str
|
assert "get_boiling_point" in logs_str
|
||||||
|
|
||||||
|
|
||||||
def test_rag_agent(llama_stack_client, agent_config):
|
@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
|
||||||
|
def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
|
||||||
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
|
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
|
||||||
documents = [
|
documents = [
|
||||||
Document(
|
Document(
|
||||||
|
@ -469,7 +470,7 @@ def test_rag_agent(llama_stack_client, agent_config):
|
||||||
**agent_config,
|
**agent_config,
|
||||||
"toolgroups": [
|
"toolgroups": [
|
||||||
dict(
|
dict(
|
||||||
name="builtin::rag",
|
name=rag_tool_name,
|
||||||
args={
|
args={
|
||||||
"vector_db_ids": [vector_db_id],
|
"vector_db_ids": [vector_db_id],
|
||||||
},
|
},
|
||||||
|
@ -483,10 +484,6 @@ def test_rag_agent(llama_stack_client, agent_config):
|
||||||
"Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
|
"Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
|
||||||
"grouped",
|
"grouped",
|
||||||
),
|
),
|
||||||
(
|
|
||||||
"What `tune` command to use for getting access to Llama3-8B-Instruct ?",
|
|
||||||
"download",
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
for prompt, expected_kw in user_prompts:
|
for prompt, expected_kw in user_prompts:
|
||||||
response = rag_agent.create_turn(
|
response = rag_agent.create_turn(
|
||||||
|
@ -541,7 +538,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
|
||||||
**agent_config,
|
**agent_config,
|
||||||
"toolgroups": [
|
"toolgroups": [
|
||||||
dict(
|
dict(
|
||||||
name="builtin::rag",
|
name="builtin::rag/knowledge_search",
|
||||||
args={"vector_db_ids": [vector_db_id]},
|
args={"vector_db_ids": [vector_db_id]},
|
||||||
),
|
),
|
||||||
"builtin::code_interpreter",
|
"builtin::code_interpreter",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue