mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
tool def
This commit is contained in:
parent
482a0e4839
commit
a7dd22988b
2 changed files with 77 additions and 37 deletions
|
@ -25,7 +25,6 @@ class MetaReferenceToolRuntimeImpl(
|
|||
def __init__(self, config: MetaReferenceToolRuntimeConfig):
|
||||
self.config = config
|
||||
self.tools: Dict[str, Type[BaseTool]] = {}
|
||||
self.tool_instances: Dict[str, BaseTool] = {}
|
||||
self._discover_tools()
|
||||
|
||||
def _discover_tools(self):
|
||||
|
@ -44,6 +43,33 @@ class MetaReferenceToolRuntimeImpl(
|
|||
):
|
||||
self.tools[attr.tool_id()] = attr
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
if tool.identifier not in self.tools:
|
||||
raise ValueError(f"Tool {tool.identifier} not found in available tools")
|
||||
|
||||
# Validate provider_metadata against tool's config type if specified
|
||||
tool_class = self.tools[tool.identifier]
|
||||
config_type = tool_class.get_provider_config_type()
|
||||
if (
|
||||
config_type
|
||||
and tool.provider_metadata
|
||||
and tool.provider_metadata.get("config")
|
||||
):
|
||||
config_type(**tool.provider_metadata.get("config"))
|
||||
|
||||
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
|
||||
if tool_id not in self.tools:
|
||||
raise ValueError(f"Tool {tool_id} not found")
|
||||
|
||||
tool_instance = await self._create_tool_instance(tool_id)
|
||||
return await tool_instance.execute(**args)
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
raise NotImplementedError("Meta Reference does not support unregistering tools")
|
||||
|
||||
async def _create_tool_instance(
|
||||
self, tool_id: str, tool_def: Optional[Tool] = None
|
||||
) -> BaseTool:
|
||||
|
@ -64,41 +90,6 @@ class MetaReferenceToolRuntimeImpl(
|
|||
|
||||
return tool_class(config=config)
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
if tool.identifier not in self.tools:
|
||||
raise ValueError(f"Tool {tool.identifier} not found in available tools")
|
||||
|
||||
# Validate provider_metadata against tool's config type if specified
|
||||
tool_class = self.tools[tool.identifier]
|
||||
config_type = tool_class.get_provider_config_type()
|
||||
if (
|
||||
config_type
|
||||
and tool.provider_metadata
|
||||
and tool.provider_metadata.get("config")
|
||||
):
|
||||
config_type(**tool.provider_metadata.get("config"))
|
||||
|
||||
self.tool_instances[tool.identifier] = await self._create_tool_instance(
|
||||
tool.identifier, tool
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
|
||||
if tool_id not in self.tools:
|
||||
raise ValueError(f"Tool {tool_id} not found")
|
||||
|
||||
if tool_id not in self.tool_instances:
|
||||
self.tool_instances[tool_id] = await self._create_tool_instance(tool_id)
|
||||
|
||||
return await self.tool_instances[tool_id].execute(**args)
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
if tool_id in self.tool_instances:
|
||||
del self.tool_instances[tool_id]
|
||||
raise NotImplementedError("Meta Reference does not support unregistering tools")
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.api_key:
|
||||
|
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional, Type, TypeVar
|
||||
from typing import Any, Dict, get_type_hints, List, Optional, Type, TypeVar
|
||||
|
||||
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
||||
from llama_stack.apis.tools.tools import Tool, ToolParameter, ToolReturn
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
@ -33,3 +37,48 @@ class BaseTool(ABC):
|
|||
def get_provider_config_type(cls) -> Optional[Type[T]]:
|
||||
"""Override to specify a Pydantic model for tool configuration"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_tool_definition(cls) -> Tool:
|
||||
"""Generate a Tool definition from the class implementation"""
|
||||
# Get execute method
|
||||
execute_method = cls.execute
|
||||
signature = inspect.signature(execute_method)
|
||||
docstring = execute_method.__doc__ or "No description available"
|
||||
|
||||
# Extract parameters
|
||||
parameters: List[ToolParameter] = []
|
||||
type_hints = get_type_hints(execute_method)
|
||||
|
||||
for name, param in signature.parameters.items():
|
||||
if name == "self":
|
||||
continue
|
||||
|
||||
param_type = type_hints.get(name, Any).__name__
|
||||
required = param.default == param.empty
|
||||
default = None if param.default == param.empty else param.default
|
||||
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name=name,
|
||||
type_hint=param_type,
|
||||
description=f"Parameter: {name}", # Could be enhanced with docstring parsing
|
||||
required=required,
|
||||
default=default,
|
||||
)
|
||||
)
|
||||
|
||||
# Extract return info
|
||||
return_type = type_hints.get("return", Any).__name__
|
||||
|
||||
return Tool(
|
||||
identifier=cls.tool_id(),
|
||||
provider_resource_id=cls.tool_id(),
|
||||
name=cls.__name__,
|
||||
description=docstring,
|
||||
parameters=parameters,
|
||||
returns=ToolReturn(
|
||||
type_hint=return_type, description="Tool execution result"
|
||||
),
|
||||
tool_prompt_format=ToolPromptFormat.json,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue