mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
simplify search tool and enable configuration for search engine
This commit is contained in:
parent
640c5f8ab9
commit
bdede6d14e
6 changed files with 56 additions and 48 deletions
|
@ -41,11 +41,19 @@ class ToolDefinitionCommon(BaseModel):
|
||||||
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchEngineType(Enum):
|
||||||
|
bing = "bing"
|
||||||
|
brave = "brave"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class BraveSearchToolDefinition(ToolDefinitionCommon):
|
class SearchToolDefinition(ToolDefinitionCommon):
|
||||||
|
# NOTE: brave_search is just a placeholder since model always uses
|
||||||
|
# brave_search as tool call name
|
||||||
type: Literal[AgenticSystemTool.brave_search.value] = (
|
type: Literal[AgenticSystemTool.brave_search.value] = (
|
||||||
AgenticSystemTool.brave_search.value
|
AgenticSystemTool.brave_search.value
|
||||||
)
|
)
|
||||||
|
engine: SearchEngineType = SearchEngineType.brave
|
||||||
remote_execution: Optional[RestAPIExecutionConfig] = None
|
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -163,7 +171,7 @@ class MemoryToolDefinition(ToolDefinitionCommon):
|
||||||
|
|
||||||
AgenticSystemToolDefinition = Annotated[
|
AgenticSystemToolDefinition = Annotated[
|
||||||
Union[
|
Union[
|
||||||
BraveSearchToolDefinition,
|
SearchToolDefinition,
|
||||||
WolframAlphaToolDefinition,
|
WolframAlphaToolDefinition,
|
||||||
PhotogenToolDefinition,
|
PhotogenToolDefinition,
|
||||||
CodeInterpreterToolDefinition,
|
CodeInterpreterToolDefinition,
|
||||||
|
|
|
@ -16,7 +16,6 @@ from pydantic import BaseModel
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
|
|
||||||
from llama_toolchain.core.datatypes import RemoteProviderConfig
|
from llama_toolchain.core.datatypes import RemoteProviderConfig
|
||||||
|
|
||||||
from .api import * # noqa: F403
|
from .api import * # noqa: F403
|
||||||
|
@ -135,21 +134,8 @@ async def run_main(host: str, port: int):
|
||||||
api = AgenticSystemClient(f"http://{host}:{port}")
|
api = AgenticSystemClient(f"http://{host}:{port}")
|
||||||
|
|
||||||
tool_definitions = [
|
tool_definitions = [
|
||||||
BraveSearchToolDefinition(
|
SearchToolDefinition(engine=SearchEngineType.bing),
|
||||||
remote_execution=RestAPIExecutionConfig(
|
WolframAlphaToolDefinition(),
|
||||||
url=URL(uri="https://api.bing.microsoft.com/v7.0/search"),
|
|
||||||
method=RestAPIMethod.GET,
|
|
||||||
headers={
|
|
||||||
"Ocp-Apim-Subscription-Key": "2259b3f9e0cb4fc9b968bb3b02ab13e7"
|
|
||||||
},
|
|
||||||
params={
|
|
||||||
"count": 3,
|
|
||||||
"textDecorations": True,
|
|
||||||
"textFormat": "HTML",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
),
|
|
||||||
# WolframAlphaToolDefinition(),
|
|
||||||
CodeInterpreterToolDefinition(),
|
CodeInterpreterToolDefinition(),
|
||||||
]
|
]
|
||||||
tool_definitions += [
|
tool_definitions += [
|
||||||
|
|
|
@ -710,7 +710,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
def _get_tools(self) -> List[ToolDefinition]:
|
def _get_tools(self) -> List[ToolDefinition]:
|
||||||
ret = []
|
ret = []
|
||||||
for t in self.agent_config.tools:
|
for t in self.agent_config.tools:
|
||||||
if isinstance(t, BraveSearchToolDefinition):
|
if isinstance(t, SearchToolDefinition):
|
||||||
ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
|
ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
|
||||||
elif isinstance(t, WolframAlphaToolDefinition):
|
elif isinstance(t, WolframAlphaToolDefinition):
|
||||||
ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))
|
ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))
|
||||||
|
|
|
@ -15,10 +15,9 @@ from llama_toolchain.memory.api import Memory
|
||||||
from llama_toolchain.safety.api import Safety
|
from llama_toolchain.safety.api import Safety
|
||||||
from llama_toolchain.agentic_system.api import * # noqa: F403
|
from llama_toolchain.agentic_system.api import * # noqa: F403
|
||||||
from llama_toolchain.tools.builtin import (
|
from llama_toolchain.tools.builtin import (
|
||||||
BraveSearchTool,
|
|
||||||
CodeInterpreterTool,
|
CodeInterpreterTool,
|
||||||
PhotogenTool,
|
PhotogenTool,
|
||||||
RemoteSearchTool,
|
SearchTool,
|
||||||
WolframAlphaTool,
|
WolframAlphaTool,
|
||||||
)
|
)
|
||||||
from llama_toolchain.tools.safety import with_safety
|
from llama_toolchain.tools.safety import with_safety
|
||||||
|
@ -63,20 +62,19 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
||||||
if not key:
|
if not key:
|
||||||
raise ValueError("Wolfram API key not defined in config")
|
raise ValueError("Wolfram API key not defined in config")
|
||||||
tool = WolframAlphaTool(key)
|
tool = WolframAlphaTool(key)
|
||||||
elif isinstance(tool_defn, BraveSearchToolDefinition):
|
elif isinstance(tool_defn, SearchToolDefinition):
|
||||||
if tool_defn.remote_execution is not None:
|
key = None
|
||||||
tool = RemoteSearchTool(tool_defn.remote_execution)
|
if tool_defn.engine == SearchEngineType.brave:
|
||||||
else:
|
|
||||||
key = self.config.brave_search_api_key
|
key = self.config.brave_search_api_key
|
||||||
if not key:
|
elif tool_defn.engine == SearchEngineType.bing:
|
||||||
raise ValueError("Brave API key not defined in config")
|
key = self.config.bing_search_api_key
|
||||||
tool = BraveSearchTool(key)
|
if not key:
|
||||||
|
raise ValueError("API key not defined in config")
|
||||||
|
tool = SearchTool(tool_defn.engine, key)
|
||||||
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
|
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
|
||||||
tool = CodeInterpreterTool()
|
tool = CodeInterpreterTool()
|
||||||
elif isinstance(tool_defn, PhotogenToolDefinition):
|
elif isinstance(tool_defn, PhotogenToolDefinition):
|
||||||
tool = PhotogenTool(
|
tool = PhotogenTool(dump_dir=tempfile.mkdtemp())
|
||||||
dump_dir=tempfile.mkdtemp(),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
@ -11,4 +11,5 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
class MetaReferenceImplConfig(BaseModel):
|
class MetaReferenceImplConfig(BaseModel):
|
||||||
brave_search_api_key: Optional[str] = None
|
brave_search_api_key: Optional[str] = None
|
||||||
|
bing_search_api_key: Optional[str] = None
|
||||||
wolfram_api_key: Optional[str] = None
|
wolfram_api_key: Optional[str] = None
|
||||||
|
|
|
@ -14,8 +14,6 @@ from typing import List, Optional
|
||||||
import requests
|
import requests
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
|
|
||||||
|
|
||||||
from .ipython_tool.code_execution import (
|
from .ipython_tool.code_execution import (
|
||||||
CodeExecutionContext,
|
CodeExecutionContext,
|
||||||
CodeExecutionRequest,
|
CodeExecutionRequest,
|
||||||
|
@ -85,21 +83,41 @@ class PhotogenTool(SingleMessageBuiltinTool):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class RemoteSearchTool(SingleMessageBuiltinTool):
|
class SearchTool(SingleMessageBuiltinTool):
|
||||||
def __init__(self, config: RestAPIExecutionConfig) -> None:
|
def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
|
||||||
self.config = config
|
self.api_key = api_key
|
||||||
|
if engine == SearchEngineType.bing:
|
||||||
|
self.engine = BingSearch(api_key, **kwargs)
|
||||||
|
elif engine == SearchEngineType.brave:
|
||||||
|
self.engine = BraveSearch(api_key, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown search engine: {engine}")
|
||||||
|
|
||||||
def get_name(self) -> str:
|
def get_name(self) -> str:
|
||||||
return BuiltinTool.brave_search.value
|
return BuiltinTool.brave_search.value
|
||||||
|
|
||||||
async def run_impl(self, query: str) -> str:
|
async def run_impl(self, query: str) -> str:
|
||||||
params = self.config.params.copy()
|
return await self.engine.search(query)
|
||||||
params["q"] = query
|
|
||||||
response = requests.get(
|
|
||||||
url=self.config.url,
|
class BingSearch:
|
||||||
params=params,
|
def __init__(self, api_key: str, top_k: int = 3, **kwargs) -> None:
|
||||||
headers=self.config.headers,
|
self.api_key = api_key
|
||||||
)
|
self.top_k = top_k
|
||||||
|
|
||||||
|
async def search(self, query: str) -> str:
|
||||||
|
url = "https://api.bing.microsoft.com/v7.0/search"
|
||||||
|
headers = {
|
||||||
|
"Ocp-Apim-Subscription-Key": self.api_key,
|
||||||
|
}
|
||||||
|
params = {
|
||||||
|
"count": self.top_k,
|
||||||
|
"textDecorations": True,
|
||||||
|
"textFormat": "HTML",
|
||||||
|
"q": query,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.get(url=url, params=params, headers=headers)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
clean = self._clean_response(response.json())
|
clean = self._clean_response(response.json())
|
||||||
return json.dumps(clean)
|
return json.dumps(clean)
|
||||||
|
@ -126,14 +144,11 @@ class RemoteSearchTool(SingleMessageBuiltinTool):
|
||||||
return {"query": query, "top_k": clean_response}
|
return {"query": query, "top_k": clean_response}
|
||||||
|
|
||||||
|
|
||||||
class BraveSearchTool(SingleMessageBuiltinTool):
|
class BraveSearch:
|
||||||
def __init__(self, api_key: str) -> None:
|
def __init__(self, api_key: str) -> None:
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
|
||||||
def get_name(self) -> str:
|
async def search(self, query: str) -> str:
|
||||||
return BuiltinTool.brave_search.value
|
|
||||||
|
|
||||||
async def run_impl(self, query: str) -> str:
|
|
||||||
url = "https://api.search.brave.com/res/v1/web/search"
|
url = "https://api.search.brave.com/res/v1/web/search"
|
||||||
headers = {
|
headers = {
|
||||||
"X-Subscription-Token": self.api_key,
|
"X-Subscription-Token": self.api_key,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue