add support for built-in tool type

Dinesh Yeduguru 2024-12-23 16:50:03 -08:00
parent 517bc9ebea
commit 1a66ddc1b5
8 changed files with 83 additions and 75 deletions

View file

@@ -4,9 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable
@@ -25,13 +26,21 @@ class ToolParameter(BaseModel):
default: Optional[Any] = None
@json_schema_type
class ToolHost(Enum):
distribution = "distribution"
client = "client"
model_context_protocol = "model_context_protocol"
@json_schema_type
class Tool(Resource):
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
tool_group: str
tool_host: ToolHost
description: str
parameters: List[ToolParameter]
provider_id: Optional[str] = None
built_in_type: Optional[BuiltinTool] = None
metadata: Optional[Dict[str, Any]] = None
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
@@ -39,7 +48,8 @@ class Tool(Resource):
@json_schema_type
class ToolDef(BaseModel):
class CustomToolDef(BaseModel):
type: Literal["custom"] = "custom"
name: str
description: str
parameters: List[ToolParameter]
@@ -49,6 +59,19 @@ class ToolDef(BaseModel):
)
@json_schema_type
class BuiltInToolDef(BaseModel):
type: Literal["built_in"] = "built_in"
built_in_type: BuiltinTool
metadata: Optional[Dict[str, Any]] = None
ToolDef = register_schema(
Annotated[Union[CustomToolDef, BuiltInToolDef], Field(discriminator="type")],
name="ToolDef",
)
@json_schema_type
class MCPToolGroupDef(BaseModel):
"""
@@ -149,3 +172,14 @@ class ToolRuntime(Protocol):
) -> ToolInvocationResult:
"""Run a tool with the given arguments"""
...
# Three tool types:
# 1. Built-in tools
# 2. Client tools
# 3. Model-context-protocol tools
# Support registration of agents with tool groups
# TBD: Have a client utility to hide the pre-processing tools.
# Attachments are confusing right now since they are inserted into memory first and retrieved through RAG, even before a question is asked.
#
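Below is a minimal usage sketch (not part of this commit) of how the new BuiltInToolDef sits next to CustomToolDef inside a user-defined tool group; the group id is invented, everything else is assembled from the test fixtures elsewhere in this diff:

from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.tools import (
    BuiltInToolDef,
    CustomToolDef,
    ToolGroupInput,
    ToolParameter,
    UserDefinedToolGroupDef,
)

# A single group can mix a built-in tool (identified only by its BuiltinTool
# enum value) with a custom tool that carries an explicit parameter schema.
mixed_group = ToolGroupInput(
    tool_group_id="mixed_tools_group",  # hypothetical id, not from this diff
    tool_group=UserDefinedToolGroupDef(
        tools=[
            BuiltInToolDef(
                built_in_type=BuiltinTool.brave_search,
                metadata={},
            ),
            CustomToolDef(
                name="memory",
                description="memory",
                parameters=[
                    ToolParameter(
                        name="query",
                        description="query",
                        parameter_type="string",
                        required=True,
                    ),
                ],
                metadata={},
            ),
        ],
    ),
)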

View file

@@ -516,6 +516,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
) -> None:
tools = []
tool_defs = []
tool_host = ToolHost.distribution
if provider_id is None:
if len(self.impls_by_provider_id.keys()) > 1:
raise ValueError(
@@ -529,25 +530,42 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
tool_group
)
tool_host = ToolHost.model_context_protocol
elif isinstance(tool_group, UserDefinedToolGroupDef):
tool_defs = tool_group.tools
else:
raise ValueError(f"Unknown tool group: {tool_group}")
for tool_def in tool_defs:
tools.append(
Tool(
identifier=tool_def.name,
tool_group=tool_group_id,
description=tool_def.description,
parameters=tool_def.parameters,
provider_id=provider_id,
tool_prompt_format=tool_def.tool_prompt_format,
provider_resource_id=tool_def.name,
metadata=tool_def.metadata,
if isinstance(tool_def, CustomToolDef):
tools.append(
Tool(
identifier=tool_def.name,
tool_group=tool_group_id,
description=tool_def.description,
parameters=tool_def.parameters,
provider_id=provider_id,
tool_prompt_format=tool_def.tool_prompt_format,
provider_resource_id=tool_def.name,
metadata=tool_def.metadata,
tool_host=tool_host,
)
)
elif isinstance(tool_def, BuiltInToolDef):
tools.append(
Tool(
identifier=tool_def.built_in_type.value,
tool_group=tool_group_id,
built_in_type=tool_def.built_in_type,
description="",
parameters=[],
provider_id=provider_id,
tool_prompt_format=ToolPromptFormat.json,
provider_resource_id=tool_def.built_in_type.value,
metadata=tool_def.metadata,
tool_host=tool_host,
)
)
)
for tool in tools:
existing_tool = await self.get_tool(tool.identifier)
# Compare existing and new object if one exists

View file

@@ -621,6 +621,9 @@ class ChatAgent(ShieldRunnerMixin):
ret = []
for tool_name in self.agent_config.available_tools:
tool = await self.tool_groups_api.get_tool(tool_name)
if tool.built_in_type:
ret.append(ToolDefinition(tool_name=tool.built_in_type))
continue
params = {}
for param in tool.parameters:
params[param.name] = ToolParamDefinition(

View file

@@ -25,8 +25,7 @@ class BraveSearchToolRuntimeImpl(
pass
async def register_tool(self, tool: Tool):
if tool.identifier != "brave_search":
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
pass
async def unregister_tool(self, tool_id: str) -> None:
return

View file

@@ -26,8 +26,7 @@ class TavilySearchToolRuntimeImpl(
pass
async def register_tool(self, tool: Tool):
if tool.identifier != "tavily_search":
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
pass
async def unregister_tool(self, tool_id: str) -> None:
return

View file

@@ -8,6 +8,7 @@ from typing import Any, Dict, List
from urllib.parse import urlparse
from llama_stack.apis.tools import (
CustomToolDef,
MCPToolGroupDef,
ToolDef,
ToolGroupDef,
@@ -52,7 +53,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
)
)
tools.append(
ToolDef(
CustomToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,

View file

@@ -9,10 +9,12 @@ import tempfile
import pytest
import pytest_asyncio
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.apis.tools import (
ToolDef,
BuiltInToolDef,
CustomToolDef,
ToolGroupInput,
ToolParameter,
UserDefinedToolGroupDef,
@@ -151,42 +153,12 @@ async def agents_stack(request, inference_model, safety_shield):
)
)
tool_groups = [
ToolGroupInput(
tool_group_id="brave_search_group",
tool_group=UserDefinedToolGroupDef(
tools=[
ToolDef(
name="brave_search",
description="brave_search",
parameters=[
ToolParameter(
name="query",
description="query",
parameter_type="string",
required=True,
),
],
metadata={},
),
],
),
provider_id="brave-search",
),
ToolGroupInput(
tool_group_id="tavily_search_group",
tool_group=UserDefinedToolGroupDef(
tools=[
ToolDef(
name="tavily_search",
description="tavily_search",
parameters=[
ToolParameter(
name="query",
description="query",
parameter_type="string",
required=True,
),
],
BuiltInToolDef(
built_in_type=BuiltinTool.brave_search,
metadata={},
),
],
@@ -197,7 +169,7 @@ async def agents_stack(request, inference_model, safety_shield):
tool_group_id="memory_group",
tool_group=UserDefinedToolGroupDef(
tools=[
ToolDef(
CustomToolDef(
name="memory",
description="memory",
parameters=[
@@ -230,10 +202,8 @@ async def agents_stack(request, inference_model, safety_shield):
tool_group_id="code_interpreter_group",
tool_group=UserDefinedToolGroupDef(
tools=[
ToolDef(
name="code_interpreter",
description="code_interpreter",
parameters=[],
BuiltInToolDef(
built_in_type=BuiltinTool.code_interpreter,
metadata={},
)
],

View file

@@ -150,9 +150,7 @@ async def create_agent_turn_with_search_tool(
assert isinstance(tool_execution, ToolExecutionStep)
assert len(tool_execution.tool_calls) > 0
actual_tool_name = tool_execution.tool_calls[0].tool_name
if isinstance(actual_tool_name, BuiltinTool):
actual_tool_name = actual_tool_name.value
assert actual_tool_name == tool_name
assert actual_tool_name.value == tool_name
assert len(tool_execution.tool_responses) > 0
check_turn_complete_event(turn_response, session_id, search_query_messages)
@@ -305,20 +303,6 @@ class TestAgents:
"brave_search",
)
@pytest.mark.asyncio
async def test_create_agent_turn_with_tavily_search(
self, agents_stack, search_query_messages, common_params
):
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
await create_agent_turn_with_search_tool(
agents_stack,
search_query_messages,
common_params,
"tavily_search",
)
def check_event_types(turn_response):
event_types = [chunk.event.payload.event_type for chunk in turn_response]