rebase and address feedback

This commit is contained in:
Dinesh Yeduguru 2024-12-19 07:38:08 -08:00
parent 1c770508df
commit 71434d67f3
3 changed files with 34 additions and 13 deletions

View file

@ -4,10 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from pydantic import BaseModel
from .brave_search import BraveSearchToolRuntimeImpl from .brave_search import BraveSearchToolRuntimeImpl
from .config import BraveSearchToolConfig from .config import BraveSearchToolConfig
class BraveSearchToolProviderDataValidator(BaseModel):
api_key: str
async def get_provider_impl(config: BraveSearchToolConfig, _deps): async def get_provider_impl(config: BraveSearchToolConfig, _deps):
impl = BraveSearchToolRuntimeImpl(config) impl = BraveSearchToolRuntimeImpl(config)
await impl.initialize() await impl.initialize()

View file

@ -4,17 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List from typing import Any, Dict
import requests import requests
from llama_stack.apis.tools import InvokeToolResult, Tool, ToolRuntime from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import BraveSearchToolConfig from .config import BraveSearchToolConfig
class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): class BraveSearchToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: BraveSearchToolConfig): def __init__(self, config: BraveSearchToolConfig):
self.config = config self.config = config
@ -28,24 +31,35 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def unregister_tool(self, tool_id: str) -> None: async def unregister_tool(self, tool_id: str) -> None:
return return
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> InvokeToolResult: def _get_api_key(self) -> str:
results = await self.execute(args["query"]) if self.config.api_key:
content_items = "\n".join([str(result) for result in results]) return self.config.api_key
return InvokeToolResult(
content=content_items,
)
async def execute(self, query: str) -> List[dict]: provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key:
raise ValueError(
'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
)
return provider_data.api_key
async def invoke_tool(
self, tool_id: str, args: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
url = "https://api.search.brave.com/res/v1/web/search" url = "https://api.search.brave.com/res/v1/web/search"
headers = { headers = {
"X-Subscription-Token": self.config.api_key, "X-Subscription-Token": api_key,
"Accept-Encoding": "gzip", "Accept-Encoding": "gzip",
"Accept": "application/json", "Accept": "application/json",
} }
payload = {"q": query} payload = {"q": args["query"]}
response = requests.get(url=url, params=payload, headers=headers) response = requests.get(url=url, params=payload, headers=headers)
response.raise_for_status() response.raise_for_status()
return self._clean_brave_response(response.json()) results = self._clean_brave_response(response.json())
content_items = "\n".join([str(result) for result in results])
return ToolInvocationResult(
content=content_items,
)
def _clean_brave_response(self, search_response): def _clean_brave_response(self, search_response):
clean_response = [] clean_response = []

View file

@ -17,5 +17,6 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=[], pip_packages=[],
module="llama_stack.providers.inline.tool_runtime.brave_search", module="llama_stack.providers.inline.tool_runtime.brave_search",
config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig", config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
), ),
] ]