feat: RFC: tools API rework

# What does this PR do?
This PR proposes updates to the tools API in Inference and Agent.

Goals:
1. Agent's tool specification should be consistent with Inference's tool spec, but with add-ons.
2. Formal types should be defined for built in tools. Currently Agent tools args are untyped, e.g. how does one know that `builtin::rag_tool` takes a `vector_db_ids` param or even how to know 'builtin::rag_tool' is even available (in code, outside of docs)?

Inference:
1. BuiltinTool is to be removed and replaced by a formal `type` parameter.
2. 'brave_search' is replaced by 'web_search' to be more generic. It will still be translated back to brave_search when the prompt is constructed to be consistent with model training.
3. I'm not sure what `photogen` is. Maybe it can be removed?

Agent:
1. Uses the same format as in Inference for builtin tools.
2. New tools types are added, i.e. knowledge_sesarch (currently rag_tool), and MCP tool.
3. Toolgroup as a concept will be removed since it's really only used for MCP.
4. Instead MCPTool is its own type and available tools provided by the server will be expanded by default. Users can specify a subset of tool names if desired.

Example snippet:
```

agent = Agent(
    client,
    model=model_id,
    instructions="You are a helpful assistant. Use the tools you have access to for providing relevant answers.",
    tools=[
        KnowledgeSearchTool(vector_store_id="1234"),
        KnowledgeSearchTool(vector_store_id="5678", name="paper_search", description="Search research papers"),
        KnowledgeSearchTool(vector_store_id="1357", name="wiki_search", description="Search wiki pages"),
        # no need to register toolgroup, just pass in the server uri
        # all available tools will be used
        MCPTool(server_uri="http://localhost:8000/sse"),
        # can specify a subset of available tools
        MCPTool(server_uri="http://localhost:8000/sse", tool_names=["list_directory"]),
        MCPTool(server_uri="http://localhost:8000/sse", tool_names=["list_directory"]),
        # custom tool
        my_custom_tool,
    ]
)
```

## Test Plan
# What does this PR do?


## Test Plan
# What does this PR do?


## Test Plan
This commit is contained in:
Eric Huang 2025-03-26 11:14:40 -07:00
parent 39e094736f
commit 7027b537e0
22 changed files with 951 additions and 525 deletions

View file

@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
ToolChoice,
ToolConfig,
ToolDefinition,
ToolDefinitionDeprecated,
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
@ -54,6 +55,9 @@ from llama_stack.apis.tools import (
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
ToolType,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import RoutingTable
@ -229,7 +233,7 @@ class InferenceRouter(Inference):
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tools: Optional[List[ToolDefinition] | List[ToolDefinitionDeprecated]] = None,
tool_choice: Optional[ToolChoice] = None,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
@ -259,24 +263,42 @@ class InferenceRouter(Inference):
params["tool_prompt_format"] = tool_prompt_format
tool_config = ToolConfig(**params)
tools = tools or []
# TODO: remove ToolDefinitionDeprecated in v0.1.10
converted_tools = []
for tool in tools or []:
if isinstance(tool, ToolDefinitionDeprecated):
logger.warning(f"ToolDefinitionDeprecated: {tool}, use ToolDefinition instead")
converted_tools.append(tool.to_tool_definition())
else:
converted_tools.append(tool)
if tool_config.tool_choice == ToolChoice.none:
tools = []
converted_tools = []
elif tool_config.tool_choice == ToolChoice.auto:
pass
elif tool_config.tool_choice == ToolChoice.required:
pass
else:
# verify tool_choice is one of the tools
tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
if tool_config.tool_choice not in tool_names:
raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
for t in converted_tools:
if t.type == ToolType.function.value:
if tool_config.tool_choice == t.name:
break
elif t.type in (
ToolType.web_search.value,
ToolType.wolfram_alpha.value,
ToolType.code_interpreter.value,
):
if tool_config.tool_choice == t.type:
break
else:
raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {converted_tools}")
params = dict(
model_id=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tools=converted_tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,