diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index b1bef0882..d1d2c266d 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -3711,6 +3711,12 @@ "type": "string" } }, + "custom_tools": { + "type": "array", + "items": { + "$ref": "#/components/schemas/CustomToolDef" + } + }, "preprocessing_tools": { "type": "array", "items": { @@ -3747,6 +3753,111 @@ "enable_session_persistence" ] }, + "CustomToolDef": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "custom", + "default": "custom" + }, + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "parameters": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolParameter" + } + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "tool_prompt_format": { + "$ref": "#/components/schemas/ToolPromptFormat", + "default": "json" + } + }, + "additionalProperties": false, + "required": [ + "type", + "name", + "description", + "parameters", + "metadata" + ] + }, + "ToolParameter": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "parameter_type": { + "type": "string" + }, + "description": { + "type": "string" + }, + "required": { + "type": "boolean" + }, + "default": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "additionalProperties": false, + "required": [ + "name", + "parameter_type", + "description", + "required" + ] + }, "CreateAgentRequest": { "type": "object", "properties": { @@ -4403,39 +4514,16 @@ "session_id" ] }, - "MCPToolGroupDef": { + "BuiltInToolDef": { "type": "object", "properties": { "type": { "type": "string", - "const": "model_context_protocol", - "default": "model_context_protocol" + "const": "built_in", + "default": "built_in" }, - "endpoint": { - "$ref": "#/components/schemas/URL" - } - }, - "additionalProperties": false, - "required": [ - "type", - "endpoint" - ], - "title": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information." - }, - "ToolDef": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "description": { - "type": "string" - }, - "parameters": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolParameter" - } + "built_in_type": { + "$ref": "#/components/schemas/BuiltinTool" }, "metadata": { "type": "object", @@ -4461,18 +4549,41 @@ } ] } - }, - "tool_prompt_format": { - "$ref": "#/components/schemas/ToolPromptFormat", - "default": "json" } }, "additionalProperties": false, "required": [ - "name", - "description", - "parameters", - "metadata" + "type", + "built_in_type" + ] + }, + "MCPToolGroupDef": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "model_context_protocol", + "default": "model_context_protocol" + }, + "endpoint": { + "$ref": "#/components/schemas/URL" + } + }, + "additionalProperties": false, + "required": [ + "type", + "endpoint" + ], + "title": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information." + }, + "ToolDef": { + "oneOf": [ + { + "$ref": "#/components/schemas/CustomToolDef" + }, + { + "$ref": "#/components/schemas/BuiltInToolDef" + } ] }, "ToolGroupDef": { @@ -4485,52 +4596,6 @@ } ] }, - "ToolParameter": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "parameter_type": { - "type": "string" - }, - "description": { - "type": "string" - }, - "required": { - "type": "boolean" - }, - "default": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "additionalProperties": false, - "required": [ - "name", - "parameter_type", - "description", - "required" - ] - }, "UserDefinedToolGroupDef": { "type": "object", "properties": { @@ -5797,6 +5862,9 @@ "tool_group": { "type": "string" }, + "tool_host": { + "$ref": "#/components/schemas/ToolHost" + }, "description": { "type": "string" }, @@ -5806,6 +5874,9 @@ "$ref": "#/components/schemas/ToolParameter" } }, + "built_in_type": { + "$ref": "#/components/schemas/BuiltinTool" + }, "metadata": { "type": "object", "additionalProperties": { @@ -5840,12 +5911,22 @@ "required": [ "identifier", "provider_resource_id", + "provider_id", "type", "tool_group", + "tool_host", "description", "parameters" ] }, + "ToolHost": { + "type": "string", + "enum": [ + "distribution", + "client", + "model_context_protocol" + ] + }, "ToolGroup": { "type": "object", "properties": { @@ -7942,6 +8023,10 @@ "name": "BenchmarkEvalTaskConfig", "description": "" }, + { + "name": "BuiltInToolDef", + "description": "" + }, { "name": "BuiltinTool", "description": "" @@ -8002,6 +8087,10 @@ "name": "CreateAgentTurnRequest", "description": "" }, + { + "name": "CustomToolDef", + "description": "" + }, { "name": "DPOAlignmentConfig", "description": "" @@ -8481,6 +8570,10 @@ { "name": "ToolGroups" }, + { + "name": "ToolHost", + "description": "" + }, { "name": "ToolInvocationResult", "description": "" @@ -8624,6 +8717,7 @@ "BatchCompletionRequest", "BatchCompletionResponse", "BenchmarkEvalTaskConfig", + "BuiltInToolDef", "BuiltinTool", "CancelTrainingJobRequest", "ChatCompletionRequest", @@ -8639,6 +8733,7 @@ "CreateAgentRequest", "CreateAgentSessionRequest", "CreateAgentTurnRequest", + "CustomToolDef", "DPOAlignmentConfig", "DataConfig", "Dataset", @@ -8746,6 +8841,7 @@ "ToolExecutionStep", "ToolGroup", "ToolGroupDef", + "ToolHost", "ToolInvocationResult", "ToolParamDefinition", "ToolParameter", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 5da647b54..4f7a9c91c 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -21,6 +21,10 @@ components: items: type: string type: array + custom_tools: + items: + $ref: '#/components/schemas/CustomToolDef' + type: array enable_session_persistence: type: boolean input_shields: @@ -389,6 +393,29 @@ components: - type - eval_candidate type: object + BuiltInToolDef: + additionalProperties: false + properties: + built_in_type: + $ref: '#/components/schemas/BuiltinTool' + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + type: + const: built_in + default: built_in + type: string + required: + - type + - built_in_type + type: object BuiltinTool: enum: - brave_search @@ -607,6 +634,41 @@ components: - session_id - messages type: object + CustomToolDef: + additionalProperties: false + properties: + description: + type: string + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + name: + type: string + parameters: + items: + $ref: '#/components/schemas/ToolParameter' + type: array + tool_prompt_format: + $ref: '#/components/schemas/ToolPromptFormat' + default: json + type: + const: custom + default: custom + type: string + required: + - type + - name + - description + - parameters + - metadata + type: object DPOAlignmentConfig: additionalProperties: false properties: @@ -2557,6 +2619,8 @@ components: Tool: additionalProperties: false properties: + built_in_type: + $ref: '#/components/schemas/BuiltinTool' description: type: string identifier: @@ -2581,6 +2645,8 @@ components: type: string tool_group: type: string + tool_host: + $ref: '#/components/schemas/ToolHost' tool_prompt_format: $ref: '#/components/schemas/ToolPromptFormat' default: json @@ -2591,8 +2657,10 @@ components: required: - identifier - provider_resource_id + - provider_id - type - tool_group + - tool_host - description - parameters type: object @@ -2661,35 +2729,9 @@ components: - required type: string ToolDef: - additionalProperties: false - properties: - description: - type: string - metadata: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - name: - type: string - parameters: - items: - $ref: '#/components/schemas/ToolParameter' - type: array - tool_prompt_format: - $ref: '#/components/schemas/ToolPromptFormat' - default: json - required: - - name - - description - - parameters - - metadata - type: object + oneOf: + - $ref: '#/components/schemas/CustomToolDef' + - $ref: '#/components/schemas/BuiltInToolDef' ToolDefinition: additionalProperties: false properties: @@ -2761,6 +2803,12 @@ components: oneOf: - $ref: '#/components/schemas/MCPToolGroupDef' - $ref: '#/components/schemas/UserDefinedToolGroupDef' + ToolHost: + enum: + - distribution + - client + - model_context_protocol + type: string ToolInvocationResult: additionalProperties: false properties: @@ -4738,6 +4786,8 @@ tags: - description: name: BenchmarkEvalTaskConfig +- description: + name: BuiltInToolDef - description: name: BuiltinTool - description: name: CreateAgentTurnRequest +- description: + name: CustomToolDef - description: name: DPOAlignmentConfig @@ -5111,6 +5163,8 @@ tags: - description: name: ToolGroupDef - name: ToolGroups +- description: + name: ToolHost - description: name: ToolInvocationResult @@ -5224,6 +5278,7 @@ x-tagGroups: - BatchCompletionRequest - BatchCompletionResponse - BenchmarkEvalTaskConfig + - BuiltInToolDef - BuiltinTool - CancelTrainingJobRequest - ChatCompletionRequest @@ -5239,6 +5294,7 @@ x-tagGroups: - CreateAgentRequest - CreateAgentSessionRequest - CreateAgentTurnRequest + - CustomToolDef - DPOAlignmentConfig - DataConfig - Dataset @@ -5346,6 +5402,7 @@ x-tagGroups: - ToolExecutionStep - ToolGroup - ToolGroupDef + - ToolHost - ToolInvocationResult - ToolParamDefinition - ToolParameter diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 325ce9490..3348211c9 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -18,13 +18,11 @@ from typing import ( runtime_checkable, ) -from llama_models.llama3.api.datatypes import ToolParamDefinition from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated from llama_stack.apis.common.content_types import URL, InterleavedContent -from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig from llama_stack.apis.inference import ( CompletionMessage, SamplingParams, @@ -140,6 +138,7 @@ class AgentConfigCommon(BaseModel): input_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list) available_tools: Optional[List[str]] = Field(default_factory=list) + custom_tools: Optional[List[CustomToolDef]] = Field(default_factory=list) preprocessing_tools: Optional[List[str]] = Field(default_factory=list) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_prompt_format: Optional[ToolPromptFormat] = Field( diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index cc4ef38a9..ba190f567 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -400,6 +400,10 @@ class ChatAgent(ShieldRunnerMixin): output_attachments = [] n_iter = 0 + # Build a map of custom tools to their definitions for faster lookup + custom_tools = {} + for tool in self.agent_config.custom_tools: + custom_tools[tool.name] = tool while True: msg = input_messages[-1] @@ -530,6 +534,9 @@ class ChatAgent(ShieldRunnerMixin): else: log.info(f"{str(message)}") tool_call = message.tool_calls[0] + if tool_call.tool_name in custom_tools: + yield message + return step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( @@ -619,6 +626,22 @@ class ChatAgent(ShieldRunnerMixin): async def _get_tools(self) -> List[ToolDefinition]: ret = [] + for tool in self.agent_config.custom_tools: + params = {} + for param in tool.parameters: + params[param.name] = ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + ret.append( + ToolDefinition( + tool_name=tool.name, + description=tool.description, + parameters=params, + ) + ) for tool_name in self.agent_config.available_tools: tool = await self.tool_groups_api.get_tool(tool_name) if tool.built_in_type: diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 7939259d1..ef3c087fa 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -9,16 +9,13 @@ from typing import Dict, List from uuid import uuid4 import pytest - -from llama_stack.providers.tests.env import get_env_or_fail from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.custom_tool import CustomTool from llama_stack_client.lib.agents.event_logger import EventLogger -from llama_stack_client.types import CompletionMessage, ToolResponseMessage +from llama_stack_client.types import ToolResponseMessage from llama_stack_client.types.agent_create_params import AgentConfig -from llama_stack_client.types.tool_param_definition_param import ( - ToolParamDefinitionParam, -) +from llama_stack_client.types.custom_tool_def import Parameter +from llama_stack_client.types.shared.completion_message import CompletionMessage class TestCustomTool(CustomTool): @@ -54,13 +51,17 @@ class TestCustomTool(CustomTool): def get_description(self) -> str: return "Get the boiling point of a imaginary liquids (eg. polyjuice)" - def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]: + def get_params_definition(self) -> Dict[str, Parameter]: return { - "liquid_name": ToolParamDefinitionParam( - param_type="string", description="The name of the liquid", required=True + "liquid_name": Parameter( + name="liquid_name", + parameter_type="string", + description="The name of the liquid", + required=True, ), - "celcius": ToolParamDefinitionParam( - param_type="boolean", + "celcius": Parameter( + name="celcius", + parameter_type="boolean", description="Whether to return the boiling point in Celcius", required=False, ), @@ -203,37 +204,16 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config): def test_custom_tool(llama_stack_client, agent_config): + custom_tool = TestCustomTool() agent_config = { **agent_config, "model": "meta-llama/Llama-3.2-3B-Instruct", - "tools": [ - { - "type": "brave_search", - "engine": "brave", - "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"), - }, - { - "function_name": "get_boiling_point", - "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)", - "parameters": { - "liquid_name": { - "param_type": "str", - "description": "The name of the liquid", - "required": True, - }, - "celcius": { - "param_type": "boolean", - "description": "Whether to return the boiling point in Celcius", - "required": False, - }, - }, - "type": "function_call", - }, - ], + "available_tools": ["brave_search"], + "custom_tools": [custom_tool.get_tool_definition()], "tool_prompt_format": "python_list", } - agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),)) + agent = Agent(llama_stack_client, agent_config, custom_tools=(custom_tool,)) session_id = agent.create_session(f"test-session-{uuid4()}") response = agent.create_turn(