From de065a60f2276e452c785129f09cb9ed88105d90 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 19 Dec 2024 20:59:47 -0800 Subject: [PATCH] final changes --- llama_stack/apis/resource.py | 1 + llama_stack/apis/tools/tools.py | 54 +++++---- llama_stack/distribution/datatypes.py | 16 +-- llama_stack/distribution/distribution.py | 2 +- llama_stack/distribution/resolver.py | 6 +- llama_stack/distribution/routers/__init__.py | 5 +- llama_stack/distribution/routers/routers.py | 22 ++-- .../distribution/routers/routing_tables.py | 113 ++++++++++-------- llama_stack/providers/datatypes.py | 2 +- .../tool_runtime/brave_search/brave_search.py | 6 +- .../model_context_protocol.py | 23 ++-- 11 files changed, 142 insertions(+), 108 deletions(-) diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py index 67ea94c19..a85f5a31c 100644 --- a/llama_stack/apis/resource.py +++ b/llama_stack/apis/resource.py @@ -19,6 +19,7 @@ class ResourceType(Enum): scoring_function = "scoring_function" eval_task = "eval_task" tool = "tool" + tool_group = "tool_group" class Resource(BaseModel): diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index ce053fd66..23110543b 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -48,30 +48,34 @@ class ToolDef(BaseModel): @json_schema_type -class MCPToolGroup(BaseModel): +class MCPToolGroupDef(BaseModel): """ A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information. """ type: Literal["model_context_protocol"] = "model_context_protocol" - name: str endpoint: URL @json_schema_type -class UserDefinedToolGroup(BaseModel): +class UserDefinedToolGroupDef(BaseModel): type: Literal["user_defined"] = "user_defined" - name: str tools: List[ToolDef] -ToolGroup = register_schema( - Annotated[Union[MCPToolGroup, UserDefinedToolGroup], Field(discriminator="type")], +ToolGroupDef = register_schema( + Annotated[ + Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type") + ], name="ToolGroup", ) +class ToolGroup(Resource): + type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value + + @json_schema_type class ToolInvocationResult(BaseModel): content: InterleavedContent @@ -80,34 +84,44 @@ class ToolInvocationResult(BaseModel): class ToolStore(Protocol): - def get_tool(self, tool_id: str) -> Tool: ... + def get_tool(self, tool_name: str) -> Tool: ... @runtime_checkable @trace_protocol -class Tools(Protocol): +class ToolGroups(Protocol): @webmethod(route="/toolgroups/register", method="POST") async def register_tool_group( self, - tool_group: ToolGroup, + tool_group_id: str, + tool_group: ToolGroupDef, provider_id: Optional[str] = None, ) -> None: """Register a tool group""" ... - @webmethod(route="/tools/get", method="GET") - async def get_tool( + @webmethod(route="/toolgroups/get", method="GET") + async def get_tool_group( self, - tool_id: str, - ) -> Tool: ... + tool_group_id: str, + ) -> ToolGroup: ... + + @webmethod(route="/toolgroups/list", method="GET") + async def list_tool_groups(self) -> List[ToolGroup]: + """List tool groups with optional provider""" + ... @webmethod(route="/tools/list", method="GET") - async def list_tools(self) -> List[Tool]: - """List tools with optional provider""" + async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]: + """List tools with optional tool group""" + ... - @webmethod(route="/tools/unregister", method="POST") - async def unregister_tool(self, tool_id: str) -> None: - """Unregister a tool""" + @webmethod(route="/tools/get", method="GET") + async def get_tool(self, tool_name: str) -> Tool: ... + + @webmethod(route="/toolgroups/unregister", method="POST") + async def unregister_tool_group(self, tool_group_id: str) -> None: + """Unregister a tool group""" ... @@ -117,11 +131,11 @@ class ToolRuntime(Protocol): tool_store: ToolStore @webmethod(route="/tool-runtime/discover", method="POST") - async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: ... + async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ... @webmethod(route="/tool-runtime/invoke", method="POST") async def invoke_tool( - self, tool_id: str, args: Dict[str, Any] + self, tool_name: str, args: Dict[str, Any] ) -> ToolInvocationResult: """Run a tool with the given arguments""" ... diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index f70616895..f2dea6012 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -8,20 +8,20 @@ from typing import Dict, List, Optional, Union from pydantic import BaseModel, Field -from llama_stack.providers.datatypes import * # noqa: F403 -from llama_stack.apis.models import * # noqa: F403 -from llama_stack.apis.shields import * # noqa: F403 -from llama_stack.apis.memory_banks import * # noqa: F403 -from llama_stack.apis.datasets import * # noqa: F403 -from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.eval import Eval from llama_stack.apis.eval_tasks import EvalTaskInput from llama_stack.apis.inference import Inference from llama_stack.apis.memory import Memory +from llama_stack.apis.memory_banks import * # noqa: F403 +from llama_stack.apis.models import * # noqa: F403 from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring -from llama_stack.apis.tools import Tool, ToolRuntime +from llama_stack.apis.scoring_functions import * # noqa: F403 +from llama_stack.apis.shields import * # noqa: F403 +from llama_stack.apis.tools import Tool, ToolGroup, ToolRuntime +from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.providers.utils.kvstore.config import KVStoreConfig LLAMA_STACK_BUILD_CONFIG_VERSION = "2" @@ -39,6 +39,7 @@ RoutableObject = Union[ ScoringFn, EvalTask, Tool, + ToolGroup, ] @@ -51,6 +52,7 @@ RoutableObjectWithProvider = Annotated[ ScoringFn, EvalTask, Tool, + ToolGroup, ], Field(discriminator="type"), ] diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index 1478737da..4183d92cd 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -48,7 +48,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]: router_api=Api.eval, ), AutoRoutedApiInfo( - routing_table_api=Api.tools, + routing_table_api=Api.tool_groups, router_api=Api.tool_runtime, ), ] diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 14e5e7a86..439971315 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -30,7 +30,7 @@ from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry -from llama_stack.apis.tools import ToolRuntime, Tools +from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.distribution.client import get_client_impl from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.store import DistributionRegistry @@ -61,7 +61,7 @@ def api_protocol_map() -> Dict[Api, Any]: Api.eval: Eval, Api.eval_tasks: EvalTasks, Api.post_training: PostTraining, - Api.tools: Tools, + Api.tool_groups: ToolGroups, Api.tool_runtime: ToolRuntime, } @@ -69,7 +69,7 @@ def api_protocol_map() -> Dict[Api, Any]: def additional_protocols_map() -> Dict[Api, Any]: return { Api.inference: (ModelsProtocolPrivate, Models, Api.models), - Api.tools: (ToolsProtocolPrivate, Tools, Api.tools), + Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups), Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks), Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields), Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets), diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 05e741598..693f1fbe2 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -7,7 +7,6 @@ from typing import Any from llama_stack.distribution.datatypes import * # noqa: F403 - from llama_stack.distribution.store import DistributionRegistry from .routing_tables import ( @@ -17,7 +16,7 @@ from .routing_tables import ( ModelsRoutingTable, ScoringFunctionsRoutingTable, ShieldsRoutingTable, - ToolsRoutingTable, + ToolGroupsRoutingTable, ) @@ -34,7 +33,7 @@ async def get_routing_table_impl( "datasets": DatasetsRoutingTable, "scoring_functions": ScoringFunctionsRoutingTable, "eval_tasks": EvalTasksRoutingTable, - "tools": ToolsRoutingTable, + "tool_groups": ToolGroupsRoutingTable, } if api.value not in api_to_tables: diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 9c9cfec6f..a25a848db 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -6,16 +6,16 @@ from typing import Any, AsyncGenerator, Dict, List, Optional -from llama_stack.apis.datasetio.datasetio import DatasetIO -from llama_stack.apis.memory_banks.memory_banks import BankParams -from llama_stack.distribution.datatypes import RoutingTable -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.apis.scoring import * # noqa: F403 +from llama_stack.apis.datasetio.datasetio import DatasetIO from llama_stack.apis.eval import * # noqa: F403 +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.apis.memory_banks.memory_banks import BankParams +from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.apis.scoring import * # noqa: F403 from llama_stack.apis.tools import * # noqa: F403 +from llama_stack.distribution.datatypes import RoutingTable class MemoryRouter(Memory): @@ -388,13 +388,13 @@ class ToolRuntimeRouter(ToolRuntime): async def shutdown(self) -> None: pass - async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any: - return await self.routing_table.get_provider_impl(tool_id).invoke_tool( - tool_id=tool_id, + async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any: + return await self.routing_table.get_provider_impl(tool_name).invoke_tool( + tool_name=tool_name, args=args, ) - async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: + async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: return await self.routing_table.get_provider_impl( tool_group.name ).discover_tools(tool_group) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 690a4e9b7..3fb086b72 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -6,21 +6,19 @@ from typing import Any, Dict, List, Optional +from llama_models.llama3.api.datatypes import * # noqa: F403 from pydantic import parse_obj_as -from llama_models.llama3.api.datatypes import * # noqa: F403 - -from llama_stack.apis.models import * # noqa: F403 -from llama_stack.apis.shields import * # noqa: F403 -from llama_stack.apis.memory_banks import * # noqa: F403 +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.eval_tasks import * # noqa: F403 +from llama_stack.apis.memory_banks import * # noqa: F403 +from llama_stack.apis.models import * # noqa: F403 +from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.tools import * # noqa: F403 -from llama_stack.apis.common.content_types import URL - -from llama_stack.apis.common.type_system import ParamType -from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.distribution.store import DistributionRegistry def get_impl_api(p: Any) -> Api: @@ -131,7 +129,7 @@ class CommonRoutingTableImpl(RoutingTable): return ("Scoring", "scoring_function") elif isinstance(self, EvalTasksRoutingTable): return ("Eval", "eval_task") - elif isinstance(self, ToolsRoutingTable): + elif isinstance(self, ToolGroupsRoutingTable): return ("Tools", "tool") else: raise ValueError("Unknown routing table type") @@ -471,65 +469,86 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks): await self.register_object(eval_task) -class ToolsRoutingTable(CommonRoutingTableImpl, Tools): - async def list_tools(self) -> List[Tool]: - return await self.get_all_with_type("tool") +class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): + async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]: + tools = await self.get_all_with_type("tool") + if tool_group_id: + tools = [tool for tool in tools if tool.tool_group == tool_group_id] + return tools - async def get_tool(self, tool_id: str) -> Tool: - return await self.get_object_by_identifier("tool", tool_id) + async def list_tool_groups(self) -> List[ToolGroup]: + return await self.get_all_with_type("tool_group") + + async def get_tool_group(self, tool_group_id: str) -> ToolGroup: + return await self.get_object_by_identifier("tool_group", tool_group_id) + + async def get_tool(self, tool_name: str) -> Tool: + return await self.get_object_by_identifier("tool", tool_name) async def register_tool_group( self, - tool_group: ToolGroup, + tool_group_id: str, + tool_group: ToolGroupDef, provider_id: Optional[str] = None, ) -> None: tools = [] - if isinstance(tool_group, MCPToolGroup): - # TODO: Actually find the right MCP provider - if provider_id is None: - raise ValueError("MCP provider_id not specified") - tools = await self.impls_by_provider_id[provider_id].discover_tools( + tool_defs = [] + if provider_id is None: + if len(self.impls_by_provider_id.keys()) > 1: + raise ValueError( + f"No provider_id specified and multiple providers available. Please specify a provider_id. Available providers: {', '.join(self.impls_by_provider_id.keys())}" + ) + provider_id = list(self.impls_by_provider_id.keys())[0] + + if isinstance(tool_group, MCPToolGroupDef): + tool_defs = await self.impls_by_provider_id[provider_id].discover_tools( tool_group ) - for tool in tools: - tool.provider_id = provider_id - elif isinstance(tool_group, UserDefinedToolGroup): - for tool in tool_group.tools: - tools.append( - Tool( - identifier=tool.name, - tool_group=tool_group.name, - name=tool.name, - description=tool.description, - parameters=tool.parameters, - provider_id=provider_id, - tool_prompt_format=tool.tool_prompt_format, - provider_resource_id=tool.name, - metadata=tool.metadata, - ) - ) + elif isinstance(tool_group, UserDefinedToolGroupDef): + tool_defs = tool_group.tools else: raise ValueError(f"Unknown tool group: {tool_group}") + for tool_def in tool_defs: + tools.append( + Tool( + identifier=tool_def.name, + tool_group=tool_group_id, + description=tool_def.description, + parameters=tool_def.parameters, + provider_id=provider_id, + tool_prompt_format=tool_def.tool_prompt_format, + provider_resource_id=tool_def.name, + metadata=tool_def.metadata, + ) + ) for tool in tools: existing_tool = await self.get_tool(tool.identifier) # Compare existing and new object if one exists if existing_tool: - # Compare all fields except provider_id since that might be None in new obj - if tool.provider_id is None: - tool.provider_id = existing_tool.provider_id existing_dict = existing_tool.model_dump() new_dict = tool.model_dump() if existing_dict != new_dict: raise ValueError( - f"Object {tool.name} already exists in registry. Please use a different identifier." + f"Object {tool.identifier} already exists in registry. Please use a different identifier." ) await self.register_object(tool) - async def unregister_tool(self, tool_id: str) -> None: - tool = await self.get_tool(tool_id) - if tool is None: - raise ValueError(f"Tool {tool_id} not found") - await self.unregister_object(tool) + await self.dist_registry.register( + ToolGroup( + identifier=tool_group_id, + provider_id=provider_id, + provider_resource_id=tool_group_id, + ) + ) + + async def unregister_tool_group(self, tool_group_id: str) -> None: + tool_group = await self.get_tool_group(tool_group_id) + if tool_group is None: + raise ValueError(f"Tool group {tool_group_id} not found") + tools = await self.list_tools(tool_group_id) + for tool in tools: + await self.unregister_object(tool) + await self.unregister_object(tool_group) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 7a82e282e..ce0c9f52e 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -40,7 +40,7 @@ class Api(Enum): datasets = "datasets" scoring_functions = "scoring_functions" eval_tasks = "eval_tasks" - tools = "tools" + tool_groups = "tool_groups" # built-in API inspect = "inspect" diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py index 464963b40..ca0141552 100644 --- a/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py @@ -8,7 +8,7 @@ from typing import Any, Dict, List import requests -from llama_stack.apis.tools import Tool, ToolGroup, ToolInvocationResult, ToolRuntime +from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ToolsProtocolPrivate @@ -42,11 +42,11 @@ class BraveSearchToolRuntimeImpl( ) return provider_data.api_key - async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: + async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: raise NotImplementedError("Brave search tool group not supported") async def invoke_tool( - self, tool_id: str, args: Dict[str, Any] + self, tool_name: str, args: Dict[str, Any] ) -> ToolInvocationResult: api_key = self._get_api_key() url = "https://api.search.brave.com/res/v1/web/search" diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 0c6661731..b9bf3fe36 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -8,14 +8,15 @@ from typing import Any, Dict, List from urllib.parse import urlparse from llama_stack.apis.tools import ( - MCPToolGroup, - Tool, - ToolGroup, + MCPToolGroupDef, + ToolDef, + ToolGroupDef, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.providers.datatypes import ToolsProtocolPrivate + from mcp import ClientSession from mcp.client.sse import sse_client @@ -29,8 +30,8 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): async def initialize(self): pass - async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: - if not isinstance(tool_group, MCPToolGroup): + async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: + if not isinstance(tool_group, MCPToolGroupDef): raise ValueError(f"Unsupported tool group type: {type(tool_group)}") tools = [] @@ -51,25 +52,23 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): ) ) tools.append( - Tool( - identifier=tool.name, + ToolDef( + name=tool.name, description=tool.description, - tool_group=tool_group.name, parameters=parameters, metadata={ "endpoint": tool_group.endpoint.uri, }, - provider_resource_id=tool.name, ) ) return tools async def invoke_tool( - self, tool_id: str, args: Dict[str, Any] + self, tool_name: str, args: Dict[str, Any] ) -> ToolInvocationResult: - tool = await self.tool_store.get_tool(tool_id) + tool = await self.tool_store.get_tool(tool_name) if tool.metadata is None or tool.metadata.get("endpoint") is None: - raise ValueError(f"Tool {tool_id} does not have metadata") + raise ValueError(f"Tool {tool_name} does not have metadata") endpoint = tool.metadata.get("endpoint") if urlparse(endpoint).scheme not in ("http", "https"): raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")