mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
add support for built in tool type
This commit is contained in:
parent
517bc9ebea
commit
1a66ddc1b5
8 changed files with 83 additions and 75 deletions
|
@ -4,9 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
@ -25,13 +26,21 @@ class ToolParameter(BaseModel):
|
|||
default: Optional[Any] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolHost(Enum):
|
||||
distribution = "distribution"
|
||||
client = "client"
|
||||
model_context_protocol = "model_context_protocol"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Tool(Resource):
|
||||
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
|
||||
tool_group: str
|
||||
tool_host: ToolHost
|
||||
description: str
|
||||
parameters: List[ToolParameter]
|
||||
provider_id: Optional[str] = None
|
||||
built_in_type: Optional[BuiltinTool] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
|
@ -39,7 +48,8 @@ class Tool(Resource):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class ToolDef(BaseModel):
|
||||
class CustomToolDef(BaseModel):
|
||||
type: Literal["custom"] = "custom"
|
||||
name: str
|
||||
description: str
|
||||
parameters: List[ToolParameter]
|
||||
|
@ -49,6 +59,19 @@ class ToolDef(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BuiltInToolDef(BaseModel):
|
||||
type: Literal["built_in"] = "built_in"
|
||||
built_in_type: BuiltinTool
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
ToolDef = register_schema(
|
||||
Annotated[Union[CustomToolDef, BuiltInToolDef], Field(discriminator="type")],
|
||||
name="ToolDef",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MCPToolGroupDef(BaseModel):
|
||||
"""
|
||||
|
@ -149,3 +172,14 @@ class ToolRuntime(Protocol):
|
|||
) -> ToolInvocationResult:
|
||||
"""Run a tool with the given arguments"""
|
||||
...
|
||||
|
||||
|
||||
# Three tool types:
|
||||
# 1. Built-in tools
|
||||
# 2. Client tools
|
||||
# 3. Model-context-protocol tools
|
||||
|
||||
# Suport registration of agents with tool groups
|
||||
# TBD: Have a client utility to hide the pre processing tools.
|
||||
# Attachments are confusing right now since they are inserted into memory first and retireved through RAG, even before a question is asked.
|
||||
#
|
||||
|
|
|
@ -516,6 +516,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
) -> None:
|
||||
tools = []
|
||||
tool_defs = []
|
||||
tool_host = ToolHost.distribution
|
||||
if provider_id is None:
|
||||
if len(self.impls_by_provider_id.keys()) > 1:
|
||||
raise ValueError(
|
||||
|
@ -529,25 +530,42 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
|
||||
tool_group
|
||||
)
|
||||
|
||||
tool_host = ToolHost.model_context_protocol
|
||||
elif isinstance(tool_group, UserDefinedToolGroupDef):
|
||||
tool_defs = tool_group.tools
|
||||
else:
|
||||
raise ValueError(f"Unknown tool group: {tool_group}")
|
||||
|
||||
for tool_def in tool_defs:
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=tool_def.name,
|
||||
tool_group=tool_group_id,
|
||||
description=tool_def.description,
|
||||
parameters=tool_def.parameters,
|
||||
provider_id=provider_id,
|
||||
tool_prompt_format=tool_def.tool_prompt_format,
|
||||
provider_resource_id=tool_def.name,
|
||||
metadata=tool_def.metadata,
|
||||
if isinstance(tool_def, CustomToolDef):
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=tool_def.name,
|
||||
tool_group=tool_group_id,
|
||||
description=tool_def.description,
|
||||
parameters=tool_def.parameters,
|
||||
provider_id=provider_id,
|
||||
tool_prompt_format=tool_def.tool_prompt_format,
|
||||
provider_resource_id=tool_def.name,
|
||||
metadata=tool_def.metadata,
|
||||
tool_host=tool_host,
|
||||
)
|
||||
)
|
||||
elif isinstance(tool_def, BuiltInToolDef):
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=tool_def.built_in_type.value,
|
||||
tool_group=tool_group_id,
|
||||
built_in_type=tool_def.built_in_type,
|
||||
description="",
|
||||
parameters=[],
|
||||
provider_id=provider_id,
|
||||
tool_prompt_format=ToolPromptFormat.json,
|
||||
provider_resource_id=tool_def.built_in_type.value,
|
||||
metadata=tool_def.metadata,
|
||||
tool_host=tool_host,
|
||||
)
|
||||
)
|
||||
)
|
||||
for tool in tools:
|
||||
existing_tool = await self.get_tool(tool.identifier)
|
||||
# Compare existing and new object if one exists
|
||||
|
|
|
@ -621,6 +621,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
ret = []
|
||||
for tool_name in self.agent_config.available_tools:
|
||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||
if tool.built_in_type:
|
||||
ret.append(ToolDefinition(tool_name=tool.built_in_type))
|
||||
continue
|
||||
params = {}
|
||||
for param in tool.parameters:
|
||||
params[param.name] = ToolParamDefinition(
|
||||
|
|
|
@ -25,8 +25,7 @@ class BraveSearchToolRuntimeImpl(
|
|||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
if tool.identifier != "brave_search":
|
||||
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
return
|
||||
|
|
|
@ -26,8 +26,7 @@ class TavilySearchToolRuntimeImpl(
|
|||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
if tool.identifier != "tavily_search":
|
||||
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
return
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Dict, List
|
|||
from urllib.parse import urlparse
|
||||
|
||||
from llama_stack.apis.tools import (
|
||||
CustomToolDef,
|
||||
MCPToolGroupDef,
|
||||
ToolDef,
|
||||
ToolGroupDef,
|
||||
|
@ -52,7 +53,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
)
|
||||
)
|
||||
tools.append(
|
||||
ToolDef(
|
||||
CustomToolDef(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=parameters,
|
||||
|
|
|
@ -9,10 +9,12 @@ import tempfile
|
|||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||
|
||||
from llama_stack.apis.models import ModelInput, ModelType
|
||||
from llama_stack.apis.tools import (
|
||||
ToolDef,
|
||||
BuiltInToolDef,
|
||||
CustomToolDef,
|
||||
ToolGroupInput,
|
||||
ToolParameter,
|
||||
UserDefinedToolGroupDef,
|
||||
|
@ -151,42 +153,12 @@ async def agents_stack(request, inference_model, safety_shield):
|
|||
)
|
||||
)
|
||||
tool_groups = [
|
||||
ToolGroupInput(
|
||||
tool_group_id="brave_search_group",
|
||||
tool_group=UserDefinedToolGroupDef(
|
||||
tools=[
|
||||
ToolDef(
|
||||
name="brave_search",
|
||||
description="brave_search",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="query",
|
||||
parameter_type="string",
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
metadata={},
|
||||
),
|
||||
],
|
||||
),
|
||||
provider_id="brave-search",
|
||||
),
|
||||
ToolGroupInput(
|
||||
tool_group_id="tavily_search_group",
|
||||
tool_group=UserDefinedToolGroupDef(
|
||||
tools=[
|
||||
ToolDef(
|
||||
name="tavily_search",
|
||||
description="tavily_search",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="query",
|
||||
parameter_type="string",
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
BuiltInToolDef(
|
||||
built_in_type=BuiltinTool.brave_search,
|
||||
metadata={},
|
||||
),
|
||||
],
|
||||
|
@ -197,7 +169,7 @@ async def agents_stack(request, inference_model, safety_shield):
|
|||
tool_group_id="memory_group",
|
||||
tool_group=UserDefinedToolGroupDef(
|
||||
tools=[
|
||||
ToolDef(
|
||||
CustomToolDef(
|
||||
name="memory",
|
||||
description="memory",
|
||||
parameters=[
|
||||
|
@ -230,10 +202,8 @@ async def agents_stack(request, inference_model, safety_shield):
|
|||
tool_group_id="code_interpreter_group",
|
||||
tool_group=UserDefinedToolGroupDef(
|
||||
tools=[
|
||||
ToolDef(
|
||||
name="code_interpreter",
|
||||
description="code_interpreter",
|
||||
parameters=[],
|
||||
BuiltInToolDef(
|
||||
built_in_type=BuiltinTool.code_interpreter,
|
||||
metadata={},
|
||||
)
|
||||
],
|
||||
|
|
|
@ -150,9 +150,7 @@ async def create_agent_turn_with_search_tool(
|
|||
assert isinstance(tool_execution, ToolExecutionStep)
|
||||
assert len(tool_execution.tool_calls) > 0
|
||||
actual_tool_name = tool_execution.tool_calls[0].tool_name
|
||||
if isinstance(actual_tool_name, BuiltinTool):
|
||||
actual_tool_name = actual_tool_name.value
|
||||
assert actual_tool_name == tool_name
|
||||
assert actual_tool_name.value == tool_name
|
||||
assert len(tool_execution.tool_responses) > 0
|
||||
|
||||
check_turn_complete_event(turn_response, session_id, search_query_messages)
|
||||
|
@ -305,20 +303,6 @@ class TestAgents:
|
|||
"brave_search",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_turn_with_tavily_search(
|
||||
self, agents_stack, search_query_messages, common_params
|
||||
):
|
||||
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
||||
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
||||
|
||||
await create_agent_turn_with_search_tool(
|
||||
agents_stack,
|
||||
search_query_messages,
|
||||
common_params,
|
||||
"tavily_search",
|
||||
)
|
||||
|
||||
|
||||
def check_event_types(turn_response):
|
||||
event_types = [chunk.event.payload.event_type for chunk in turn_response]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue