simplify search tool and enable configuration for search engine

This commit is contained in:
Hardik Shah 2024-09-09 18:41:11 -07:00
parent 640c5f8ab9
commit bdede6d14e
6 changed files with 56 additions and 48 deletions

View file

@ -14,8 +14,6 @@ from typing import List, Optional
import requests
from termcolor import cprint
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
from .ipython_tool.code_execution import (
CodeExecutionContext,
CodeExecutionRequest,
@ -85,21 +83,41 @@ class PhotogenTool(SingleMessageBuiltinTool):
raise NotImplementedError()
class RemoteSearchTool(SingleMessageBuiltinTool):
def __init__(self, config: RestAPIExecutionConfig) -> None:
self.config = config
class SearchTool(SingleMessageBuiltinTool):
def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
self.api_key = api_key
if engine == SearchEngineType.bing:
self.engine = BingSearch(api_key, **kwargs)
elif engine == SearchEngineType.brave:
self.engine = BraveSearch(api_key, **kwargs)
else:
raise ValueError(f"Unknown search engine: {engine}")
def get_name(self) -> str:
return BuiltinTool.brave_search.value
async def run_impl(self, query: str) -> str:
params = self.config.params.copy()
params["q"] = query
response = requests.get(
url=self.config.url,
params=params,
headers=self.config.headers,
)
return await self.engine.search(query)
class BingSearch:
def __init__(self, api_key: str, top_k: int = 3, **kwargs) -> None:
self.api_key = api_key
self.top_k = top_k
async def search(self, query: str) -> str:
url = "https://api.bing.microsoft.com/v7.0/search"
headers = {
"Ocp-Apim-Subscription-Key": self.api_key,
}
params = {
"count": self.top_k,
"textDecorations": True,
"textFormat": "HTML",
"q": query,
}
response = requests.get(url=url, params=params, headers=headers)
response.raise_for_status()
clean = self._clean_response(response.json())
return json.dumps(clean)
@ -126,14 +144,11 @@ class RemoteSearchTool(SingleMessageBuiltinTool):
return {"query": query, "top_k": clean_response}
class BraveSearchTool(SingleMessageBuiltinTool):
class BraveSearch:
def __init__(self, api_key: str) -> None:
self.api_key = api_key
def get_name(self) -> str:
return BuiltinTool.brave_search.value
async def run_impl(self, query: str) -> str:
async def search(self, query: str) -> str:
url = "https://api.search.brave.com/res/v1/web/search"
headers = {
"X-Subscription-Token": self.api_key,