This commit is contained in:
Dinesh Yeduguru 2024-12-17 16:06:57 -08:00
parent 482a0e4839
commit a7dd22988b
2 changed files with 77 additions and 37 deletions

View file

@ -25,7 +25,6 @@ class MetaReferenceToolRuntimeImpl(
def __init__(self, config: MetaReferenceToolRuntimeConfig): def __init__(self, config: MetaReferenceToolRuntimeConfig):
self.config = config self.config = config
self.tools: Dict[str, Type[BaseTool]] = {} self.tools: Dict[str, Type[BaseTool]] = {}
self.tool_instances: Dict[str, BaseTool] = {}
self._discover_tools() self._discover_tools()
def _discover_tools(self): def _discover_tools(self):
@ -44,6 +43,33 @@ class MetaReferenceToolRuntimeImpl(
): ):
self.tools[attr.tool_id()] = attr self.tools[attr.tool_id()] = attr
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
if tool.identifier not in self.tools:
raise ValueError(f"Tool {tool.identifier} not found in available tools")
# Validate provider_metadata against tool's config type if specified
tool_class = self.tools[tool.identifier]
config_type = tool_class.get_provider_config_type()
if (
config_type
and tool.provider_metadata
and tool.provider_metadata.get("config")
):
config_type(**tool.provider_metadata.get("config"))
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
if tool_id not in self.tools:
raise ValueError(f"Tool {tool_id} not found")
tool_instance = await self._create_tool_instance(tool_id)
return await tool_instance.execute(**args)
async def unregister_tool(self, tool_id: str) -> None:
raise NotImplementedError("Meta Reference does not support unregistering tools")
async def _create_tool_instance( async def _create_tool_instance(
self, tool_id: str, tool_def: Optional[Tool] = None self, tool_id: str, tool_def: Optional[Tool] = None
) -> BaseTool: ) -> BaseTool:
@ -64,41 +90,6 @@ class MetaReferenceToolRuntimeImpl(
return tool_class(config=config) return tool_class(config=config)
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
if tool.identifier not in self.tools:
raise ValueError(f"Tool {tool.identifier} not found in available tools")
# Validate provider_metadata against tool's config type if specified
tool_class = self.tools[tool.identifier]
config_type = tool_class.get_provider_config_type()
if (
config_type
and tool.provider_metadata
and tool.provider_metadata.get("config")
):
config_type(**tool.provider_metadata.get("config"))
self.tool_instances[tool.identifier] = await self._create_tool_instance(
tool.identifier, tool
)
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
if tool_id not in self.tools:
raise ValueError(f"Tool {tool_id} not found")
if tool_id not in self.tool_instances:
self.tool_instances[tool_id] = await self._create_tool_instance(tool_id)
return await self.tool_instances[tool_id].execute(**args)
async def unregister_tool(self, tool_id: str) -> None:
if tool_id in self.tool_instances:
del self.tool_instances[tool_id]
raise NotImplementedError("Meta Reference does not support unregistering tools")
def _get_api_key(self) -> str: def _get_api_key(self) -> str:
provider_data = self.get_request_provider_data() provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key: if provider_data is None or not provider_data.api_key:

View file

@ -4,8 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type, TypeVar from typing import Any, Dict, get_type_hints, List, Optional, Type, TypeVar
from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_stack.apis.tools.tools import Tool, ToolParameter, ToolReturn
T = TypeVar("T") T = TypeVar("T")
@ -33,3 +37,48 @@ class BaseTool(ABC):
def get_provider_config_type(cls) -> Optional[Type[T]]: def get_provider_config_type(cls) -> Optional[Type[T]]:
"""Override to specify a Pydantic model for tool configuration""" """Override to specify a Pydantic model for tool configuration"""
return None return None
@classmethod
def get_tool_definition(cls) -> Tool:
"""Generate a Tool definition from the class implementation"""
# Get execute method
execute_method = cls.execute
signature = inspect.signature(execute_method)
docstring = execute_method.__doc__ or "No description available"
# Extract parameters
parameters: List[ToolParameter] = []
type_hints = get_type_hints(execute_method)
for name, param in signature.parameters.items():
if name == "self":
continue
param_type = type_hints.get(name, Any).__name__
required = param.default == param.empty
default = None if param.default == param.empty else param.default
parameters.append(
ToolParameter(
name=name,
type_hint=param_type,
description=f"Parameter: {name}", # Could be enhanced with docstring parsing
required=required,
default=default,
)
)
# Extract return info
return_type = type_hints.get("return", Any).__name__
return Tool(
identifier=cls.tool_id(),
provider_resource_id=cls.tool_id(),
name=cls.__name__,
description=docstring,
parameters=parameters,
returns=ToolReturn(
type_hint=return_type, description="Tool execution result"
),
tool_prompt_format=ToolPromptFormat.json,
)