diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py b/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py index 418f0fd5a..e9f0eeae8 100644 --- a/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py +++ b/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py @@ -4,10 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from pydantic import BaseModel + from .brave_search import BraveSearchToolRuntimeImpl from .config import BraveSearchToolConfig +class BraveSearchToolProviderDataValidator(BaseModel): + api_key: str + + async def get_provider_impl(config: BraveSearchToolConfig, _deps): impl = BraveSearchToolRuntimeImpl(config) await impl.initialize() diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py index 0ecf9f9c3..cb673d88f 100644 --- a/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py @@ -4,17 +4,20 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List +from typing import Any, Dict import requests -from llama_stack.apis.tools import InvokeToolResult, Tool, ToolRuntime +from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime +from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ToolsProtocolPrivate from .config import BraveSearchToolConfig -class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): +class BraveSearchToolRuntimeImpl( + ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData +): def __init__(self, config: BraveSearchToolConfig): self.config = config @@ -28,24 +31,35 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): async def unregister_tool(self, tool_id: str) -> None: return - async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> InvokeToolResult: - results = await self.execute(args["query"]) - content_items = "\n".join([str(result) for result in results]) - return InvokeToolResult( - content=content_items, - ) + def _get_api_key(self) -> str: + if self.config.api_key: + return self.config.api_key - async def execute(self, query: str) -> List[dict]: + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.api_key: + raise ValueError( + 'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": }' + ) + return provider_data.api_key + + async def invoke_tool( + self, tool_id: str, args: Dict[str, Any] + ) -> ToolInvocationResult: + api_key = self._get_api_key() url = "https://api.search.brave.com/res/v1/web/search" headers = { - "X-Subscription-Token": self.config.api_key, + "X-Subscription-Token": api_key, "Accept-Encoding": "gzip", "Accept": "application/json", } - payload = {"q": query} + payload = {"q": args["query"]} response = requests.get(url=url, params=payload, headers=headers) response.raise_for_status() - return self._clean_brave_response(response.json()) + results = self._clean_brave_response(response.json()) + content_items = "\n".join([str(result) for result in results]) + return ToolInvocationResult( + content=content_items, + ) def _clean_brave_response(self, search_response): clean_response = [] diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 64dc520fd..a732845be 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -17,5 +17,6 @@ def available_providers() -> List[ProviderSpec]: pip_packages=[], module="llama_stack.providers.inline.tool_runtime.brave_search", config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig", + provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", ), ]