mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
simplify search tool and enable configuration for search engine
This commit is contained in:
parent
640c5f8ab9
commit
bdede6d14e
6 changed files with 56 additions and 48 deletions
|
@ -14,8 +14,6 @@ from typing import List, Optional
|
|||
import requests
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
|
||||
|
||||
from .ipython_tool.code_execution import (
|
||||
CodeExecutionContext,
|
||||
CodeExecutionRequest,
|
||||
|
@ -85,21 +83,41 @@ class PhotogenTool(SingleMessageBuiltinTool):
|
|||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RemoteSearchTool(SingleMessageBuiltinTool):
|
||||
def __init__(self, config: RestAPIExecutionConfig) -> None:
|
||||
self.config = config
|
||||
class SearchTool(SingleMessageBuiltinTool):
|
||||
def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
|
||||
self.api_key = api_key
|
||||
if engine == SearchEngineType.bing:
|
||||
self.engine = BingSearch(api_key, **kwargs)
|
||||
elif engine == SearchEngineType.brave:
|
||||
self.engine = BraveSearch(api_key, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown search engine: {engine}")
|
||||
|
||||
def get_name(self) -> str:
|
||||
return BuiltinTool.brave_search.value
|
||||
|
||||
async def run_impl(self, query: str) -> str:
|
||||
params = self.config.params.copy()
|
||||
params["q"] = query
|
||||
response = requests.get(
|
||||
url=self.config.url,
|
||||
params=params,
|
||||
headers=self.config.headers,
|
||||
)
|
||||
return await self.engine.search(query)
|
||||
|
||||
|
||||
class BingSearch:
|
||||
def __init__(self, api_key: str, top_k: int = 3, **kwargs) -> None:
|
||||
self.api_key = api_key
|
||||
self.top_k = top_k
|
||||
|
||||
async def search(self, query: str) -> str:
|
||||
url = "https://api.bing.microsoft.com/v7.0/search"
|
||||
headers = {
|
||||
"Ocp-Apim-Subscription-Key": self.api_key,
|
||||
}
|
||||
params = {
|
||||
"count": self.top_k,
|
||||
"textDecorations": True,
|
||||
"textFormat": "HTML",
|
||||
"q": query,
|
||||
}
|
||||
|
||||
response = requests.get(url=url, params=params, headers=headers)
|
||||
response.raise_for_status()
|
||||
clean = self._clean_response(response.json())
|
||||
return json.dumps(clean)
|
||||
|
@ -126,14 +144,11 @@ class RemoteSearchTool(SingleMessageBuiltinTool):
|
|||
return {"query": query, "top_k": clean_response}
|
||||
|
||||
|
||||
class BraveSearchTool(SingleMessageBuiltinTool):
|
||||
class BraveSearch:
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
|
||||
def get_name(self) -> str:
|
||||
return BuiltinTool.brave_search.value
|
||||
|
||||
async def run_impl(self, query: str) -> str:
|
||||
async def search(self, query: str) -> str:
|
||||
url = "https://api.search.brave.com/res/v1/web/search"
|
||||
headers = {
|
||||
"X-Subscription-Token": self.api_key,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue