add brave tool provider

This commit is contained in:
Dinesh Yeduguru 2024-12-18 16:37:02 -08:00
parent ea0ca7454a
commit 1c770508df
6 changed files with 120 additions and 41 deletions

View file

@ -108,6 +108,8 @@ class CommonRoutingTableImpl(RoutingTable):
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.eval_task_store = self
elif api == Api.tool_runtime:
p.tool_store = self
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
@ -129,6 +131,8 @@ class CommonRoutingTableImpl(RoutingTable):
return ("Scoring", "scoring_function")
elif isinstance(self, EvalTasksRoutingTable):
return ("Eval", "eval_task")
elif isinstance(self, ToolsRoutingTable):
return ("Tools", "tool")
else:
raise ValueError("Unknown routing table type")

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import MetaReferenceToolRuntimeConfig
from .meta_reference import MetaReferenceToolRuntimeImpl
from .brave_search import BraveSearchToolRuntimeImpl
from .config import BraveSearchToolConfig
async def get_provider_impl(config: MetaReferenceToolRuntimeConfig, _deps):
impl = MetaReferenceToolRuntimeImpl(config)
async def get_provider_impl(config: BraveSearchToolConfig, _deps):
impl = BraveSearchToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
import requests
from llama_stack.apis.tools import InvokeToolResult, Tool, ToolRuntime
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import BraveSearchToolConfig
class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(self, config: BraveSearchToolConfig):
self.config = config
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
if tool.identifier != "brave_search":
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
async def unregister_tool(self, tool_id: str) -> None:
return
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> InvokeToolResult:
results = await self.execute(args["query"])
content_items = "\n".join([str(result) for result in results])
return InvokeToolResult(
content=content_items,
)
async def execute(self, query: str) -> List[dict]:
url = "https://api.search.brave.com/res/v1/web/search"
headers = {
"X-Subscription-Token": self.config.api_key,
"Accept-Encoding": "gzip",
"Accept": "application/json",
}
payload = {"q": query}
response = requests.get(url=url, params=payload, headers=headers)
response.raise_for_status()
return self._clean_brave_response(response.json())
def _clean_brave_response(self, search_response):
clean_response = []
if "mixed" in search_response:
mixed_results = search_response["mixed"]
for m in mixed_results["main"][: self.config.max_results]:
r_type = m["type"]
results = search_response[r_type]["results"]
cleaned = self._clean_result_by_type(r_type, results, m.get("index"))
clean_response.append(cleaned)
return clean_response
def _clean_result_by_type(self, r_type, results, idx=None):
type_cleaners = {
"web": (
["type", "title", "url", "description", "date", "extra_snippets"],
lambda x: x[idx],
),
"faq": (["type", "question", "answer", "title", "url"], lambda x: x),
"infobox": (
["type", "title", "url", "description", "long_desc"],
lambda x: x[idx],
),
"videos": (["type", "url", "title", "description", "date"], lambda x: x),
"locations": (
[
"type",
"title",
"url",
"description",
"coordinates",
"postal_address",
"contact",
"rating",
"distance",
"zoom_level",
],
lambda x: x,
),
"news": (["type", "title", "url", "description"], lambda x: x),
}
if r_type not in type_cleaners:
return ""
selected_keys, result_selector = type_cleaners[r_type]
results = result_selector(results)
if isinstance(results, list):
cleaned = [
{k: v for k, v in item.items() if k in selected_keys}
for item in results
]
else:
cleaned = {k: v for k, v in results.items() if k in selected_keys}
return str(cleaned)

View file

@ -7,5 +7,6 @@
from pydantic import BaseModel
class MetaReferenceToolRuntimeConfig(BaseModel):
pass
class BraveSearchToolConfig(BaseModel):
api_key: str
max_results: int = 3

View file

@ -1,32 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import MetaReferenceToolRuntimeConfig
class MetaReferenceToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(self, config: MetaReferenceToolRuntimeConfig):
self.config = config
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
print(f"registering tool {tool.identifier}")
pass
async def unregister_tool(self, tool_id: str) -> None:
pass
async def invoke_tool(
self, tool_id: str, args: Dict[str, Any]
) -> ToolInvocationResult:
pass

View file

@ -13,9 +13,9 @@ def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::meta-reference",
provider_type="inline::brave-search",
pip_packages=[],
module="llama_stack.providers.inline.tool_runtime.meta_reference",
config_class="llama_stack.providers.inline.tool_runtime.meta_reference.MetaReferenceToolRuntimeConfig",
module="llama_stack.providers.inline.tool_runtime.brave_search",
config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
),
]