forked from phoenix-oss/llama-stack-mirror
Enable Bing search (#59)
* add tool for bing search * simplify search tool and enable configuration for search engine * dropped commented code --------- Co-authored-by: Hardik Shah <hjshah@fb.com>
This commit is contained in:
parent
2b63074676
commit
a11d92601b
7 changed files with 87 additions and 18 deletions
|
@ -41,11 +41,19 @@ class ToolDefinitionCommon(BaseModel):
|
|||
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class SearchEngineType(Enum):
|
||||
bing = "bing"
|
||||
brave = "brave"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BraveSearchToolDefinition(ToolDefinitionCommon):
|
||||
class SearchToolDefinition(ToolDefinitionCommon):
|
||||
# NOTE: brave_search is just a placeholder since model always uses
|
||||
# brave_search as tool call name
|
||||
type: Literal[AgenticSystemTool.brave_search.value] = (
|
||||
AgenticSystemTool.brave_search.value
|
||||
)
|
||||
engine: SearchEngineType = SearchEngineType.brave
|
||||
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||
|
||||
|
||||
|
@ -163,7 +171,7 @@ class MemoryToolDefinition(ToolDefinitionCommon):
|
|||
|
||||
AgenticSystemToolDefinition = Annotated[
|
||||
Union[
|
||||
BraveSearchToolDefinition,
|
||||
SearchToolDefinition,
|
||||
WolframAlphaToolDefinition,
|
||||
PhotogenToolDefinition,
|
||||
CodeInterpreterToolDefinition,
|
||||
|
|
|
@ -134,7 +134,7 @@ async def run_main(host: str, port: int):
|
|||
api = AgenticSystemClient(f"http://{host}:{port}")
|
||||
|
||||
tool_definitions = [
|
||||
BraveSearchToolDefinition(),
|
||||
SearchToolDefinition(engine=SearchEngineType.bing),
|
||||
WolframAlphaToolDefinition(),
|
||||
CodeInterpreterToolDefinition(),
|
||||
]
|
||||
|
|
|
@ -710,7 +710,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
def _get_tools(self) -> List[ToolDefinition]:
|
||||
ret = []
|
||||
for t in self.agent_config.tools:
|
||||
if isinstance(t, BraveSearchToolDefinition):
|
||||
if isinstance(t, SearchToolDefinition):
|
||||
ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
|
||||
elif isinstance(t, WolframAlphaToolDefinition):
|
||||
ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))
|
||||
|
|
|
@ -15,9 +15,9 @@ from llama_toolchain.memory.api import Memory
|
|||
from llama_toolchain.safety.api import Safety
|
||||
from llama_toolchain.agentic_system.api import * # noqa: F403
|
||||
from llama_toolchain.tools.builtin import (
|
||||
BraveSearchTool,
|
||||
CodeInterpreterTool,
|
||||
PhotogenTool,
|
||||
SearchTool,
|
||||
WolframAlphaTool,
|
||||
)
|
||||
from llama_toolchain.tools.safety import with_safety
|
||||
|
@ -62,17 +62,19 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
|||
if not key:
|
||||
raise ValueError("Wolfram API key not defined in config")
|
||||
tool = WolframAlphaTool(key)
|
||||
elif isinstance(tool_defn, BraveSearchToolDefinition):
|
||||
key = self.config.brave_search_api_key
|
||||
elif isinstance(tool_defn, SearchToolDefinition):
|
||||
key = None
|
||||
if tool_defn.engine == SearchEngineType.brave:
|
||||
key = self.config.brave_search_api_key
|
||||
elif tool_defn.engine == SearchEngineType.bing:
|
||||
key = self.config.bing_search_api_key
|
||||
if not key:
|
||||
raise ValueError("Brave API key not defined in config")
|
||||
tool = BraveSearchTool(key)
|
||||
raise ValueError("API key not defined in config")
|
||||
tool = SearchTool(tool_defn.engine, key)
|
||||
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
|
||||
tool = CodeInterpreterTool()
|
||||
elif isinstance(tool_defn, PhotogenToolDefinition):
|
||||
tool = PhotogenTool(
|
||||
dump_dir=tempfile.mkdtemp(),
|
||||
)
|
||||
tool = PhotogenTool(dump_dir=tempfile.mkdtemp())
|
||||
else:
|
||||
continue
|
||||
|
||||
|
|
|
@ -11,4 +11,5 @@ from pydantic import BaseModel
|
|||
|
||||
class MetaReferenceImplConfig(BaseModel):
|
||||
brave_search_api_key: Optional[str] = None
|
||||
bing_search_api_key: Optional[str] = None
|
||||
wolfram_api_key: Optional[str] = None
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import URL
|
||||
|
||||
|
@ -26,6 +26,6 @@ class RestAPIMethod(Enum):
|
|||
class RestAPIExecutionConfig(BaseModel):
|
||||
url: URL
|
||||
method: RestAPIMethod
|
||||
params: Optional[Dict[str, str]] = None
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
body: Optional[Dict[str, str]] = None
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
headers: Optional[Dict[str, Any]] = None
|
||||
body: Optional[Dict[str, Any]] = None
|
||||
|
|
|
@ -83,14 +83,72 @@ class PhotogenTool(SingleMessageBuiltinTool):
|
|||
raise NotImplementedError()
|
||||
|
||||
|
||||
class BraveSearchTool(SingleMessageBuiltinTool):
|
||||
def __init__(self, api_key: str) -> None:
|
||||
class SearchTool(SingleMessageBuiltinTool):
|
||||
def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
|
||||
self.api_key = api_key
|
||||
if engine == SearchEngineType.bing:
|
||||
self.engine = BingSearch(api_key, **kwargs)
|
||||
elif engine == SearchEngineType.brave:
|
||||
self.engine = BraveSearch(api_key, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown search engine: {engine}")
|
||||
|
||||
def get_name(self) -> str:
|
||||
return BuiltinTool.brave_search.value
|
||||
|
||||
async def run_impl(self, query: str) -> str:
|
||||
return await self.engine.search(query)
|
||||
|
||||
|
||||
class BingSearch:
|
||||
def __init__(self, api_key: str, top_k: int = 3, **kwargs) -> None:
|
||||
self.api_key = api_key
|
||||
self.top_k = top_k
|
||||
|
||||
async def search(self, query: str) -> str:
|
||||
url = "https://api.bing.microsoft.com/v7.0/search"
|
||||
headers = {
|
||||
"Ocp-Apim-Subscription-Key": self.api_key,
|
||||
}
|
||||
params = {
|
||||
"count": self.top_k,
|
||||
"textDecorations": True,
|
||||
"textFormat": "HTML",
|
||||
"q": query,
|
||||
}
|
||||
|
||||
response = requests.get(url=url, params=params, headers=headers)
|
||||
response.raise_for_status()
|
||||
clean = self._clean_response(response.json())
|
||||
return json.dumps(clean)
|
||||
|
||||
def _clean_response(self, search_response):
|
||||
clean_response = []
|
||||
query = search_response["queryContext"]["originalQuery"]
|
||||
if "webPages" in search_response:
|
||||
pages = search_response["webPages"]["value"]
|
||||
for p in pages:
|
||||
selected_keys = {"name", "url", "snippet"}
|
||||
clean_response.append(
|
||||
{k: v for k, v in p.items() if k in selected_keys}
|
||||
)
|
||||
if "news" in search_response:
|
||||
clean_news = []
|
||||
news = search_response["news"]["value"]
|
||||
for n in news:
|
||||
selected_keys = {"name", "url", "description"}
|
||||
clean_news.append({k: v for k, v in n.items() if k in selected_keys})
|
||||
|
||||
clean_response.append(clean_news)
|
||||
|
||||
return {"query": query, "top_k": clean_response}
|
||||
|
||||
|
||||
class BraveSearch:
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
|
||||
async def search(self, query: str) -> str:
|
||||
url = "https://api.search.brave.com/res/v1/web/search"
|
||||
headers = {
|
||||
"X-Subscription-Token": self.api_key,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue