mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 01:03:59 +00:00
add brave tool provider
This commit is contained in:
parent
ea0ca7454a
commit
1c770508df
6 changed files with 120 additions and 41 deletions
|
@ -108,6 +108,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
await add_objects(scoring_functions, pid, ScoringFn)
|
await add_objects(scoring_functions, pid, ScoringFn)
|
||||||
elif api == Api.eval:
|
elif api == Api.eval:
|
||||||
p.eval_task_store = self
|
p.eval_task_store = self
|
||||||
|
elif api == Api.tool_runtime:
|
||||||
|
p.tool_store = self
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
for p in self.impls_by_provider_id.values():
|
for p in self.impls_by_provider_id.values():
|
||||||
|
@ -129,6 +131,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
return ("Scoring", "scoring_function")
|
return ("Scoring", "scoring_function")
|
||||||
elif isinstance(self, EvalTasksRoutingTable):
|
elif isinstance(self, EvalTasksRoutingTable):
|
||||||
return ("Eval", "eval_task")
|
return ("Eval", "eval_task")
|
||||||
|
elif isinstance(self, ToolsRoutingTable):
|
||||||
|
return ("Tools", "tool")
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown routing table type")
|
raise ValueError("Unknown routing table type")
|
||||||
|
|
||||||
|
|
|
@ -4,11 +4,11 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .config import MetaReferenceToolRuntimeConfig
|
from .brave_search import BraveSearchToolRuntimeImpl
|
||||||
from .meta_reference import MetaReferenceToolRuntimeImpl
|
from .config import BraveSearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: MetaReferenceToolRuntimeConfig, _deps):
|
async def get_provider_impl(config: BraveSearchToolConfig, _deps):
|
||||||
impl = MetaReferenceToolRuntimeImpl(config)
|
impl = BraveSearchToolRuntimeImpl(config)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
|
@ -0,0 +1,106 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from llama_stack.apis.tools import InvokeToolResult, Tool, ToolRuntime
|
||||||
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||||
|
|
||||||
|
from .config import BraveSearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
|
class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
|
def __init__(self, config: BraveSearchToolConfig):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def register_tool(self, tool: Tool):
|
||||||
|
if tool.identifier != "brave_search":
|
||||||
|
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
|
||||||
|
|
||||||
|
async def unregister_tool(self, tool_id: str) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> InvokeToolResult:
|
||||||
|
results = await self.execute(args["query"])
|
||||||
|
content_items = "\n".join([str(result) for result in results])
|
||||||
|
return InvokeToolResult(
|
||||||
|
content=content_items,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute(self, query: str) -> List[dict]:
|
||||||
|
url = "https://api.search.brave.com/res/v1/web/search"
|
||||||
|
headers = {
|
||||||
|
"X-Subscription-Token": self.config.api_key,
|
||||||
|
"Accept-Encoding": "gzip",
|
||||||
|
"Accept": "application/json",
|
||||||
|
}
|
||||||
|
payload = {"q": query}
|
||||||
|
response = requests.get(url=url, params=payload, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
return self._clean_brave_response(response.json())
|
||||||
|
|
||||||
|
def _clean_brave_response(self, search_response):
|
||||||
|
clean_response = []
|
||||||
|
if "mixed" in search_response:
|
||||||
|
mixed_results = search_response["mixed"]
|
||||||
|
for m in mixed_results["main"][: self.config.max_results]:
|
||||||
|
r_type = m["type"]
|
||||||
|
results = search_response[r_type]["results"]
|
||||||
|
cleaned = self._clean_result_by_type(r_type, results, m.get("index"))
|
||||||
|
clean_response.append(cleaned)
|
||||||
|
|
||||||
|
return clean_response
|
||||||
|
|
||||||
|
def _clean_result_by_type(self, r_type, results, idx=None):
|
||||||
|
type_cleaners = {
|
||||||
|
"web": (
|
||||||
|
["type", "title", "url", "description", "date", "extra_snippets"],
|
||||||
|
lambda x: x[idx],
|
||||||
|
),
|
||||||
|
"faq": (["type", "question", "answer", "title", "url"], lambda x: x),
|
||||||
|
"infobox": (
|
||||||
|
["type", "title", "url", "description", "long_desc"],
|
||||||
|
lambda x: x[idx],
|
||||||
|
),
|
||||||
|
"videos": (["type", "url", "title", "description", "date"], lambda x: x),
|
||||||
|
"locations": (
|
||||||
|
[
|
||||||
|
"type",
|
||||||
|
"title",
|
||||||
|
"url",
|
||||||
|
"description",
|
||||||
|
"coordinates",
|
||||||
|
"postal_address",
|
||||||
|
"contact",
|
||||||
|
"rating",
|
||||||
|
"distance",
|
||||||
|
"zoom_level",
|
||||||
|
],
|
||||||
|
lambda x: x,
|
||||||
|
),
|
||||||
|
"news": (["type", "title", "url", "description"], lambda x: x),
|
||||||
|
}
|
||||||
|
|
||||||
|
if r_type not in type_cleaners:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
selected_keys, result_selector = type_cleaners[r_type]
|
||||||
|
results = result_selector(results)
|
||||||
|
|
||||||
|
if isinstance(results, list):
|
||||||
|
cleaned = [
|
||||||
|
{k: v for k, v in item.items() if k in selected_keys}
|
||||||
|
for item in results
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
cleaned = {k: v for k, v in results.items() if k in selected_keys}
|
||||||
|
|
||||||
|
return str(cleaned)
|
|
@ -7,5 +7,6 @@
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceToolRuntimeConfig(BaseModel):
|
class BraveSearchToolConfig(BaseModel):
|
||||||
pass
|
api_key: str
|
||||||
|
max_results: int = 3
|
|
@ -1,32 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime
|
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
|
||||||
|
|
||||||
from .config import MetaReferenceToolRuntimeConfig
|
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|
||||||
def __init__(self, config: MetaReferenceToolRuntimeConfig):
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool):
|
|
||||||
print(f"registering tool {tool.identifier}")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def invoke_tool(
|
|
||||||
self, tool_id: str, args: Dict[str, Any]
|
|
||||||
) -> ToolInvocationResult:
|
|
||||||
pass
|
|
|
@ -13,9 +13,9 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
return [
|
return [
|
||||||
InlineProviderSpec(
|
InlineProviderSpec(
|
||||||
api=Api.tool_runtime,
|
api=Api.tool_runtime,
|
||||||
provider_type="inline::meta-reference",
|
provider_type="inline::brave-search",
|
||||||
pip_packages=[],
|
pip_packages=[],
|
||||||
module="llama_stack.providers.inline.tool_runtime.meta_reference",
|
module="llama_stack.providers.inline.tool_runtime.brave_search",
|
||||||
config_class="llama_stack.providers.inline.tool_runtime.meta_reference.MetaReferenceToolRuntimeConfig",
|
config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue