Enable Bing search (#59)

* add tool for bing search

* simplify search tool and enable configuration for search engine

* dropped commented code

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
This commit is contained in:
Hardik Shah 2024-09-10 12:34:29 -07:00 committed by GitHub
parent 2b63074676
commit a11d92601b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 87 additions and 18 deletions

View file

@ -41,11 +41,19 @@ class ToolDefinitionCommon(BaseModel):
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
class SearchEngineType(Enum):
bing = "bing"
brave = "brave"
@json_schema_type
class BraveSearchToolDefinition(ToolDefinitionCommon):
class SearchToolDefinition(ToolDefinitionCommon):
# NOTE: brave_search is just a placeholder since model always uses
# brave_search as tool call name
type: Literal[AgenticSystemTool.brave_search.value] = (
AgenticSystemTool.brave_search.value
)
engine: SearchEngineType = SearchEngineType.brave
remote_execution: Optional[RestAPIExecutionConfig] = None
@ -163,7 +171,7 @@ class MemoryToolDefinition(ToolDefinitionCommon):
AgenticSystemToolDefinition = Annotated[
Union[
BraveSearchToolDefinition,
SearchToolDefinition,
WolframAlphaToolDefinition,
PhotogenToolDefinition,
CodeInterpreterToolDefinition,

View file

@ -134,7 +134,7 @@ async def run_main(host: str, port: int):
api = AgenticSystemClient(f"http://{host}:{port}")
tool_definitions = [
BraveSearchToolDefinition(),
SearchToolDefinition(engine=SearchEngineType.bing),
WolframAlphaToolDefinition(),
CodeInterpreterToolDefinition(),
]

View file

@ -710,7 +710,7 @@ class ChatAgent(ShieldRunnerMixin):
def _get_tools(self) -> List[ToolDefinition]:
ret = []
for t in self.agent_config.tools:
if isinstance(t, BraveSearchToolDefinition):
if isinstance(t, SearchToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
elif isinstance(t, WolframAlphaToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))

View file

@ -15,9 +15,9 @@ from llama_toolchain.memory.api import Memory
from llama_toolchain.safety.api import Safety
from llama_toolchain.agentic_system.api import * # noqa: F403
from llama_toolchain.tools.builtin import (
BraveSearchTool,
CodeInterpreterTool,
PhotogenTool,
SearchTool,
WolframAlphaTool,
)
from llama_toolchain.tools.safety import with_safety
@ -62,17 +62,19 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
if not key:
raise ValueError("Wolfram API key not defined in config")
tool = WolframAlphaTool(key)
elif isinstance(tool_defn, BraveSearchToolDefinition):
key = self.config.brave_search_api_key
elif isinstance(tool_defn, SearchToolDefinition):
key = None
if tool_defn.engine == SearchEngineType.brave:
key = self.config.brave_search_api_key
elif tool_defn.engine == SearchEngineType.bing:
key = self.config.bing_search_api_key
if not key:
raise ValueError("Brave API key not defined in config")
tool = BraveSearchTool(key)
raise ValueError("API key not defined in config")
tool = SearchTool(tool_defn.engine, key)
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
tool = CodeInterpreterTool()
elif isinstance(tool_defn, PhotogenToolDefinition):
tool = PhotogenTool(
dump_dir=tempfile.mkdtemp(),
)
tool = PhotogenTool(dump_dir=tempfile.mkdtemp())
else:
continue

View file

@ -11,4 +11,5 @@ from pydantic import BaseModel
class MetaReferenceImplConfig(BaseModel):
brave_search_api_key: Optional[str] = None
bing_search_api_key: Optional[str] = None
wolfram_api_key: Optional[str] = None