feat(tools): use { input_schema, output_schema } for ToolDefinition

This commit is contained in:
Ashwin Bharambe 2025-09-30 19:13:15 -07:00
parent 42414a1a1b
commit 139320e19f
20 changed files with 1989 additions and 386 deletions

View file

@ -27,14 +27,12 @@ from llama_stack.models.llama.datatypes import (
StopReason,
ToolCall,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
register_schema(ToolCall)
register_schema(ToolParamDefinition)
register_schema(ToolDefinition)
from enum import StrEnum

View file

@ -65,13 +65,15 @@ class ToolDef(BaseModel):
:param name: Name of the tool
:param description: (Optional) Human-readable description of what the tool does
:param parameters: (Optional) List of parameters this tool accepts
:param input_schema: (Optional) JSON Schema for tool inputs (MCP inputSchema)
:param output_schema: (Optional) JSON Schema for tool outputs (MCP outputSchema)
:param metadata: (Optional) Additional metadata about the tool
"""
name: str
description: str | None = None
parameters: list[ToolParameter] | None = None
input_schema: dict[str, Any] | None = None
output_schema: dict[str, Any] | None = None
metadata: dict[str, Any] | None = None

View file

@ -257,7 +257,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
return result
except Exception as e:
if logger.isEnabledFor(logging.DEBUG):
if logger.isEnabledFor(logging.INFO):
logger.exception(f"Error executing endpoint {route=} {method=}")
else:
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")

View file

@ -88,19 +88,11 @@ class StopReason(Enum):
out_of_tokens = "out_of_tokens"
class ToolParamDefinition(BaseModel):
param_type: str
description: str | None = None
required: bool | None = True
items: Any | None = None
title: str | None = None
default: Any | None = None
class ToolDefinition(BaseModel):
tool_name: BuiltinTool | str
description: str | None = None
parameters: dict[str, ToolParamDefinition] | None = None
input_schema: dict[str, Any] | None = None
output_schema: dict[str, Any] | None = None
@field_validator("tool_name", mode="before")
@classmethod

View file

@ -18,7 +18,6 @@ from typing import Any
from llama_stack.apis.inference import (
BuiltinTool,
ToolDefinition,
ToolParamDefinition,
)
from .base import PromptTemplate, PromptTemplateGeneratorBase
@ -101,11 +100,8 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
{%- set tname = t.tool_name -%}
{%- set tdesc = t.description -%}
{%- set tparams = t.parameters -%}
{%- set required_params = [] -%}
{%- for name, param in tparams.items() if param.required == true -%}
{%- set _ = required_params.append(name) -%}
{%- endfor -%}
{%- set tprops = t.input_schema.get('properties', {}) -%}
{%- set required_params = t.input_schema.get('required', []) -%}
{
"type": "function",
"function": {
@ -114,11 +110,11 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
"parameters": {
"type": "object",
"properties": [
{%- for name, param in tparams.items() %}
{%- for name, param in tprops.items() %}
{
"{{name}}": {
"type": "object",
"description": "{{param.description}}"
"description": "{{param.get('description', '')}}"
}
}{% if not loop.last %},{% endif %}
{%- endfor %}
@ -143,17 +139,19 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
ToolDefinition(
tool_name="trending_songs",
description="Returns the trending songs on a Music site",
parameters={
"n": ToolParamDefinition(
param_type="int",
description="The number of songs to return",
required=True,
),
"genre": ToolParamDefinition(
param_type="str",
description="The genre of the songs to return",
required=False,
),
input_schema={
"type": "object",
"properties": {
"n": {
"type": "int",
"description": "The number of songs to return",
},
"genre": {
"type": "str",
"description": "The genre of the songs to return",
},
},
"required": ["n"],
},
),
]
@ -170,11 +168,14 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
{#- manually setting up JSON because jinja sorts keys in unexpected ways -#}
{%- set tname = t.tool_name -%}
{%- set tdesc = t.description -%}
{%- set modified_params = t.parameters.copy() -%}
{%- for key, value in modified_params.items() -%}
{%- if 'default' in value -%}
{%- set _ = value.pop('default', None) -%}
{%- set tprops = t.input_schema.get('properties', {}) -%}
{%- set modified_params = {} -%}
{%- for key, value in tprops.items() -%}
{%- set param_copy = value.copy() -%}
{%- if 'default' in param_copy -%}
{%- set _ = param_copy.pop('default', None) -%}
{%- endif -%}
{%- set _ = modified_params.update({key: param_copy}) -%}
{%- endfor -%}
{%- set tparams = modified_params | tojson -%}
Use the function '{{ tname }}' to '{{ tdesc }}':
@ -205,17 +206,19 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
ToolDefinition(
tool_name="trending_songs",
description="Returns the trending songs on a Music site",
parameters={
"n": ToolParamDefinition(
param_type="int",
description="The number of songs to return",
required=True,
),
"genre": ToolParamDefinition(
param_type="str",
description="The genre of the songs to return",
required=False,
),
input_schema={
"type": "object",
"properties": {
"n": {
"type": "int",
"description": "The number of songs to return",
},
"genre": {
"type": "str",
"description": "The genre of the songs to return",
},
},
"required": ["n"],
},
),
]
@ -255,11 +258,8 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
{%- set tname = t.tool_name -%}
{%- set tdesc = t.description -%}
{%- set tparams = t.parameters -%}
{%- set required_params = [] -%}
{%- for name, param in tparams.items() if param.required == true -%}
{%- set _ = required_params.append(name) -%}
{%- endfor -%}
{%- set tprops = t.input_schema.get('properties', {}) -%}
{%- set required_params = t.input_schema.get('required', []) -%}
{
"name": "{{tname}}",
"description": "{{tdesc}}",
@ -267,11 +267,11 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
"type": "dict",
"required": {{ required_params | tojson }},
"properties": {
{%- for name, param in tparams.items() %}
{%- for name, param in tprops.items() %}
"{{name}}": {
"type": "{{param.param_type}}",
"description": "{{param.description}}"{% if param.default %},
"default": "{{param.default}}"{% endif %}
"type": "{{param.get('type', 'string')}}",
"description": "{{param.get('description', '')}}"{% if param.get('default') %},
"default": "{{param.get('default')}}"{% endif %}
}{% if not loop.last %},{% endif %}
{%- endfor %}
}
@ -299,18 +299,20 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
ToolDefinition(
tool_name="get_weather",
description="Get weather info for places",
parameters={
"city": ToolParamDefinition(
param_type="string",
description="The name of the city to get the weather for",
required=True,
),
"metric": ToolParamDefinition(
param_type="string",
description="The metric for weather. Options are: celsius, fahrenheit",
required=False,
default="celsius",
),
input_schema={
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The name of the city to get the weather for",
},
"metric": {
"type": "string",
"description": "The metric for weather. Options are: celsius, fahrenheit",
"default": "celsius",
},
},
"required": ["city"],
},
),
]

View file

@ -13,7 +13,7 @@
import textwrap
from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
from llama_stack.apis.inference import ToolDefinition
from llama_stack.models.llama.llama3.prompt_templates.base import (
PromptTemplate,
PromptTemplateGeneratorBase,
@ -81,11 +81,8 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
{%- set tname = t.tool_name -%}
{%- set tdesc = t.description -%}
{%- set tparams = t.parameters -%}
{%- set required_params = [] -%}
{%- for name, param in tparams.items() if param.required == true -%}
{%- set _ = required_params.append(name) -%}
{%- endfor -%}
{%- set tprops = t.input_schema.get('properties', {}) -%}
{%- set required_params = t.input_schema.get('required', []) -%}
{
"name": "{{tname}}",
"description": "{{tdesc}}",
@ -93,11 +90,11 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
"type": "dict",
"required": {{ required_params | tojson }},
"properties": {
{%- for name, param in tparams.items() %}
{%- for name, param in tprops.items() %}
"{{name}}": {
"type": "{{param.param_type}}",
"description": "{{param.description}}"{% if param.default %},
"default": "{{param.default}}"{% endif %}
"type": "{{param.get('type', 'string')}}",
"description": "{{param.get('description', '')}}"{% if param.get('default') %},
"default": "{{param.get('default')}}"{% endif %}
}{% if not loop.last %},{% endif %}
{%- endfor %}
}
@ -119,18 +116,20 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
ToolDefinition(
tool_name="get_weather",
description="Get weather info for places",
parameters={
"city": ToolParamDefinition(
param_type="string",
description="The name of the city to get the weather for",
required=True,
),
"metric": ToolParamDefinition(
param_type="string",
description="The metric for weather. Options are: celsius, fahrenheit",
required=False,
default="celsius",
),
input_schema={
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The name of the city to get the weather for",
},
"metric": {
"type": "string",
"description": "The metric for weather. Options are: celsius, fahrenheit",
"default": "celsius",
},
},
"required": ["city"],
},
),
]

View file

@ -54,7 +54,6 @@ from llama_stack.apis.inference import (
StopReason,
SystemMessage,
ToolDefinition,
ToolParamDefinition,
ToolResponse,
ToolResponseMessage,
UserMessage,
@ -790,20 +789,38 @@ class ChatAgent(ShieldRunnerMixin):
for tool_def in self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
# Build JSON Schema from tool parameters
properties = {}
required = []
for param in tool_def.parameters:
param_schema = {
"type": param.parameter_type,
"description": param.description,
}
if param.default is not None:
param_schema["default"] = param.default
if param.items is not None:
param_schema["items"] = param.items
if param.title is not None:
param_schema["title"] = param.title
properties[param.name] = param_schema
if param.required:
required.append(param.name)
input_schema = {
"type": "object",
"properties": properties,
"required": required,
}
tool_name_to_def[tool_def.name] = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
items=param.items,
title=param.title,
default=param.default,
)
for param in tool_def.parameters
},
input_schema=input_schema,
)
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
@ -835,20 +852,37 @@ class ChatAgent(ShieldRunnerMixin):
if tool_name_to_def.get(identifier, None):
raise ValueError(f"Tool {identifier} already exists")
if identifier:
# Build JSON Schema from tool parameters
properties = {}
required = []
for param in tool_def.parameters:
param_schema = {
"type": param.parameter_type,
"description": param.description,
}
if param.default is not None:
param_schema["default"] = param.default
if param.items is not None:
param_schema["items"] = param.items
if param.title is not None:
param_schema["title"] = param.title
properties[param.name] = param_schema
if param.required:
required.append(param.name)
input_schema = {
"type": "object",
"properties": properties,
"required": required,
}
tool_name_to_def[tool_def.identifier] = ToolDefinition(
tool_name=identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
items=param.items,
title=param.title,
default=param.default,
)
for param in tool_def.parameters
},
input_schema=input_schema,
)
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})

View file

@ -62,22 +62,38 @@ def convert_tooldef_to_chat_tool(tool_def):
ChatCompletionToolParam suitable for OpenAI chat completion
"""
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.models.llama.datatypes import ToolDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
# Build JSON Schema from tool parameters
properties = {}
required = []
for param in tool_def.parameters:
param_schema = {
"type": param.parameter_type,
"description": param.description,
}
if param.default is not None:
param_schema["default"] = param.default
if param.items is not None:
param_schema["items"] = param.items
properties[param.name] = param_schema
if param.required:
required.append(param.name)
input_schema = {
"type": "object",
"properties": properties,
"required": required,
}
internal_tool_def = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
items=param.items,
)
for param in tool_def.parameters
},
input_schema=input_schema,
)
return convert_tooldef_to_openai_tool(internal_tool_def)
@ -526,22 +542,37 @@ class StreamingResponseOrchestrator:
from openai.types.chat import ChatCompletionToolParam
from llama_stack.apis.tools import Tool
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.models.llama.datatypes import ToolDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
# Build JSON Schema from tool parameters
properties = {}
required = []
for param in tool.parameters:
param_schema = {
"type": param.parameter_type,
"description": param.description,
}
if param.default is not None:
param_schema["default"] = param.default
properties[param.name] = param_schema
if param.required:
required.append(param.name)
input_schema = {
"type": "object",
"properties": properties,
"required": required,
}
tool_def = ToolDefinition(
tool_name=tool_name,
description=tool.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool.parameters
},
input_schema=input_schema,
)
return convert_tooldef_to_openai_tool(tool_def)

View file

@ -127,7 +127,6 @@ from llama_stack.models.llama.datatypes import (
StopReason,
ToolCall,
ToolDefinition,
ToolParamDefinition,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
@ -747,14 +746,8 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
ToolDefinition:
tool_name: str | BuiltinTool
description: Optional[str]
parameters: Optional[Dict[str, ToolParamDefinition]]
ToolParamDefinition:
param_type: str
description: Optional[str]
required: Optional[bool]
default: Optional[Any]
input_schema: Optional[Dict[str, Any]] # JSON Schema
output_schema: Optional[Dict[str, Any]] # JSON Schema (not used by OpenAI)
OpenAI spec -
@ -763,20 +756,11 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
"function": {
"name": tool_name,
"description": description,
"parameters": {
"type": "object",
"properties": {
param_name: {
"type": param_type,
"description": description,
"default": default,
},
...
},
"required": [param_name, ...],
},
"parameters": {<JSON Schema>},
},
}
NOTE: OpenAI does not support output_schema, so it is dropped here.
"""
out = {
"type": "function",
@ -785,37 +769,19 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
function = out["function"]
if isinstance(tool.tool_name, BuiltinTool):
function.update(name=tool.tool_name.value) # TODO(mf): is this sufficient?
function["name"] = tool.tool_name.value
else:
function.update(name=tool.tool_name)
function["name"] = tool.tool_name
if tool.description:
function.update(description=tool.description)
function["description"] = tool.description
if tool.parameters:
parameters = {
"type": "object",
"properties": {},
}
properties = parameters["properties"]
required = []
for param_name, param in tool.parameters.items():
properties[param_name] = to_openai_param_type(param.param_type)
if param.description:
properties[param_name].update(description=param.description)
if param.default:
properties[param_name].update(default=param.default)
if param.items:
properties[param_name].update(items=param.items)
if param.title:
properties[param_name].update(title=param.title)
if param.required:
required.append(param_name)
if tool.input_schema:
# Pass through the entire JSON Schema as-is
function["parameters"] = tool.input_schema
if required:
parameters.update(required=required)
function.update(parameters=parameters)
# NOTE: OpenAI does not support output_schema, so we drop it here
# It's stored in LlamaStack for validation and other provider usage
return out
@ -876,22 +842,12 @@ def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) ->
tool_fn = tool.get("function", {})
tool_name = tool_fn.get("name", None)
tool_desc = tool_fn.get("description", None)
tool_params = tool_fn.get("parameters", None)
lls_tool_params = {}
if tool_params is not None:
tool_param_properties = tool_params.get("properties", {})
for tool_param_key, tool_param_value in tool_param_properties.items():
tool_param_def = ToolParamDefinition(
param_type=str(tool_param_value.get("type", None)),
description=tool_param_value.get("description", None),
)
lls_tool_params[tool_param_key] = tool_param_def
lls_tool = ToolDefinition(
tool_name=tool_name,
description=tool_desc,
parameters=lls_tool_params,
input_schema=tool_params, # Pass through entire JSON Schema
)
lls_tools.append(lls_tool)
return lls_tools

View file

@ -20,7 +20,6 @@ from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolDef,
ToolInvocationResult,
ToolParameter,
)
from llama_stack.core.datatypes import AuthenticationRequiredError
from llama_stack.log import get_logger
@ -113,24 +112,12 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
async with client_wrapper(endpoint, headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
required="default" not in param_schema,
items=param_schema.get("items", None),
title=param_schema.get("title", None),
default=param_schema.get("default", None),
)
)
tools.append(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
input_schema=tool.inputSchema,
output_schema=getattr(tool, "outputSchema", None),
metadata={
"endpoint": endpoint,
},