mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
add tool for bing search
This commit is contained in:
parent
741310f78e
commit
640c5f8ab9
5 changed files with 79 additions and 10 deletions
|
@ -16,6 +16,7 @@ from pydantic import BaseModel
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
|
||||||
from llama_toolchain.core.datatypes import RemoteProviderConfig
|
from llama_toolchain.core.datatypes import RemoteProviderConfig
|
||||||
|
|
||||||
from .api import * # noqa: F403
|
from .api import * # noqa: F403
|
||||||
|
@ -134,8 +135,21 @@ async def run_main(host: str, port: int):
|
||||||
api = AgenticSystemClient(f"http://{host}:{port}")
|
api = AgenticSystemClient(f"http://{host}:{port}")
|
||||||
|
|
||||||
tool_definitions = [
|
tool_definitions = [
|
||||||
BraveSearchToolDefinition(),
|
BraveSearchToolDefinition(
|
||||||
WolframAlphaToolDefinition(),
|
remote_execution=RestAPIExecutionConfig(
|
||||||
|
url=URL(uri="https://api.bing.microsoft.com/v7.0/search"),
|
||||||
|
method=RestAPIMethod.GET,
|
||||||
|
headers={
|
||||||
|
"Ocp-Apim-Subscription-Key": "2259b3f9e0cb4fc9b968bb3b02ab13e7"
|
||||||
|
},
|
||||||
|
params={
|
||||||
|
"count": 3,
|
||||||
|
"textDecorations": True,
|
||||||
|
"textFormat": "HTML",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
),
|
||||||
|
# WolframAlphaToolDefinition(),
|
||||||
CodeInterpreterToolDefinition(),
|
CodeInterpreterToolDefinition(),
|
||||||
]
|
]
|
||||||
tool_definitions += [
|
tool_definitions += [
|
||||||
|
|
|
@ -18,6 +18,7 @@ from llama_toolchain.tools.builtin import (
|
||||||
BraveSearchTool,
|
BraveSearchTool,
|
||||||
CodeInterpreterTool,
|
CodeInterpreterTool,
|
||||||
PhotogenTool,
|
PhotogenTool,
|
||||||
|
RemoteSearchTool,
|
||||||
WolframAlphaTool,
|
WolframAlphaTool,
|
||||||
)
|
)
|
||||||
from llama_toolchain.tools.safety import with_safety
|
from llama_toolchain.tools.safety import with_safety
|
||||||
|
@ -63,10 +64,13 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
||||||
raise ValueError("Wolfram API key not defined in config")
|
raise ValueError("Wolfram API key not defined in config")
|
||||||
tool = WolframAlphaTool(key)
|
tool = WolframAlphaTool(key)
|
||||||
elif isinstance(tool_defn, BraveSearchToolDefinition):
|
elif isinstance(tool_defn, BraveSearchToolDefinition):
|
||||||
key = self.config.brave_search_api_key
|
if tool_defn.remote_execution is not None:
|
||||||
if not key:
|
tool = RemoteSearchTool(tool_defn.remote_execution)
|
||||||
raise ValueError("Brave API key not defined in config")
|
else:
|
||||||
tool = BraveSearchTool(key)
|
key = self.config.brave_search_api_key
|
||||||
|
if not key:
|
||||||
|
raise ValueError("Brave API key not defined in config")
|
||||||
|
tool = BraveSearchTool(key)
|
||||||
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
|
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
|
||||||
tool = CodeInterpreterTool()
|
tool = CodeInterpreterTool()
|
||||||
elif isinstance(tool_defn, PhotogenToolDefinition):
|
elif isinstance(tool_defn, PhotogenToolDefinition):
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import URL
|
from llama_models.llama3.api.datatypes import URL
|
||||||
|
|
||||||
|
@ -26,6 +26,6 @@ class RestAPIMethod(Enum):
|
||||||
class RestAPIExecutionConfig(BaseModel):
|
class RestAPIExecutionConfig(BaseModel):
|
||||||
url: URL
|
url: URL
|
||||||
method: RestAPIMethod
|
method: RestAPIMethod
|
||||||
params: Optional[Dict[str, str]] = None
|
params: Optional[Dict[str, Any]] = None
|
||||||
headers: Optional[Dict[str, str]] = None
|
headers: Optional[Dict[str, Any]] = None
|
||||||
body: Optional[Dict[str, str]] = None
|
body: Optional[Dict[str, Any]] = None
|
||||||
|
|
|
@ -86,6 +86,14 @@ class InferenceClient(Inference):
|
||||||
|
|
||||||
async def run_main(host: str, port: int, stream: bool):
|
async def run_main(host: str, port: int, stream: bool):
|
||||||
client = InferenceClient(f"http://{host}:{port}")
|
client = InferenceClient(f"http://{host}:{port}")
|
||||||
|
# from llama_toolchain.inference.meta_reference import get_provider_impl
|
||||||
|
# from .meta_reference.config import MetaReferenceImplConfig
|
||||||
|
|
||||||
|
# config = MetaReferenceImplConfig(
|
||||||
|
# model="Meta-Llama3.1-8B-Instruct",
|
||||||
|
# max_seq_len=4096,
|
||||||
|
# )
|
||||||
|
# client = await get_provider_impl(config, {})
|
||||||
|
|
||||||
message = UserMessage(content="hello world, troll me in two-paragraphs about 42")
|
message = UserMessage(content="hello world, troll me in two-paragraphs about 42")
|
||||||
cprint(f"User>{message.content}", "green")
|
cprint(f"User>{message.content}", "green")
|
||||||
|
|
|
@ -14,6 +14,8 @@ from typing import List, Optional
|
||||||
import requests
|
import requests
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
|
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
|
||||||
|
|
||||||
from .ipython_tool.code_execution import (
|
from .ipython_tool.code_execution import (
|
||||||
CodeExecutionContext,
|
CodeExecutionContext,
|
||||||
CodeExecutionRequest,
|
CodeExecutionRequest,
|
||||||
|
@ -83,6 +85,47 @@ class PhotogenTool(SingleMessageBuiltinTool):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteSearchTool(SingleMessageBuiltinTool):
|
||||||
|
def __init__(self, config: RestAPIExecutionConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return BuiltinTool.brave_search.value
|
||||||
|
|
||||||
|
async def run_impl(self, query: str) -> str:
|
||||||
|
params = self.config.params.copy()
|
||||||
|
params["q"] = query
|
||||||
|
response = requests.get(
|
||||||
|
url=self.config.url,
|
||||||
|
params=params,
|
||||||
|
headers=self.config.headers,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
clean = self._clean_response(response.json())
|
||||||
|
return json.dumps(clean)
|
||||||
|
|
||||||
|
def _clean_response(self, search_response):
|
||||||
|
clean_response = []
|
||||||
|
query = search_response["queryContext"]["originalQuery"]
|
||||||
|
if "webPages" in search_response:
|
||||||
|
pages = search_response["webPages"]["value"]
|
||||||
|
for p in pages:
|
||||||
|
selected_keys = {"name", "url", "snippet"}
|
||||||
|
clean_response.append(
|
||||||
|
{k: v for k, v in p.items() if k in selected_keys}
|
||||||
|
)
|
||||||
|
if "news" in search_response:
|
||||||
|
clean_news = []
|
||||||
|
news = search_response["news"]["value"]
|
||||||
|
for n in news:
|
||||||
|
selected_keys = {"name", "url", "description"}
|
||||||
|
clean_news.append({k: v for k, v in n.items() if k in selected_keys})
|
||||||
|
|
||||||
|
clean_response.append(clean_news)
|
||||||
|
|
||||||
|
return {"query": query, "top_k": clean_response}
|
||||||
|
|
||||||
|
|
||||||
class BraveSearchTool(SingleMessageBuiltinTool):
|
class BraveSearchTool(SingleMessageBuiltinTool):
|
||||||
def __init__(self, api_key: str) -> None:
|
def __init__(self, api_key: str) -> None:
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue