This commit is contained in:
Dinesh Yeduguru 2024-12-17 16:06:57 -08:00
parent 482a0e4839
commit a7dd22988b
2 changed files with 77 additions and 37 deletions

View file

@ -25,7 +25,6 @@ class MetaReferenceToolRuntimeImpl(
def __init__(self, config: MetaReferenceToolRuntimeConfig):
self.config = config
self.tools: Dict[str, Type[BaseTool]] = {}
self.tool_instances: Dict[str, BaseTool] = {}
self._discover_tools()
def _discover_tools(self):
@ -44,6 +43,33 @@ class MetaReferenceToolRuntimeImpl(
):
self.tools[attr.tool_id()] = attr
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
if tool.identifier not in self.tools:
raise ValueError(f"Tool {tool.identifier} not found in available tools")
# Validate provider_metadata against tool's config type if specified
tool_class = self.tools[tool.identifier]
config_type = tool_class.get_provider_config_type()
if (
config_type
and tool.provider_metadata
and tool.provider_metadata.get("config")
):
config_type(**tool.provider_metadata.get("config"))
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
if tool_id not in self.tools:
raise ValueError(f"Tool {tool_id} not found")
tool_instance = await self._create_tool_instance(tool_id)
return await tool_instance.execute(**args)
async def unregister_tool(self, tool_id: str) -> None:
raise NotImplementedError("Meta Reference does not support unregistering tools")
async def _create_tool_instance(
self, tool_id: str, tool_def: Optional[Tool] = None
) -> BaseTool:
@ -64,41 +90,6 @@ class MetaReferenceToolRuntimeImpl(
return tool_class(config=config)
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
if tool.identifier not in self.tools:
raise ValueError(f"Tool {tool.identifier} not found in available tools")
# Validate provider_metadata against tool's config type if specified
tool_class = self.tools[tool.identifier]
config_type = tool_class.get_provider_config_type()
if (
config_type
and tool.provider_metadata
and tool.provider_metadata.get("config")
):
config_type(**tool.provider_metadata.get("config"))
self.tool_instances[tool.identifier] = await self._create_tool_instance(
tool.identifier, tool
)
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
if tool_id not in self.tools:
raise ValueError(f"Tool {tool_id} not found")
if tool_id not in self.tool_instances:
self.tool_instances[tool_id] = await self._create_tool_instance(tool_id)
return await self.tool_instances[tool_id].execute(**args)
async def unregister_tool(self, tool_id: str) -> None:
if tool_id in self.tool_instances:
del self.tool_instances[tool_id]
raise NotImplementedError("Meta Reference does not support unregistering tools")
def _get_api_key(self) -> str:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key:

View file

@ -4,8 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type, TypeVar
from typing import Any, Dict, get_type_hints, List, Optional, Type, TypeVar
from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_stack.apis.tools.tools import Tool, ToolParameter, ToolReturn
T = TypeVar("T")
@ -33,3 +37,48 @@ class BaseTool(ABC):
def get_provider_config_type(cls) -> Optional[Type[T]]:
"""Override to specify a Pydantic model for tool configuration"""
return None
@classmethod
def get_tool_definition(cls) -> Tool:
"""Generate a Tool definition from the class implementation"""
# Get execute method
execute_method = cls.execute
signature = inspect.signature(execute_method)
docstring = execute_method.__doc__ or "No description available"
# Extract parameters
parameters: List[ToolParameter] = []
type_hints = get_type_hints(execute_method)
for name, param in signature.parameters.items():
if name == "self":
continue
param_type = type_hints.get(name, Any).__name__
required = param.default == param.empty
default = None if param.default == param.empty else param.default
parameters.append(
ToolParameter(
name=name,
type_hint=param_type,
description=f"Parameter: {name}", # Could be enhanced with docstring parsing
required=required,
default=default,
)
)
# Extract return info
return_type = type_hints.get("return", Any).__name__
return Tool(
identifier=cls.tool_id(),
provider_resource_id=cls.tool_id(),
name=cls.__name__,
description=docstring,
parameters=parameters,
returns=ToolReturn(
type_hint=return_type, description="Tool execution result"
),
tool_prompt_format=ToolPromptFormat.json,
)