more substantial cleanup of Tool vs. ToolDef crap

This commit is contained in:
Ashwin Bharambe 2025-10-01 15:54:14 -07:00
parent fa6ed28aea
commit 6749c853c0
34 changed files with 2676 additions and 615 deletions

View file

@ -19,48 +19,6 @@ from llama_stack.schema_utils import json_schema_type, webmethod
from .rag_tool import RAGToolRuntime
@json_schema_type
class ToolParameter(BaseModel):
"""Parameter definition for a tool.
:param name: Name of the parameter
:param parameter_type: Type of the parameter (e.g., string, integer)
:param description: Human-readable description of what the parameter does
:param required: Whether this parameter is required for tool invocation
:param items: Type of the elements when parameter_type is array
:param title: (Optional) Title of the parameter
:param default: (Optional) Default value for the parameter if not provided
"""
name: str
parameter_type: str
description: str
required: bool = Field(default=True)
items: dict | None = None
title: str | None = None
default: Any | None = None
@json_schema_type
class Tool(Resource):
"""A tool that can be invoked by agents.
:param type: Type of resource, always 'tool'
:param toolgroup_id: ID of the tool group this tool belongs to
:param description: Human-readable description of what the tool does
:param input_schema: JSON Schema for the tool's input parameters
:param output_schema: JSON Schema for the tool's output
:param metadata: (Optional) Additional metadata about the tool
"""
type: Literal[ResourceType.tool] = ResourceType.tool
toolgroup_id: str
description: str
input_schema: dict[str, Any] | None = None
output_schema: dict[str, Any] | None = None
metadata: dict[str, Any] | None = None
@json_schema_type
class ToolDef(BaseModel):
"""Tool definition used in runtime contexts.
@ -70,8 +28,10 @@ class ToolDef(BaseModel):
:param input_schema: (Optional) JSON Schema for tool inputs (MCP inputSchema)
:param output_schema: (Optional) JSON Schema for tool outputs (MCP outputSchema)
:param metadata: (Optional) Additional metadata about the tool
:param toolgroup_id: (Optional) ID of the tool group this tool belongs to
"""
toolgroup_id: str | None = None
name: str
description: str | None = None
input_schema: dict[str, Any] | None = None
@ -126,7 +86,7 @@ class ToolInvocationResult(BaseModel):
class ToolStore(Protocol):
async def get_tool(self, tool_name: str) -> Tool: ...
async def get_tool(self, tool_name: str) -> ToolDef: ...
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
@ -139,15 +99,6 @@ class ListToolGroupsResponse(BaseModel):
data: list[ToolGroup]
class ListToolsResponse(BaseModel):
"""Response containing a list of tools.
:param data: List of tools
"""
data: list[Tool]
class ListToolDefsResponse(BaseModel):
"""Response containing a list of tool definitions.
@ -198,11 +149,11 @@ class ToolGroups(Protocol):
...
@webmethod(route="/tools", method="GET", level=LLAMA_STACK_API_V1)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
"""List tools with optional tool group.
:param toolgroup_id: The ID of the tool group to list tools for.
:returns: A ListToolsResponse.
:returns: A ListToolDefsResponse.
"""
...
@ -210,11 +161,11 @@ class ToolGroups(Protocol):
async def get_tool(
self,
tool_name: str,
) -> Tool:
) -> ToolDef:
"""Get a tool by its name.
:param tool_name: The name of the tool to get.
:returns: A Tool.
:returns: A ToolDef.
"""
...

View file

@ -22,7 +22,7 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.access_control.datatypes import AccessRule
@ -84,15 +84,11 @@ class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
pass
class ToolWithOwner(Tool, ResourceWithOwner):
pass
class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
pass
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | ToolGroup
RoutableObjectWithProvider = Annotated[
ModelWithOwner
@ -101,7 +97,6 @@ RoutableObjectWithProvider = Annotated[
| DatasetWithOwner
| ScoringFnWithOwner
| BenchmarkWithOwner
| ToolWithOwner
| ToolGroupWithOwner,
Field(discriminator="type"),
]

View file

@ -11,7 +11,7 @@ from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.tools import (
ListToolsResponse,
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
@ -86,6 +86,6 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolsResponse:
) -> ListToolDefsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.list_tools(tool_group_id)

View file

@ -8,7 +8,7 @@ from typing import Any
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.errors import ToolGroupNotFoundError
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.apis.tools import ListToolDefsResponse, ListToolGroupsResponse, ToolDef, ToolGroup, ToolGroups
from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
from llama_stack.log import get_logger
@ -27,7 +27,7 @@ def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
toolgroups_to_tools: dict[str, list[Tool]] = {}
toolgroups_to_tools: dict[str, list[ToolDef]] = {}
tool_to_toolgroup: dict[str, str] = {}
# overridden
@ -43,7 +43,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
routing_key = self.tool_to_toolgroup[routing_key]
return await super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
if toolgroup_id:
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
toolgroup_id = group_id
@ -68,31 +68,19 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
continue
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
return ListToolsResponse(data=all_tools)
return ListToolDefsResponse(data=all_tools)
async def _index_tools(self, toolgroup: ToolGroup):
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
# TODO: kill this Tool vs ToolDef distinction
tooldefs = tooldefs_response.data
tools = []
for t in tooldefs:
tools.append(
Tool(
identifier=t.name,
toolgroup_id=toolgroup.identifier,
description=t.description or "",
input_schema=t.input_schema,
output_schema=t.output_schema,
metadata=t.metadata,
provider_id=toolgroup.provider_id,
)
)
t.toolgroup_id = toolgroup.identifier
self.toolgroups_to_tools[toolgroup.identifier] = tools
for tool in tools:
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
self.toolgroups_to_tools[toolgroup.identifier] = tooldefs
for tool in tooldefs:
self.tool_to_toolgroup[tool.name] = toolgroup.identifier
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
@ -103,12 +91,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
raise ToolGroupNotFoundError(toolgroup_id)
return tool_group
async def get_tool(self, tool_name: str) -> Tool:
async def get_tool(self, tool_name: str) -> ToolDef:
if tool_name in self.tool_to_toolgroup:
toolgroup_id = self.tool_to_toolgroup[tool_name]
tools = self.toolgroups_to_tools[toolgroup_id]
for tool in tools:
if tool.identifier == tool_name:
if tool.name == tool_name:
return tool
raise ValueError(f"Tool '{tool_name}' not found")
@ -133,7 +121,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
# baked in some of the code and tests right now.
if not toolgroup.mcp_endpoint:
await self._index_tools(toolgroup)
return toolgroup
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
await self.unregister_object(await self.get_tool_group(toolgroup_id))

View file

@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v9"
KEY_VERSION = "v10"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

View file

@ -81,7 +81,7 @@ def tool_chat_page():
for toolgroup_id in toolgroup_selection:
tools = client.tools.list(toolgroup_id=toolgroup_id)
grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
grouped_tools[toolgroup_id] = [tool.name for tool in tools]
total_tools += len(tools)
st.markdown(f"Active Tools: 🛠 {total_tools}")

View file

@ -37,14 +37,7 @@ RecursiveType = Primitive | list[Primitive] | dict[str, Primitive]
class ToolCall(BaseModel):
call_id: str
tool_name: BuiltinTool | str
# Plan is to deprecate the Dict in favor of a JSON string
# that is parsed on the client side instead of trying to manage
# the recursive type here.
# Making this a union so that client side can start prepping for this change.
# Eventually, we will remove both the Dict and arguments_json field,
# and arguments will just be a str
arguments: str | dict[str, RecursiveType]
arguments_json: str | None = None
arguments: str
@field_validator("tool_name", mode="before")
@classmethod

View file

@ -232,8 +232,7 @@ class ChatFormat:
ToolCall(
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
arguments_json=json.dumps(tool_arguments),
arguments=json.dumps(tool_arguments),
)
)
content = ""

View file

@ -298,8 +298,7 @@ class ChatFormat:
ToolCall(
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
arguments_json=json.dumps(tool_arguments),
arguments=json.dumps(tool_arguments),
)
)
content = ""

View file

@ -804,61 +804,34 @@ class ChatAgent(ShieldRunnerMixin):
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
)
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
if input_tool_name is not None and not any(tool.name == input_tool_name for tool in tools.data):
raise ValueError(
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.name for tool in tools.data])}"
)
for tool_def in tools.data:
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
identifier: str | BuiltinTool | None = tool_def.identifier
identifier: str | BuiltinTool | None = tool_def.name
if identifier == "web_search":
identifier = BuiltinTool.brave_search
else:
identifier = BuiltinTool(identifier)
else:
# add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
if input_tool_name in (None, tool_def.identifier):
identifier = tool_def.identifier
if input_tool_name in (None, tool_def.name):
identifier = tool_def.name
else:
identifier = None
if tool_name_to_def.get(identifier, None):
raise ValueError(f"Tool {identifier} already exists")
if identifier:
# Build JSON Schema from tool parameters
properties = {}
required = []
for param in tool_def.parameters:
param_schema = {
"type": param.parameter_type,
"description": param.description,
}
if param.default is not None:
param_schema["default"] = param.default
if param.items is not None:
param_schema["items"] = param.items
if param.title is not None:
param_schema["title"] = param.title
properties[param.name] = param_schema
if param.required:
required.append(param.name)
input_schema = {
"type": "object",
"properties": properties,
"required": required,
}
tool_name_to_def[tool_def.identifier] = ToolDefinition(
tool_name_to_def[identifier] = ToolDefinition(
tool_name=identifier,
description=tool_def.description,
input_schema=input_schema,
input_schema=tool_def.input_schema,
)
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
self.tool_defs, self.tool_name_to_args = (
list(tool_name_to_def.values()),

View file

@ -33,7 +33,6 @@ from llama_stack.apis.tools import (
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.apis.vector_io import (
@ -301,13 +300,16 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
ToolDef(
name="knowledge_search",
description="Search for information in a database.",
parameters=[
ToolParameter(
name="query",
description="The query to search for. Can be a natural language sentence or keywords.",
parameter_type="string",
),
],
input_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for. Can be a natural language sentence or keywords.",
}
},
"required": ["query"],
},
),
]
)

View file

@ -99,8 +99,7 @@ def _convert_to_vllm_tool_calls_in_response(
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
arguments_json=call.function.arguments,
arguments=call.function.arguments,
)
for call in tool_calls
]
@ -160,7 +159,6 @@ def _process_vllm_chat_completion_end_of_stream(
for _index, tool_call_buf in sorted(tool_call_bufs.items()):
args_str = tool_call_buf.arguments or "{}"
try:
args = json.loads(args_str)
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@ -169,8 +167,7 @@ def _process_vllm_chat_completion_end_of_stream(
tool_call=ToolCall(
call_id=tool_call_buf.call_id,
tool_name=tool_call_buf.tool_name,
arguments=args,
arguments_json=args_str,
arguments=args_str,
),
parse_status=ToolCallParseStatus.succeeded,
),

View file

@ -15,7 +15,6 @@ from llama_stack.apis.tools import (
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
@ -57,13 +56,16 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq
ToolDef(
name="web_search",
description="Search the web using Bing Search API",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
input_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for",
}
},
"required": ["query"],
},
)
]
)

View file

@ -14,7 +14,6 @@ from llama_stack.apis.tools import (
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
@ -56,13 +55,16 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe
ToolDef(
name="web_search",
description="Search the web for information",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
input_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for",
}
},
"required": ["query"],
},
built_in_type=BuiltinTool.brave_search,
)
]

View file

@ -15,7 +15,6 @@ from llama_stack.apis.tools import (
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
@ -56,13 +55,16 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
ToolDef(
name="web_search",
description="Search the web for information",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
input_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for",
}
},
"required": ["query"],
},
)
]
)

View file

@ -15,7 +15,6 @@ from llama_stack.apis.tools import (
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
@ -57,13 +56,16 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
ToolDef(
name="wolfram_alpha",
description="Query WolframAlpha for computational knowledge",
parameters=[
ToolParameter(
name="query",
description="The query to compute",
parameter_type="string",
)
],
input_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to compute",
}
},
"required": ["query"],
},
)
]
)

View file

@ -538,18 +538,13 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
# arguments_json can be None, so attempt it first and fall back to arguments
if hasattr(tc, "arguments_json") and tc.arguments_json:
arguments = tc.arguments_json
else:
arguments = json.dumps(tc.arguments)
result["tool_calls"].append(
{
"id": tc.call_id,
"type": "function",
"function": {
"name": tool_name,
"arguments": arguments,
"arguments": tc.arguments,
},
}
)
@ -685,8 +680,7 @@ def convert_tool_call(
valid_tool_call = ToolCall(
call_id=tool_call.id,
tool_name=tool_call.function.name,
arguments=json.loads(tool_call.function.arguments),
arguments_json=tool_call.function.arguments,
arguments=tool_call.function.arguments,
)
except Exception:
return UnparseableToolCall(
@ -897,8 +891,7 @@ def _convert_openai_tool_calls(
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
arguments_json=call.function.arguments,
arguments=call.function.arguments,
)
for call in tool_calls
]
@ -1184,8 +1177,7 @@ async def convert_openai_chat_completion_stream(
tool_call = ToolCall(
call_id=buffer["call_id"],
tool_name=buffer["name"],
arguments=arguments,
arguments_json=buffer["arguments"],
arguments=buffer["arguments"],
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@ -1418,7 +1410,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
openai_tool_call = OpenAIChoiceDeltaToolCall(
index=0,
function=OpenAIChoiceDeltaToolCallFunction(
arguments=tool_call.arguments_json,
arguments=tool_call.arguments,
),
)
delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])