mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 10:13:05 +00:00
add support for built in tool type
This commit is contained in:
parent
517bc9ebea
commit
1a66ddc1b5
8 changed files with 83 additions and 75 deletions
|
@ -4,9 +4,10 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat
|
||||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Protocol, runtime_checkable
|
from typing_extensions import Protocol, runtime_checkable
|
||||||
|
@ -25,13 +26,21 @@ class ToolParameter(BaseModel):
|
||||||
default: Optional[Any] = None
|
default: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ToolHost(Enum):
|
||||||
|
distribution = "distribution"
|
||||||
|
client = "client"
|
||||||
|
model_context_protocol = "model_context_protocol"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Tool(Resource):
|
class Tool(Resource):
|
||||||
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
|
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
|
||||||
tool_group: str
|
tool_group: str
|
||||||
|
tool_host: ToolHost
|
||||||
description: str
|
description: str
|
||||||
parameters: List[ToolParameter]
|
parameters: List[ToolParameter]
|
||||||
provider_id: Optional[str] = None
|
built_in_type: Optional[BuiltinTool] = None
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||||
default=ToolPromptFormat.json
|
default=ToolPromptFormat.json
|
||||||
|
@ -39,7 +48,8 @@ class Tool(Resource):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ToolDef(BaseModel):
|
class CustomToolDef(BaseModel):
|
||||||
|
type: Literal["custom"] = "custom"
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
parameters: List[ToolParameter]
|
parameters: List[ToolParameter]
|
||||||
|
@ -49,6 +59,19 @@ class ToolDef(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class BuiltInToolDef(BaseModel):
|
||||||
|
type: Literal["built_in"] = "built_in"
|
||||||
|
built_in_type: BuiltinTool
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
ToolDef = register_schema(
|
||||||
|
Annotated[Union[CustomToolDef, BuiltInToolDef], Field(discriminator="type")],
|
||||||
|
name="ToolDef",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class MCPToolGroupDef(BaseModel):
|
class MCPToolGroupDef(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -149,3 +172,14 @@ class ToolRuntime(Protocol):
|
||||||
) -> ToolInvocationResult:
|
) -> ToolInvocationResult:
|
||||||
"""Run a tool with the given arguments"""
|
"""Run a tool with the given arguments"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# Three tool types:
|
||||||
|
# 1. Built-in tools
|
||||||
|
# 2. Client tools
|
||||||
|
# 3. Model-context-protocol tools
|
||||||
|
|
||||||
|
# Suport registration of agents with tool groups
|
||||||
|
# TBD: Have a client utility to hide the pre processing tools.
|
||||||
|
# Attachments are confusing right now since they are inserted into memory first and retireved through RAG, even before a question is asked.
|
||||||
|
#
|
||||||
|
|
|
@ -516,6 +516,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
) -> None:
|
) -> None:
|
||||||
tools = []
|
tools = []
|
||||||
tool_defs = []
|
tool_defs = []
|
||||||
|
tool_host = ToolHost.distribution
|
||||||
if provider_id is None:
|
if provider_id is None:
|
||||||
if len(self.impls_by_provider_id.keys()) > 1:
|
if len(self.impls_by_provider_id.keys()) > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -529,25 +530,42 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
|
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
|
||||||
tool_group
|
tool_group
|
||||||
)
|
)
|
||||||
|
tool_host = ToolHost.model_context_protocol
|
||||||
elif isinstance(tool_group, UserDefinedToolGroupDef):
|
elif isinstance(tool_group, UserDefinedToolGroupDef):
|
||||||
tool_defs = tool_group.tools
|
tool_defs = tool_group.tools
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown tool group: {tool_group}")
|
raise ValueError(f"Unknown tool group: {tool_group}")
|
||||||
|
|
||||||
for tool_def in tool_defs:
|
for tool_def in tool_defs:
|
||||||
tools.append(
|
if isinstance(tool_def, CustomToolDef):
|
||||||
Tool(
|
tools.append(
|
||||||
identifier=tool_def.name,
|
Tool(
|
||||||
tool_group=tool_group_id,
|
identifier=tool_def.name,
|
||||||
description=tool_def.description,
|
tool_group=tool_group_id,
|
||||||
parameters=tool_def.parameters,
|
description=tool_def.description,
|
||||||
provider_id=provider_id,
|
parameters=tool_def.parameters,
|
||||||
tool_prompt_format=tool_def.tool_prompt_format,
|
provider_id=provider_id,
|
||||||
provider_resource_id=tool_def.name,
|
tool_prompt_format=tool_def.tool_prompt_format,
|
||||||
metadata=tool_def.metadata,
|
provider_resource_id=tool_def.name,
|
||||||
|
metadata=tool_def.metadata,
|
||||||
|
tool_host=tool_host,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif isinstance(tool_def, BuiltInToolDef):
|
||||||
|
tools.append(
|
||||||
|
Tool(
|
||||||
|
identifier=tool_def.built_in_type.value,
|
||||||
|
tool_group=tool_group_id,
|
||||||
|
built_in_type=tool_def.built_in_type,
|
||||||
|
description="",
|
||||||
|
parameters=[],
|
||||||
|
provider_id=provider_id,
|
||||||
|
tool_prompt_format=ToolPromptFormat.json,
|
||||||
|
provider_resource_id=tool_def.built_in_type.value,
|
||||||
|
metadata=tool_def.metadata,
|
||||||
|
tool_host=tool_host,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
existing_tool = await self.get_tool(tool.identifier)
|
existing_tool = await self.get_tool(tool.identifier)
|
||||||
# Compare existing and new object if one exists
|
# Compare existing and new object if one exists
|
||||||
|
|
|
@ -621,6 +621,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
ret = []
|
ret = []
|
||||||
for tool_name in self.agent_config.available_tools:
|
for tool_name in self.agent_config.available_tools:
|
||||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||||
|
if tool.built_in_type:
|
||||||
|
ret.append(ToolDefinition(tool_name=tool.built_in_type))
|
||||||
|
continue
|
||||||
params = {}
|
params = {}
|
||||||
for param in tool.parameters:
|
for param in tool.parameters:
|
||||||
params[param.name] = ToolParamDefinition(
|
params[param.name] = ToolParamDefinition(
|
||||||
|
|
|
@ -25,8 +25,7 @@ class BraveSearchToolRuntimeImpl(
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool):
|
async def register_tool(self, tool: Tool):
|
||||||
if tool.identifier != "brave_search":
|
pass
|
||||||
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
|
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_tool(self, tool_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
|
@ -26,8 +26,7 @@ class TavilySearchToolRuntimeImpl(
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool):
|
async def register_tool(self, tool: Tool):
|
||||||
if tool.identifier != "tavily_search":
|
pass
|
||||||
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
|
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_tool(self, tool_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Dict, List
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
|
CustomToolDef,
|
||||||
MCPToolGroupDef,
|
MCPToolGroupDef,
|
||||||
ToolDef,
|
ToolDef,
|
||||||
ToolGroupDef,
|
ToolGroupDef,
|
||||||
|
@ -52,7 +53,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
tools.append(
|
tools.append(
|
||||||
ToolDef(
|
CustomToolDef(
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
description=tool.description,
|
description=tool.description,
|
||||||
parameters=parameters,
|
parameters=parameters,
|
||||||
|
|
|
@ -9,10 +9,12 @@ import tempfile
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||||
|
|
||||||
from llama_stack.apis.models import ModelInput, ModelType
|
from llama_stack.apis.models import ModelInput, ModelType
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ToolDef,
|
BuiltInToolDef,
|
||||||
|
CustomToolDef,
|
||||||
ToolGroupInput,
|
ToolGroupInput,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
UserDefinedToolGroupDef,
|
UserDefinedToolGroupDef,
|
||||||
|
@ -151,42 +153,12 @@ async def agents_stack(request, inference_model, safety_shield):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
tool_groups = [
|
tool_groups = [
|
||||||
ToolGroupInput(
|
|
||||||
tool_group_id="brave_search_group",
|
|
||||||
tool_group=UserDefinedToolGroupDef(
|
|
||||||
tools=[
|
|
||||||
ToolDef(
|
|
||||||
name="brave_search",
|
|
||||||
description="brave_search",
|
|
||||||
parameters=[
|
|
||||||
ToolParameter(
|
|
||||||
name="query",
|
|
||||||
description="query",
|
|
||||||
parameter_type="string",
|
|
||||||
required=True,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
metadata={},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
provider_id="brave-search",
|
|
||||||
),
|
|
||||||
ToolGroupInput(
|
ToolGroupInput(
|
||||||
tool_group_id="tavily_search_group",
|
tool_group_id="tavily_search_group",
|
||||||
tool_group=UserDefinedToolGroupDef(
|
tool_group=UserDefinedToolGroupDef(
|
||||||
tools=[
|
tools=[
|
||||||
ToolDef(
|
BuiltInToolDef(
|
||||||
name="tavily_search",
|
built_in_type=BuiltinTool.brave_search,
|
||||||
description="tavily_search",
|
|
||||||
parameters=[
|
|
||||||
ToolParameter(
|
|
||||||
name="query",
|
|
||||||
description="query",
|
|
||||||
parameter_type="string",
|
|
||||||
required=True,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
metadata={},
|
metadata={},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
@ -197,7 +169,7 @@ async def agents_stack(request, inference_model, safety_shield):
|
||||||
tool_group_id="memory_group",
|
tool_group_id="memory_group",
|
||||||
tool_group=UserDefinedToolGroupDef(
|
tool_group=UserDefinedToolGroupDef(
|
||||||
tools=[
|
tools=[
|
||||||
ToolDef(
|
CustomToolDef(
|
||||||
name="memory",
|
name="memory",
|
||||||
description="memory",
|
description="memory",
|
||||||
parameters=[
|
parameters=[
|
||||||
|
@ -230,10 +202,8 @@ async def agents_stack(request, inference_model, safety_shield):
|
||||||
tool_group_id="code_interpreter_group",
|
tool_group_id="code_interpreter_group",
|
||||||
tool_group=UserDefinedToolGroupDef(
|
tool_group=UserDefinedToolGroupDef(
|
||||||
tools=[
|
tools=[
|
||||||
ToolDef(
|
BuiltInToolDef(
|
||||||
name="code_interpreter",
|
built_in_type=BuiltinTool.code_interpreter,
|
||||||
description="code_interpreter",
|
|
||||||
parameters=[],
|
|
||||||
metadata={},
|
metadata={},
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
|
|
@ -150,9 +150,7 @@ async def create_agent_turn_with_search_tool(
|
||||||
assert isinstance(tool_execution, ToolExecutionStep)
|
assert isinstance(tool_execution, ToolExecutionStep)
|
||||||
assert len(tool_execution.tool_calls) > 0
|
assert len(tool_execution.tool_calls) > 0
|
||||||
actual_tool_name = tool_execution.tool_calls[0].tool_name
|
actual_tool_name = tool_execution.tool_calls[0].tool_name
|
||||||
if isinstance(actual_tool_name, BuiltinTool):
|
assert actual_tool_name.value == tool_name
|
||||||
actual_tool_name = actual_tool_name.value
|
|
||||||
assert actual_tool_name == tool_name
|
|
||||||
assert len(tool_execution.tool_responses) > 0
|
assert len(tool_execution.tool_responses) > 0
|
||||||
|
|
||||||
check_turn_complete_event(turn_response, session_id, search_query_messages)
|
check_turn_complete_event(turn_response, session_id, search_query_messages)
|
||||||
|
@ -305,20 +303,6 @@ class TestAgents:
|
||||||
"brave_search",
|
"brave_search",
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_agent_turn_with_tavily_search(
|
|
||||||
self, agents_stack, search_query_messages, common_params
|
|
||||||
):
|
|
||||||
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
|
||||||
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
|
||||||
|
|
||||||
await create_agent_turn_with_search_tool(
|
|
||||||
agents_stack,
|
|
||||||
search_query_messages,
|
|
||||||
common_params,
|
|
||||||
"tavily_search",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def check_event_types(turn_response):
|
def check_event_types(turn_response):
|
||||||
event_types = [chunk.event.payload.event_type for chunk in turn_response]
|
event_types = [chunk.event.payload.event_type for chunk in turn_response]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue