mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 16:54:42 +00:00
rebase and address feedback
This commit is contained in:
parent
1c770508df
commit
71434d67f3
3 changed files with 34 additions and 13 deletions
|
@ -4,10 +4,16 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from .brave_search import BraveSearchToolRuntimeImpl
|
from .brave_search import BraveSearchToolRuntimeImpl
|
||||||
from .config import BraveSearchToolConfig
|
from .config import BraveSearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
|
class BraveSearchToolProviderDataValidator(BaseModel):
|
||||||
|
api_key: str
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: BraveSearchToolConfig, _deps):
|
async def get_provider_impl(config: BraveSearchToolConfig, _deps):
|
||||||
impl = BraveSearchToolRuntimeImpl(config)
|
impl = BraveSearchToolRuntimeImpl(config)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
|
|
|
@ -4,17 +4,20 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from llama_stack.apis.tools import InvokeToolResult, Tool, ToolRuntime
|
from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime
|
||||||
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||||
|
|
||||||
from .config import BraveSearchToolConfig
|
from .config import BraveSearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
class BraveSearchToolRuntimeImpl(
|
||||||
|
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
|
||||||
|
):
|
||||||
def __init__(self, config: BraveSearchToolConfig):
|
def __init__(self, config: BraveSearchToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
@ -28,24 +31,35 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_tool(self, tool_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> InvokeToolResult:
|
def _get_api_key(self) -> str:
|
||||||
results = await self.execute(args["query"])
|
if self.config.api_key:
|
||||||
content_items = "\n".join([str(result) for result in results])
|
return self.config.api_key
|
||||||
return InvokeToolResult(
|
|
||||||
content=content_items,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def execute(self, query: str) -> List[dict]:
|
provider_data = self.get_request_provider_data()
|
||||||
|
if provider_data is None or not provider_data.api_key:
|
||||||
|
raise ValueError(
|
||||||
|
'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
|
||||||
|
)
|
||||||
|
return provider_data.api_key
|
||||||
|
|
||||||
|
async def invoke_tool(
|
||||||
|
self, tool_id: str, args: Dict[str, Any]
|
||||||
|
) -> ToolInvocationResult:
|
||||||
|
api_key = self._get_api_key()
|
||||||
url = "https://api.search.brave.com/res/v1/web/search"
|
url = "https://api.search.brave.com/res/v1/web/search"
|
||||||
headers = {
|
headers = {
|
||||||
"X-Subscription-Token": self.config.api_key,
|
"X-Subscription-Token": api_key,
|
||||||
"Accept-Encoding": "gzip",
|
"Accept-Encoding": "gzip",
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
}
|
}
|
||||||
payload = {"q": query}
|
payload = {"q": args["query"]}
|
||||||
response = requests.get(url=url, params=payload, headers=headers)
|
response = requests.get(url=url, params=payload, headers=headers)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return self._clean_brave_response(response.json())
|
results = self._clean_brave_response(response.json())
|
||||||
|
content_items = "\n".join([str(result) for result in results])
|
||||||
|
return ToolInvocationResult(
|
||||||
|
content=content_items,
|
||||||
|
)
|
||||||
|
|
||||||
def _clean_brave_response(self, search_response):
|
def _clean_brave_response(self, search_response):
|
||||||
clean_response = []
|
clean_response = []
|
||||||
|
|
|
@ -17,5 +17,6 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
pip_packages=[],
|
pip_packages=[],
|
||||||
module="llama_stack.providers.inline.tool_runtime.brave_search",
|
module="llama_stack.providers.inline.tool_runtime.brave_search",
|
||||||
config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
|
config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
|
||||||
|
provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue