add support for built in tool type

This commit is contained in:
Dinesh Yeduguru 2024-12-23 16:50:03 -08:00
parent 517bc9ebea
commit 1a66ddc1b5
8 changed files with 83 additions and 75 deletions

View file

@ -4,9 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Union from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from llama_models.llama3.api.datatypes import ToolPromptFormat from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat
from llama_models.schema_utils import json_schema_type, register_schema, webmethod from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable from typing_extensions import Protocol, runtime_checkable
@ -25,13 +26,21 @@ class ToolParameter(BaseModel):
default: Optional[Any] = None default: Optional[Any] = None
@json_schema_type
class ToolHost(Enum):
distribution = "distribution"
client = "client"
model_context_protocol = "model_context_protocol"
@json_schema_type @json_schema_type
class Tool(Resource): class Tool(Resource):
type: Literal[ResourceType.tool.value] = ResourceType.tool.value type: Literal[ResourceType.tool.value] = ResourceType.tool.value
tool_group: str tool_group: str
tool_host: ToolHost
description: str description: str
parameters: List[ToolParameter] parameters: List[ToolParameter]
provider_id: Optional[str] = None built_in_type: Optional[BuiltinTool] = None
metadata: Optional[Dict[str, Any]] = None metadata: Optional[Dict[str, Any]] = None
tool_prompt_format: Optional[ToolPromptFormat] = Field( tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json default=ToolPromptFormat.json
@ -39,7 +48,8 @@ class Tool(Resource):
@json_schema_type @json_schema_type
class ToolDef(BaseModel): class CustomToolDef(BaseModel):
type: Literal["custom"] = "custom"
name: str name: str
description: str description: str
parameters: List[ToolParameter] parameters: List[ToolParameter]
@ -49,6 +59,19 @@ class ToolDef(BaseModel):
) )
@json_schema_type
class BuiltInToolDef(BaseModel):
type: Literal["built_in"] = "built_in"
built_in_type: BuiltinTool
metadata: Optional[Dict[str, Any]] = None
ToolDef = register_schema(
Annotated[Union[CustomToolDef, BuiltInToolDef], Field(discriminator="type")],
name="ToolDef",
)
@json_schema_type @json_schema_type
class MCPToolGroupDef(BaseModel): class MCPToolGroupDef(BaseModel):
""" """
@ -149,3 +172,14 @@ class ToolRuntime(Protocol):
) -> ToolInvocationResult: ) -> ToolInvocationResult:
"""Run a tool with the given arguments""" """Run a tool with the given arguments"""
... ...
# Three tool types:
# 1. Built-in tools
# 2. Client tools
# 3. Model-context-protocol tools
# Suport registration of agents with tool groups
# TBD: Have a client utility to hide the pre processing tools.
# Attachments are confusing right now since they are inserted into memory first and retireved through RAG, even before a question is asked.
#

View file

@ -516,6 +516,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
) -> None: ) -> None:
tools = [] tools = []
tool_defs = [] tool_defs = []
tool_host = ToolHost.distribution
if provider_id is None: if provider_id is None:
if len(self.impls_by_provider_id.keys()) > 1: if len(self.impls_by_provider_id.keys()) > 1:
raise ValueError( raise ValueError(
@ -529,25 +530,42 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools( tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
tool_group tool_group
) )
tool_host = ToolHost.model_context_protocol
elif isinstance(tool_group, UserDefinedToolGroupDef): elif isinstance(tool_group, UserDefinedToolGroupDef):
tool_defs = tool_group.tools tool_defs = tool_group.tools
else: else:
raise ValueError(f"Unknown tool group: {tool_group}") raise ValueError(f"Unknown tool group: {tool_group}")
for tool_def in tool_defs: for tool_def in tool_defs:
tools.append( if isinstance(tool_def, CustomToolDef):
Tool( tools.append(
identifier=tool_def.name, Tool(
tool_group=tool_group_id, identifier=tool_def.name,
description=tool_def.description, tool_group=tool_group_id,
parameters=tool_def.parameters, description=tool_def.description,
provider_id=provider_id, parameters=tool_def.parameters,
tool_prompt_format=tool_def.tool_prompt_format, provider_id=provider_id,
provider_resource_id=tool_def.name, tool_prompt_format=tool_def.tool_prompt_format,
metadata=tool_def.metadata, provider_resource_id=tool_def.name,
metadata=tool_def.metadata,
tool_host=tool_host,
)
)
elif isinstance(tool_def, BuiltInToolDef):
tools.append(
Tool(
identifier=tool_def.built_in_type.value,
tool_group=tool_group_id,
built_in_type=tool_def.built_in_type,
description="",
parameters=[],
provider_id=provider_id,
tool_prompt_format=ToolPromptFormat.json,
provider_resource_id=tool_def.built_in_type.value,
metadata=tool_def.metadata,
tool_host=tool_host,
)
) )
)
for tool in tools: for tool in tools:
existing_tool = await self.get_tool(tool.identifier) existing_tool = await self.get_tool(tool.identifier)
# Compare existing and new object if one exists # Compare existing and new object if one exists

View file

@ -621,6 +621,9 @@ class ChatAgent(ShieldRunnerMixin):
ret = [] ret = []
for tool_name in self.agent_config.available_tools: for tool_name in self.agent_config.available_tools:
tool = await self.tool_groups_api.get_tool(tool_name) tool = await self.tool_groups_api.get_tool(tool_name)
if tool.built_in_type:
ret.append(ToolDefinition(tool_name=tool.built_in_type))
continue
params = {} params = {}
for param in tool.parameters: for param in tool.parameters:
params[param.name] = ToolParamDefinition( params[param.name] = ToolParamDefinition(

View file

@ -25,8 +25,7 @@ class BraveSearchToolRuntimeImpl(
pass pass
async def register_tool(self, tool: Tool): async def register_tool(self, tool: Tool):
if tool.identifier != "brave_search": pass
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
async def unregister_tool(self, tool_id: str) -> None: async def unregister_tool(self, tool_id: str) -> None:
return return

View file

@ -26,8 +26,7 @@ class TavilySearchToolRuntimeImpl(
pass pass
async def register_tool(self, tool: Tool): async def register_tool(self, tool: Tool):
if tool.identifier != "tavily_search": pass
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
async def unregister_tool(self, tool_id: str) -> None: async def unregister_tool(self, tool_id: str) -> None:
return return

View file

@ -8,6 +8,7 @@ from typing import Any, Dict, List
from urllib.parse import urlparse from urllib.parse import urlparse
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
CustomToolDef,
MCPToolGroupDef, MCPToolGroupDef,
ToolDef, ToolDef,
ToolGroupDef, ToolGroupDef,
@ -52,7 +53,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
) )
) )
tools.append( tools.append(
ToolDef( CustomToolDef(
name=tool.name, name=tool.name,
description=tool.description, description=tool.description,
parameters=parameters, parameters=parameters,

View file

@ -9,10 +9,12 @@ import tempfile
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.models import ModelInput, ModelType from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ToolDef, BuiltInToolDef,
CustomToolDef,
ToolGroupInput, ToolGroupInput,
ToolParameter, ToolParameter,
UserDefinedToolGroupDef, UserDefinedToolGroupDef,
@ -151,42 +153,12 @@ async def agents_stack(request, inference_model, safety_shield):
) )
) )
tool_groups = [ tool_groups = [
ToolGroupInput(
tool_group_id="brave_search_group",
tool_group=UserDefinedToolGroupDef(
tools=[
ToolDef(
name="brave_search",
description="brave_search",
parameters=[
ToolParameter(
name="query",
description="query",
parameter_type="string",
required=True,
),
],
metadata={},
),
],
),
provider_id="brave-search",
),
ToolGroupInput( ToolGroupInput(
tool_group_id="tavily_search_group", tool_group_id="tavily_search_group",
tool_group=UserDefinedToolGroupDef( tool_group=UserDefinedToolGroupDef(
tools=[ tools=[
ToolDef( BuiltInToolDef(
name="tavily_search", built_in_type=BuiltinTool.brave_search,
description="tavily_search",
parameters=[
ToolParameter(
name="query",
description="query",
parameter_type="string",
required=True,
),
],
metadata={}, metadata={},
), ),
], ],
@ -197,7 +169,7 @@ async def agents_stack(request, inference_model, safety_shield):
tool_group_id="memory_group", tool_group_id="memory_group",
tool_group=UserDefinedToolGroupDef( tool_group=UserDefinedToolGroupDef(
tools=[ tools=[
ToolDef( CustomToolDef(
name="memory", name="memory",
description="memory", description="memory",
parameters=[ parameters=[
@ -230,10 +202,8 @@ async def agents_stack(request, inference_model, safety_shield):
tool_group_id="code_interpreter_group", tool_group_id="code_interpreter_group",
tool_group=UserDefinedToolGroupDef( tool_group=UserDefinedToolGroupDef(
tools=[ tools=[
ToolDef( BuiltInToolDef(
name="code_interpreter", built_in_type=BuiltinTool.code_interpreter,
description="code_interpreter",
parameters=[],
metadata={}, metadata={},
) )
], ],

View file

@ -150,9 +150,7 @@ async def create_agent_turn_with_search_tool(
assert isinstance(tool_execution, ToolExecutionStep) assert isinstance(tool_execution, ToolExecutionStep)
assert len(tool_execution.tool_calls) > 0 assert len(tool_execution.tool_calls) > 0
actual_tool_name = tool_execution.tool_calls[0].tool_name actual_tool_name = tool_execution.tool_calls[0].tool_name
if isinstance(actual_tool_name, BuiltinTool): assert actual_tool_name.value == tool_name
actual_tool_name = actual_tool_name.value
assert actual_tool_name == tool_name
assert len(tool_execution.tool_responses) > 0 assert len(tool_execution.tool_responses) > 0
check_turn_complete_event(turn_response, session_id, search_query_messages) check_turn_complete_event(turn_response, session_id, search_query_messages)
@ -305,20 +303,6 @@ class TestAgents:
"brave_search", "brave_search",
) )
@pytest.mark.asyncio
async def test_create_agent_turn_with_tavily_search(
self, agents_stack, search_query_messages, common_params
):
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
await create_agent_turn_with_search_tool(
agents_stack,
search_query_messages,
common_params,
"tavily_search",
)
def check_event_types(turn_response): def check_event_types(turn_response):
event_types = [chunk.event.payload.event_type for chunk in turn_response] event_types = [chunk.event.payload.event_type for chunk in turn_response]