diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py
index fadb78182..0120de44e 100644
--- a/llama_toolchain/agentic_system/client.py
+++ b/llama_toolchain/agentic_system/client.py
@@ -16,6 +16,7 @@ from pydantic import BaseModel
 from termcolor import cprint
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
 from llama_toolchain.core.datatypes import RemoteProviderConfig
 
 from .api import *  # noqa: F403
@@ -134,8 +135,21 @@ async def run_main(host: str, port: int):
     api = AgenticSystemClient(f"http://{host}:{port}")
 
     tool_definitions = [
-        BraveSearchToolDefinition(),
-        WolframAlphaToolDefinition(),
+        BraveSearchToolDefinition(
+            remote_execution=RestAPIExecutionConfig(
+                url=URL(uri="https://api.bing.microsoft.com/v7.0/search"),
+                method=RestAPIMethod.GET,
+                headers={
+                    "Ocp-Apim-Subscription-Key": "2259b3f9e0cb4fc9b968bb3b02ab13e7"
+                },
+                params={
+                    "count": 3,
+                    "textDecorations": True,
+                    "textFormat": "HTML",
+                },
+            )
+        ),
+        # WolframAlphaToolDefinition(),
         CodeInterpreterToolDefinition(),
     ]
     tool_definitions += [
diff --git a/llama_toolchain/agentic_system/meta_reference/agentic_system.py b/llama_toolchain/agentic_system/meta_reference/agentic_system.py
index 09fbfdde5..79d20fd26 100644
--- a/llama_toolchain/agentic_system/meta_reference/agentic_system.py
+++ b/llama_toolchain/agentic_system/meta_reference/agentic_system.py
@@ -18,6 +18,7 @@ from llama_toolchain.tools.builtin import (
     BraveSearchTool,
     CodeInterpreterTool,
     PhotogenTool,
+    RemoteSearchTool,
     WolframAlphaTool,
 )
 from llama_toolchain.tools.safety import with_safety
@@ -63,10 +64,13 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
                     raise ValueError("Wolfram API key not defined in config")
                 tool = WolframAlphaTool(key)
             elif isinstance(tool_defn, BraveSearchToolDefinition):
-                key = self.config.brave_search_api_key
-                if not key:
-                    raise ValueError("Brave API key not defined in config")
-                tool = BraveSearchTool(key)
+                if tool_defn.remote_execution is not None:
+                    tool = RemoteSearchTool(tool_defn.remote_execution)
+                else:
+                    key = self.config.brave_search_api_key
+                    if not key:
+                        raise ValueError("Brave API key not defined in config")
+                    tool = BraveSearchTool(key)
             elif isinstance(tool_defn, CodeInterpreterToolDefinition):
                 tool = CodeInterpreterTool()
             elif isinstance(tool_defn, PhotogenToolDefinition):
diff --git a/llama_toolchain/common/deployment_types.py b/llama_toolchain/common/deployment_types.py
index 8b67eff0d..af05aaae4 100644
--- a/llama_toolchain/common/deployment_types.py
+++ b/llama_toolchain/common/deployment_types.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 
 from llama_models.llama3.api.datatypes import URL
 
@@ -26,6 +26,6 @@ class RestAPIMethod(Enum):
 class RestAPIExecutionConfig(BaseModel):
     url: URL
     method: RestAPIMethod
-    params: Optional[Dict[str, str]] = None
-    headers: Optional[Dict[str, str]] = None
-    body: Optional[Dict[str, str]] = None
+    params: Optional[Dict[str, Any]] = None
+    headers: Optional[Dict[str, Any]] = None
+    body: Optional[Dict[str, Any]] = None
diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py
index 5ba9314bc..1e1dcb8c0 100644
--- a/llama_toolchain/inference/client.py
+++ b/llama_toolchain/inference/client.py
@@ -86,6 +86,14 @@ class InferenceClient(Inference):
 async def run_main(host: str, port: int, stream: bool):
     client = InferenceClient(f"http://{host}:{port}")
 
+    # from llama_toolchain.inference.meta_reference import get_provider_impl
+    # from .meta_reference.config import MetaReferenceImplConfig
+
+    # config = MetaReferenceImplConfig(
+    #     model="Meta-Llama3.1-8B-Instruct",
+    #     max_seq_len=4096,
+    # )
+    # client = await get_provider_impl(config, {})
     message = UserMessage(content="hello world, troll me in two-paragraphs about 42")
     cprint(f"User>{message.content}", "green")
 
diff --git a/llama_toolchain/tools/builtin.py b/llama_toolchain/tools/builtin.py
index 3a53e2e26..1792fde57 100644
--- a/llama_toolchain/tools/builtin.py
+++ b/llama_toolchain/tools/builtin.py
@@ -14,6 +14,8 @@ from typing import List, Optional
 import requests
 from termcolor import cprint
 
+from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
+
 from .ipython_tool.code_execution import (
     CodeExecutionContext,
     CodeExecutionRequest,
@@ -83,6 +85,47 @@ class PhotogenTool(SingleMessageBuiltinTool):
         raise NotImplementedError()
 
 
+class RemoteSearchTool(SingleMessageBuiltinTool):
+    def __init__(self, config: RestAPIExecutionConfig) -> None:
+        self.config = config
+
+    def get_name(self) -> str:
+        return BuiltinTool.brave_search.value
+
+    async def run_impl(self, query: str) -> str:
+        params = self.config.params.copy()
+        params["q"] = query
+        response = requests.get(
+            url=self.config.url,
+            params=params,
+            headers=self.config.headers,
+        )
+        response.raise_for_status()
+        clean = self._clean_response(response.json())
+        return json.dumps(clean)
+
+    def _clean_response(self, search_response):
+        clean_response = []
+        query = search_response["queryContext"]["originalQuery"]
+        if "webPages" in search_response:
+            pages = search_response["webPages"]["value"]
+            for p in pages:
+                selected_keys = {"name", "url", "snippet"}
+                clean_response.append(
+                    {k: v for k, v in p.items() if k in selected_keys}
+                )
+        if "news" in search_response:
+            clean_news = []
+            news = search_response["news"]["value"]
+            for n in news:
+                selected_keys = {"name", "url", "description"}
+                clean_news.append({k: v for k, v in n.items() if k in selected_keys})
+
+            clean_response.append(clean_news)
+
+        return {"query": query, "top_k": clean_response}
+
+
 class BraveSearchTool(SingleMessageBuiltinTool):
     def __init__(self, api_key: str) -> None:
         self.api_key = api_key
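For reference, below is a minimal standalone sketch (not part of the patch) showing how the new RemoteSearchTool and RestAPIExecutionConfig fit together, mirroring the Bing configuration used in agentic_system/client.py above. The subscription key and the query string are placeholders you would supply yourself; the import paths assume this patch has been applied.

# Sketch only: exercises RemoteSearchTool directly, outside the agentic system loop.
import asyncio
import json

from llama_models.llama3.api.datatypes import URL
from llama_toolchain.common.deployment_types import (
    RestAPIExecutionConfig,
    RestAPIMethod,
)
from llama_toolchain.tools.builtin import RemoteSearchTool


async def main() -> None:
    # Same Bing web-search endpoint and parameters as in agentic_system/client.py;
    # the subscription key here is a placeholder, not a working credential.
    config = RestAPIExecutionConfig(
        url=URL(uri="https://api.bing.microsoft.com/v7.0/search"),
        method=RestAPIMethod.GET,
        headers={"Ocp-Apim-Subscription-Key": "<your-bing-api-key>"},
        params={"count": 3, "textDecorations": True, "textFormat": "HTML"},
    )
    tool = RemoteSearchTool(config)

    # run_impl issues the GET request, keeps only name/url/snippet (and, for news
    # results, name/url/description) fields, and returns the cleaned payload as JSON text.
    result = await tool.run_impl("latest Llama model releases")
    print(json.dumps(json.loads(result), indent=2))


if __name__ == "__main__":
    asyncio.run(main())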