feat(tools)!: substantial clean up of "Tool" related datatypes (#3627)

This is a sweeping change to clean up some gunk around our "Tool"
definitions.

First, we had two types `Tool` and `ToolDef`. The first of these was a
"Resource" type for the registry but we had stopped registering tools
inside the Registry long back (and only registered ToolGroups.) The
latter was for specifying tools for the Agents API. This PR removes the
former and adds an optional `toolgroup_id` field to the latter.

Secondly, as pointed out by @bbrowning in
https://github.com/llamastack/llama-stack/pull/3003#issuecomment-3245270132,
we were doing a lossy conversion from a full JSON schema from the MCP
tool specification into our ToolDefinition to send it to the model.
There is no necessity to do this -- we ourselves aren't doing any
execution at all but merely passing it to the chat completions API which
supports this. By doing this (and by doing it poorly), we encountered
limitations like not supporting array items, or not resolving $refs,
etc.

To fix this, we replaced the `parameters` field by `{ input_schema,
output_schema }` which can be full blown JSON schemas.

Finally, there were some types in our llama-related chat format
conversion which needed some cleanup. We are taking this opportunity to
clean those up.

This PR is a substantial breaking change to the API. However, given our
window for introducing breaking changes, this suits us just fine. I will
be landing a concurrent `llama-stack-client` change as well since API
shapes are changing.
This commit is contained in:
Ashwin Bharambe 2025-10-02 15:12:03 -07:00 committed by GitHub
parent 1f5003d50e
commit ef0736527d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
179 changed files with 34186 additions and 9171 deletions

View file

@ -60,7 +60,6 @@ from llama_stack.apis.inference import (
StopReason,
SystemMessage,
ToolDefinition,
ToolParamDefinition,
ToolResponse,
ToolResponseMessage,
UserMessage,
@ -866,20 +865,12 @@ class ChatAgent(ShieldRunnerMixin):
for tool_def in self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
# Use input_schema from ToolDef directly
tool_name_to_def[tool_def.name] = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
items=param.items,
title=param.title,
default=param.default,
)
for param in tool_def.parameters
},
input_schema=tool_def.input_schema,
)
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
@ -889,44 +880,34 @@ class ChatAgent(ShieldRunnerMixin):
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
)
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
if input_tool_name is not None and not any(tool.name == input_tool_name for tool in tools.data):
raise ValueError(
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.name for tool in tools.data])}"
)
for tool_def in tools.data:
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
identifier: str | BuiltinTool | None = tool_def.identifier
identifier: str | BuiltinTool | None = tool_def.name
if identifier == "web_search":
identifier = BuiltinTool.brave_search
else:
identifier = BuiltinTool(identifier)
else:
# add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
if input_tool_name in (None, tool_def.identifier):
identifier = tool_def.identifier
if input_tool_name in (None, tool_def.name):
identifier = tool_def.name
else:
identifier = None
if tool_name_to_def.get(identifier, None):
raise ValueError(f"Tool {identifier} already exists")
if identifier:
tool_name_to_def[tool_def.identifier] = ToolDefinition(
tool_name_to_def[identifier] = ToolDefinition(
tool_name=identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
items=param.items,
title=param.title,
default=param.default,
)
for param in tool_def.parameters
},
input_schema=tool_def.input_schema,
)
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
self.tool_defs, self.tool_name_to_args = (
list(tool_name_to_def.values()),
@ -970,12 +951,18 @@ class ChatAgent(ShieldRunnerMixin):
tool_name_str = tool_name
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
try:
args = json.loads(tool_call.arguments)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse arguments for tool call: {tool_call.arguments}") from e
result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name_str,
kwargs={
"session_id": session_id,
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
**tool_call.arguments,
**args,
**self.tool_name_to_args.get(tool_name_str, {}),
},
)

View file

@ -62,22 +62,13 @@ def convert_tooldef_to_chat_tool(tool_def):
ChatCompletionToolParam suitable for OpenAI chat completion
"""
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.models.llama.datatypes import ToolDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
internal_tool_def = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
items=param.items,
)
for param in tool_def.parameters
},
input_schema=tool_def.input_schema,
)
return convert_tooldef_to_openai_tool(internal_tool_def)
@ -528,23 +519,15 @@ class StreamingResponseOrchestrator:
"""Process all tools and emit appropriate streaming events."""
from openai.types.chat import ChatCompletionToolParam
from llama_stack.apis.tools import Tool
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.apis.tools import ToolDef
from llama_stack.models.llama.datatypes import ToolDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
def make_openai_tool(tool_name: str, tool: ToolDef) -> ChatCompletionToolParam:
tool_def = ToolDefinition(
tool_name=tool_name,
description=tool.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool.parameters
},
input_schema=tool.input_schema,
)
return convert_tooldef_to_openai_tool(tool_def)
@ -631,16 +614,11 @@ class StreamingResponseOrchestrator:
MCPListToolsTool(
name=t.name,
description=t.description,
input_schema={
input_schema=t.input_schema
or {
"type": "object",
"properties": {
p.name: {
"type": p.parameter_type,
"description": p.description,
}
for p in t.parameters
},
"required": [p.name for p in t.parameters if p.required],
"properties": {},
"required": [],
},
)
)

View file

@ -68,9 +68,7 @@ public class FunctionTagCustomToolGenerator {
{
"name": "{{t.tool_name}}",
"description": "{{t.description}}",
"parameters": {
"type": "dict",
"properties": { {{t.parameters}} }
"input_schema": { {{t.input_schema}} }
}
{{/let}}

View file

@ -33,7 +33,6 @@ from llama_stack.apis.tools import (
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.apis.vector_io import (
@ -301,13 +300,16 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
ToolDef(
name="knowledge_search",
description="Search for information in a database.",
parameters=[
ToolParameter(
name="query",
description="The query to search for. Can be a natural language sentence or keywords.",
parameter_type="string",
),
],
input_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for. Can be a natural language sentence or keywords.",
}
},
"required": ["query"],
},
),
]
)

View file

@ -89,8 +89,7 @@ def _convert_to_vllm_tool_calls_in_response(
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
arguments_json=call.function.arguments,
arguments=call.function.arguments,
)
for call in tool_calls
]
@ -100,18 +99,6 @@ def _convert_to_vllm_tools_in_request(tools: list[ToolDefinition]) -> list[dict]
compat_tools = []
for tool in tools:
properties = {}
compat_required = []
if tool.parameters:
for tool_key, tool_param in tool.parameters.items():
properties[tool_key] = {"type": tool_param.param_type}
if tool_param.description:
properties[tool_key]["description"] = tool_param.description
if tool_param.default:
properties[tool_key]["default"] = tool_param.default
if tool_param.required:
compat_required.append(tool_key)
# The tool.tool_name can be a str or a BuiltinTool enum. If
# it's the latter, convert to a string.
tool_name = tool.tool_name
@ -123,10 +110,11 @@ def _convert_to_vllm_tools_in_request(tools: list[ToolDefinition]) -> list[dict]
"function": {
"name": tool_name,
"description": tool.description,
"parameters": {
"parameters": tool.input_schema
or {
"type": "object",
"properties": properties,
"required": compat_required,
"properties": {},
"required": [],
},
},
}
@ -161,7 +149,6 @@ def _process_vllm_chat_completion_end_of_stream(
for _index, tool_call_buf in sorted(tool_call_bufs.items()):
args_str = tool_call_buf.arguments or "{}"
try:
args = json.loads(args_str)
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@ -170,8 +157,7 @@ def _process_vllm_chat_completion_end_of_stream(
tool_call=ToolCall(
call_id=tool_call_buf.call_id,
tool_name=tool_call_buf.tool_name,
arguments=args,
arguments_json=args_str,
arguments=args_str,
),
parse_status=ToolCallParseStatus.succeeded,
),

View file

@ -15,7 +15,6 @@ from llama_stack.apis.tools import (
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
@ -57,13 +56,16 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq
ToolDef(
name="web_search",
description="Search the web using Bing Search API",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
input_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for",
}
},
"required": ["query"],
},
)
]
)

View file

@ -14,7 +14,6 @@ from llama_stack.apis.tools import (
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
@ -56,13 +55,16 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe
ToolDef(
name="web_search",
description="Search the web for information",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
input_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for",
}
},
"required": ["query"],
},
built_in_type=BuiltinTool.brave_search,
)
]

View file

@ -15,7 +15,6 @@ from llama_stack.apis.tools import (
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
@ -56,13 +55,16 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
ToolDef(
name="web_search",
description="Search the web for information",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
input_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for",
}
},
"required": ["query"],
},
)
]
)

View file

@ -15,7 +15,6 @@ from llama_stack.apis.tools import (
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
@ -57,13 +56,16 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
ToolDef(
name="wolfram_alpha",
description="Query WolframAlpha for computational knowledge",
parameters=[
ToolParameter(
name="query",
description="The query to compute",
parameter_type="string",
)
],
input_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to compute",
}
},
"required": ["query"],
},
)
]
)

View file

@ -125,7 +125,6 @@ from llama_stack.models.llama.datatypes import (
StopReason,
ToolCall,
ToolDefinition,
ToolParamDefinition,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
@ -537,18 +536,13 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
# arguments_json can be None, so attempt it first and fall back to arguments
if hasattr(tc, "arguments_json") and tc.arguments_json:
arguments = tc.arguments_json
else:
arguments = json.dumps(tc.arguments)
result["tool_calls"].append(
{
"id": tc.call_id,
"type": "function",
"function": {
"name": tool_name,
"arguments": arguments,
"arguments": tc.arguments,
},
}
)
@ -641,7 +635,7 @@ async def convert_message_to_openai_dict_new(
id=tool.call_id,
function=OpenAIFunction(
name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
arguments=json.dumps(tool.arguments),
arguments=tool.arguments, # Already a JSON string, don't double-encode
),
type="function",
)
@ -684,8 +678,7 @@ def convert_tool_call(
valid_tool_call = ToolCall(
call_id=tool_call.id,
tool_name=tool_call.function.name,
arguments=json.loads(tool_call.function.arguments),
arguments_json=tool_call.function.arguments,
arguments=tool_call.function.arguments,
)
except Exception:
return UnparseableToolCall(
@ -745,14 +738,8 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
ToolDefinition:
tool_name: str | BuiltinTool
description: Optional[str]
parameters: Optional[Dict[str, ToolParamDefinition]]
ToolParamDefinition:
param_type: str
description: Optional[str]
required: Optional[bool]
default: Optional[Any]
input_schema: Optional[Dict[str, Any]] # JSON Schema
output_schema: Optional[Dict[str, Any]] # JSON Schema (not used by OpenAI)
OpenAI spec -
@ -761,20 +748,11 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
"function": {
"name": tool_name,
"description": description,
"parameters": {
"type": "object",
"properties": {
param_name: {
"type": param_type,
"description": description,
"default": default,
},
...
},
"required": [param_name, ...],
},
"parameters": {<JSON Schema>},
},
}
NOTE: OpenAI does not support output_schema, so it is dropped here.
"""
out = {
"type": "function",
@ -783,37 +761,19 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
function = out["function"]
if isinstance(tool.tool_name, BuiltinTool):
function.update(name=tool.tool_name.value) # TODO(mf): is this sufficient?
function["name"] = tool.tool_name.value
else:
function.update(name=tool.tool_name)
function["name"] = tool.tool_name
if tool.description:
function.update(description=tool.description)
function["description"] = tool.description
if tool.parameters:
parameters = {
"type": "object",
"properties": {},
}
properties = parameters["properties"]
required = []
for param_name, param in tool.parameters.items():
properties[param_name] = to_openai_param_type(param.param_type)
if param.description:
properties[param_name].update(description=param.description)
if param.default:
properties[param_name].update(default=param.default)
if param.items:
properties[param_name].update(items=param.items)
if param.title:
properties[param_name].update(title=param.title)
if param.required:
required.append(param_name)
if tool.input_schema:
# Pass through the entire JSON Schema as-is
function["parameters"] = tool.input_schema
if required:
parameters.update(required=required)
function.update(parameters=parameters)
# NOTE: OpenAI does not support output_schema, so we drop it here
# It's stored in LlamaStack for validation and other provider usage
return out
@ -874,22 +834,12 @@ def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) ->
tool_fn = tool.get("function", {})
tool_name = tool_fn.get("name", None)
tool_desc = tool_fn.get("description", None)
tool_params = tool_fn.get("parameters", None)
lls_tool_params = {}
if tool_params is not None:
tool_param_properties = tool_params.get("properties", {})
for tool_param_key, tool_param_value in tool_param_properties.items():
tool_param_def = ToolParamDefinition(
param_type=str(tool_param_value.get("type", None)),
description=tool_param_value.get("description", None),
)
lls_tool_params[tool_param_key] = tool_param_def
lls_tool = ToolDefinition(
tool_name=tool_name,
description=tool_desc,
parameters=lls_tool_params,
input_schema=tool_params, # Pass through entire JSON Schema
)
lls_tools.append(lls_tool)
return lls_tools
@ -939,8 +889,7 @@ def _convert_openai_tool_calls(
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
arguments_json=call.function.arguments,
arguments=call.function.arguments,
)
for call in tool_calls
]
@ -1222,12 +1171,10 @@ async def convert_openai_chat_completion_stream(
)
try:
arguments = json.loads(buffer["arguments"])
tool_call = ToolCall(
call_id=buffer["call_id"],
tool_name=buffer["name"],
arguments=arguments,
arguments_json=buffer["arguments"],
arguments=buffer["arguments"],
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@ -1390,7 +1337,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
openai_tool_call = OpenAIChoiceDeltaToolCall(
index=0,
function=OpenAIChoiceDeltaToolCallFunction(
arguments=tool_call.arguments_json,
arguments=tool_call.arguments,
),
)
delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])

View file

@ -286,34 +286,34 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
messages = [await _localize_image_url(m) for m in messages]
resp = await self.client.chat.completions.create(
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
params = await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
resp = await self.client.chat.completions.create(**params)
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
async def openai_embeddings(

View file

@ -20,7 +20,6 @@ from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolDef,
ToolInvocationResult,
ToolParameter,
)
from llama_stack.core.datatypes import AuthenticationRequiredError
from llama_stack.log import get_logger
@ -113,24 +112,12 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
async with client_wrapper(endpoint, headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
required="default" not in param_schema,
items=param_schema.get("items", None),
title=param_schema.get("title", None),
default=param_schema.get("default", None),
)
)
tools.append(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
input_schema=tool.inputSchema,
output_schema=getattr(tool, "outputSchema", None),
metadata={
"endpoint": endpoint,
},