From b7ad53ca9373c1ff3c352d6f2ae8ffb4b329e48b Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 16 Dec 2024 13:01:52 -0800 Subject: [PATCH] minor fixes --- llama_stack/apis/tools/tools.py | 1 + llama_stack/distribution/datatypes.py | 4 ++++ llama_stack/distribution/routers/routing_tables.py | 5 +++++ 3 files changed, 10 insertions(+) diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index d9baa33de..239375b11 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -61,6 +61,7 @@ class Tools(Protocol): parameters: List[ToolParameter], returns: ToolReturn, provider_id: Optional[str] = None, + provider_resource_id: Optional[str] = None, provider_metadata: Optional[Dict[str, Any]] = None, tool_prompt_format: Optional[ToolPromptFormat] = None, ) -> Tool: diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 1159372d4..f70616895 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -21,6 +21,7 @@ from llama_stack.apis.inference import Inference from llama_stack.apis.memory import Memory from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring +from llama_stack.apis.tools import Tool, ToolRuntime from llama_stack.providers.utils.kvstore.config import KVStoreConfig LLAMA_STACK_BUILD_CONFIG_VERSION = "2" @@ -37,6 +38,7 @@ RoutableObject = Union[ Dataset, ScoringFn, EvalTask, + Tool, ] @@ -48,6 +50,7 @@ RoutableObjectWithProvider = Annotated[ Dataset, ScoringFn, EvalTask, + Tool, ], Field(discriminator="type"), ] @@ -59,6 +62,7 @@ RoutedProtocol = Union[ DatasetIO, Scoring, Eval, + ToolRuntime, ] diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 5adbc5cb5..f076475a2 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -482,6 +482,7 @@ class ToolsRoutingTable(CommonRoutingTableImpl, Tools): parameters: List[ToolParameter], returns: ToolReturn, provider_id: Optional[str] = None, + provider_resource_id: Optional[str] = None, provider_metadata: Optional[Dict[str, Any]] = None, tool_prompt_format: Optional[ToolPromptFormat] = None, ) -> None: @@ -496,6 +497,9 @@ class ToolsRoutingTable(CommonRoutingTableImpl, Tools): raise ValueError( "No provider specified and multiple providers available. Please specify a provider_id." ) + if provider_resource_id is None: + provider_resource_id = tool_id + tool = Tool( identifier=tool_id, name=name, @@ -503,6 +507,7 @@ class ToolsRoutingTable(CommonRoutingTableImpl, Tools): parameters=parameters, returns=returns, provider_id=provider_id, + provider_resource_id=provider_resource_id, provider_metadata=provider_metadata, tool_prompt_format=tool_prompt_format, )