mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-24 06:13:56 +00:00
add wolfram alpha, bing search
This commit is contained in:
parent
f9a98c278a
commit
94cca7a72a
14 changed files with 411 additions and 1 deletions
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .bing_search import BingSearchToolRuntimeImpl
|
||||
from .config import BingSearchToolConfig
|
||||
|
||||
__all__ = ["BingSearchToolConfig", "BingSearchToolRuntimeImpl"]
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BingSearchToolProviderDataValidator(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: BingSearchToolConfig, _deps):
|
||||
impl = BingSearchToolRuntimeImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
from .config import BingSearchToolConfig
|
||||
|
||||
|
||||
class BingSearchToolRuntimeImpl(
|
||||
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(self, config: BingSearchToolConfig):
|
||||
self.config = config
|
||||
self.url = "https://api.bing.microsoft.com/v7.0/search"
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
return
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
if self.config.api_key:
|
||||
return self.config.api_key
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.api_key:
|
||||
raise ValueError(
|
||||
'Pass Bing Search API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
|
||||
)
|
||||
return provider_data.api_key
|
||||
|
||||
async def list_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="web_search",
|
||||
description="Search the web using Bing Search API",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
built_in_type=BuiltinTool.brave_search,
|
||||
)
|
||||
]
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
headers = {
|
||||
"Ocp-Apim-Subscription-Key": api_key,
|
||||
}
|
||||
params = {
|
||||
"count": self.config.top_k,
|
||||
"textDecorations": True,
|
||||
"textFormat": "HTML",
|
||||
"q": args["query"],
|
||||
}
|
||||
|
||||
response = requests.get(
|
||||
url=self.url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=json.dumps(self._clean_response(response.json()))
|
||||
)
|
||||
|
||||
def _clean_response(self, search_response):
|
||||
clean_response = []
|
||||
query = search_response["queryContext"]["originalQuery"]
|
||||
if "webPages" in search_response:
|
||||
pages = search_response["webPages"]["value"]
|
||||
for p in pages:
|
||||
selected_keys = {"name", "url", "snippet"}
|
||||
clean_response.append(
|
||||
{k: v for k, v in p.items() if k in selected_keys}
|
||||
)
|
||||
if "news" in search_response:
|
||||
clean_news = []
|
||||
news = search_response["news"]["value"]
|
||||
for n in news:
|
||||
selected_keys = {"name", "url", "description"}
|
||||
clean_news.append({k: v for k, v in n.items() if k in selected_keys})
|
||||
|
||||
clean_response.append(clean_news)
|
||||
|
||||
return {"query": query, "top_k": clean_response}
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BingSearchToolConfig(BaseModel):
|
||||
"""Configuration for Bing Search Tool Runtime"""
|
||||
|
||||
api_key: Optional[str] = None
|
||||
top_k: int = 3
|
||||
|
|
@ -7,6 +7,7 @@
|
|||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
|
|
@ -62,6 +63,7 @@ class BraveSearchToolRuntimeImpl(
|
|||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
built_in_type=BuiltinTool.brave_search,
|
||||
)
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import json
|
|||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
|
|
@ -63,6 +64,7 @@ class TavilySearchToolRuntimeImpl(
|
|||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
built_in_type=BuiltinTool.brave_search,
|
||||
)
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import WolframAlphaToolConfig
|
||||
from .wolfram_alpha import WolframAlphaToolRuntimeImpl
|
||||
|
||||
__all__ = ["WolframAlphaToolConfig", "WolframAlphaToolRuntimeImpl"]
|
||||
|
||||
|
||||
class WolframAlphaToolProviderDataValidator(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: WolframAlphaToolConfig, _deps):
|
||||
impl = WolframAlphaToolRuntimeImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class WolframAlphaToolConfig(BaseModel):
|
||||
"""Configuration for WolframAlpha Tool Runtime"""
|
||||
|
||||
api_key: Optional[str] = None
|
||||
|
|
@ -0,0 +1,148 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
from .config import WolframAlphaToolConfig
|
||||
|
||||
|
||||
class WolframAlphaToolRuntimeImpl(
|
||||
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(self, config: WolframAlphaToolConfig):
|
||||
self.config = config
|
||||
self.url = "https://api.wolframalpha.com/v2/query"
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
return
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
if self.config.api_key:
|
||||
return self.config.api_key
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.api_key:
|
||||
raise ValueError(
|
||||
'Pass WolframAlpha API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
|
||||
)
|
||||
return provider_data.api_key
|
||||
|
||||
async def list_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="wolfram_alpha",
|
||||
description="Query WolframAlpha for computational knowledge",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to compute",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
built_in_type=BuiltinTool.wolfram_alpha,
|
||||
)
|
||||
]
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
params = {
|
||||
"input": args["query"],
|
||||
"appid": api_key,
|
||||
"format": "plaintext",
|
||||
"output": "json",
|
||||
}
|
||||
response = requests.get(
|
||||
self.url,
|
||||
params=params,
|
||||
)
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=json.dumps(self._clean_wolfram_alpha_response(response.json()))
|
||||
)
|
||||
|
||||
def _clean_wolfram_alpha_response(self, wa_response):
|
||||
remove = {
|
||||
"queryresult": [
|
||||
"datatypes",
|
||||
"error",
|
||||
"timedout",
|
||||
"timedoutpods",
|
||||
"numpods",
|
||||
"timing",
|
||||
"parsetiming",
|
||||
"parsetimedout",
|
||||
"recalculate",
|
||||
"id",
|
||||
"host",
|
||||
"server",
|
||||
"related",
|
||||
"version",
|
||||
{
|
||||
"pods": [
|
||||
"scanner",
|
||||
"id",
|
||||
"error",
|
||||
"expressiontypes",
|
||||
"states",
|
||||
"infos",
|
||||
"position",
|
||||
"numsubpods",
|
||||
]
|
||||
},
|
||||
"assumptions",
|
||||
],
|
||||
}
|
||||
for main_key in remove:
|
||||
for key_to_remove in remove[main_key]:
|
||||
try:
|
||||
if key_to_remove == "assumptions":
|
||||
if "assumptions" in wa_response[main_key]:
|
||||
del wa_response[main_key][key_to_remove]
|
||||
if isinstance(key_to_remove, dict):
|
||||
for sub_key in key_to_remove:
|
||||
if sub_key == "pods":
|
||||
for i in range(len(wa_response[main_key][sub_key])):
|
||||
if (
|
||||
wa_response[main_key][sub_key][i]["title"]
|
||||
== "Result"
|
||||
):
|
||||
del wa_response[main_key][sub_key][i + 1 :]
|
||||
break
|
||||
sub_items = wa_response[main_key][sub_key]
|
||||
for i in range(len(sub_items)):
|
||||
for sub_key_to_remove in key_to_remove[sub_key]:
|
||||
if sub_key_to_remove in sub_items[i]:
|
||||
del sub_items[i][sub_key_to_remove]
|
||||
elif key_to_remove in wa_response[main_key]:
|
||||
del wa_response[main_key][key_to_remove]
|
||||
except KeyError:
|
||||
pass
|
||||
return wa_response
|
||||
Loading…
Add table
Add a link
Reference in a new issue