mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
feat: allow specifying specific tool within toolgroup (#1239)
Summary: E.g. `builtin::rag::knowledge_search` Test Plan: ``` LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/ --safety-shield meta-llama/Llama-Guard-3-8B ```
This commit is contained in:
parent
657efc67bc
commit
c8a20b8ed0
7 changed files with 80 additions and 64 deletions
|
@ -132,7 +132,7 @@ def rag_chat_page():
|
|||
},
|
||||
toolgroups=[
|
||||
dict(
|
||||
name="builtin::rag",
|
||||
name="builtin::rag/knowledge_search",
|
||||
args={
|
||||
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
|
||||
},
|
||||
|
|
|
@ -497,19 +497,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# TODO: simplify all of this code, it can be simpler
|
||||
toolgroup_args = {}
|
||||
toolgroups = set()
|
||||
for toolgroup in self.agent_config.toolgroups:
|
||||
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
toolgroups.add(toolgroup.name)
|
||||
toolgroup_args[toolgroup.name] = toolgroup.args
|
||||
tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name)
|
||||
toolgroups.add(tool_group_name)
|
||||
toolgroup_args[tool_group_name] = toolgroup.args
|
||||
else:
|
||||
toolgroups.add(toolgroup)
|
||||
if toolgroups_for_turn:
|
||||
for toolgroup in toolgroups_for_turn:
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
toolgroups.add(toolgroup.name)
|
||||
toolgroup_args[toolgroup.name] = toolgroup.args
|
||||
else:
|
||||
toolgroups.add(toolgroup)
|
||||
|
||||
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
||||
if documents:
|
||||
|
@ -542,7 +536,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
tools=[tool for tool in tool_defs.values()],
|
||||
tools=tool_defs,
|
||||
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
||||
response_format=self.agent_config.response_format,
|
||||
stream=True,
|
||||
|
@ -768,7 +762,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def _get_tool_defs(
|
||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||
) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
|
||||
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
|
||||
# Determine which tools to include
|
||||
agent_config_toolgroups = set(
|
||||
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
|
||||
|
@ -783,13 +777,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
}
|
||||
)
|
||||
|
||||
tool_def_map = {}
|
||||
tool_name_to_def = {}
|
||||
tool_to_group = {}
|
||||
|
||||
for tool_def in self.agent_config.client_tools:
|
||||
if tool_def_map.get(tool_def.name, None):
|
||||
if tool_name_to_def.get(tool_def.name, None):
|
||||
raise ValueError(f"Tool {tool_def.name} already exists")
|
||||
tool_def_map[tool_def.name] = ToolDefinition(
|
||||
tool_name_to_def[tool_def.name] = ToolDefinition(
|
||||
tool_name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
|
@ -803,10 +797,17 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
},
|
||||
)
|
||||
tool_to_group[tool_def.name] = "__client_tools__"
|
||||
for toolgroup_name in agent_config_toolgroups:
|
||||
if toolgroup_name not in toolgroups_for_turn_set:
|
||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||
if toolgroup_name_with_maybe_tool_name not in toolgroups_for_turn_set:
|
||||
continue
|
||||
|
||||
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
||||
if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
|
||||
raise ValueError(
|
||||
f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||
)
|
||||
|
||||
for tool_def in tools.data:
|
||||
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
|
||||
tool_name = tool_def.identifier
|
||||
|
@ -816,10 +817,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
built_in_type = BuiltinTool(tool_name)
|
||||
|
||||
if tool_def_map.get(built_in_type, None):
|
||||
if tool_name_to_def.get(built_in_type, None):
|
||||
raise ValueError(f"Tool {built_in_type} already exists")
|
||||
|
||||
tool_def_map[built_in_type] = ToolDefinition(
|
||||
tool_name_to_def[built_in_type] = ToolDefinition(
|
||||
tool_name=built_in_type,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
|
@ -835,24 +836,42 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_to_group[built_in_type] = tool_def.toolgroup_id
|
||||
continue
|
||||
|
||||
if tool_def_map.get(tool_def.identifier, None):
|
||||
if tool_name_to_def.get(tool_def.identifier, None):
|
||||
raise ValueError(f"Tool {tool_def.identifier} already exists")
|
||||
tool_def_map[tool_def.identifier] = ToolDefinition(
|
||||
tool_name=tool_def.identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
|
||||
if tool_name in (None, tool_def.identifier):
|
||||
tool_name_to_def[tool_def.identifier] = ToolDefinition(
|
||||
tool_name=tool_def.identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
|
||||
|
||||
return tool_def_map, tool_to_group
|
||||
return list(tool_name_to_def.values()), tool_to_group
|
||||
|
||||
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
|
||||
"""Parse a toolgroup name into its components.
|
||||
|
||||
Args:
|
||||
toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search")
|
||||
|
||||
Returns:
|
||||
A tuple of (tool_type, tool_group, tool_name)
|
||||
"""
|
||||
split_names = toolgroup_name_with_maybe_tool_name.split("/")
|
||||
if len(split_names) == 2:
|
||||
# e.g. "builtin::rag"
|
||||
tool_group, tool_name = split_names
|
||||
else:
|
||||
tool_group, tool_name = split_names[0], None
|
||||
return tool_group, tool_name
|
||||
|
||||
async def handle_documents(
|
||||
self,
|
||||
|
@ -861,8 +880,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages: List[Message],
|
||||
tool_defs: Dict[str, ToolDefinition],
|
||||
) -> None:
|
||||
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
|
||||
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
|
||||
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs)
|
||||
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs)
|
||||
content_items = []
|
||||
url_items = []
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue