diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py
index 93a3718a0..a85f5a31c 100644
--- a/llama_stack/apis/resource.py
+++ b/llama_stack/apis/resource.py
@@ -18,6 +18,8 @@ class ResourceType(Enum):
     dataset = "dataset"
     scoring_function = "scoring_function"
     eval_task = "eval_task"
+    tool = "tool"
+    tool_group = "tool_group"
 
 
 class Resource(BaseModel):
diff --git a/llama_stack/apis/tools/__init__.py b/llama_stack/apis/tools/__init__.py
new file mode 100644
index 000000000..f747fcdc2
--- /dev/null
+++ b/llama_stack/apis/tools/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .tools import *  # noqa: F401 F403
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
new file mode 100644
index 000000000..23110543b
--- /dev/null
+++ b/llama_stack/apis/tools/tools.py
@@ -0,0 +1,141 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Annotated, Any, Dict, List, Literal, Optional, Union
+
+from llama_models.llama3.api.datatypes import ToolPromptFormat
+from llama_models.schema_utils import json_schema_type, register_schema, webmethod
+from pydantic import BaseModel, Field
+from typing_extensions import Protocol, runtime_checkable
+
+from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.resource import Resource, ResourceType
+from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+
+
+@json_schema_type
+class ToolParameter(BaseModel):
+    name: str
+    parameter_type: str
+    description: str
+
+
+@json_schema_type
+class Tool(Resource):
+    type: Literal[ResourceType.tool.value] = ResourceType.tool.value
+    tool_group: str
+    description: str
+    parameters: List[ToolParameter]
+    provider_id: Optional[str] = None
+    metadata: Optional[Dict[str, Any]] = None
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(
+        default=ToolPromptFormat.json
+    )
+
+
+@json_schema_type
+class ToolDef(BaseModel):
+    name: str
+    description: str
+    parameters: List[ToolParameter]
+    metadata: Dict[str, Any]
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(
+        default=ToolPromptFormat.json
+    )
+
+
+@json_schema_type
+class MCPToolGroupDef(BaseModel):
+    """
+    A tool group that is defined in a Model Context Protocol server.
+    Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
+    """
+
+    type: Literal["model_context_protocol"] = "model_context_protocol"
+    endpoint: URL
+
+
+@json_schema_type
+class UserDefinedToolGroupDef(BaseModel):
+    type: Literal["user_defined"] = "user_defined"
+    tools: List[ToolDef]
+
+
+ToolGroupDef = register_schema(
+    Annotated[
+        Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
+    ],
+    name="ToolGroup",
+)
+
+
+class ToolGroup(Resource):
+    type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
+
+
+@json_schema_type
+class ToolInvocationResult(BaseModel):
+    content: InterleavedContent
+    error_message: Optional[str] = None
+    error_code: Optional[int] = None
+
+
+class ToolStore(Protocol):
+    def get_tool(self, tool_name: str) -> Tool: ...
+
+
+@runtime_checkable
+@trace_protocol
+class ToolGroups(Protocol):
+    @webmethod(route="/toolgroups/register", method="POST")
+    async def register_tool_group(
+        self,
+        tool_group_id: str,
+        tool_group: ToolGroupDef,
+        provider_id: Optional[str] = None,
+    ) -> None:
+        """Register a tool group"""
+        ...
+
+    @webmethod(route="/toolgroups/get", method="GET")
+    async def get_tool_group(
+        self,
+        tool_group_id: str,
+    ) -> ToolGroup: ...
+
+    @webmethod(route="/toolgroups/list", method="GET")
+    async def list_tool_groups(self) -> List[ToolGroup]:
+        """List tool groups"""
+        ...
+
+    @webmethod(route="/tools/list", method="GET")
+    async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
+        """List tools with optional tool group"""
+        ...
+
+    @webmethod(route="/tools/get", method="GET")
+    async def get_tool(self, tool_name: str) -> Tool: ...
+
+    @webmethod(route="/toolgroups/unregister", method="POST")
+    async def unregister_tool_group(self, tool_group_id: str) -> None:
+        """Unregister a tool group"""
+        ...
+
+
+@runtime_checkable
+@trace_protocol
+class ToolRuntime(Protocol):
+    tool_store: ToolStore
+
+    @webmethod(route="/tool-runtime/discover", method="POST")
+    async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ...
+
+    @webmethod(route="/tool-runtime/invoke", method="POST")
+    async def invoke_tool(
+        self, tool_name: str, args: Dict[str, Any]
+    ) -> ToolInvocationResult:
+        """Run a tool with the given arguments"""
+        ...
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index 1159372d4..f2dea6012 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -8,19 +8,20 @@ from typing import Dict, List, Optional, Union
 
 from pydantic import BaseModel, Field
 
-from llama_stack.providers.datatypes import *  # noqa: F403
-from llama_stack.apis.models import *  # noqa: F403
-from llama_stack.apis.shields import *  # noqa: F403
-from llama_stack.apis.memory_banks import *  # noqa: F403
-from llama_stack.apis.datasets import *  # noqa: F403
-from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.eval import Eval
 from llama_stack.apis.eval_tasks import EvalTaskInput
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.memory import Memory
+from llama_stack.apis.memory_banks import *  # noqa: F403
+from llama_stack.apis.models import *  # noqa: F403
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring
+from llama_stack.apis.scoring_functions import *  # noqa: F403
+from llama_stack.apis.shields import *  # noqa: F403
+from llama_stack.apis.tools import Tool, ToolGroup, ToolRuntime
+from llama_stack.providers.datatypes import *  # noqa: F403
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig
 
 LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
@@ -37,6 +38,8 @@ RoutableObject = Union[
     Dataset,
     ScoringFn,
     EvalTask,
+    Tool,
+    ToolGroup,
 ]
 
 
@@ -48,6 +51,8 @@ RoutableObjectWithProvider = Annotated[
         Dataset,
         ScoringFn,
         EvalTask,
+        Tool,
+        ToolGroup,
     ],
     Field(discriminator="type"),
 ]
@@ -59,6 +64,7 @@ RoutedProtocol = Union[
     DatasetIO,
     Scoring,
     Eval,
+    ToolRuntime,
 ]
diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py
index 6fc4545c7..4183d92cd 100644
--- a/llama_stack/distribution/distribution.py
+++ b/llama_stack/distribution/distribution.py
@@ -47,6 +47,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
             routing_table_api=Api.eval_tasks,
             router_api=Api.eval,
         ),
+        AutoRoutedApiInfo(
+            routing_table_api=Api.tool_groups,
+            router_api=Api.tool_runtime,
+        ),
     ]
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 4541b01eb..439971315 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -30,6 +30,7 @@ from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
+from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.distribution.client import get_client_impl
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.store import DistributionRegistry
@@ -60,12 +61,15 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.eval: Eval,
         Api.eval_tasks: EvalTasks,
         Api.post_training: PostTraining,
+        Api.tool_groups: ToolGroups,
+        Api.tool_runtime: ToolRuntime,
     }
 
 
 def additional_protocols_map() -> Dict[Api, Any]:
     return {
         Api.inference: (ModelsProtocolPrivate, Models, Api.models),
+        Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
         Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
         Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
         Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py
index 57e81ac30..693f1fbe2 100644
--- a/llama_stack/distribution/routers/__init__.py
+++ b/llama_stack/distribution/routers/__init__.py
@@ -7,7 +7,6 @@
 from typing import Any
 
 from llama_stack.distribution.datatypes import *  # noqa: F403
-
 from llama_stack.distribution.store import DistributionRegistry
 
 from .routing_tables import (
@@ -17,6 +16,7 @@ from .routing_tables import (
     ModelsRoutingTable,
     ScoringFunctionsRoutingTable,
     ShieldsRoutingTable,
+    ToolGroupsRoutingTable,
 )
 
 
@@ -33,6 +33,7 @@ async def get_routing_table_impl(
         "datasets": DatasetsRoutingTable,
         "scoring_functions": ScoringFunctionsRoutingTable,
         "eval_tasks": EvalTasksRoutingTable,
+        "tool_groups": ToolGroupsRoutingTable,
     }
 
     if api.value not in api_to_tables:
@@ -51,6 +52,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
         MemoryRouter,
         SafetyRouter,
         ScoringRouter,
+        ToolRuntimeRouter,
     )
 
     api_to_routers = {
@@ -60,6 +62,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
         "datasetio": DatasetIORouter,
         "scoring": ScoringRouter,
         "eval": EvalRouter,
+        "tool_runtime": ToolRuntimeRouter,
     }
     if api.value not in api_to_routers:
         raise ValueError(f"API {api.value} not found in router map")
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 586ebfae4..a25a848db 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -6,15 +6,16 @@ from typing import Any, AsyncGenerator, Dict, List, Optional
 
-from llama_stack.apis.datasetio.datasetio import DatasetIO
-from llama_stack.apis.memory_banks.memory_banks import BankParams
-from llama_stack.distribution.datatypes import RoutingTable
-from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.apis.datasetio import *  # noqa: F403
-from llama_stack.apis.scoring import *  # noqa: F403
+from llama_stack.apis.datasetio.datasetio import DatasetIO
 from llama_stack.apis.eval import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.memory_banks.memory_banks import BankParams
+from llama_stack.apis.safety import *  # noqa: F403
+from llama_stack.apis.scoring import *  # noqa: F403
+from llama_stack.apis.tools import *  # noqa: F403
+from llama_stack.distribution.datatypes import RoutingTable
 
 
 class MemoryRouter(Memory):
@@ -372,3 +373,28 @@ class EvalRouter(Eval):
             task_id,
             job_id,
         )
+
+
+class ToolRuntimeRouter(ToolRuntime):
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
+        return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
+            tool_name=tool_name,
+            args=args,
+        )
+
+    async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
+        return await self.routing_table.get_provider_impl(
+            tool_group.name
+        ).discover_tools(tool_group)
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index ecf47a054..3fb086b72 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -6,21 +6,19 @@ from typing import Any, Dict, List, Optional
 
+from llama_models.llama3.api.datatypes import *  # noqa: F403
 from pydantic import parse_obj_as
 
-from llama_models.llama3.api.datatypes import *  # noqa: F403
-
-from llama_stack.apis.models import *  # noqa: F403
-from llama_stack.apis.shields import *  # noqa: F403
-from llama_stack.apis.memory_banks import *  # noqa: F403
+from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.common.type_system import ParamType
 from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.eval_tasks import *  # noqa: F403
-
-from llama_stack.apis.common.content_types import URL
-
-from llama_stack.apis.common.type_system import ParamType
-from llama_stack.distribution.store import DistributionRegistry
+from llama_stack.apis.memory_banks import *  # noqa: F403
+from llama_stack.apis.models import *  # noqa: F403
+from llama_stack.apis.shields import *  # noqa: F403
+from llama_stack.apis.tools import *  # noqa: F403
 from llama_stack.distribution.datatypes import *  # noqa: F403
+from llama_stack.distribution.store import DistributionRegistry
 
 
 def get_impl_api(p: Any) -> Api:
@@ -45,6 +43,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
         return await p.register_scoring_function(obj)
     elif api == Api.eval:
         return await p.register_eval_task(obj)
+    elif api == Api.tool_runtime:
+        return await p.register_tool(obj)
     else:
         raise ValueError(f"Unknown API {api} for registering object with provider")
 
@@ -57,6 +57,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
         return await p.unregister_model(obj.identifier)
     elif api == Api.datasetio:
         return await p.unregister_dataset(obj.identifier)
+    elif api == Api.tool_runtime:
+        return await p.unregister_tool(obj.identifier)
     else:
         raise ValueError(f"Unregister not supported for {api}")
 
@@ -104,6 +106,8 @@ class CommonRoutingTableImpl(RoutingTable):
                 await add_objects(scoring_functions, pid, ScoringFn)
             elif api == Api.eval:
                 p.eval_task_store = self
+            elif api == Api.tool_runtime:
+                p.tool_store = self
 
     async def shutdown(self) -> None:
         for p in self.impls_by_provider_id.values():
@@ -125,6 +129,8 @@ class CommonRoutingTableImpl(RoutingTable):
             return ("Scoring", "scoring_function")
         elif isinstance(self, EvalTasksRoutingTable):
             return ("Eval", "eval_task")
+        elif isinstance(self, ToolGroupsRoutingTable):
+            return ("Tools", "tool")
         else:
             raise ValueError("Unknown routing table type")
@@ -461,3 +467,88 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
             provider_resource_id=provider_eval_task_id,
         )
         await self.register_object(eval_task)
+
+
+class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
+    async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
+        tools = await self.get_all_with_type("tool")
+        if tool_group_id:
+            tools = [tool for tool in tools if tool.tool_group == tool_group_id]
+        return tools
+
+    async def list_tool_groups(self) -> List[ToolGroup]:
+        return await self.get_all_with_type("tool_group")
+
+    async def get_tool_group(self, tool_group_id: str) -> ToolGroup:
+        return await self.get_object_by_identifier("tool_group", tool_group_id)
+
+    async def get_tool(self, tool_name: str) -> Tool:
+        return await self.get_object_by_identifier("tool", tool_name)
+
+    async def register_tool_group(
+        self,
+        tool_group_id: str,
+        tool_group: ToolGroupDef,
+        provider_id: Optional[str] = None,
+    ) -> None:
+        tools = []
+        tool_defs = []
+        if provider_id is None:
+            if len(self.impls_by_provider_id.keys()) > 1:
+                raise ValueError(
+                    f"No provider_id specified and multiple providers available. Please specify a provider_id. Available providers: {', '.join(self.impls_by_provider_id.keys())}"
+                )
+            provider_id = list(self.impls_by_provider_id.keys())[0]
+
+        if isinstance(tool_group, MCPToolGroupDef):
+            tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
+                tool_group
+            )
+
+        elif isinstance(tool_group, UserDefinedToolGroupDef):
+            tool_defs = tool_group.tools
+        else:
+            raise ValueError(f"Unknown tool group: {tool_group}")
+
+        for tool_def in tool_defs:
+            tools.append(
+                Tool(
+                    identifier=tool_def.name,
+                    tool_group=tool_group_id,
+                    description=tool_def.description,
+                    parameters=tool_def.parameters,
+                    provider_id=provider_id,
+                    tool_prompt_format=tool_def.tool_prompt_format,
+                    provider_resource_id=tool_def.name,
+                    metadata=tool_def.metadata,
+                )
+            )
+        for tool in tools:
+            existing_tool = await self.get_tool(tool.identifier)
+            # Compare existing and new object if one exists
+            if existing_tool:
+                existing_dict = existing_tool.model_dump()
+                new_dict = tool.model_dump()
+
+                if existing_dict != new_dict:
+                    raise ValueError(
+                        f"Object {tool.identifier} already exists in registry. Please use a different identifier."
+                    )
+            await self.register_object(tool)
+
+        await self.dist_registry.register(
+            ToolGroup(
+                identifier=tool_group_id,
+                provider_id=provider_id,
+                provider_resource_id=tool_group_id,
+            )
+        )
+
+    async def unregister_tool_group(self, tool_group_id: str) -> None:
+        tool_group = await self.get_tool_group(tool_group_id)
+        if tool_group is None:
+            raise ValueError(f"Tool group {tool_group_id} not found")
+        tools = await self.list_tools(tool_group_id)
+        for tool in tools:
+            await self.unregister_object(tool)
+        await self.unregister_object(tool_group)
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index c506a754c..ce0c9f52e 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -17,6 +17,7 @@ from llama_stack.apis.memory_banks.memory_banks import MemoryBank
 from llama_stack.apis.models import Model
 from llama_stack.apis.scoring_functions import ScoringFn
 from llama_stack.apis.shields import Shield
+from llama_stack.apis.tools import Tool
 
 
 @json_schema_type
@@ -29,6 +30,7 @@ class Api(Enum):
     scoring = "scoring"
     eval = "eval"
     post_training = "post_training"
+    tool_runtime = "tool_runtime"
 
     telemetry = "telemetry"
 
@@ -38,6 +40,7 @@ class Api(Enum):
     datasets = "datasets"
     scoring_functions = "scoring_functions"
     eval_tasks = "eval_tasks"
+    tool_groups = "tool_groups"
 
     # built-in API
     inspect = "inspect"
@@ -75,6 +78,12 @@ class EvalTasksProtocolPrivate(Protocol):
     async def register_eval_task(self, eval_task: EvalTask) -> None: ...
 
 
+class ToolsProtocolPrivate(Protocol):
+    async def register_tool(self, tool: Tool) -> None: ...
+
+    async def unregister_tool(self, tool_id: str) -> None: ...
+
+
 @json_schema_type
 class ProviderSpec(BaseModel):
     api: Api
diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py b/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py
new file mode 100644
index 000000000..e9f0eeae8
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+from .brave_search import BraveSearchToolRuntimeImpl
+from .config import BraveSearchToolConfig
+
+
+class BraveSearchToolProviderDataValidator(BaseModel):
+    api_key: str
+
+
+async def get_provider_impl(config: BraveSearchToolConfig, _deps):
+    impl = BraveSearchToolRuntimeImpl(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py
new file mode 100644
index 000000000..ca0141552
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py
@@ -0,0 +1,123 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List
+
+import requests
+
+from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
+from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.providers.datatypes import ToolsProtocolPrivate
+
+from .config import BraveSearchToolConfig
+
+
+class BraveSearchToolRuntimeImpl(
+    ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
+):
+    def __init__(self, config: BraveSearchToolConfig):
+        self.config = config
+
+    async def initialize(self):
+        pass
+
+    async def register_tool(self, tool: Tool):
+        if tool.identifier != "brave_search":
+            raise ValueError(f"Tool identifier {tool.identifier} is not supported")
+
+    async def unregister_tool(self, tool_id: str) -> None:
+        return
+
+    def _get_api_key(self) -> str:
+        if self.config.api_key:
+            return self.config.api_key
+
+        provider_data = self.get_request_provider_data()
+        if provider_data is None or not provider_data.api_key:
+            raise ValueError(
+                'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key> }'
+            )
+        return provider_data.api_key
+
+    async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
+        raise NotImplementedError("Brave search tool group not supported")
+
+    async def invoke_tool(
+        self, tool_name: str, args: Dict[str, Any]
+    ) -> ToolInvocationResult:
+        api_key = self._get_api_key()
+        url = "https://api.search.brave.com/res/v1/web/search"
+        headers = {
+            "X-Subscription-Token": api_key,
+            "Accept-Encoding": "gzip",
+            "Accept": "application/json",
+        }
+        payload = {"q": args["query"]}
+        response = requests.get(url=url, params=payload, headers=headers)
+        response.raise_for_status()
+        results = self._clean_brave_response(response.json())
+        content_items = "\n".join([str(result) for result in results])
+        return ToolInvocationResult(
+            content=content_items,
+        )
+
+    def _clean_brave_response(self, search_response):
+        clean_response = []
+        if "mixed" in search_response:
+            mixed_results = search_response["mixed"]
+            for m in mixed_results["main"][: self.config.max_results]:
+                r_type = m["type"]
+                results = search_response[r_type]["results"]
+                cleaned = self._clean_result_by_type(r_type, results, m.get("index"))
+                clean_response.append(cleaned)
+
+        return clean_response
+
+    def _clean_result_by_type(self, r_type, results, idx=None):
+        type_cleaners = {
+            "web": (
+                ["type", "title", "url", "description", "date", "extra_snippets"],
+                lambda x: x[idx],
+            ),
+            "faq": (["type", "question", "answer", "title", "url"], lambda x: x),
+            "infobox": (
+                ["type", "title", "url", "description", "long_desc"],
+                lambda x: x[idx],
+            ),
+            "videos": (["type", "url", "title", "description", "date"], lambda x: x),
+            "locations": (
+                [
+                    "type",
+                    "title",
+                    "url",
+                    "description",
+                    "coordinates",
+                    "postal_address",
+                    "contact",
+                    "rating",
+                    "distance",
+                    "zoom_level",
+                ],
+                lambda x: x,
+            ),
+            "news": (["type", "title", "url", "description"], lambda x: x),
+        }
+
+        if r_type not in type_cleaners:
+            return ""
+
+        selected_keys, result_selector = type_cleaners[r_type]
+        results = result_selector(results)
+
+        if isinstance(results, list):
+            cleaned = [
+                {k: v for k, v in item.items() if k in selected_keys}
+                for item in results
+            ]
+        else:
+            cleaned = {k: v for k, v in results.items() if k in selected_keys}
+
+        return str(cleaned)
diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/config.py b/llama_stack/providers/inline/tool_runtime/brave_search/config.py
new file mode 100644
index 000000000..565d428f7
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/brave_search/config.py
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class BraveSearchToolConfig(BaseModel):
+    api_key: Optional[str] = Field(
+        default=None,
+        description="The Brave Search API Key",
+    )
+    max_results: int = Field(
+        default=3,
+        description="The maximum number of results to return",
+    )
diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py
new file mode 100644
index 000000000..f3e6aead8
--- /dev/null
+++ b/llama_stack/providers/registry/tool_runtime.py
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_stack.distribution.datatypes import (
+    AdapterSpec,
+    Api,
+    InlineProviderSpec,
+    ProviderSpec,
+    remote_provider_spec,
+)
+
+
+def available_providers() -> List[ProviderSpec]:
+    return [
+        InlineProviderSpec(
+            api=Api.tool_runtime,
+            provider_type="inline::brave-search",
+            pip_packages=[],
+            module="llama_stack.providers.inline.tool_runtime.brave_search",
+            config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
+            provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
+        ),
+        remote_provider_spec(
+            api=Api.tool_runtime,
+            adapter=AdapterSpec(
+                adapter_type="model-context-protocol",
+                module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
+                config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
+                pip_packages=["mcp"],
+            ),
+        ),
+    ]
diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py
new file mode 100644
index 000000000..3b05f5632
--- /dev/null
+++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+from .config import ModelContextProtocolConfig
+
+from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
+
+
+class ModelContextProtocolToolProviderDataValidator(BaseModel):
+    api_key: str
+
+
+async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
+    impl = ModelContextProtocolToolRuntimeImpl(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py
new file mode 100644
index 000000000..ffe4c9887
--- /dev/null
+++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+
+class ModelContextProtocolConfig(BaseModel):
+    pass
diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
new file mode 100644
index 000000000..b9bf3fe36
--- /dev/null
+++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List
+from urllib.parse import urlparse
+
+from llama_stack.apis.tools import (
+    MCPToolGroupDef,
+    ToolDef,
+    ToolGroupDef,
+    ToolInvocationResult,
+    ToolParameter,
+    ToolRuntime,
+)
+from llama_stack.providers.datatypes import ToolsProtocolPrivate
+
+from mcp import ClientSession
+from mcp.client.sse import sse_client
+
+from .config import ModelContextProtocolConfig
+
+
+class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
+    def __init__(self, config: ModelContextProtocolConfig):
+        self.config = config
+
+    async def initialize(self):
+        pass
+
+    async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
+        if not isinstance(tool_group, MCPToolGroupDef):
+            raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
+
+        tools = []
+        async with sse_client(tool_group.endpoint.uri) as streams:
+            async with ClientSession(*streams) as session:
+                await session.initialize()
+                tools_result = await session.list_tools()
+                for tool in tools_result.tools:
+                    parameters = []
+                    for param_name, param_schema in tool.inputSchema.get(
+                        "properties", {}
+                    ).items():
+                        parameters.append(
+                            ToolParameter(
+                                name=param_name,
+                                parameter_type=param_schema.get("type", "string"),
+                                description=param_schema.get("description", ""),
+                            )
+                        )
+                    tools.append(
+                        ToolDef(
+                            name=tool.name,
+                            description=tool.description,
+                            parameters=parameters,
+                            metadata={
+                                "endpoint": tool_group.endpoint.uri,
+                            },
+                        )
+                    )
+        return tools
+
+    async def invoke_tool(
+        self, tool_name: str, args: Dict[str, Any]
+    ) -> ToolInvocationResult:
+        tool = await self.tool_store.get_tool(tool_name)
+        if tool.metadata is None or tool.metadata.get("endpoint") is None:
+            raise ValueError(f"Tool {tool_name} does not have metadata")
+        endpoint = tool.metadata.get("endpoint")
+        if urlparse(endpoint).scheme not in ("http", "https"):
+            raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
+
+        async with sse_client(endpoint) as streams:
+            async with ClientSession(*streams) as session:
+                await session.initialize()
+                result = await session.call_tool(tool.identifier, args)
+
+        return ToolInvocationResult(
+            content="\n".join([result.model_dump_json() for result in result.content]),
+            error_code=1 if result.isError else 0,
+        )
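# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the patch): how a caller might
# exercise the ToolGroups / ToolRuntime APIs introduced above. It assumes an
# in-process `tool_groups_impl` (a ToolGroups implementation) and
# `tool_runtime_impl` (a ToolRuntime implementation) obtained from a running
# stack; those variable names and the `demo` wrapper are hypothetical.
from llama_stack.apis.tools import ToolDef, ToolParameter, UserDefinedToolGroupDef


async def demo(tool_groups_impl, tool_runtime_impl) -> None:
    # Describe a single user-defined tool and register it as a tool group.
    tool_group = UserDefinedToolGroupDef(
        tools=[
            ToolDef(
                name="brave_search",
                description="Web search via the Brave Search API",
                parameters=[
                    ToolParameter(
                        name="query",
                        parameter_type="string",
                        description="The search query",
                    )
                ],
                metadata={},
            )
        ]
    )
    await tool_groups_impl.register_tool_group(
        tool_group_id="web_search", tool_group=tool_group
    )

    # List what was registered, then invoke the tool through the runtime.
    tools = await tool_groups_impl.list_tools(tool_group_id="web_search")
    print([t.identifier for t in tools])

    result = await tool_runtime_impl.invoke_tool(
        tool_name="brave_search", args={"query": "model context protocol"}
    )
    print(result.content)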