From 1a66ddc1b55f0a0a3f143de38225fa4fc154fbe7 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 23 Dec 2024 16:50:03 -0800 Subject: [PATCH] add support for built in tool type --- llama_stack/apis/tools/tools.py | 40 ++++++++++++++-- .../distribution/routers/routing_tables.py | 42 ++++++++++++----- .../agents/meta_reference/agent_instance.py | 3 ++ .../tool_runtime/brave_search/brave_search.py | 3 +- .../tavily_search/tavily_search.py | 3 +- .../model_context_protocol.py | 3 +- .../providers/tests/agents/fixtures.py | 46 ++++--------------- .../providers/tests/agents/test_agents.py | 18 +------- 8 files changed, 83 insertions(+), 75 deletions(-) diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index e3c2ca52c..65d5b8444 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -4,9 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from enum import Enum from typing import Annotated, Any, Dict, List, Literal, Optional, Union -from llama_models.llama3.api.datatypes import ToolPromptFormat +from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Protocol, runtime_checkable @@ -25,13 +26,21 @@ class ToolParameter(BaseModel): default: Optional[Any] = None +@json_schema_type +class ToolHost(Enum): + distribution = "distribution" + client = "client" + model_context_protocol = "model_context_protocol" + + @json_schema_type class Tool(Resource): type: Literal[ResourceType.tool.value] = ResourceType.tool.value tool_group: str + tool_host: ToolHost description: str parameters: List[ToolParameter] - provider_id: Optional[str] = None + built_in_type: Optional[BuiltinTool] = None metadata: Optional[Dict[str, Any]] = None tool_prompt_format: Optional[ToolPromptFormat] 
= Field( default=ToolPromptFormat.json @@ -39,7 +48,8 @@ class Tool(Resource): @json_schema_type -class ToolDef(BaseModel): +class CustomToolDef(BaseModel): + type: Literal["custom"] = "custom" name: str description: str parameters: List[ToolParameter] @@ -49,6 +59,19 @@ class ToolDef(BaseModel): ) +@json_schema_type +class BuiltInToolDef(BaseModel): + type: Literal["built_in"] = "built_in" + built_in_type: BuiltinTool + metadata: Optional[Dict[str, Any]] = None + + +ToolDef = register_schema( + Annotated[Union[CustomToolDef, BuiltInToolDef], Field(discriminator="type")], + name="ToolDef", +) + + @json_schema_type class MCPToolGroupDef(BaseModel): """ @@ -149,3 +172,14 @@ class ToolRuntime(Protocol): ) -> ToolInvocationResult: """Run a tool with the given arguments""" ... + + +# Three tool types: +# 1. Built-in tools +# 2. Client tools +# 3. Model-context-protocol tools + +# Support registration of agents with tool groups +# TBD: Have a client utility to hide the pre-processing tools. +# Attachments are confusing right now since they are inserted into memory first and retrieved through RAG, even before a question is asked. 
+# diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 8d622a5c2..2aff0f3a2 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -516,6 +516,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): ) -> None: tools = [] tool_defs = [] + tool_host = ToolHost.distribution if provider_id is None: if len(self.impls_by_provider_id.keys()) > 1: raise ValueError( @@ -529,25 +530,42 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): tool_defs = await self.impls_by_provider_id[provider_id].discover_tools( tool_group ) - + tool_host = ToolHost.model_context_protocol elif isinstance(tool_group, UserDefinedToolGroupDef): tool_defs = tool_group.tools else: raise ValueError(f"Unknown tool group: {tool_group}") for tool_def in tool_defs: - tools.append( - Tool( - identifier=tool_def.name, - tool_group=tool_group_id, - description=tool_def.description, - parameters=tool_def.parameters, - provider_id=provider_id, - tool_prompt_format=tool_def.tool_prompt_format, - provider_resource_id=tool_def.name, - metadata=tool_def.metadata, + if isinstance(tool_def, CustomToolDef): + tools.append( + Tool( + identifier=tool_def.name, + tool_group=tool_group_id, + description=tool_def.description, + parameters=tool_def.parameters, + provider_id=provider_id, + tool_prompt_format=tool_def.tool_prompt_format, + provider_resource_id=tool_def.name, + metadata=tool_def.metadata, + tool_host=tool_host, + ) + ) + elif isinstance(tool_def, BuiltInToolDef): + tools.append( + Tool( + identifier=tool_def.built_in_type.value, + tool_group=tool_group_id, + built_in_type=tool_def.built_in_type, + description="", + parameters=[], + provider_id=provider_id, + tool_prompt_format=ToolPromptFormat.json, + provider_resource_id=tool_def.built_in_type.value, + metadata=tool_def.metadata, + tool_host=tool_host, + ) ) - ) for tool in tools: 
existing_tool = await self.get_tool(tool.identifier) # Compare existing and new object if one exists diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 8075ea2bd..cc4ef38a9 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -621,6 +621,9 @@ class ChatAgent(ShieldRunnerMixin): ret = [] for tool_name in self.agent_config.available_tools: tool = await self.tool_groups_api.get_tool(tool_name) + if tool.built_in_type: + ret.append(ToolDefinition(tool_name=tool.built_in_type)) + continue params = {} for param in tool.parameters: params[param.name] = ToolParamDefinition( diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py index ca0141552..cd0468d93 100644 --- a/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py @@ -25,8 +25,7 @@ class BraveSearchToolRuntimeImpl( pass async def register_tool(self, tool: Tool): - if tool.identifier != "brave_search": - raise ValueError(f"Tool identifier {tool.identifier} is not supported") + pass async def unregister_tool(self, tool_id: str) -> None: return diff --git a/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py index 94a387f30..f4e980929 100644 --- a/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py @@ -26,8 +26,7 @@ class TavilySearchToolRuntimeImpl( pass async def register_tool(self, tool: Tool): - if tool.identifier != "tavily_search": - raise ValueError(f"Tool identifier {tool.identifier} is not supported") + pass async 
def unregister_tool(self, tool_id: str) -> None: return diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index b9bf3fe36..c77929f99 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List from urllib.parse import urlparse from llama_stack.apis.tools import ( + CustomToolDef, MCPToolGroupDef, ToolDef, ToolGroupDef, @@ -52,7 +53,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): ) ) tools.append( - ToolDef( + CustomToolDef( name=tool.name, description=tool.description, parameters=parameters, diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index ca44325d7..97d0d47e6 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -9,10 +9,12 @@ import tempfile import pytest import pytest_asyncio +from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.models import ModelInput, ModelType from llama_stack.apis.tools import ( - ToolDef, + BuiltInToolDef, + CustomToolDef, ToolGroupInput, ToolParameter, UserDefinedToolGroupDef, @@ -151,42 +153,12 @@ async def agents_stack(request, inference_model, safety_shield): ) ) tool_groups = [ - ToolGroupInput( - tool_group_id="brave_search_group", - tool_group=UserDefinedToolGroupDef( - tools=[ - ToolDef( - name="brave_search", - description="brave_search", - parameters=[ - ToolParameter( - name="query", - description="query", - parameter_type="string", - required=True, - ), - ], - metadata={}, - ), - ], - ), - provider_id="brave-search", - ), ToolGroupInput( tool_group_id="tavily_search_group", tool_group=UserDefinedToolGroupDef( 
tools=[ - ToolDef( - name="tavily_search", - description="tavily_search", - parameters=[ - ToolParameter( - name="query", - description="query", - parameter_type="string", - required=True, - ), - ], + BuiltInToolDef( + built_in_type=BuiltinTool.brave_search, metadata={}, ), ], @@ -197,7 +169,7 @@ async def agents_stack(request, inference_model, safety_shield): tool_group_id="memory_group", tool_group=UserDefinedToolGroupDef( tools=[ - ToolDef( + CustomToolDef( name="memory", description="memory", parameters=[ @@ -230,10 +202,8 @@ async def agents_stack(request, inference_model, safety_shield): tool_group_id="code_interpreter_group", tool_group=UserDefinedToolGroupDef( tools=[ - ToolDef( - name="code_interpreter", - description="code_interpreter", - parameters=[], + BuiltInToolDef( + built_in_type=BuiltinTool.code_interpreter, metadata={}, ) ], diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 147f04b02..a8c472da4 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -150,9 +150,7 @@ async def create_agent_turn_with_search_tool( assert isinstance(tool_execution, ToolExecutionStep) assert len(tool_execution.tool_calls) > 0 actual_tool_name = tool_execution.tool_calls[0].tool_name - if isinstance(actual_tool_name, BuiltinTool): - actual_tool_name = actual_tool_name.value - assert actual_tool_name == tool_name + assert actual_tool_name.value == tool_name assert len(tool_execution.tool_responses) > 0 check_turn_complete_event(turn_response, session_id, search_query_messages) @@ -305,20 +303,6 @@ class TestAgents: "brave_search", ) - @pytest.mark.asyncio - async def test_create_agent_turn_with_tavily_search( - self, agents_stack, search_query_messages, common_params - ): - if "TAVILY_SEARCH_API_KEY" not in os.environ: - pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") - - await create_agent_turn_with_search_tool( - 
agents_stack, - search_query_messages, - common_params, - "tavily_search", - ) - def check_event_types(turn_response): event_types = [chunk.event.payload.event_type for chunk in turn_response]