From a7a55748cacbda8f0e30ee26bf5d71cf2da66920 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 8 Jan 2025 18:24:35 -0800 Subject: [PATCH] address feedback --- llama_stack/apis/tools/tools.py | 5 ++-- .../distribution/routers/routing_tables.py | 1 - .../agents/meta_reference/agent_instance.py | 25 +++++++++++++------ .../meta_reference/tests/test_chat_agent.py | 5 ++-- .../code_interpreter/code_interpreter.py | 3 --- .../tool_runtime/bing_search/bing_search.py | 2 -- .../tavily_search/tavily_search.py | 2 -- .../wolfram_alpha/wolfram_alpha.py | 2 -- 8 files changed, 21 insertions(+), 24 deletions(-) diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index dbfd85220..e430ec46d 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -7,7 +7,7 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional -from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat +from llama_models.llama3.api.datatypes import ToolPromptFormat from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Protocol, runtime_checkable @@ -40,7 +40,6 @@ class Tool(Resource): tool_host: ToolHost description: str parameters: List[ToolParameter] - built_in_type: Optional[BuiltinTool] = None metadata: Optional[Dict[str, Any]] = None tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json @@ -53,7 +52,6 @@ class ToolDef(BaseModel): description: Optional[str] = None parameters: Optional[List[ToolParameter]] = None metadata: Optional[Dict[str, Any]] = None - built_in_type: Optional[BuiltinTool] = None tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json ) @@ -130,6 +128,7 @@ class ToolGroups(Protocol): class ToolRuntime(Protocol): tool_store: ToolStore + # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed. @webmethod(route="/tool-runtime/list-tools", method="GET") async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 36ddda7a6..d4cb708a2 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -527,7 +527,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): provider_resource_id=tool_def.name, metadata=tool_def.metadata, tool_host=tool_host, - built_in_type=tool_def.built_in_type, ) ) for tool in tools: diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index b72856674..2cd86bcaa 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -78,6 +78,7 @@ def make_random_string(length: int = 8): TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") MEMORY_QUERY_TOOL = "query_memory" WEB_SEARCH_TOOL = "web_search" +MEMORY_GROUP = "builtin::memory" class ChatAgent(ShieldRunnerMixin): @@ -741,16 +742,24 @@ class ChatAgent(ShieldRunnerMixin): continue tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name) for tool_def in tools: - if tool_def.built_in_type: - if tool_def_map.get(tool_def.built_in_type, None): - raise ValueError( - f"Tool {tool_def.built_in_type} already exists" - ) + if ( + toolgroup_name.startswith("builtin") + and toolgroup_name != MEMORY_GROUP + ): + tool_name = tool_def.identifier + built_in_type = BuiltinTool.brave_search + if tool_name == "web_search": + built_in_type = BuiltinTool.brave_search + else: + built_in_type = BuiltinTool(tool_name) - tool_def_map[tool_def.built_in_type] = ToolDefinition( - tool_name=tool_def.built_in_type + if tool_def_map.get(built_in_type, None): + raise ValueError(f"Tool {built_in_type} already exists") + + tool_def_map[built_in_type] = ToolDefinition( + tool_name=built_in_type ) - tool_to_group[tool_def.built_in_type] = tool_def.toolgroup_id + tool_to_group[built_in_type] = tool_def.toolgroup_id continue if tool_def_map.get(tool_def.identifier, None): diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py index 6b8a846ee..a7e6efc8c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py +++ b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py @@ -198,7 +198,7 @@ class MockToolGroupsAPI: toolgroup_id=MEMORY_TOOLGROUP, tool_host=ToolHost.client, description="Mock tool", - provider_id="mock_provider", + provider_id="builtin::memory", parameters=[], ) ] @@ -208,10 +208,9 @@ class MockToolGroupsAPI: identifier="code_interpreter", provider_resource_id="code_interpreter", toolgroup_id=CODE_INTERPRETER_TOOLGROUP, - built_in_type=BuiltinTool.code_interpreter, tool_host=ToolHost.client, description="Mock tool", - provider_id="mock_provider", + provider_id="builtin::code_interpreter", parameters=[], ) ] diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py index 98026fa3d..361c91a92 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py @@ -9,8 +9,6 @@ import logging import tempfile from typing import Any, Dict, List, Optional -from llama_models.llama3.api.datatypes import BuiltinTool - from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( Tool, @@ -58,7 +56,6 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): parameter_type="string", ), ], - built_in_type=BuiltinTool.code_interpreter, ) ] diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py index a69f08ce8..5cf36acbc 100644 --- a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +++ b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py @@ -8,7 +8,6 @@ import json from typing import Any, Dict, List, Optional import requests -from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( @@ -65,7 +64,6 @@ class BingSearchToolRuntimeImpl( parameter_type="string", ) ], - built_in_type=BuiltinTool.brave_search, ) ] diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index 8f666a6fb..8f86edfb1 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -8,7 +8,6 @@ import json from typing import Any, Dict, List, Optional import requests -from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( @@ -64,7 +63,6 @@ class TavilySearchToolRuntimeImpl( parameter_type="string", ) ], - built_in_type=BuiltinTool.brave_search, ) ] diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py index 13c298eb2..af99d7b2a 100644 --- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py @@ -8,7 +8,6 @@ import json from typing import Any, Dict, List, Optional import requests -from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( @@ -65,7 +64,6 @@ class WolframAlphaToolRuntimeImpl( parameter_type="string", ) ], - built_in_type=BuiltinTool.wolfram_alpha, ) ]