From a11d92601bcc7bfe3d22432a80ab94ecea3d2730 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Tue, 10 Sep 2024 12:34:29 -0700 Subject: [PATCH] Enable Bing search (#59) * add tool for bing search * simplify search tool and enable configuration for search engine * dropped commented code --------- Co-authored-by: Hardik Shah --- llama_toolchain/agentic_system/api/api.py | 12 +++- llama_toolchain/agentic_system/client.py | 2 +- .../meta_reference/agent_instance.py | 2 +- .../meta_reference/agentic_system.py | 18 +++--- .../agentic_system/meta_reference/config.py | 1 + llama_toolchain/common/deployment_types.py | 8 +-- llama_toolchain/tools/builtin.py | 62 ++++++++++++++++++- 7 files changed, 87 insertions(+), 18 deletions(-) diff --git a/llama_toolchain/agentic_system/api/api.py b/llama_toolchain/agentic_system/api/api.py index 68ec980e6..b8be54861 100644 --- a/llama_toolchain/agentic_system/api/api.py +++ b/llama_toolchain/agentic_system/api/api.py @@ -41,11 +41,19 @@ class ToolDefinitionCommon(BaseModel): output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list) +class SearchEngineType(Enum): + bing = "bing" + brave = "brave" + + @json_schema_type -class BraveSearchToolDefinition(ToolDefinitionCommon): +class SearchToolDefinition(ToolDefinitionCommon): + # NOTE: brave_search is just a placeholder since model always uses + # brave_search as tool call name type: Literal[AgenticSystemTool.brave_search.value] = ( AgenticSystemTool.brave_search.value ) + engine: SearchEngineType = SearchEngineType.brave remote_execution: Optional[RestAPIExecutionConfig] = None @@ -163,7 +171,7 @@ class MemoryToolDefinition(ToolDefinitionCommon): AgenticSystemToolDefinition = Annotated[ Union[ - BraveSearchToolDefinition, + SearchToolDefinition, WolframAlphaToolDefinition, PhotogenToolDefinition, CodeInterpreterToolDefinition, diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py index fadb78182..b47e402f0 100644 --- a/llama_toolchain/agentic_system/client.py +++ b/llama_toolchain/agentic_system/client.py @@ -134,7 +134,7 @@ async def run_main(host: str, port: int): api = AgenticSystemClient(f"http://{host}:{port}") tool_definitions = [ - BraveSearchToolDefinition(), + SearchToolDefinition(engine=SearchEngineType.bing), WolframAlphaToolDefinition(), CodeInterpreterToolDefinition(), ] diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py index 4d38e0032..36c3d19e8 100644 --- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py +++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py @@ -710,7 +710,7 @@ class ChatAgent(ShieldRunnerMixin): def _get_tools(self) -> List[ToolDefinition]: ret = [] for t in self.agent_config.tools: - if isinstance(t, BraveSearchToolDefinition): + if isinstance(t, SearchToolDefinition): ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search)) elif isinstance(t, WolframAlphaToolDefinition): ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha)) diff --git a/llama_toolchain/agentic_system/meta_reference/agentic_system.py b/llama_toolchain/agentic_system/meta_reference/agentic_system.py index 09fbfdde5..9caa3a75b 100644 --- a/llama_toolchain/agentic_system/meta_reference/agentic_system.py +++ b/llama_toolchain/agentic_system/meta_reference/agentic_system.py @@ -15,9 +15,9 @@ from llama_toolchain.memory.api import Memory from llama_toolchain.safety.api import Safety from llama_toolchain.agentic_system.api import * # noqa: F403 from llama_toolchain.tools.builtin import ( - BraveSearchTool, CodeInterpreterTool, PhotogenTool, + SearchTool, WolframAlphaTool, ) from llama_toolchain.tools.safety import with_safety @@ -62,17 +62,19 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem): if not key: raise ValueError("Wolfram API key not defined in config") tool = WolframAlphaTool(key) - elif isinstance(tool_defn, BraveSearchToolDefinition): - key = self.config.brave_search_api_key + elif isinstance(tool_defn, SearchToolDefinition): + key = None + if tool_defn.engine == SearchEngineType.brave: + key = self.config.brave_search_api_key + elif tool_defn.engine == SearchEngineType.bing: + key = self.config.bing_search_api_key if not key: - raise ValueError("Brave API key not defined in config") - tool = BraveSearchTool(key) + raise ValueError("API key not defined in config") + tool = SearchTool(tool_defn.engine, key) elif isinstance(tool_defn, CodeInterpreterToolDefinition): tool = CodeInterpreterTool() elif isinstance(tool_defn, PhotogenToolDefinition): - tool = PhotogenTool( - dump_dir=tempfile.mkdtemp(), - ) + tool = PhotogenTool(dump_dir=tempfile.mkdtemp()) else: continue diff --git a/llama_toolchain/agentic_system/meta_reference/config.py b/llama_toolchain/agentic_system/meta_reference/config.py index 367ab17a5..f1a92f2e7 100644 --- a/llama_toolchain/agentic_system/meta_reference/config.py +++ b/llama_toolchain/agentic_system/meta_reference/config.py @@ -11,4 +11,5 @@ from pydantic import BaseModel class MetaReferenceImplConfig(BaseModel): brave_search_api_key: Optional[str] = None + bing_search_api_key: Optional[str] = None wolfram_api_key: Optional[str] = None diff --git a/llama_toolchain/common/deployment_types.py b/llama_toolchain/common/deployment_types.py index 8b67eff0d..af05aaae4 100644 --- a/llama_toolchain/common/deployment_types.py +++ b/llama_toolchain/common/deployment_types.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import Dict, Optional +from typing import Any, Dict, Optional from llama_models.llama3.api.datatypes import URL @@ -26,6 +26,6 @@ class RestAPIMethod(Enum): class RestAPIExecutionConfig(BaseModel): url: URL method: RestAPIMethod - params: Optional[Dict[str, str]] = None - headers: Optional[Dict[str, str]] = None - body: Optional[Dict[str, str]] = None + params: Optional[Dict[str, Any]] = None + headers: Optional[Dict[str, Any]] = None + body: Optional[Dict[str, Any]] = None diff --git a/llama_toolchain/tools/builtin.py b/llama_toolchain/tools/builtin.py index 3a53e2e26..56fda3723 100644 --- a/llama_toolchain/tools/builtin.py +++ b/llama_toolchain/tools/builtin.py @@ -83,14 +83,72 @@ class PhotogenTool(SingleMessageBuiltinTool): raise NotImplementedError() -class BraveSearchTool(SingleMessageBuiltinTool): - def __init__(self, api_key: str) -> None: +class SearchTool(SingleMessageBuiltinTool): + def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None: self.api_key = api_key + if engine == SearchEngineType.bing: + self.engine = BingSearch(api_key, **kwargs) + elif engine == SearchEngineType.brave: + self.engine = BraveSearch(api_key, **kwargs) + else: + raise ValueError(f"Unknown search engine: {engine}") def get_name(self) -> str: return BuiltinTool.brave_search.value async def run_impl(self, query: str) -> str: + return await self.engine.search(query) + + +class BingSearch: + def __init__(self, api_key: str, top_k: int = 3, **kwargs) -> None: + self.api_key = api_key + self.top_k = top_k + + async def search(self, query: str) -> str: + url = "https://api.bing.microsoft.com/v7.0/search" + headers = { + "Ocp-Apim-Subscription-Key": self.api_key, + } + params = { + "count": self.top_k, + "textDecorations": True, + "textFormat": "HTML", + "q": query, + } + + response = requests.get(url=url, params=params, headers=headers) + response.raise_for_status() + clean = self._clean_response(response.json()) + return json.dumps(clean) + + def _clean_response(self, search_response): + clean_response = [] + query = search_response["queryContext"]["originalQuery"] + if "webPages" in search_response: + pages = search_response["webPages"]["value"] + for p in pages: + selected_keys = {"name", "url", "snippet"} + clean_response.append( + {k: v for k, v in p.items() if k in selected_keys} + ) + if "news" in search_response: + clean_news = [] + news = search_response["news"]["value"] + for n in news: + selected_keys = {"name", "url", "description"} + clean_news.append({k: v for k, v in n.items() if k in selected_keys}) + + clean_response.append(clean_news) + + return {"query": query, "top_k": clean_response} + + +class BraveSearch: + def __init__(self, api_key: str) -> None: + self.api_key = api_key + + async def search(self, query: str) -> str: url = "https://api.search.brave.com/res/v1/web/search" headers = { "X-Subscription-Token": self.api_key,