From e65a6fac9d6c8dca8bbd071cc7a9af8382e49fb9 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Fri, 13 Dec 2024 12:09:12 -0800
Subject: [PATCH] init

---
 llama_stack/apis/tools/__init__.py                        |  7 ++
 llama_stack/apis/tools/tools.py                           | 86 +++++++++++++++++++
 llama_stack/distribution/resolver.py                      |  2 +
 llama_stack/providers/datatypes.py                        |  5 ++
 .../inline/tools/meta_reference/__init__.py               |  7 ++
 .../inline/tools/meta_reference/config.py                 | 11 +++
 .../tools/meta_reference/meta_reference.py                | 17 ++++
 7 files changed, 135 insertions(+)
 create mode 100644 llama_stack/apis/tools/__init__.py
 create mode 100644 llama_stack/apis/tools/tools.py
 create mode 100644 llama_stack/providers/inline/tools/meta_reference/__init__.py
 create mode 100644 llama_stack/providers/inline/tools/meta_reference/config.py
 create mode 100644 llama_stack/providers/inline/tools/meta_reference/meta_reference.py

diff --git a/llama_stack/apis/tools/__init__.py b/llama_stack/apis/tools/__init__.py
new file mode 100644
index 000000000..f747fcdc2
--- /dev/null
+++ b/llama_stack/apis/tools/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .tools import *  # noqa: F401 F403
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
new file mode 100644
index 000000000..572a74998
--- /dev/null
+++ b/llama_stack/apis/tools/tools.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional

from llama_models.llama3.api.datatypes import ToolPromptFormat

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable

from llama_stack.apis.resource import Resource
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol


@json_schema_type
class ToolParameter(BaseModel):
    """One parameter in a tool's call signature."""

    name: str
    type_hint: str
    description: str
    required: bool = True
    default: Optional[Any] = None


@json_schema_type
class ToolReturn(BaseModel):
    """Describes the value a tool returns."""

    type_hint: str
    description: str


@json_schema_type
class Tool(Resource):
    """A registerable tool resource that may be backed by different providers."""

    resource_type: Literal["tool"] = "tool"
    name: str
    description: str
    parameters: List[ToolParameter]
    returns: ToolReturn
    # Opaque, provider-specific configuration attached at registration time.
    provider_metadata: Optional[Dict[str, Any]] = None
    tool_prompt_format: Optional[ToolPromptFormat] = Field(
        default=ToolPromptFormat.json
    )


@runtime_checkable
@trace_protocol
class Tools(Protocol):
    """Registry-facing API: register and look up tools."""

    async def register_tool(
        self,
        tool_id: str,
        name: str,
        description: str,
        parameters: List[ToolParameter],
        returns: ToolReturn,
        provider_metadata: Optional[Dict[str, Any]] = None,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
    ) -> Tool:
        """Register a tool with provider-specific metadata."""
        ...

    async def get_tool(self, identifier: str) -> Tool:
        """Fetch a single registered tool by its identifier."""
        ...

    async def list_tools(self, provider_id: Optional[str] = None) -> List[Tool]:
        """List tools with optional provider"""
        ...


@runtime_checkable
@trace_protocol
class ToolRuntime(Protocol):
    """Execution-facing API: invoke a registered tool."""

    # NOTE(review): this method is sync while every other method in this
    # module is `async def` — confirm whether it should be async as well.
    def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
        """Run a tool with the given arguments"""
        ...
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 4541b01eb..885e9bbc0 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -30,6 +30,7 @@ from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
+from llama_stack.apis.tools import Tools
 from llama_stack.distribution.client import get_client_impl
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.store import DistributionRegistry
@@ -66,6 +67,7 @@ def api_protocol_map() -> Dict[Api, Any]:
 def additional_protocols_map() -> Dict[Api, Any]:
     return {
         Api.inference: (ModelsProtocolPrivate, Models, Api.models),
+        Api.tools: (ToolsProtocolPrivate, Tools, Api.tools),
         Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
         Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
         Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index c506a754c..f49222bca 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -17,6 +17,7 @@ from llama_stack.apis.memory_banks.memory_banks import MemoryBank
 from llama_stack.apis.models import Model
 from llama_stack.apis.scoring_functions import ScoringFn
 from llama_stack.apis.shields import Shield
+from llama_stack.apis.tools import Tool
 
 
 @json_schema_type
@@ -75,6 +76,10 @@ class EvalTasksProtocolPrivate(Protocol):
     async def register_eval_task(self, eval_task: EvalTask) -> None: ...
 
 
+class ToolsProtocolPrivate(Protocol):
+    async def register_tool(self, tool: Tool) -> None: ...
+
+
 @json_schema_type
 class ProviderSpec(BaseModel):
     api: Api
diff --git a/llama_stack/providers/inline/tools/meta_reference/__init__.py b/llama_stack/providers/inline/tools/meta_reference/__init__.py
new file mode 100644
index 000000000..da392fdb3
--- /dev/null
+++ b/llama_stack/providers/inline/tools/meta_reference/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .meta_reference import *  # noqa: F401 F403
diff --git a/llama_stack/providers/inline/tools/meta_reference/config.py b/llama_stack/providers/inline/tools/meta_reference/config.py
new file mode 100644
index 000000000..61dfcf52e
--- /dev/null
+++ b/llama_stack/providers/inline/tools/meta_reference/config.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+
+class MetaReferenceToolConfig(BaseModel):
+    pass
diff --git a/llama_stack/providers/inline/tools/meta_reference/meta_reference.py b/llama_stack/providers/inline/tools/meta_reference/meta_reference.py
new file mode 100644
index 000000000..c69e83203
--- /dev/null
+++ b/llama_stack/providers/inline/tools/meta_reference/meta_reference.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.tools import Tool, Tools
+
+from .config import MetaReferenceToolConfig
+
+
+class MetaReferenceTool(Tools):
+    def __init__(self, config: MetaReferenceToolConfig):
+        self.config = config
+
+    async def register_tool(self, tool: Tool):
+        pass