add brave tool provider (#653)

# What does this PR do?

Adds a new brave tool provider


## Test Plan

```
curl -X POST 'http://localhost:5000/alpha/toolgroups/register' \
-H 'Content-Type: application/json' \
-d '{
  "name": "search",
  "tool_group": {
    "type": "user_defined",
    "tools": [
      {
        "name": "brave_search",
        "description": "A web search tool",
        "parameters": [
          {
            "name": "query",
            "parameter_type": "string",
            "description": "The query to search"
          }
        ],
        "metadata": {},
        "tool_prompt_format": "json"
      }
    ]
  }
}'

curl -X POST 'http://localhost:5000/alpha/tool-runtime/invoke' \
-H 'Content-Type: application/json' \
-d '{
    "tool_id": "brave_search",
    "args": {
        "query": "who is meta ceo"
    }
}' | jq .content
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1973  100  1884  100    89  11288    533 --:--:-- --:--:-- --:--:-- 11885
"{'title': 'Mark Zuckerberg, Founder, Chairman and Chief Executive ...', 'url': 'https://about.meta.com/media-gallery/executives/mark-zuckerberg/', 'description': 'Not Logged In · Please log in to see this page', 'type': 'search_result'}\n{'title': 'Meta - Leadership & Governance', 'url': 'https://investor.fb.com/leadership-and-governance/', 'description': '<strong>Mark Zuckerberg</strong> is the founder, chairman and CEO of Meta, which he originally founded as Facebook in 2004. Mark is responsible for setting the overall direction and product strategy for the company. He leads the design of Meta&#x27;s services and development of its core technology and infrastructure.', 'type': 'search_result'}\n[{'type': 'video_result', 'url': '2372542949/', 'title': 'Mark Zuckerberg, the CEO of Meta, has officially joined the ...', 'description': \"Express Tribune, Karachi, Pakistan. 2,334,400 likes · 36,360 talking about this · 205 were here. The Express Tribune is Pakistan's #1 brand for breaking news in politics, sports, business, lifestyle\"}, {'type': 'video_result', 'url': 'https://www.youtube.com/watch?v=Y3oeQqtRvqk', 'title': \"Meta CEO: Mark Zuckerberg becomes World's Second Richest Person!\", 'description': 'Try VectorVest Risk-Free ➥➥➥ https://www.vectorvest.com/YTUse this link for a FREE Stock Analysis Report ➥➥➥ vectorvest.com/YTFSAVectorVest Merch Store ➥➥➥'}, {'type': 'video_result', 'url': '5348412224/', 'title': '#WATCH | Meta founder and CEO Mark Zuckerberg recently ...', 'description': 'See posts, photos and more on Facebook'}]"

curl -X POST 'http://localhost:5000/alpha/tool-runtime/invoke' \
-H 'Content-Type: application/json' -H 'X-LlamaStack-ProviderData: {"api_key": "<KEY>"}' \
-d '{
    "tool_id": "brave_search",
    "args": {
        "query": "who is meta ceo"
    }
}'
```
This commit is contained in:
Dinesh Yeduguru 2024-12-19 16:15:08 -08:00 committed by GitHub
parent ea0ca7454a
commit a297d27d48
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 316 additions and 58 deletions

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import ModelContextProtocolConfig
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
class ModelContextProtocolToolProviderDataValidator(BaseModel):
api_key: str
async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
impl = ModelContextProtocolToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class ModelContextProtocolConfig(BaseModel):
pass

View file

@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from urllib.parse import urlparse
from llama_stack.apis.tools import (
MCPToolGroup,
Tool,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from mcp import ClientSession
from mcp.client.sse import sse_client
from .config import ModelContextProtocolConfig
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(self, config: ModelContextProtocolConfig):
self.config = config
async def initialize(self):
pass
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
if not isinstance(tool_group, MCPToolGroup):
raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
tools = []
async with sse_client(tool_group.endpoint.uri) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get(
"properties", {}
).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
)
)
tools.append(
Tool(
identifier=tool.name,
description=tool.description,
tool_group=tool_group.name,
parameters=parameters,
metadata={
"endpoint": tool_group.endpoint.uri,
},
provider_resource_id=tool.name,
)
)
return tools
async def invoke_tool(
self, tool_id: str, args: Dict[str, Any]
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_id)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
raise ValueError(f"Tool {tool_id} does not have metadata")
endpoint = tool.metadata.get("endpoint")
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
async with sse_client(endpoint) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
result = await session.call_tool(tool.identifier, args)
return ToolInvocationResult(
content="\n".join([result.model_dump_json() for result in result.content]),
error_code=1 if result.isError else 0,
)