add wolfram alpha, bing search

This commit is contained in:
Dinesh Yeduguru 2025-01-07 16:07:03 -08:00
parent f9a98c278a
commit 94cca7a72a
14 changed files with 411 additions and 1 deletions

View file

@ -53,6 +53,7 @@ class ToolDef(BaseModel):
description: Optional[str] = None description: Optional[str] = None
parameters: Optional[List[ToolParameter]] = None parameters: Optional[List[ToolParameter]] = None
metadata: Optional[Dict[str, Any]] = None metadata: Optional[Dict[str, Any]] = None
built_in_type: Optional[BuiltinTool] = None
tool_prompt_format: Optional[ToolPromptFormat] = Field( tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json default=ToolPromptFormat.json
) )

View file

@ -527,6 +527,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
provider_resource_id=tool_def.name, provider_resource_id=tool_def.name,
metadata=tool_def.metadata, metadata=tool_def.metadata,
tool_host=tool_host, tool_host=tool_host,
built_in_type=tool_def.built_in_type,
) )
) )
for tool in tools: for tool in tools:

View file

@ -9,6 +9,8 @@ import logging
import tempfile import tempfile
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
Tool, Tool,
@ -56,6 +58,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
parameter_type="string", parameter_type="string",
), ),
], ],
built_in_type=BuiltinTool.code_interpreter,
) )
] ]

View file

@ -42,6 +42,16 @@ def available_providers() -> List[ProviderSpec]:
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
), ),
), ),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="bing-search",
module="llama_stack.providers.remote.tool_runtime.bing_search",
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
),
),
remote_provider_spec( remote_provider_spec(
api=Api.tool_runtime, api=Api.tool_runtime,
adapter=AdapterSpec( adapter=AdapterSpec(
@ -52,6 +62,16 @@ def available_providers() -> List[ProviderSpec]:
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
), ),
), ),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="wolfram-alpha",
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
),
),
remote_provider_spec( remote_provider_spec(
api=Api.tool_runtime, api=Api.tool_runtime,
adapter=AdapterSpec( adapter=AdapterSpec(

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .bing_search import BingSearchToolRuntimeImpl
from .config import BingSearchToolConfig
__all__ = ["BingSearchToolConfig", "BingSearchToolRuntimeImpl"]
from pydantic import BaseModel
class BingSearchToolProviderDataValidator(BaseModel):
api_key: str
async def get_adapter_impl(config: BingSearchToolConfig, _deps):
impl = BingSearchToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
import requests
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import BingSearchToolConfig
class BingSearchToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: BingSearchToolConfig):
self.config = config
self.url = "https://api.bing.microsoft.com/v7.0/search"
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
pass
async def unregister_tool(self, tool_id: str) -> None:
return
def _get_api_key(self) -> str:
if self.config.api_key:
return self.config.api_key
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key:
raise ValueError(
'Pass Bing Search API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
)
return provider_data.api_key
async def list_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="web_search",
description="Search the web using Bing Search API",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
built_in_type=BuiltinTool.brave_search,
)
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
headers = {
"Ocp-Apim-Subscription-Key": api_key,
}
params = {
"count": self.config.top_k,
"textDecorations": True,
"textFormat": "HTML",
"q": args["query"],
}
response = requests.get(
url=self.url,
params=params,
headers=headers,
)
response.raise_for_status()
return ToolInvocationResult(
content=json.dumps(self._clean_response(response.json()))
)
def _clean_response(self, search_response):
clean_response = []
query = search_response["queryContext"]["originalQuery"]
if "webPages" in search_response:
pages = search_response["webPages"]["value"]
for p in pages:
selected_keys = {"name", "url", "snippet"}
clean_response.append(
{k: v for k, v in p.items() if k in selected_keys}
)
if "news" in search_response:
clean_news = []
news = search_response["news"]["value"]
for n in news:
selected_keys = {"name", "url", "description"}
clean_news.append({k: v for k, v in n.items() if k in selected_keys})
clean_response.append(clean_news)
return {"query": query, "top_k": clean_response}

View file

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel
class BingSearchToolConfig(BaseModel):
"""Configuration for Bing Search Tool Runtime"""
api_key: Optional[str] = None
top_k: int = 3

View file

@ -7,6 +7,7 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import requests import requests
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
@ -62,6 +63,7 @@ class BraveSearchToolRuntimeImpl(
parameter_type="string", parameter_type="string",
) )
], ],
built_in_type=BuiltinTool.brave_search,
) )
] ]

View file

@ -8,6 +8,7 @@ import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import requests import requests
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
@ -63,6 +64,7 @@ class TavilySearchToolRuntimeImpl(
parameter_type="string", parameter_type="string",
) )
], ],
built_in_type=BuiltinTool.brave_search,
) )
] ]

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import WolframAlphaToolConfig
from .wolfram_alpha import WolframAlphaToolRuntimeImpl
__all__ = ["WolframAlphaToolConfig", "WolframAlphaToolRuntimeImpl"]
class WolframAlphaToolProviderDataValidator(BaseModel):
api_key: str
async def get_adapter_impl(config: WolframAlphaToolConfig, _deps):
impl = WolframAlphaToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel
class WolframAlphaToolConfig(BaseModel):
"""Configuration for WolframAlpha Tool Runtime"""
api_key: Optional[str] = None

View file

@ -0,0 +1,148 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
import requests
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import WolframAlphaToolConfig
class WolframAlphaToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: WolframAlphaToolConfig):
self.config = config
self.url = "https://api.wolframalpha.com/v2/query"
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
pass
async def unregister_tool(self, tool_id: str) -> None:
return
def _get_api_key(self) -> str:
if self.config.api_key:
return self.config.api_key
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key:
raise ValueError(
'Pass WolframAlpha API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
)
return provider_data.api_key
async def list_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="wolfram_alpha",
description="Query WolframAlpha for computational knowledge",
parameters=[
ToolParameter(
name="query",
description="The query to compute",
parameter_type="string",
)
],
built_in_type=BuiltinTool.wolfram_alpha,
)
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
params = {
"input": args["query"],
"appid": api_key,
"format": "plaintext",
"output": "json",
}
response = requests.get(
self.url,
params=params,
)
return ToolInvocationResult(
content=json.dumps(self._clean_wolfram_alpha_response(response.json()))
)
def _clean_wolfram_alpha_response(self, wa_response):
remove = {
"queryresult": [
"datatypes",
"error",
"timedout",
"timedoutpods",
"numpods",
"timing",
"parsetiming",
"parsetimedout",
"recalculate",
"id",
"host",
"server",
"related",
"version",
{
"pods": [
"scanner",
"id",
"error",
"expressiontypes",
"states",
"infos",
"position",
"numsubpods",
]
},
"assumptions",
],
}
for main_key in remove:
for key_to_remove in remove[main_key]:
try:
if key_to_remove == "assumptions":
if "assumptions" in wa_response[main_key]:
del wa_response[main_key][key_to_remove]
if isinstance(key_to_remove, dict):
for sub_key in key_to_remove:
if sub_key == "pods":
for i in range(len(wa_response[main_key][sub_key])):
if (
wa_response[main_key][sub_key][i]["title"]
== "Result"
):
del wa_response[main_key][sub_key][i + 1 :]
break
sub_items = wa_response[main_key][sub_key]
for i in range(len(sub_items)):
for sub_key_to_remove in key_to_remove[sub_key]:
if sub_key_to_remove in sub_items[i]:
del sub_items[i][sub_key_to_remove]
elif key_to_remove in wa_response[main_key]:
del wa_response[main_key][key_to_remove]
except KeyError:
pass
return wa_response

View file

@ -33,6 +33,13 @@ def tool_runtime_memory_and_search() -> ProviderFixture:
"api_key": os.environ["TAVILY_SEARCH_API_KEY"], "api_key": os.environ["TAVILY_SEARCH_API_KEY"],
}, },
), ),
Provider(
provider_id="wolfram-alpha",
provider_type="remote::wolfram-alpha",
config={
"api_key": os.environ["WOLFRAM_ALPHA_API_KEY"],
},
),
], ],
) )
@ -53,12 +60,24 @@ def tool_group_input_tavily_search() -> ToolGroupInput:
) )
@pytest.fixture(scope="session")
def tool_group_input_wolfram_alpha() -> ToolGroupInput:
return ToolGroupInput(
toolgroup_id="builtin::wolfram_alpha",
provider_id="wolfram-alpha",
)
TOOL_RUNTIME_FIXTURES = ["memory_and_search"] TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")
async def tools_stack( async def tools_stack(
request, inference_model, tool_group_input_memory, tool_group_input_tavily_search request,
inference_model,
tool_group_input_memory,
tool_group_input_tavily_search,
tool_group_input_wolfram_alpha,
): ):
fixture_dict = request.param fixture_dict = request.param
@ -104,6 +123,7 @@ async def tools_stack(
models=models, models=models,
tool_groups=[ tool_groups=[
tool_group_input_tavily_search, tool_group_input_tavily_search,
tool_group_input_wolfram_alpha,
tool_group_input_memory, tool_group_input_memory,
], ],
) )

View file

@ -20,6 +20,11 @@ def sample_search_query():
return "What are the latest developments in quantum computing?" return "What are the latest developments in quantum computing?"
@pytest.fixture
def sample_wolfram_alpha_query():
return "What is the square root of 16?"
@pytest.fixture @pytest.fixture
def sample_documents(): def sample_documents():
urls = [ urls = [
@ -61,6 +66,24 @@ class TestTools:
assert len(response.content) > 0 assert len(response.content) > 0
assert isinstance(response.content, str) assert isinstance(response.content, str)
@pytest.mark.asyncio
async def test_wolfram_alpha_tool(self, tools_stack, sample_wolfram_alpha_query):
"""Test the wolfram alpha tool functionality."""
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
tools_impl = tools_stack.impls[Api.tool_runtime]
response = await tools_impl.invoke_tool(
tool_name="wolfram_alpha", args={"query": sample_wolfram_alpha_query}
)
# Verify the response
assert isinstance(response, ToolInvocationResult)
assert response.content is not None
assert len(response.content) > 0
assert isinstance(response.content, str)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_memory_tool(self, tools_stack, sample_documents): async def test_memory_tool(self, tools_stack, sample_documents):
"""Test the memory tool functionality.""" """Test the memory tool functionality."""