diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index b1bef0882..d1d2c266d 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -3711,6 +3711,12 @@
"type": "string"
}
},
+ "custom_tools": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/CustomToolDef"
+ }
+ },
"preprocessing_tools": {
"type": "array",
"items": {
@@ -3747,6 +3753,111 @@
"enable_session_persistence"
]
},
+ "CustomToolDef": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "custom",
+ "default": "custom"
+ },
+ "name": {
+ "type": "string"
+ },
+ "description": {
+ "type": "string"
+ },
+ "parameters": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ToolParameter"
+ }
+ },
+ "metadata": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
+ },
+ "tool_prompt_format": {
+ "$ref": "#/components/schemas/ToolPromptFormat",
+ "default": "json"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "name",
+ "description",
+ "parameters",
+ "metadata"
+ ]
+ },
+ "ToolParameter": {
+ "type": "object",
+ "properties": {
+ "name": {
+ "type": "string"
+ },
+ "parameter_type": {
+ "type": "string"
+ },
+ "description": {
+ "type": "string"
+ },
+ "required": {
+ "type": "boolean"
+ },
+ "default": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "name",
+ "parameter_type",
+ "description",
+ "required"
+ ]
+ },
"CreateAgentRequest": {
"type": "object",
"properties": {
@@ -4403,39 +4514,16 @@
"session_id"
]
},
- "MCPToolGroupDef": {
+ "BuiltInToolDef": {
"type": "object",
"properties": {
"type": {
"type": "string",
- "const": "model_context_protocol",
- "default": "model_context_protocol"
+ "const": "built_in",
+ "default": "built_in"
},
- "endpoint": {
- "$ref": "#/components/schemas/URL"
- }
- },
- "additionalProperties": false,
- "required": [
- "type",
- "endpoint"
- ],
- "title": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information."
- },
- "ToolDef": {
- "type": "object",
- "properties": {
- "name": {
- "type": "string"
- },
- "description": {
- "type": "string"
- },
- "parameters": {
- "type": "array",
- "items": {
- "$ref": "#/components/schemas/ToolParameter"
- }
+ "built_in_type": {
+ "$ref": "#/components/schemas/BuiltinTool"
},
"metadata": {
"type": "object",
@@ -4461,18 +4549,41 @@
}
]
}
- },
- "tool_prompt_format": {
- "$ref": "#/components/schemas/ToolPromptFormat",
- "default": "json"
}
},
"additionalProperties": false,
"required": [
- "name",
- "description",
- "parameters",
- "metadata"
+ "type",
+ "built_in_type"
+ ]
+ },
+ "MCPToolGroupDef": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "model_context_protocol",
+ "default": "model_context_protocol"
+ },
+ "endpoint": {
+ "$ref": "#/components/schemas/URL"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "endpoint"
+ ],
+ "title": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information."
+ },
+ "ToolDef": {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/CustomToolDef"
+ },
+ {
+ "$ref": "#/components/schemas/BuiltInToolDef"
+ }
]
},
"ToolGroupDef": {
@@ -4485,52 +4596,6 @@
}
]
},
- "ToolParameter": {
- "type": "object",
- "properties": {
- "name": {
- "type": "string"
- },
- "parameter_type": {
- "type": "string"
- },
- "description": {
- "type": "string"
- },
- "required": {
- "type": "boolean"
- },
- "default": {
- "oneOf": [
- {
- "type": "null"
- },
- {
- "type": "boolean"
- },
- {
- "type": "number"
- },
- {
- "type": "string"
- },
- {
- "type": "array"
- },
- {
- "type": "object"
- }
- ]
- }
- },
- "additionalProperties": false,
- "required": [
- "name",
- "parameter_type",
- "description",
- "required"
- ]
- },
"UserDefinedToolGroupDef": {
"type": "object",
"properties": {
@@ -5797,6 +5862,9 @@
"tool_group": {
"type": "string"
},
+ "tool_host": {
+ "$ref": "#/components/schemas/ToolHost"
+ },
"description": {
"type": "string"
},
@@ -5806,6 +5874,9 @@
"$ref": "#/components/schemas/ToolParameter"
}
},
+ "built_in_type": {
+ "$ref": "#/components/schemas/BuiltinTool"
+ },
"metadata": {
"type": "object",
"additionalProperties": {
@@ -5840,12 +5911,22 @@
"required": [
"identifier",
"provider_resource_id",
+ "provider_id",
"type",
"tool_group",
+ "tool_host",
"description",
"parameters"
]
},
+ "ToolHost": {
+ "type": "string",
+ "enum": [
+ "distribution",
+ "client",
+ "model_context_protocol"
+ ]
+ },
"ToolGroup": {
"type": "object",
"properties": {
@@ -7942,6 +8023,10 @@
"name": "BenchmarkEvalTaskConfig",
"description": ""
},
+ {
+ "name": "BuiltInToolDef",
+ "description": ""
+ },
{
"name": "BuiltinTool",
"description": ""
@@ -8002,6 +8087,10 @@
"name": "CreateAgentTurnRequest",
"description": ""
},
+ {
+ "name": "CustomToolDef",
+ "description": ""
+ },
{
"name": "DPOAlignmentConfig",
"description": ""
@@ -8481,6 +8570,10 @@
{
"name": "ToolGroups"
},
+ {
+ "name": "ToolHost",
+ "description": ""
+ },
{
"name": "ToolInvocationResult",
"description": ""
@@ -8624,6 +8717,7 @@
"BatchCompletionRequest",
"BatchCompletionResponse",
"BenchmarkEvalTaskConfig",
+ "BuiltInToolDef",
"BuiltinTool",
"CancelTrainingJobRequest",
"ChatCompletionRequest",
@@ -8639,6 +8733,7 @@
"CreateAgentRequest",
"CreateAgentSessionRequest",
"CreateAgentTurnRequest",
+ "CustomToolDef",
"DPOAlignmentConfig",
"DataConfig",
"Dataset",
@@ -8746,6 +8841,7 @@
"ToolExecutionStep",
"ToolGroup",
"ToolGroupDef",
+ "ToolHost",
"ToolInvocationResult",
"ToolParamDefinition",
"ToolParameter",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 5da647b54..4f7a9c91c 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -21,6 +21,10 @@ components:
items:
type: string
type: array
+ custom_tools:
+ items:
+ $ref: '#/components/schemas/CustomToolDef'
+ type: array
enable_session_persistence:
type: boolean
input_shields:
@@ -389,6 +393,29 @@ components:
- type
- eval_candidate
type: object
+ BuiltInToolDef:
+ additionalProperties: false
+ properties:
+ built_in_type:
+ $ref: '#/components/schemas/BuiltinTool'
+ metadata:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ type:
+ const: built_in
+ default: built_in
+ type: string
+ required:
+ - type
+ - built_in_type
+ type: object
BuiltinTool:
enum:
- brave_search
@@ -607,6 +634,41 @@ components:
- session_id
- messages
type: object
+ CustomToolDef:
+ additionalProperties: false
+ properties:
+ description:
+ type: string
+ metadata:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ name:
+ type: string
+ parameters:
+ items:
+ $ref: '#/components/schemas/ToolParameter'
+ type: array
+ tool_prompt_format:
+ $ref: '#/components/schemas/ToolPromptFormat'
+ default: json
+ type:
+ const: custom
+ default: custom
+ type: string
+ required:
+ - type
+ - name
+ - description
+ - parameters
+ - metadata
+ type: object
DPOAlignmentConfig:
additionalProperties: false
properties:
@@ -2557,6 +2619,8 @@ components:
Tool:
additionalProperties: false
properties:
+ built_in_type:
+ $ref: '#/components/schemas/BuiltinTool'
description:
type: string
identifier:
@@ -2581,6 +2645,8 @@ components:
type: string
tool_group:
type: string
+ tool_host:
+ $ref: '#/components/schemas/ToolHost'
tool_prompt_format:
$ref: '#/components/schemas/ToolPromptFormat'
default: json
@@ -2591,8 +2657,10 @@ components:
required:
- identifier
- provider_resource_id
+ - provider_id
- type
- tool_group
+ - tool_host
- description
- parameters
type: object
@@ -2661,35 +2729,9 @@ components:
- required
type: string
ToolDef:
- additionalProperties: false
- properties:
- description:
- type: string
- metadata:
- additionalProperties:
- oneOf:
- - type: 'null'
- - type: boolean
- - type: number
- - type: string
- - type: array
- - type: object
- type: object
- name:
- type: string
- parameters:
- items:
- $ref: '#/components/schemas/ToolParameter'
- type: array
- tool_prompt_format:
- $ref: '#/components/schemas/ToolPromptFormat'
- default: json
- required:
- - name
- - description
- - parameters
- - metadata
- type: object
+ oneOf:
+ - $ref: '#/components/schemas/CustomToolDef'
+ - $ref: '#/components/schemas/BuiltInToolDef'
ToolDefinition:
additionalProperties: false
properties:
@@ -2761,6 +2803,12 @@ components:
oneOf:
- $ref: '#/components/schemas/MCPToolGroupDef'
- $ref: '#/components/schemas/UserDefinedToolGroupDef'
+ ToolHost:
+ enum:
+ - distribution
+ - client
+ - model_context_protocol
+ type: string
ToolInvocationResult:
additionalProperties: false
properties:
@@ -4738,6 +4786,8 @@ tags:
- description:
name: BenchmarkEvalTaskConfig
+- description:
+ name: BuiltInToolDef
- description:
name: BuiltinTool
- description:
name: CreateAgentTurnRequest
+- description:
+ name: CustomToolDef
- description:
name: DPOAlignmentConfig
@@ -5111,6 +5163,8 @@ tags:
- description:
name: ToolGroupDef
- name: ToolGroups
+- description:
+ name: ToolHost
- description:
name: ToolInvocationResult
@@ -5224,6 +5278,7 @@ x-tagGroups:
- BatchCompletionRequest
- BatchCompletionResponse
- BenchmarkEvalTaskConfig
+ - BuiltInToolDef
- BuiltinTool
- CancelTrainingJobRequest
- ChatCompletionRequest
@@ -5239,6 +5294,7 @@ x-tagGroups:
- CreateAgentRequest
- CreateAgentSessionRequest
- CreateAgentTurnRequest
+ - CustomToolDef
- DPOAlignmentConfig
- DataConfig
- Dataset
@@ -5346,6 +5402,7 @@ x-tagGroups:
- ToolExecutionStep
- ToolGroup
- ToolGroupDef
+ - ToolHost
- ToolInvocationResult
- ToolParamDefinition
- ToolParameter
diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 325ce9490..3348211c9 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -18,13 +18,11 @@ from typing import (
runtime_checkable,
)
-from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import URL, InterleavedContent
-from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
from llama_stack.apis.inference import (
CompletionMessage,
SamplingParams,
@@ -140,6 +138,7 @@ class AgentConfigCommon(BaseModel):
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
available_tools: Optional[List[str]] = Field(default_factory=list)
+ custom_tools: Optional[List[CustomToolDef]] = Field(default_factory=list)
preprocessing_tools: Optional[List[str]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index cc4ef38a9..ba190f567 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -400,6 +400,10 @@ class ChatAgent(ShieldRunnerMixin):
output_attachments = []
n_iter = 0
+ # Build a map of custom tools to their definitions for faster lookup
+ custom_tools = {}
+ for tool in self.agent_config.custom_tools:
+ custom_tools[tool.name] = tool
while True:
msg = input_messages[-1]
@@ -530,6 +534,9 @@ class ChatAgent(ShieldRunnerMixin):
else:
log.info(f"{str(message)}")
tool_call = message.tool_calls[0]
+ if tool_call.tool_name in custom_tools:
+ yield message
+ return
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@@ -619,6 +626,22 @@ class ChatAgent(ShieldRunnerMixin):
async def _get_tools(self) -> List[ToolDefinition]:
ret = []
+ for tool in self.agent_config.custom_tools:
+ params = {}
+ for param in tool.parameters:
+ params[param.name] = ToolParamDefinition(
+ param_type=param.parameter_type,
+ description=param.description,
+ required=param.required,
+ default=param.default,
+ )
+ ret.append(
+ ToolDefinition(
+ tool_name=tool.name,
+ description=tool.description,
+ parameters=params,
+ )
+ )
for tool_name in self.agent_config.available_tools:
tool = await self.tool_groups_api.get_tool(tool_name)
if tool.built_in_type:
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 7939259d1..ef3c087fa 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -9,16 +9,13 @@ from typing import Dict, List
from uuid import uuid4
import pytest
-
-from llama_stack.providers.tests.env import get_env_or_fail
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.custom_tool import CustomTool
from llama_stack_client.lib.agents.event_logger import EventLogger
-from llama_stack_client.types import CompletionMessage, ToolResponseMessage
+from llama_stack_client.types import ToolResponseMessage
from llama_stack_client.types.agent_create_params import AgentConfig
-from llama_stack_client.types.tool_param_definition_param import (
- ToolParamDefinitionParam,
-)
+from llama_stack_client.types.custom_tool_def import Parameter
+from llama_stack_client.types.shared.completion_message import CompletionMessage
class TestCustomTool(CustomTool):
@@ -54,13 +51,17 @@ class TestCustomTool(CustomTool):
def get_description(self) -> str:
return "Get the boiling point of a imaginary liquids (eg. polyjuice)"
- def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:
+ def get_params_definition(self) -> Dict[str, Parameter]:
return {
- "liquid_name": ToolParamDefinitionParam(
- param_type="string", description="The name of the liquid", required=True
+ "liquid_name": Parameter(
+ name="liquid_name",
+ parameter_type="string",
+ description="The name of the liquid",
+ required=True,
),
- "celcius": ToolParamDefinitionParam(
- param_type="boolean",
+ "celcius": Parameter(
+ name="celcius",
+ parameter_type="boolean",
description="Whether to return the boiling point in Celcius",
required=False,
),
@@ -203,37 +204,16 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
def test_custom_tool(llama_stack_client, agent_config):
+ custom_tool = TestCustomTool()
agent_config = {
**agent_config,
"model": "meta-llama/Llama-3.2-3B-Instruct",
- "tools": [
- {
- "type": "brave_search",
- "engine": "brave",
- "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
- },
- {
- "function_name": "get_boiling_point",
- "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)",
- "parameters": {
- "liquid_name": {
- "param_type": "str",
- "description": "The name of the liquid",
- "required": True,
- },
- "celcius": {
- "param_type": "boolean",
- "description": "Whether to return the boiling point in Celcius",
- "required": False,
- },
- },
- "type": "function_call",
- },
- ],
+ "available_tools": ["brave_search"],
+ "custom_tools": [custom_tool.get_tool_definition()],
"tool_prompt_format": "python_list",
}
- agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),))
+ agent = Agent(llama_stack_client, agent_config, custom_tools=(custom_tool,))
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(