add brave tool provider (#653)

# What does this PR do?

Adds a new brave tool provider


## Test Plan

```
curl -X POST 'http://localhost:5000/alpha/toolgroups/register' \
-H 'Content-Type: application/json' \
-d '{
  "name": "search",
  "tool_group": {
    "type": "user_defined",
    "tools": [
      {
        "name": "brave_search",
        "description": "A web search tool",
        "parameters": [
          {
            "name": "query",
            "parameter_type": "string",
            "description": "The query to search"
          }
        ],
        "metadata": {},
        "tool_prompt_format": "json"
      }
    ]
  }
}'

curl -X POST 'http://localhost:5000/alpha/tool-runtime/invoke' \
-H 'Content-Type: application/json' \
-d '{
    "tool_id": "brave_search",
    "args": {
        "query": "who is meta ceo"
    }
}' | jq .content
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1973  100  1884  100    89  11288    533 --:--:-- --:--:-- --:--:-- 11885
"{'title': 'Mark Zuckerberg, Founder, Chairman and Chief Executive ...', 'url': 'https://about.meta.com/media-gallery/executives/mark-zuckerberg/', 'description': 'Not Logged In · Please log in to see this page', 'type': 'search_result'}\n{'title': 'Meta - Leadership & Governance', 'url': 'https://investor.fb.com/leadership-and-governance/', 'description': '<strong>Mark Zuckerberg</strong> is the founder, chairman and CEO of Meta, which he originally founded as Facebook in 2004. Mark is responsible for setting the overall direction and product strategy for the company. He leads the design of Meta&#x27;s services and development of its core technology and infrastructure.', 'type': 'search_result'}\n[{'type': 'video_result', 'url': '2372542949/', 'title': 'Mark Zuckerberg, the CEO of Meta, has officially joined the ...', 'description': \"Express Tribune, Karachi, Pakistan. 2,334,400 likes · 36,360 talking about this · 205 were here. The Express Tribune is Pakistan's #1 brand for breaking news in politics, sports, business, lifestyle\"}, {'type': 'video_result', 'url': 'https://www.youtube.com/watch?v=Y3oeQqtRvqk', 'title': \"Meta CEO: Mark Zuckerberg becomes World's Second Richest Person!\", 'description': 'Try VectorVest Risk-Free ➥➥➥ https://www.vectorvest.com/YTUse this link for a FREE Stock Analysis Report ➥➥➥ vectorvest.com/YTFSAVectorVest Merch Store ➥➥➥'}, {'type': 'video_result', 'url': '5348412224/', 'title': '#WATCH | Meta founder and CEO Mark Zuckerberg recently ...', 'description': 'See posts, photos and more on Facebook'}]"

curl -X POST 'http://localhost:5000/alpha/tool-runtime/invoke' \
-H 'Content-Type: application/json' -H 'X-LlamaStack-ProviderData: {"api_key": "<KEY>"}' \
-d '{
    "tool_id": "brave_search",
    "args": {
        "query": "who is meta ceo"
    }
}'
```
This commit is contained in:
Dinesh Yeduguru 2024-12-19 16:15:08 -08:00 committed by GitHub
parent ea0ca7454a
commit a297d27d48
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 316 additions and 58 deletions

View file

@ -26,11 +26,11 @@ class ToolParameter(BaseModel):
@json_schema_type @json_schema_type
class Tool(Resource): class Tool(Resource):
type: Literal[ResourceType.tool.value] = ResourceType.tool.value type: Literal[ResourceType.tool.value] = ResourceType.tool.value
name: str
tool_group: str tool_group: str
description: str description: str
parameters: List[ToolParameter] parameters: List[ToolParameter]
provider_id: Optional[str] = None provider_id: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
tool_prompt_format: Optional[ToolPromptFormat] = Field( tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json default=ToolPromptFormat.json
) )
@ -55,12 +55,14 @@ class MCPToolGroup(BaseModel):
""" """
type: Literal["model_context_protocol"] = "model_context_protocol" type: Literal["model_context_protocol"] = "model_context_protocol"
name: str
endpoint: URL endpoint: URL
@json_schema_type @json_schema_type
class UserDefinedToolGroup(BaseModel): class UserDefinedToolGroup(BaseModel):
type: Literal["user_defined"] = "user_defined" type: Literal["user_defined"] = "user_defined"
name: str
tools: List[ToolDef] tools: List[ToolDef]
@ -87,7 +89,6 @@ class Tools(Protocol):
@webmethod(route="/toolgroups/register", method="POST") @webmethod(route="/toolgroups/register", method="POST")
async def register_tool_group( async def register_tool_group(
self, self,
name: str,
tool_group: ToolGroup, tool_group: ToolGroup,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
) -> None: ) -> None:
@ -115,6 +116,9 @@ class Tools(Protocol):
class ToolRuntime(Protocol): class ToolRuntime(Protocol):
tool_store: ToolStore tool_store: ToolStore
@webmethod(route="/tool-runtime/discover", method="POST")
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: ...
@webmethod(route="/tool-runtime/invoke", method="POST") @webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool( async def invoke_tool(
self, tool_id: str, args: Dict[str, Any] self, tool_id: str, args: Dict[str, Any]

View file

@ -393,3 +393,8 @@ class ToolRuntimeRouter(ToolRuntime):
tool_id=tool_id, tool_id=tool_id,
args=args, args=args,
) )
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
return await self.routing_table.get_provider_impl(
tool_group.name
).discover_tools(tool_group)

View file

@ -108,6 +108,8 @@ class CommonRoutingTableImpl(RoutingTable):
await add_objects(scoring_functions, pid, ScoringFn) await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval: elif api == Api.eval:
p.eval_task_store = self p.eval_task_store = self
elif api == Api.tool_runtime:
p.tool_store = self
async def shutdown(self) -> None: async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values(): for p in self.impls_by_provider_id.values():
@ -129,6 +131,8 @@ class CommonRoutingTableImpl(RoutingTable):
return ("Scoring", "scoring_function") return ("Scoring", "scoring_function")
elif isinstance(self, EvalTasksRoutingTable): elif isinstance(self, EvalTasksRoutingTable):
return ("Eval", "eval_task") return ("Eval", "eval_task")
elif isinstance(self, ToolsRoutingTable):
return ("Tools", "tool")
else: else:
raise ValueError("Unknown routing table type") raise ValueError("Unknown routing table type")
@ -476,34 +480,40 @@ class ToolsRoutingTable(CommonRoutingTableImpl, Tools):
async def register_tool_group( async def register_tool_group(
self, self,
name: str,
tool_group: ToolGroup, tool_group: ToolGroup,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
) -> None: ) -> None:
tools = [] tools = []
if isinstance(tool_group, MCPToolGroup): if isinstance(tool_group, MCPToolGroup):
# TODO: first needs to be resolved to corresponding tools available in the MCP server # TODO: Actually find the right MCP provider
raise NotImplementedError("MCP tool provider not implemented yet") if provider_id is None:
raise ValueError("MCP provider_id not specified")
tools = await self.impls_by_provider_id[provider_id].discover_tools(
tool_group
)
for tool in tools:
tool.provider_id = provider_id
elif isinstance(tool_group, UserDefinedToolGroup): elif isinstance(tool_group, UserDefinedToolGroup):
for tool in tool_group.tools: for tool in tool_group.tools:
tools.append( tools.append(
Tool( Tool(
identifier=tool.name, identifier=tool.name,
tool_group=name, tool_group=tool_group.name,
name=tool.name, name=tool.name,
description=tool.description, description=tool.description,
parameters=tool.parameters, parameters=tool.parameters,
provider_id=provider_id, provider_id=provider_id,
tool_prompt_format=tool.tool_prompt_format, tool_prompt_format=tool.tool_prompt_format,
provider_resource_id=tool.name, provider_resource_id=tool.name,
metadata=tool.metadata,
) )
) )
else: else:
raise ValueError(f"Unknown tool group: {tool_group}") raise ValueError(f"Unknown tool group: {tool_group}")
for tool in tools: for tool in tools:
existing_tool = await self.get_tool(tool.name) existing_tool = await self.get_tool(tool.identifier)
# Compare existing and new object if one exists # Compare existing and new object if one exists
if existing_tool: if existing_tool:
# Compare all fields except provider_id since that might be None in new obj # Compare all fields except provider_id since that might be None in new obj

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .brave_search import BraveSearchToolRuntimeImpl
from .config import BraveSearchToolConfig
class BraveSearchToolProviderDataValidator(BaseModel):
api_key: str
async def get_provider_impl(config: BraveSearchToolConfig, _deps):
impl = BraveSearchToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
import requests
from llama_stack.apis.tools import Tool, ToolGroup, ToolInvocationResult, ToolRuntime
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import BraveSearchToolConfig
class BraveSearchToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: BraveSearchToolConfig):
self.config = config
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
if tool.identifier != "brave_search":
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
async def unregister_tool(self, tool_id: str) -> None:
return
def _get_api_key(self) -> str:
if self.config.api_key:
return self.config.api_key
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key:
raise ValueError(
'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
)
return provider_data.api_key
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
raise NotImplementedError("Brave search tool group not supported")
async def invoke_tool(
self, tool_id: str, args: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
url = "https://api.search.brave.com/res/v1/web/search"
headers = {
"X-Subscription-Token": api_key,
"Accept-Encoding": "gzip",
"Accept": "application/json",
}
payload = {"q": args["query"]}
response = requests.get(url=url, params=payload, headers=headers)
response.raise_for_status()
results = self._clean_brave_response(response.json())
content_items = "\n".join([str(result) for result in results])
return ToolInvocationResult(
content=content_items,
)
def _clean_brave_response(self, search_response):
clean_response = []
if "mixed" in search_response:
mixed_results = search_response["mixed"]
for m in mixed_results["main"][: self.config.max_results]:
r_type = m["type"]
results = search_response[r_type]["results"]
cleaned = self._clean_result_by_type(r_type, results, m.get("index"))
clean_response.append(cleaned)
return clean_response
def _clean_result_by_type(self, r_type, results, idx=None):
type_cleaners = {
"web": (
["type", "title", "url", "description", "date", "extra_snippets"],
lambda x: x[idx],
),
"faq": (["type", "question", "answer", "title", "url"], lambda x: x),
"infobox": (
["type", "title", "url", "description", "long_desc"],
lambda x: x[idx],
),
"videos": (["type", "url", "title", "description", "date"], lambda x: x),
"locations": (
[
"type",
"title",
"url",
"description",
"coordinates",
"postal_address",
"contact",
"rating",
"distance",
"zoom_level",
],
lambda x: x,
),
"news": (["type", "title", "url", "description"], lambda x: x),
}
if r_type not in type_cleaners:
return ""
selected_keys, result_selector = type_cleaners[r_type]
results = result_selector(results)
if isinstance(results, list):
cleaned = [
{k: v for k, v in item.items() if k in selected_keys}
for item in results
]
else:
cleaned = {k: v for k, v in results.items() if k in selected_keys}
return str(cleaned)

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel, Field
class BraveSearchToolConfig(BaseModel):
api_key: Optional[str] = Field(
default=None,
description="The Brave Search API Key",
)
max_results: int = Field(
default=3,
description="The maximum number of results to return",
)

View file

@ -1,14 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import MetaReferenceToolRuntimeConfig
from .meta_reference import MetaReferenceToolRuntimeImpl
async def get_provider_impl(config: MetaReferenceToolRuntimeConfig, _deps):
impl = MetaReferenceToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -1,32 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import MetaReferenceToolRuntimeConfig
class MetaReferenceToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(self, config: MetaReferenceToolRuntimeConfig):
self.config = config
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
print(f"registering tool {tool.identifier}")
pass
async def unregister_tool(self, tool_id: str) -> None:
pass
async def invoke_tool(
self, tool_id: str, args: Dict[str, Any]
) -> ToolInvocationResult:
pass

View file

@ -6,16 +6,32 @@
from typing import List from typing import List
from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec from llama_stack.distribution.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
def available_providers() -> List[ProviderSpec]: def available_providers() -> List[ProviderSpec]:
return [ return [
InlineProviderSpec( InlineProviderSpec(
api=Api.tool_runtime, api=Api.tool_runtime,
provider_type="inline::meta-reference", provider_type="inline::brave-search",
pip_packages=[], pip_packages=[],
module="llama_stack.providers.inline.tool_runtime.meta_reference", module="llama_stack.providers.inline.tool_runtime.brave_search",
config_class="llama_stack.providers.inline.tool_runtime.meta_reference.MetaReferenceToolRuntimeConfig", config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="model-context-protocol",
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
pip_packages=["mcp"],
),
), ),
] ]

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import ModelContextProtocolConfig
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
class ModelContextProtocolToolProviderDataValidator(BaseModel):
api_key: str
async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
impl = ModelContextProtocolToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -7,5 +7,5 @@
from pydantic import BaseModel from pydantic import BaseModel
class MetaReferenceToolRuntimeConfig(BaseModel): class ModelContextProtocolConfig(BaseModel):
pass pass

View file

@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from urllib.parse import urlparse
from llama_stack.apis.tools import (
MCPToolGroup,
Tool,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from mcp import ClientSession
from mcp.client.sse import sse_client
from .config import ModelContextProtocolConfig
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(self, config: ModelContextProtocolConfig):
self.config = config
async def initialize(self):
pass
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
if not isinstance(tool_group, MCPToolGroup):
raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
tools = []
async with sse_client(tool_group.endpoint.uri) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get(
"properties", {}
).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
)
)
tools.append(
Tool(
identifier=tool.name,
description=tool.description,
tool_group=tool_group.name,
parameters=parameters,
metadata={
"endpoint": tool_group.endpoint.uri,
},
provider_resource_id=tool.name,
)
)
return tools
async def invoke_tool(
self, tool_id: str, args: Dict[str, Any]
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_id)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
raise ValueError(f"Tool {tool_id} does not have metadata")
endpoint = tool.metadata.get("endpoint")
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
async with sse_client(endpoint) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
result = await session.call_tool(tool.identifier, args)
return ToolInvocationResult(
content="\n".join([result.model_dump_json() for result in result.content]),
error_code=1 if result.isError else 0,
)