From a7dd22988b1251c00a588c7184b8b538796ac769 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Tue, 17 Dec 2024 16:06:57 -0800
Subject: [PATCH] tool def

---
 .../meta_reference/meta_reference.py          | 63 ++++++++-----------
 .../tool_runtime/meta_reference/tools/base.py | 51 ++++++++++++++-
 2 files changed, 77 insertions(+), 37 deletions(-)

diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py b/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py
index 89efafecd..47e6c2257 100644
--- a/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py
+++ b/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py
@@ -25,7 +25,6 @@ class MetaReferenceToolRuntimeImpl(
     def __init__(self, config: MetaReferenceToolRuntimeConfig):
         self.config = config
         self.tools: Dict[str, Type[BaseTool]] = {}
-        self.tool_instances: Dict[str, BaseTool] = {}
         self._discover_tools()
 
     def _discover_tools(self):
@@ -44,6 +43,33 @@ class MetaReferenceToolRuntimeImpl(
         ):
             self.tools[attr.tool_id()] = attr
 
+    async def initialize(self):
+        pass
+
+    async def register_tool(self, tool: Tool):
+        if tool.identifier not in self.tools:
+            raise ValueError(f"Tool {tool.identifier} not found in available tools")
+
+        # Validate provider_metadata against tool's config type if specified
+        tool_class = self.tools[tool.identifier]
+        config_type = tool_class.get_provider_config_type()
+        if (
+            config_type
+            and tool.provider_metadata
+            and tool.provider_metadata.get("config")
+        ):
+            config_type(**tool.provider_metadata.get("config"))
+
+    async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
+        if tool_id not in self.tools:
+            raise ValueError(f"Tool {tool_id} not found")
+
+        tool_instance = await self._create_tool_instance(tool_id)
+        return await tool_instance.execute(**args)
+
+    async def unregister_tool(self, tool_id: str) -> None:
+        raise NotImplementedError("Meta Reference does not support unregistering tools")
+
     async def _create_tool_instance(
         self, tool_id: str, tool_def: Optional[Tool] = None
     ) -> BaseTool:
@@ -64,41 +90,6 @@ class MetaReferenceToolRuntimeImpl(
 
         return tool_class(config=config)
 
-    async def initialize(self):
-        pass
-
-    async def register_tool(self, tool: Tool):
-        if tool.identifier not in self.tools:
-            raise ValueError(f"Tool {tool.identifier} not found in available tools")
-
-        # Validate provider_metadata against tool's config type if specified
-        tool_class = self.tools[tool.identifier]
-        config_type = tool_class.get_provider_config_type()
-        if (
-            config_type
-            and tool.provider_metadata
-            and tool.provider_metadata.get("config")
-        ):
-            config_type(**tool.provider_metadata.get("config"))
-
-        self.tool_instances[tool.identifier] = await self._create_tool_instance(
-            tool.identifier, tool
-        )
-
-    async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
-        if tool_id not in self.tools:
-            raise ValueError(f"Tool {tool_id} not found")
-
-        if tool_id not in self.tool_instances:
-            self.tool_instances[tool_id] = await self._create_tool_instance(tool_id)
-
-        return await self.tool_instances[tool_id].execute(**args)
-
-    async def unregister_tool(self, tool_id: str) -> None:
-        if tool_id in self.tool_instances:
-            del self.tool_instances[tool_id]
-        raise NotImplementedError("Meta Reference does not support unregistering tools")
-
     def _get_api_key(self) -> str:
         provider_data = self.get_request_provider_data()
         if provider_data is None or not provider_data.api_key:
diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/tools/base.py b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/base.py
index 79e20f85e..d9acfa9be 100644
--- a/llama_stack/providers/inline/tool_runtime/meta_reference/tools/base.py
+++ b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/base.py
@@ -4,8 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import inspect
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, Type, TypeVar
+from typing import Any, Dict, get_type_hints, List, Optional, Type, TypeVar
+
+from llama_models.llama3.api.datatypes import ToolPromptFormat
+from llama_stack.apis.tools.tools import Tool, ToolParameter, ToolReturn
 
 T = TypeVar("T")
 
@@ -33,3 +37,48 @@ class BaseTool(ABC):
     def get_provider_config_type(cls) -> Optional[Type[T]]:
         """Override to specify a Pydantic model for tool configuration"""
         return None
+
+    @classmethod
+    def get_tool_definition(cls) -> Tool:
+        """Generate a Tool definition from the class implementation"""
+        # Get execute method
+        execute_method = cls.execute
+        signature = inspect.signature(execute_method)
+        docstring = execute_method.__doc__ or "No description available"
+
+        # Extract parameters
+        parameters: List[ToolParameter] = []
+        type_hints = get_type_hints(execute_method)
+
+        for name, param in signature.parameters.items():
+            if name == "self":
+                continue
+
+            param_type = type_hints.get(name, Any).__name__
+            required = param.default == param.empty
+            default = None if param.default == param.empty else param.default
+
+            parameters.append(
+                ToolParameter(
+                    name=name,
+                    type_hint=param_type,
+                    description=f"Parameter: {name}",  # Could be enhanced with docstring parsing
+                    required=required,
+                    default=default,
+                )
+            )
+
+        # Extract return info
+        return_type = type_hints.get("return", Any).__name__
+
+        return Tool(
+            identifier=cls.tool_id(),
+            provider_resource_id=cls.tool_id(),
+            name=cls.__name__,
+            description=docstring,
+            parameters=parameters,
+            returns=ToolReturn(
+                type_hint=return_type, description="Tool execution result"
+            ),
+            tool_prompt_format=ToolPromptFormat.json,
+        )
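
Usage note (illustrative, not part of the diff): a minimal sketch of how the new
BaseTool.get_tool_definition() classmethod could be exercised. The EchoTool class
below, its import path, and the assumption that tool_id() and execute() are the
abstract methods a subclass must implement are mine; only get_tool_definition(),
Tool, ToolParameter, and ToolReturn come from this patch.

from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import (
    BaseTool,
)


class EchoTool(BaseTool):
    # Hypothetical tool, used only to illustrate the introspection path.

    @classmethod
    def tool_id(cls) -> str:
        return "echo"

    async def execute(self, text: str, repeat: int = 1) -> str:
        """Echo the input text, optionally repeated."""
        return " ".join([text] * repeat)


# get_tool_definition() inspects execute()'s signature and type hints:
# "text" becomes a required str parameter, "repeat" an optional int with
# default 1, and the str return annotation populates ToolReturn.
tool_def = EchoTool.get_tool_definition()
print(tool_def.identifier, [p.name for p in tool_def.parameters])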