diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 75f1cb9c0..09184d09a 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -36,7 +36,7 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.memory import MemoryBank from llama_stack.apis.safety import SafetyViolation -from llama_stack.apis.tools import CustomToolDef +from llama_stack.apis.tools import UserDefinedToolDef from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @@ -137,8 +137,8 @@ class AgentConfigCommon(BaseModel): input_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list) - available_tools: Optional[List[str]] = Field(default_factory=list) - custom_tools: Optional[List[CustomToolDef]] = Field(default_factory=list) + tool_names: Optional[List[str]] = Field(default_factory=list) + client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list) preprocessing_tools: Optional[List[str]] = Field(default_factory=list) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_prompt_format: Optional[ToolPromptFormat] = Field( diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 65d5b8444..6585f3fd2 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -48,8 +48,8 @@ class Tool(Resource): @json_schema_type -class CustomToolDef(BaseModel): - type: Literal["custom"] = "custom" +class UserDefinedToolDef(BaseModel): + type: Literal["user_defined"] = "user_defined" name: str description: str parameters: List[ToolParameter] @@ -67,7 +67,7 @@ class BuiltInToolDef(BaseModel): ToolDef = register_schema( - Annotated[Union[CustomToolDef, BuiltInToolDef], Field(discriminator="type")], + Annotated[Union[UserDefinedToolDef, BuiltInToolDef], Field(discriminator="type")], name="ToolDef", ) @@ -172,14 +172,3 @@ class ToolRuntime(Protocol): ) -> ToolInvocationResult: """Run a tool with the given arguments""" ... - - -# Three tool types: -# 1. Built-in tools -# 2. Client tools -# 3. Model-context-protocol tools - -# Suport registration of agents with tool groups -# TBD: Have a client utility to hide the pre processing tools. -# Attachments are confusing right now since they are inserted into memory first and retireved through RAG, even before a question is asked. -# diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index f0d55eaf2..ccea470ae 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -28,7 +28,6 @@ from llama_stack.apis.scoring_functions import ( from llama_stack.apis.shields import Shield, Shields from llama_stack.apis.tools import ( BuiltInToolDef, - CustomToolDef, MCPToolGroupDef, Tool, ToolGroup, @@ -36,6 +35,7 @@ from llama_stack.apis.tools import ( ToolGroups, ToolHost, ToolPromptFormat, + UserDefinedToolDef, UserDefinedToolGroupDef, ) from llama_stack.distribution.datatypes import ( @@ -540,7 +540,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): raise ValueError(f"Unknown tool group: {tool_group}") for tool_def in tool_defs: - if isinstance(tool_def, CustomToolDef): + if isinstance(tool_def, UserDefinedToolDef): tools.append( Tool( identifier=tool_def.name, diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 219afe621..b035ac098 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -429,9 +429,9 @@ class ChatAgent(ShieldRunnerMixin): n_iter = 0 # Build a map of custom tools to their definitions for faster lookup - custom_tools = {} - for tool in self.agent_config.custom_tools: - custom_tools[tool.name] = tool + client_tools = {} + for tool in self.agent_config.client_tools: + client_tools[tool.name] = tool while True: step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( @@ -560,7 +560,7 @@ class ChatAgent(ShieldRunnerMixin): else: log.info(f"{str(message)}") tool_call = message.tool_calls[0] - if tool_call.tool_name in custom_tools: + if tool_call.tool_name in client_tools: yield message return @@ -656,7 +656,7 @@ class ChatAgent(ShieldRunnerMixin): async def _get_tools(self) -> List[ToolDefinition]: ret = [] - for tool in self.agent_config.custom_tools: + for tool in self.agent_config.client_tools: params = {} for param in tool.parameters: params[param.name] = ToolParamDefinition( @@ -672,7 +672,7 @@ class ChatAgent(ShieldRunnerMixin): parameters=params, ) ) - for tool_name in self.agent_config.available_tools: + for tool_name in self.agent_config.tool_names: tool = await self.tool_groups_api.get_tool(tool_name) if tool.built_in_type: ret.append(ToolDefinition(tool_name=tool.built_in_type)) diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index c77929f99..537ae3ab5 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -8,13 +8,13 @@ from typing import Any, Dict, List from urllib.parse import urlparse from llama_stack.apis.tools import ( - CustomToolDef, MCPToolGroupDef, ToolDef, ToolGroupDef, ToolInvocationResult, ToolParameter, ToolRuntime, + UserDefinedToolDef, ) from llama_stack.providers.datatypes import ToolsProtocolPrivate @@ -53,7 +53,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): ) ) tools.append( - CustomToolDef( + UserDefinedToolDef( name=tool.name, description=tool.description, parameters=parameters, diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index e02af9c92..44b0f8a2e 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -104,7 +104,7 @@ async def create_agent_turn_with_search_tool( agent_config = AgentConfig( **{ **common_params, - "available_tools": [tool_name], + "tool_names": [tool_name], } ) diff --git a/llama_stack/providers/tests/tools/fixtures.py b/llama_stack/providers/tests/tools/fixtures.py index 911043011..58defd57d 100644 --- a/llama_stack/providers/tests/tools/fixtures.py +++ b/llama_stack/providers/tests/tools/fixtures.py @@ -13,9 +13,9 @@ from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.models import ModelInput, ModelType from llama_stack.apis.tools import ( BuiltInToolDef, - CustomToolDef, ToolGroupInput, ToolParameter, + UserDefinedToolDef, UserDefinedToolGroupDef, ) from llama_stack.distribution.datatypes import Api, Provider @@ -50,7 +50,7 @@ def tool_group_input_memory() -> ToolGroupInput: tool_group_id="memory_group", tool_group=UserDefinedToolGroupDef( tools=[ - CustomToolDef( + UserDefinedToolDef( name="memory", description="Query the memory bank", parameters=[ diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 8e391a48b..68ff3089b 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -151,7 +151,7 @@ def test_agent_simple(llama_stack_client, agent_config): def test_builtin_tool_brave_search(llama_stack_client, agent_config): agent_config = { **agent_config, - "available_tools": [ + "tool_names": [ "brave_search", ], } @@ -181,7 +181,7 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config): def test_builtin_tool_code_execution(llama_stack_client, agent_config): agent_config = { **agent_config, - "available_tools": [ + "tool_names": [ "code_interpreter", ], } @@ -209,12 +209,12 @@ def test_custom_tool(llama_stack_client, agent_config): agent_config = { **agent_config, "model": "meta-llama/Llama-3.2-3B-Instruct", - "available_tools": ["brave_search"], - "custom_tools": [custom_tool.get_tool_definition()], + "tool_names": ["brave_search"], + "client_tools": [custom_tool.get_tool_definition()], "tool_prompt_format": "python_list", } - agent = Agent(llama_stack_client, agent_config, custom_tools=(custom_tool,)) + agent = Agent(llama_stack_client, agent_config, client_tools=(custom_tool,)) session_id = agent.create_session(f"test-session-{uuid4()}") response = agent.create_turn(