diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index cc458c32a..556edc434 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -108,6 +108,8 @@ class CommonRoutingTableImpl(RoutingTable): await add_objects(scoring_functions, pid, ScoringFn) elif api == Api.eval: p.eval_task_store = self + elif api == Api.tool_runtime: + p.tool_store = self async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): @@ -129,6 +131,8 @@ class CommonRoutingTableImpl(RoutingTable): return ("Scoring", "scoring_function") elif isinstance(self, EvalTasksRoutingTable): return ("Eval", "eval_task") + elif isinstance(self, ToolsRoutingTable): + return ("Tools", "tool") else: raise ValueError("Unknown routing table type") diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/__init__.py b/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py similarity index 51% rename from llama_stack/providers/inline/tool_runtime/meta_reference/__init__.py rename to llama_stack/providers/inline/tool_runtime/brave_search/__init__.py index f7d52c1f0..418f0fd5a 100644 --- a/llama_stack/providers/inline/tool_runtime/meta_reference/__init__.py +++ b/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py @@ -4,11 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .config import MetaReferenceToolRuntimeConfig -from .meta_reference import MetaReferenceToolRuntimeImpl +from .brave_search import BraveSearchToolRuntimeImpl +from .config import BraveSearchToolConfig -async def get_provider_impl(config: MetaReferenceToolRuntimeConfig, _deps): - impl = MetaReferenceToolRuntimeImpl(config) +async def get_provider_impl(config: BraveSearchToolConfig, _deps): + impl = BraveSearchToolRuntimeImpl(config) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py new file mode 100644 index 000000000..0ecf9f9c3 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, List + +import requests + +from llama_stack.apis.tools import InvokeToolResult, Tool, ToolRuntime +from llama_stack.providers.datatypes import ToolsProtocolPrivate + +from .config import BraveSearchToolConfig + + +class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): + def __init__(self, config: BraveSearchToolConfig): + self.config = config + + async def initialize(self): + pass + + async def register_tool(self, tool: Tool): + if tool.identifier != "brave_search": + raise ValueError(f"Tool identifier {tool.identifier} is not supported") + + async def unregister_tool(self, tool_id: str) -> None: + return + + async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> InvokeToolResult: + results = await self.execute(args["query"]) + content_items = "\n".join([str(result) for result in results]) + return InvokeToolResult( + content=content_items, + ) + + async def execute(self, query: str) -> List[dict]: + url = "https://api.search.brave.com/res/v1/web/search" + headers = { + "X-Subscription-Token": self.config.api_key, + "Accept-Encoding": "gzip", + "Accept": "application/json", + } + payload = {"q": query} + response = requests.get(url=url, params=payload, headers=headers) + response.raise_for_status() + return self._clean_brave_response(response.json()) + + def _clean_brave_response(self, search_response): + clean_response = [] + if "mixed" in search_response: + mixed_results = search_response["mixed"] + for m in mixed_results["main"][: self.config.max_results]: + r_type = m["type"] + results = search_response[r_type]["results"] + cleaned = self._clean_result_by_type(r_type, results, m.get("index")) + clean_response.append(cleaned) + + return clean_response + + def _clean_result_by_type(self, r_type, results, idx=None): + type_cleaners = { + "web": ( + ["type", "title", "url", "description", "date", "extra_snippets"], + lambda x: x[idx], + ), + "faq": (["type", "question", "answer", "title", "url"], lambda x: x), + "infobox": ( + ["type", "title", "url", "description", "long_desc"], + lambda x: x[idx], + ), + "videos": (["type", "url", "title", "description", "date"], lambda x: x), + "locations": ( + [ + "type", + "title", + "url", + "description", + "coordinates", + "postal_address", + "contact", + "rating", + "distance", + "zoom_level", + ], + lambda x: x, + ), + "news": (["type", "title", "url", "description"], lambda x: x), + } + + if r_type not in type_cleaners: + return "" + + selected_keys, result_selector = type_cleaners[r_type] + results = result_selector(results) + + if isinstance(results, list): + cleaned = [ + {k: v for k, v in item.items() if k in selected_keys} + for item in results + ] + else: + cleaned = {k: v for k, v in results.items() if k in selected_keys} + + return str(cleaned) diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/config.py b/llama_stack/providers/inline/tool_runtime/brave_search/config.py similarity index 74% rename from llama_stack/providers/inline/tool_runtime/meta_reference/config.py rename to llama_stack/providers/inline/tool_runtime/brave_search/config.py index 3f6146c51..e8fbaaec9 100644 --- a/llama_stack/providers/inline/tool_runtime/meta_reference/config.py +++ b/llama_stack/providers/inline/tool_runtime/brave_search/config.py @@ -7,5 +7,6 @@ from pydantic import BaseModel -class MetaReferenceToolRuntimeConfig(BaseModel): - pass +class BraveSearchToolConfig(BaseModel): + api_key: str + max_results: int = 3 diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py b/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py deleted file mode 100644 index 087fd918d..000000000 --- a/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, Dict - -from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime -from llama_stack.providers.datatypes import ToolsProtocolPrivate - -from .config import MetaReferenceToolRuntimeConfig - - -class MetaReferenceToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): - def __init__(self, config: MetaReferenceToolRuntimeConfig): - self.config = config - - async def initialize(self): - pass - - async def register_tool(self, tool: Tool): - print(f"registering tool {tool.identifier}") - pass - - async def unregister_tool(self, tool_id: str) -> None: - pass - - async def invoke_tool( - self, tool_id: str, args: Dict[str, Any] - ) -> ToolInvocationResult: - pass diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index c0e7a3d1b..64dc520fd 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -13,9 +13,9 @@ def available_providers() -> List[ProviderSpec]: return [ InlineProviderSpec( api=Api.tool_runtime, - provider_type="inline::meta-reference", + provider_type="inline::brave-search", pip_packages=[], - module="llama_stack.providers.inline.tool_runtime.meta_reference", - config_class="llama_stack.providers.inline.tool_runtime.meta_reference.MetaReferenceToolRuntimeConfig", + module="llama_stack.providers.inline.tool_runtime.brave_search", + config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig", ), ]