feat: RFC: tools API rework

# What does this PR do?
This PR proposes updates to the tools API in Inference and Agent.

Goals:
1. Agent's tool specification should be consistent with Inference's tool spec, but with add-ons.
2. Formal types should be defined for built-in tools. Currently Agent tool args are untyped, e.g. how does one know that `builtin::rag_tool` takes a `vector_db_ids` param, or even that `builtin::rag_tool` is available at all (in code, outside of docs)?

Inference:
1. BuiltinTool is to be removed and replaced by a formal `type` parameter.
2. 'brave_search' is replaced by 'web_search' to be more generic. It will still be translated back to brave_search when the prompt is constructed to be consistent with model training.
3. I'm not sure what `photogen` is. Maybe it can be removed?

Agent:
1. Uses the same format as in Inference for builtin tools.
2. New tool types are added, i.e. knowledge_search (currently rag_tool), and MCP tool.
3. Toolgroup as a concept will be removed since it's really only used for MCP.
4. Instead MCPTool is its own type and available tools provided by the server will be expanded by default. Users can specify a subset of tool names if desired.

Example snippet:
```

agent = Agent(
    client,
    model=model_id,
    instructions="You are a helpful assistant. Use the tools you have access to for providing relevant answers.",
    tools=[
        KnowledgeSearchTool(vector_store_id="1234"),
        KnowledgeSearchTool(vector_store_id="5678", name="paper_search", description="Search research papers"),
        KnowledgeSearchTool(vector_store_id="1357", name="wiki_search", description="Search wiki pages"),
        # no need to register toolgroup, just pass in the server uri
        # all available tools will be used
        MCPTool(server_uri="http://localhost:8000/sse"),
        # can specify a subset of available tools
        MCPTool(server_uri="http://localhost:8000/sse", tool_names=["list_directory"]),
        # custom tool
        my_custom_tool,
    ]
)
```

## Test Plan
This commit is contained in:
Eric Huang 2025-03-26 11:14:40 -07:00
parent 39e094736f
commit 7027b537e0
22 changed files with 951 additions and 525 deletions

View file

@ -53,7 +53,7 @@ from llama_stack.apis.inference import (
SamplingParams,
StopReason,
SystemMessage,
ToolDefinition,
ToolDefinitionDeprecated,
ToolResponse,
ToolResponseMessage,
UserMessage,
@ -771,7 +771,7 @@ class ChatAgent(ShieldRunnerMixin):
for tool_def in self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
tool_name_to_def[tool_def.name] = ToolDefinition(
tool_name_to_def[tool_def.name] = ToolDefinitionDeprecated(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
@ -814,7 +814,7 @@ class ChatAgent(ShieldRunnerMixin):
if tool_name_to_def.get(identifier, None):
raise ValueError(f"Tool {identifier} already exists")
if identifier:
tool_name_to_def[tool_def.identifier] = ToolDefinition(
tool_name_to_def[tool_def.identifier] = ToolDefinitionDeprecated(
tool_name=identifier,
description=tool_def.description,
parameters={
@ -854,30 +854,23 @@ class ChatAgent(ShieldRunnerMixin):
tool_call: ToolCall,
) -> ToolInvocationResult:
tool_name = tool_call.tool_name
registered_tool_names = [tool_def.tool_name for tool_def in self.tool_defs]
registered_tool_names = list(self.tool_name_to_args.keys())
if tool_name not in registered_tool_names:
raise ValueError(
f"Tool {tool_name} not found in provided tools, registered tools: {', '.join([str(x) for x in registered_tool_names])}"
)
if isinstance(tool_name, BuiltinTool):
if tool_name == BuiltinTool.brave_search:
tool_name_str = WEB_SEARCH_TOOL
else:
tool_name_str = tool_name.value
else:
tool_name_str = tool_name
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
logger.info(f"executing tool call: {tool_name} with args: {tool_call.arguments}")
result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name_str,
tool_name=tool_name,
kwargs={
"session_id": session_id,
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
**tool_call.arguments,
**self.tool_name_to_args.get(tool_name_str, {}),
**self.tool_name_to_args.get(tool_name, {}),
},
)
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
logger.debug(f"tool call {tool_name} completed with result: {result}")
return result
async def handle_documents(

View file

@ -16,7 +16,7 @@ from llama_stack.apis.inference import (
ToolChoice,
UserMessage,
)
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
from llama_stack.models.llama.datatypes import ToolDefinition, ToolType
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
@ -65,7 +65,7 @@ def _llama_stack_tools_to_openai_tools(
result = []
for t in tools:
if isinstance(t.tool_name, BuiltinTool):
if t.type != ToolType.function.value:
raise NotImplementedError("Built-in tools not yet implemented")
if t.parameters is None:
parameters = None

View file

@ -45,7 +45,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.datatypes import StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
@ -110,6 +110,8 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]
for tool in tools:
properties = {}
compat_required = []
tool_name = tool.name
if tool.parameters:
for tool_key, tool_param in tool.parameters.items():
properties[tool_key] = {"type": tool_param.param_type}
@ -120,12 +122,6 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]
if tool_param.required:
compat_required.append(tool_key)
# The tool.tool_name can be a str or a BuiltinTool enum. If
# it's the latter, convert to a string.
tool_name = tool.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
compat_tool = {
"type": "function",
"function": {

View file

@ -17,7 +17,6 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import BraveSearchToolConfig
@ -61,7 +60,6 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
parameter_type="string",
)
],
built_in_type=BuiltinTool.brave_search,
)
]

View file

@ -80,12 +80,12 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.models.llama.datatypes import (
BuiltinTool,
GreedySamplingStrategy,
SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
ToolType,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
@ -271,7 +271,7 @@ def process_chat_completion_response(
else:
# only return tool_calls if provided in the request
new_tool_calls = []
request_tools = {t.tool_name: t for t in request.tools}
request_tools = {t.name: t for t in request.tools}
for t in raw_message.tool_calls:
if t.tool_name in request_tools:
new_tool_calls.append(t)
@ -423,7 +423,7 @@ async def process_chat_completion_stream_response(
)
)
request_tools = {t.tool_name: t for t in request.tools}
request_tools = {t.name: t for t in request.tools}
for tool_call in message.tool_calls:
if tool_call.tool_name in request_tools:
yield ChatCompletionResponseStreamChunk(
@ -574,7 +574,7 @@ async def convert_message_to_openai_dict_new(
OpenAIChatCompletionMessageToolCall(
id=tool.call_id,
function=OpenAIFunction(
name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
name=tool.tool_name,
arguments=json.dumps(tool.arguments),
),
type="function",
@ -638,7 +638,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
Convert a ToolDefinition to an OpenAI API-compatible dictionary.
ToolDefinition:
tool_name: str | BuiltinTool
tool_name: str
description: Optional[str]
parameters: Optional[Dict[str, ToolParamDefinition]]
@ -677,10 +677,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
}
function = out["function"]
if isinstance(tool.tool_name, BuiltinTool):
function.update(name=tool.tool_name.value) # TODO(mf): is this sufficient?
else:
function.update(name=tool.tool_name)
function.update(name=tool.name)
if tool.description:
function.update(description=tool.description)
@ -761,6 +758,7 @@ def _convert_openai_tool_calls(
return [
ToolCall(
type=ToolType.function,
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
@ -975,6 +973,7 @@ async def convert_openai_chat_completion_stream(
try:
arguments = json.loads(buffer["arguments"])
tool_call = ToolCall(
type=ToolType.function,
call_id=buffer["call_id"],
tool_name=buffer["name"],
arguments=arguments,

View file

@ -43,6 +43,7 @@ from llama_stack.models.llama.datatypes import (
Role,
StopReason,
ToolPromptFormat,
ToolType,
is_multimodal,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
@ -374,8 +375,8 @@ def augment_messages_for_tools_llama_3_1(
messages.append(SystemMessage(content=sys_content))
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
custom_tools = [t for t in request.tools if t.type == ToolType.function.value]
if custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
@ -384,7 +385,6 @@ def augment_messages_for_tools_llama_3_1(
else:
raise ValueError(f"Non supported ToolPromptFormat {fmt}")
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
messages.append(UserMessage(content=custom_template.render()))
@ -407,7 +407,7 @@ def augment_messages_for_tools_llama_3_2(
sys_content = ""
custom_tools, builtin_tools = [], []
for t in request.tools:
if isinstance(t.tool_name, str):
if t.type == ToolType.function.value:
custom_tools.append(t)
else:
builtin_tools.append(t)
@ -419,7 +419,7 @@ def augment_messages_for_tools_llama_3_2(
sys_content += tool_template.render()
sys_content += "\n"
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
custom_tools = [t for t in request.tools if t.type == ToolType.function.value]
if custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
if fmt != ToolPromptFormat.python_list: