simplify search tool and enable configuration for search engine

This commit is contained in:
Hardik Shah 2024-09-09 18:41:11 -07:00
parent 640c5f8ab9
commit bdede6d14e
6 changed files with 56 additions and 48 deletions

View file

@ -41,11 +41,19 @@ class ToolDefinitionCommon(BaseModel):
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list) output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
class SearchEngineType(Enum):
    """Search backends that a search tool definition can select via `engine`."""

    bing = "bing"
    brave = "brave"
@json_schema_type @json_schema_type
class BraveSearchToolDefinition(ToolDefinitionCommon): class SearchToolDefinition(ToolDefinitionCommon):
# NOTE: brave_search is just a placeholder since model always uses
# brave_search as tool call name
type: Literal[AgenticSystemTool.brave_search.value] = ( type: Literal[AgenticSystemTool.brave_search.value] = (
AgenticSystemTool.brave_search.value AgenticSystemTool.brave_search.value
) )
engine: SearchEngineType = SearchEngineType.brave
remote_execution: Optional[RestAPIExecutionConfig] = None remote_execution: Optional[RestAPIExecutionConfig] = None
@ -163,7 +171,7 @@ class MemoryToolDefinition(ToolDefinitionCommon):
AgenticSystemToolDefinition = Annotated[ AgenticSystemToolDefinition = Annotated[
Union[ Union[
BraveSearchToolDefinition, SearchToolDefinition,
WolframAlphaToolDefinition, WolframAlphaToolDefinition,
PhotogenToolDefinition, PhotogenToolDefinition,
CodeInterpreterToolDefinition, CodeInterpreterToolDefinition,

View file

@ -16,7 +16,6 @@ from pydantic import BaseModel
from termcolor import cprint from termcolor import cprint
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
from llama_toolchain.core.datatypes import RemoteProviderConfig from llama_toolchain.core.datatypes import RemoteProviderConfig
from .api import * # noqa: F403 from .api import * # noqa: F403
@ -135,21 +134,8 @@ async def run_main(host: str, port: int):
api = AgenticSystemClient(f"http://{host}:{port}") api = AgenticSystemClient(f"http://{host}:{port}")
tool_definitions = [ tool_definitions = [
BraveSearchToolDefinition( SearchToolDefinition(engine=SearchEngineType.bing),
remote_execution=RestAPIExecutionConfig( WolframAlphaToolDefinition(),
url=URL(uri="https://api.bing.microsoft.com/v7.0/search"),
method=RestAPIMethod.GET,
headers={
"Ocp-Apim-Subscription-Key": "2259b3f9e0cb4fc9b968bb3b02ab13e7"
},
params={
"count": 3,
"textDecorations": True,
"textFormat": "HTML",
},
)
),
# WolframAlphaToolDefinition(),
CodeInterpreterToolDefinition(), CodeInterpreterToolDefinition(),
] ]
tool_definitions += [ tool_definitions += [

View file

@ -710,7 +710,7 @@ class ChatAgent(ShieldRunnerMixin):
def _get_tools(self) -> List[ToolDefinition]: def _get_tools(self) -> List[ToolDefinition]:
ret = [] ret = []
for t in self.agent_config.tools: for t in self.agent_config.tools:
if isinstance(t, BraveSearchToolDefinition): if isinstance(t, SearchToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search)) ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
elif isinstance(t, WolframAlphaToolDefinition): elif isinstance(t, WolframAlphaToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha)) ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))

View file

@ -15,10 +15,9 @@ from llama_toolchain.memory.api import Memory
from llama_toolchain.safety.api import Safety from llama_toolchain.safety.api import Safety
from llama_toolchain.agentic_system.api import * # noqa: F403 from llama_toolchain.agentic_system.api import * # noqa: F403
from llama_toolchain.tools.builtin import ( from llama_toolchain.tools.builtin import (
BraveSearchTool,
CodeInterpreterTool, CodeInterpreterTool,
PhotogenTool, PhotogenTool,
RemoteSearchTool, SearchTool,
WolframAlphaTool, WolframAlphaTool,
) )
from llama_toolchain.tools.safety import with_safety from llama_toolchain.tools.safety import with_safety
@ -63,20 +62,19 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
if not key: if not key:
raise ValueError("Wolfram API key not defined in config") raise ValueError("Wolfram API key not defined in config")
tool = WolframAlphaTool(key) tool = WolframAlphaTool(key)
elif isinstance(tool_defn, BraveSearchToolDefinition): elif isinstance(tool_defn, SearchToolDefinition):
if tool_defn.remote_execution is not None: key = None
tool = RemoteSearchTool(tool_defn.remote_execution) if tool_defn.engine == SearchEngineType.brave:
else:
key = self.config.brave_search_api_key key = self.config.brave_search_api_key
if not key: elif tool_defn.engine == SearchEngineType.bing:
raise ValueError("Brave API key not defined in config") key = self.config.bing_search_api_key
tool = BraveSearchTool(key) if not key:
raise ValueError("API key not defined in config")
tool = SearchTool(tool_defn.engine, key)
elif isinstance(tool_defn, CodeInterpreterToolDefinition): elif isinstance(tool_defn, CodeInterpreterToolDefinition):
tool = CodeInterpreterTool() tool = CodeInterpreterTool()
elif isinstance(tool_defn, PhotogenToolDefinition): elif isinstance(tool_defn, PhotogenToolDefinition):
tool = PhotogenTool( tool = PhotogenTool(dump_dir=tempfile.mkdtemp())
dump_dir=tempfile.mkdtemp(),
)
else: else:
continue continue

View file

@ -11,4 +11,5 @@ from pydantic import BaseModel
class MetaReferenceImplConfig(BaseModel):
    """API keys for the built-in tools.

    Every key is optional and defaults to None.
    """

    # Search keys: one per supported search engine.
    brave_search_api_key: Optional[str] = None
    bing_search_api_key: Optional[str] = None
    # Key for the Wolfram Alpha tool.
    wolfram_api_key: Optional[str] = None

View file

@ -14,8 +14,6 @@ from typing import List, Optional
import requests import requests
from termcolor import cprint from termcolor import cprint
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
from .ipython_tool.code_execution import ( from .ipython_tool.code_execution import (
CodeExecutionContext, CodeExecutionContext,
CodeExecutionRequest, CodeExecutionRequest,
@ -85,21 +83,41 @@ class PhotogenTool(SingleMessageBuiltinTool):
raise NotImplementedError() raise NotImplementedError()
class RemoteSearchTool(SingleMessageBuiltinTool): class SearchTool(SingleMessageBuiltinTool):
def __init__(self, config: RestAPIExecutionConfig) -> None: def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
self.config = config self.api_key = api_key
if engine == SearchEngineType.bing:
self.engine = BingSearch(api_key, **kwargs)
elif engine == SearchEngineType.brave:
self.engine = BraveSearch(api_key, **kwargs)
else:
raise ValueError(f"Unknown search engine: {engine}")
def get_name(self) -> str: def get_name(self) -> str:
return BuiltinTool.brave_search.value return BuiltinTool.brave_search.value
async def run_impl(self, query: str) -> str: async def run_impl(self, query: str) -> str:
params = self.config.params.copy() return await self.engine.search(query)
params["q"] = query
response = requests.get(
url=self.config.url, class BingSearch:
params=params, def __init__(self, api_key: str, top_k: int = 3, **kwargs) -> None:
headers=self.config.headers, self.api_key = api_key
) self.top_k = top_k
async def search(self, query: str) -> str:
url = "https://api.bing.microsoft.com/v7.0/search"
headers = {
"Ocp-Apim-Subscription-Key": self.api_key,
}
params = {
"count": self.top_k,
"textDecorations": True,
"textFormat": "HTML",
"q": query,
}
response = requests.get(url=url, params=params, headers=headers)
response.raise_for_status() response.raise_for_status()
clean = self._clean_response(response.json()) clean = self._clean_response(response.json())
return json.dumps(clean) return json.dumps(clean)
@ -126,14 +144,11 @@ class RemoteSearchTool(SingleMessageBuiltinTool):
return {"query": query, "top_k": clean_response} return {"query": query, "top_k": clean_response}
class BraveSearchTool(SingleMessageBuiltinTool): class BraveSearch:
def __init__(self, api_key: str) -> None: def __init__(self, api_key: str) -> None:
self.api_key = api_key self.api_key = api_key
def get_name(self) -> str: async def search(self, query: str) -> str:
return BuiltinTool.brave_search.value
async def run_impl(self, query: str) -> str:
url = "https://api.search.brave.com/res/v1/web/search" url = "https://api.search.brave.com/res/v1/web/search"
headers = { headers = {
"X-Subscription-Token": self.api_key, "X-Subscription-Token": self.api_key,