mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
add brave tool provider (#653)
# What does this PR do? Adds a new brave tool provider ## Test Plan ``` curl -X POST 'http://localhost:5000/alpha/toolgroups/register' \ -H 'Content-Type: application/json' \ -d '{ "name": "search", "tool_group": { "type": "user_defined", "tools": [ { "name": "brave_search", "description": "A web search tool", "parameters": [ { "name": "query", "parameter_type": "string", "description": "The query to search" } ], "metadata": {}, "tool_prompt_format": "json" } ] } }' curl -X POST 'http://localhost:5000/alpha/tool-runtime/invoke' \ -H 'Content-Type: application/json' \ -d '{ "tool_id": "brave_search", "args": { "query": "who is meta ceo" } }' | jq .content % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 1973 100 1884 100 89 11288 533 --:--:-- --:--:-- --:--:-- 11885 "{'title': 'Mark Zuckerberg, Founder, Chairman and Chief Executive ...', 'url': 'https://about.meta.com/media-gallery/executives/mark-zuckerberg/', 'description': 'Not Logged In · Please log in to see this page', 'type': 'search_result'}\n{'title': 'Meta - Leadership & Governance', 'url': 'https://investor.fb.com/leadership-and-governance/', 'description': '<strong>Mark Zuckerberg</strong> is the founder, chairman and CEO of Meta, which he originally founded as Facebook in 2004. Mark is responsible for setting the overall direction and product strategy for the company. He leads the design of Meta's services and development of its core technology and infrastructure.', 'type': 'search_result'}\n[{'type': 'video_result', 'url': '2372542949
/', 'title': 'Mark Zuckerberg, the CEO of Meta, has officially joined the ...', 'description': \"Express Tribune, Karachi, Pakistan. 2,334,400 likes · 36,360 talking about this · 205 were here. The Express Tribune is Pakistan's #1 brand for breaking news in politics, sports, business, lifestyle\"}, {'type': 'video_result', 'url': 'https://www.youtube.com/watch?v=Y3oeQqtRvqk', 'title': \"Meta CEO: Mark Zuckerberg becomes World's Second Richest Person!\", 'description': 'Try VectorVest Risk-Free ➥➥➥ https://www.vectorvest.com/YTUse this link for a FREE Stock Analysis Report ➥➥➥ vectorvest.com/YTFSAVectorVest Merch Store ➥➥➥'}, {'type': 'video_result', 'url': '5348412224
/', 'title': '#WATCH | Meta founder and CEO Mark Zuckerberg recently ...', 'description': 'See posts, photos and more on Facebook'}]" curl -X POST 'http://localhost:5000/alpha/tool-runtime/invoke' \ -H 'Content-Type: application/json' -H 'X-LlamaStack-ProviderData: {"api_key": "<KEY>"}' \ -d '{ "tool_id": "brave_search", "args": { "query": "who is meta ceo" } }' ```
This commit is contained in:
parent
ea0ca7454a
commit
a297d27d48
12 changed files with 316 additions and 58 deletions
|
@ -26,11 +26,11 @@ class ToolParameter(BaseModel):
|
|||
@json_schema_type
|
||||
class Tool(Resource):
|
||||
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
|
||||
name: str
|
||||
tool_group: str
|
||||
description: str
|
||||
parameters: List[ToolParameter]
|
||||
provider_id: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
)
|
||||
|
@ -55,12 +55,14 @@ class MCPToolGroup(BaseModel):
|
|||
"""
|
||||
|
||||
type: Literal["model_context_protocol"] = "model_context_protocol"
|
||||
name: str
|
||||
endpoint: URL
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UserDefinedToolGroup(BaseModel):
|
||||
type: Literal["user_defined"] = "user_defined"
|
||||
name: str
|
||||
tools: List[ToolDef]
|
||||
|
||||
|
||||
|
@ -87,7 +89,6 @@ class Tools(Protocol):
|
|||
@webmethod(route="/toolgroups/register", method="POST")
|
||||
async def register_tool_group(
|
||||
self,
|
||||
name: str,
|
||||
tool_group: ToolGroup,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> None:
|
||||
|
@ -115,6 +116,9 @@ class Tools(Protocol):
|
|||
class ToolRuntime(Protocol):
|
||||
tool_store: ToolStore
|
||||
|
||||
@webmethod(route="/tool-runtime/discover", method="POST")
|
||||
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: ...
|
||||
|
||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||
async def invoke_tool(
|
||||
self, tool_id: str, args: Dict[str, Any]
|
||||
|
|
|
@ -393,3 +393,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
tool_id=tool_id,
|
||||
args=args,
|
||||
)
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
|
||||
return await self.routing_table.get_provider_impl(
|
||||
tool_group.name
|
||||
).discover_tools(tool_group)
|
||||
|
|
|
@ -108,6 +108,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
await add_objects(scoring_functions, pid, ScoringFn)
|
||||
elif api == Api.eval:
|
||||
p.eval_task_store = self
|
||||
elif api == Api.tool_runtime:
|
||||
p.tool_store = self
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for p in self.impls_by_provider_id.values():
|
||||
|
@ -129,6 +131,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
return ("Scoring", "scoring_function")
|
||||
elif isinstance(self, EvalTasksRoutingTable):
|
||||
return ("Eval", "eval_task")
|
||||
elif isinstance(self, ToolsRoutingTable):
|
||||
return ("Tools", "tool")
|
||||
else:
|
||||
raise ValueError("Unknown routing table type")
|
||||
|
||||
|
@ -476,34 +480,40 @@ class ToolsRoutingTable(CommonRoutingTableImpl, Tools):
|
|||
|
||||
async def register_tool_group(
|
||||
self,
|
||||
name: str,
|
||||
tool_group: ToolGroup,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> None:
|
||||
tools = []
|
||||
if isinstance(tool_group, MCPToolGroup):
|
||||
# TODO: first needs to be resolved to corresponding tools available in the MCP server
|
||||
raise NotImplementedError("MCP tool provider not implemented yet")
|
||||
# TODO: Actually find the right MCP provider
|
||||
if provider_id is None:
|
||||
raise ValueError("MCP provider_id not specified")
|
||||
tools = await self.impls_by_provider_id[provider_id].discover_tools(
|
||||
tool_group
|
||||
)
|
||||
for tool in tools:
|
||||
tool.provider_id = provider_id
|
||||
elif isinstance(tool_group, UserDefinedToolGroup):
|
||||
for tool in tool_group.tools:
|
||||
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=tool.name,
|
||||
tool_group=name,
|
||||
tool_group=tool_group.name,
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=tool.parameters,
|
||||
provider_id=provider_id,
|
||||
tool_prompt_format=tool.tool_prompt_format,
|
||||
provider_resource_id=tool.name,
|
||||
metadata=tool.metadata,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown tool group: {tool_group}")
|
||||
|
||||
for tool in tools:
|
||||
existing_tool = await self.get_tool(tool.name)
|
||||
existing_tool = await self.get_tool(tool.identifier)
|
||||
# Compare existing and new object if one exists
|
||||
if existing_tool:
|
||||
# Compare all fields except provider_id since that might be None in new obj
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .brave_search import BraveSearchToolRuntimeImpl
|
||||
from .config import BraveSearchToolConfig
|
||||
|
||||
|
||||
class BraveSearchToolProviderDataValidator(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
async def get_provider_impl(config: BraveSearchToolConfig, _deps):
|
||||
impl = BraveSearchToolRuntimeImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.tools import Tool, ToolGroup, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
from .config import BraveSearchToolConfig
|
||||
|
||||
|
||||
class BraveSearchToolRuntimeImpl(
|
||||
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(self, config: BraveSearchToolConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
if tool.identifier != "brave_search":
|
||||
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
return
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
if self.config.api_key:
|
||||
return self.config.api_key
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.api_key:
|
||||
raise ValueError(
|
||||
'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
|
||||
)
|
||||
return provider_data.api_key
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
|
||||
raise NotImplementedError("Brave search tool group not supported")
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_id: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
url = "https://api.search.brave.com/res/v1/web/search"
|
||||
headers = {
|
||||
"X-Subscription-Token": api_key,
|
||||
"Accept-Encoding": "gzip",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
payload = {"q": args["query"]}
|
||||
response = requests.get(url=url, params=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
results = self._clean_brave_response(response.json())
|
||||
content_items = "\n".join([str(result) for result in results])
|
||||
return ToolInvocationResult(
|
||||
content=content_items,
|
||||
)
|
||||
|
||||
def _clean_brave_response(self, search_response):
|
||||
clean_response = []
|
||||
if "mixed" in search_response:
|
||||
mixed_results = search_response["mixed"]
|
||||
for m in mixed_results["main"][: self.config.max_results]:
|
||||
r_type = m["type"]
|
||||
results = search_response[r_type]["results"]
|
||||
cleaned = self._clean_result_by_type(r_type, results, m.get("index"))
|
||||
clean_response.append(cleaned)
|
||||
|
||||
return clean_response
|
||||
|
||||
def _clean_result_by_type(self, r_type, results, idx=None):
|
||||
type_cleaners = {
|
||||
"web": (
|
||||
["type", "title", "url", "description", "date", "extra_snippets"],
|
||||
lambda x: x[idx],
|
||||
),
|
||||
"faq": (["type", "question", "answer", "title", "url"], lambda x: x),
|
||||
"infobox": (
|
||||
["type", "title", "url", "description", "long_desc"],
|
||||
lambda x: x[idx],
|
||||
),
|
||||
"videos": (["type", "url", "title", "description", "date"], lambda x: x),
|
||||
"locations": (
|
||||
[
|
||||
"type",
|
||||
"title",
|
||||
"url",
|
||||
"description",
|
||||
"coordinates",
|
||||
"postal_address",
|
||||
"contact",
|
||||
"rating",
|
||||
"distance",
|
||||
"zoom_level",
|
||||
],
|
||||
lambda x: x,
|
||||
),
|
||||
"news": (["type", "title", "url", "description"], lambda x: x),
|
||||
}
|
||||
|
||||
if r_type not in type_cleaners:
|
||||
return ""
|
||||
|
||||
selected_keys, result_selector = type_cleaners[r_type]
|
||||
results = result_selector(results)
|
||||
|
||||
if isinstance(results, list):
|
||||
cleaned = [
|
||||
{k: v for k, v in item.items() if k in selected_keys}
|
||||
for item in results
|
||||
]
|
||||
else:
|
||||
cleaned = {k: v for k, v in results.items() if k in selected_keys}
|
||||
|
||||
return str(cleaned)
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BraveSearchToolConfig(BaseModel):
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Brave Search API Key",
|
||||
)
|
||||
max_results: int = Field(
|
||||
default=3,
|
||||
description="The maximum number of results to return",
|
||||
)
|
|
@ -1,14 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import MetaReferenceToolRuntimeConfig
|
||||
from .meta_reference import MetaReferenceToolRuntimeImpl
|
||||
|
||||
|
||||
async def get_provider_impl(config: MetaReferenceToolRuntimeConfig, _deps):
|
||||
impl = MetaReferenceToolRuntimeImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,32 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
from .config import MetaReferenceToolRuntimeConfig
|
||||
|
||||
|
||||
class MetaReferenceToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||
def __init__(self, config: MetaReferenceToolRuntimeConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
print(f"registering tool {tool.identifier}")
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_id: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
pass
|
|
@ -6,16 +6,32 @@
|
|||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
provider_type="inline::meta-reference",
|
||||
provider_type="inline::brave-search",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.inline.tool_runtime.meta_reference",
|
||||
config_class="llama_stack.providers.inline.tool_runtime.meta_reference.MetaReferenceToolRuntimeConfig",
|
||||
module="llama_stack.providers.inline.tool_runtime.brave_search",
|
||||
config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
|
||||
provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="model-context-protocol",
|
||||
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
|
||||
pip_packages=["mcp"],
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import ModelContextProtocolConfig
|
||||
|
||||
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
|
||||
|
||||
|
||||
class ModelContextProtocolToolProviderDataValidator(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
|
||||
impl = ModelContextProtocolToolRuntimeImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -7,5 +7,5 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MetaReferenceToolRuntimeConfig(BaseModel):
|
||||
class ModelContextProtocolConfig(BaseModel):
|
||||
pass
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from llama_stack.apis.tools import (
|
||||
MCPToolGroup,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from .config import ModelContextProtocolConfig
|
||||
|
||||
|
||||
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||
def __init__(self, config: ModelContextProtocolConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
|
||||
if not isinstance(tool_group, MCPToolGroup):
|
||||
raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
|
||||
|
||||
tools = []
|
||||
async with sse_client(tool_group.endpoint.uri) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
tools_result = await session.list_tools()
|
||||
for tool in tools_result.tools:
|
||||
parameters = []
|
||||
for param_name, param_schema in tool.inputSchema.get(
|
||||
"properties", {}
|
||||
).items():
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name=param_name,
|
||||
parameter_type=param_schema.get("type", "string"),
|
||||
description=param_schema.get("description", ""),
|
||||
)
|
||||
)
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=tool.name,
|
||||
description=tool.description,
|
||||
tool_group=tool_group.name,
|
||||
parameters=parameters,
|
||||
metadata={
|
||||
"endpoint": tool_group.endpoint.uri,
|
||||
},
|
||||
provider_resource_id=tool.name,
|
||||
)
|
||||
)
|
||||
return tools
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_id: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
tool = await self.tool_store.get_tool(tool_id)
|
||||
if tool.metadata is None or tool.metadata.get("endpoint") is None:
|
||||
raise ValueError(f"Tool {tool_id} does not have metadata")
|
||||
endpoint = tool.metadata.get("endpoint")
|
||||
if urlparse(endpoint).scheme not in ("http", "https"):
|
||||
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
||||
|
||||
async with sse_client(endpoint) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
result = await session.call_tool(tool.identifier, args)
|
||||
|
||||
return ToolInvocationResult(
|
||||
content="\n".join([result.model_dump_json() for result in result.content]),
|
||||
error_code=1 if result.isError else 0,
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue