mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 17:29:01 +00:00
add wolfram alpha, bing search
This commit is contained in:
parent
f9a98c278a
commit
94cca7a72a
14 changed files with 411 additions and 1 deletions
|
@ -53,6 +53,7 @@ class ToolDef(BaseModel):
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
parameters: Optional[List[ToolParameter]] = None
|
parameters: Optional[List[ToolParameter]] = None
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
built_in_type: Optional[BuiltinTool] = None
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||||
default=ToolPromptFormat.json
|
default=ToolPromptFormat.json
|
||||||
)
|
)
|
||||||
|
|
|
@ -527,6 +527,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
provider_resource_id=tool_def.name,
|
provider_resource_id=tool_def.name,
|
||||||
metadata=tool_def.metadata,
|
metadata=tool_def.metadata,
|
||||||
tool_host=tool_host,
|
tool_host=tool_host,
|
||||||
|
built_in_type=tool_def.built_in_type,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
|
|
|
@ -9,6 +9,8 @@ import logging
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
Tool,
|
Tool,
|
||||||
|
@ -56,6 +58,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
parameter_type="string",
|
parameter_type="string",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
built_in_type=BuiltinTool.code_interpreter,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -42,6 +42,16 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
api=Api.tool_runtime,
|
||||||
|
adapter=AdapterSpec(
|
||||||
|
adapter_type="bing-search",
|
||||||
|
module="llama_stack.providers.remote.tool_runtime.bing_search",
|
||||||
|
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
|
||||||
|
pip_packages=["requests"],
|
||||||
|
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
|
||||||
|
),
|
||||||
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.tool_runtime,
|
api=Api.tool_runtime,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
|
@ -52,6 +62,16 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
api=Api.tool_runtime,
|
||||||
|
adapter=AdapterSpec(
|
||||||
|
adapter_type="wolfram-alpha",
|
||||||
|
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
|
||||||
|
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
|
||||||
|
pip_packages=["requests"],
|
||||||
|
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
|
||||||
|
),
|
||||||
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.tool_runtime,
|
api=Api.tool_runtime,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .bing_search import BingSearchToolRuntimeImpl
|
||||||
|
from .config import BingSearchToolConfig
|
||||||
|
|
||||||
|
__all__ = ["BingSearchToolConfig", "BingSearchToolRuntimeImpl"]
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class BingSearchToolProviderDataValidator(BaseModel):
|
||||||
|
api_key: str
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: BingSearchToolConfig, _deps):
|
||||||
|
impl = BingSearchToolRuntimeImpl(config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
|
@ -0,0 +1,116 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import URL
|
||||||
|
from llama_stack.apis.tools import (
|
||||||
|
Tool,
|
||||||
|
ToolDef,
|
||||||
|
ToolInvocationResult,
|
||||||
|
ToolParameter,
|
||||||
|
ToolRuntime,
|
||||||
|
)
|
||||||
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||||
|
|
||||||
|
from .config import BingSearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
|
class BingSearchToolRuntimeImpl(
|
||||||
|
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
|
||||||
|
):
|
||||||
|
def __init__(self, config: BingSearchToolConfig):
|
||||||
|
self.config = config
|
||||||
|
self.url = "https://api.bing.microsoft.com/v7.0/search"
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def register_tool(self, tool: Tool):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def unregister_tool(self, tool_id: str) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
def _get_api_key(self) -> str:
|
||||||
|
if self.config.api_key:
|
||||||
|
return self.config.api_key
|
||||||
|
|
||||||
|
provider_data = self.get_request_provider_data()
|
||||||
|
if provider_data is None or not provider_data.api_key:
|
||||||
|
raise ValueError(
|
||||||
|
'Pass Bing Search API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
|
||||||
|
)
|
||||||
|
return provider_data.api_key
|
||||||
|
|
||||||
|
async def list_tools(
|
||||||
|
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||||
|
) -> List[ToolDef]:
|
||||||
|
return [
|
||||||
|
ToolDef(
|
||||||
|
name="web_search",
|
||||||
|
description="Search the web using Bing Search API",
|
||||||
|
parameters=[
|
||||||
|
ToolParameter(
|
||||||
|
name="query",
|
||||||
|
description="The query to search for",
|
||||||
|
parameter_type="string",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
built_in_type=BuiltinTool.brave_search,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
async def invoke_tool(
|
||||||
|
self, tool_name: str, args: Dict[str, Any]
|
||||||
|
) -> ToolInvocationResult:
|
||||||
|
api_key = self._get_api_key()
|
||||||
|
headers = {
|
||||||
|
"Ocp-Apim-Subscription-Key": api_key,
|
||||||
|
}
|
||||||
|
params = {
|
||||||
|
"count": self.config.top_k,
|
||||||
|
"textDecorations": True,
|
||||||
|
"textFormat": "HTML",
|
||||||
|
"q": args["query"],
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
url=self.url,
|
||||||
|
params=params,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
return ToolInvocationResult(
|
||||||
|
content=json.dumps(self._clean_response(response.json()))
|
||||||
|
)
|
||||||
|
|
||||||
|
def _clean_response(self, search_response):
|
||||||
|
clean_response = []
|
||||||
|
query = search_response["queryContext"]["originalQuery"]
|
||||||
|
if "webPages" in search_response:
|
||||||
|
pages = search_response["webPages"]["value"]
|
||||||
|
for p in pages:
|
||||||
|
selected_keys = {"name", "url", "snippet"}
|
||||||
|
clean_response.append(
|
||||||
|
{k: v for k, v in p.items() if k in selected_keys}
|
||||||
|
)
|
||||||
|
if "news" in search_response:
|
||||||
|
clean_news = []
|
||||||
|
news = search_response["news"]["value"]
|
||||||
|
for n in news:
|
||||||
|
selected_keys = {"name", "url", "description"}
|
||||||
|
clean_news.append({k: v for k, v in n.items() if k in selected_keys})
|
||||||
|
|
||||||
|
clean_response.append(clean_news)
|
||||||
|
|
||||||
|
return {"query": query, "top_k": clean_response}
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class BingSearchToolConfig(BaseModel):
|
||||||
|
"""Configuration for Bing Search Tool Runtime"""
|
||||||
|
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
top_k: int = 3
|
|
@ -7,6 +7,7 @@
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
|
@ -62,6 +63,7 @@ class BraveSearchToolRuntimeImpl(
|
||||||
parameter_type="string",
|
parameter_type="string",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
built_in_type=BuiltinTool.brave_search,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ import json
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
|
@ -63,6 +64,7 @@ class TavilySearchToolRuntimeImpl(
|
||||||
parameter_type="string",
|
parameter_type="string",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
built_in_type=BuiltinTool.brave_search,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .config import WolframAlphaToolConfig
|
||||||
|
from .wolfram_alpha import WolframAlphaToolRuntimeImpl
|
||||||
|
|
||||||
|
__all__ = ["WolframAlphaToolConfig", "WolframAlphaToolRuntimeImpl"]
|
||||||
|
|
||||||
|
|
||||||
|
class WolframAlphaToolProviderDataValidator(BaseModel):
|
||||||
|
api_key: str
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: WolframAlphaToolConfig, _deps):
|
||||||
|
impl = WolframAlphaToolRuntimeImpl(config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
|
@ -0,0 +1,15 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class WolframAlphaToolConfig(BaseModel):
|
||||||
|
"""Configuration for WolframAlpha Tool Runtime"""
|
||||||
|
|
||||||
|
api_key: Optional[str] = None
|
|
@ -0,0 +1,148 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import URL
|
||||||
|
from llama_stack.apis.tools import (
|
||||||
|
Tool,
|
||||||
|
ToolDef,
|
||||||
|
ToolInvocationResult,
|
||||||
|
ToolParameter,
|
||||||
|
ToolRuntime,
|
||||||
|
)
|
||||||
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||||
|
|
||||||
|
from .config import WolframAlphaToolConfig
|
||||||
|
|
||||||
|
|
||||||
|
class WolframAlphaToolRuntimeImpl(
|
||||||
|
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
|
||||||
|
):
|
||||||
|
def __init__(self, config: WolframAlphaToolConfig):
|
||||||
|
self.config = config
|
||||||
|
self.url = "https://api.wolframalpha.com/v2/query"
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def register_tool(self, tool: Tool):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def unregister_tool(self, tool_id: str) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
def _get_api_key(self) -> str:
|
||||||
|
if self.config.api_key:
|
||||||
|
return self.config.api_key
|
||||||
|
|
||||||
|
provider_data = self.get_request_provider_data()
|
||||||
|
if provider_data is None or not provider_data.api_key:
|
||||||
|
raise ValueError(
|
||||||
|
'Pass WolframAlpha API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
|
||||||
|
)
|
||||||
|
return provider_data.api_key
|
||||||
|
|
||||||
|
async def list_tools(
|
||||||
|
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||||
|
) -> List[ToolDef]:
|
||||||
|
return [
|
||||||
|
ToolDef(
|
||||||
|
name="wolfram_alpha",
|
||||||
|
description="Query WolframAlpha for computational knowledge",
|
||||||
|
parameters=[
|
||||||
|
ToolParameter(
|
||||||
|
name="query",
|
||||||
|
description="The query to compute",
|
||||||
|
parameter_type="string",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
built_in_type=BuiltinTool.wolfram_alpha,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
async def invoke_tool(
|
||||||
|
self, tool_name: str, args: Dict[str, Any]
|
||||||
|
) -> ToolInvocationResult:
|
||||||
|
api_key = self._get_api_key()
|
||||||
|
params = {
|
||||||
|
"input": args["query"],
|
||||||
|
"appid": api_key,
|
||||||
|
"format": "plaintext",
|
||||||
|
"output": "json",
|
||||||
|
}
|
||||||
|
response = requests.get(
|
||||||
|
self.url,
|
||||||
|
params=params,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ToolInvocationResult(
|
||||||
|
content=json.dumps(self._clean_wolfram_alpha_response(response.json()))
|
||||||
|
)
|
||||||
|
|
||||||
|
def _clean_wolfram_alpha_response(self, wa_response):
|
||||||
|
remove = {
|
||||||
|
"queryresult": [
|
||||||
|
"datatypes",
|
||||||
|
"error",
|
||||||
|
"timedout",
|
||||||
|
"timedoutpods",
|
||||||
|
"numpods",
|
||||||
|
"timing",
|
||||||
|
"parsetiming",
|
||||||
|
"parsetimedout",
|
||||||
|
"recalculate",
|
||||||
|
"id",
|
||||||
|
"host",
|
||||||
|
"server",
|
||||||
|
"related",
|
||||||
|
"version",
|
||||||
|
{
|
||||||
|
"pods": [
|
||||||
|
"scanner",
|
||||||
|
"id",
|
||||||
|
"error",
|
||||||
|
"expressiontypes",
|
||||||
|
"states",
|
||||||
|
"infos",
|
||||||
|
"position",
|
||||||
|
"numsubpods",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"assumptions",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
for main_key in remove:
|
||||||
|
for key_to_remove in remove[main_key]:
|
||||||
|
try:
|
||||||
|
if key_to_remove == "assumptions":
|
||||||
|
if "assumptions" in wa_response[main_key]:
|
||||||
|
del wa_response[main_key][key_to_remove]
|
||||||
|
if isinstance(key_to_remove, dict):
|
||||||
|
for sub_key in key_to_remove:
|
||||||
|
if sub_key == "pods":
|
||||||
|
for i in range(len(wa_response[main_key][sub_key])):
|
||||||
|
if (
|
||||||
|
wa_response[main_key][sub_key][i]["title"]
|
||||||
|
== "Result"
|
||||||
|
):
|
||||||
|
del wa_response[main_key][sub_key][i + 1 :]
|
||||||
|
break
|
||||||
|
sub_items = wa_response[main_key][sub_key]
|
||||||
|
for i in range(len(sub_items)):
|
||||||
|
for sub_key_to_remove in key_to_remove[sub_key]:
|
||||||
|
if sub_key_to_remove in sub_items[i]:
|
||||||
|
del sub_items[i][sub_key_to_remove]
|
||||||
|
elif key_to_remove in wa_response[main_key]:
|
||||||
|
del wa_response[main_key][key_to_remove]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
return wa_response
|
|
@ -33,6 +33,13 @@ def tool_runtime_memory_and_search() -> ProviderFixture:
|
||||||
"api_key": os.environ["TAVILY_SEARCH_API_KEY"],
|
"api_key": os.environ["TAVILY_SEARCH_API_KEY"],
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
Provider(
|
||||||
|
provider_id="wolfram-alpha",
|
||||||
|
provider_type="remote::wolfram-alpha",
|
||||||
|
config={
|
||||||
|
"api_key": os.environ["WOLFRAM_ALPHA_API_KEY"],
|
||||||
|
},
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -53,12 +60,24 @@ def tool_group_input_tavily_search() -> ToolGroupInput:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def tool_group_input_wolfram_alpha() -> ToolGroupInput:
|
||||||
|
return ToolGroupInput(
|
||||||
|
toolgroup_id="builtin::wolfram_alpha",
|
||||||
|
provider_id="wolfram-alpha",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
|
TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
async def tools_stack(
|
async def tools_stack(
|
||||||
request, inference_model, tool_group_input_memory, tool_group_input_tavily_search
|
request,
|
||||||
|
inference_model,
|
||||||
|
tool_group_input_memory,
|
||||||
|
tool_group_input_tavily_search,
|
||||||
|
tool_group_input_wolfram_alpha,
|
||||||
):
|
):
|
||||||
fixture_dict = request.param
|
fixture_dict = request.param
|
||||||
|
|
||||||
|
@ -104,6 +123,7 @@ async def tools_stack(
|
||||||
models=models,
|
models=models,
|
||||||
tool_groups=[
|
tool_groups=[
|
||||||
tool_group_input_tavily_search,
|
tool_group_input_tavily_search,
|
||||||
|
tool_group_input_wolfram_alpha,
|
||||||
tool_group_input_memory,
|
tool_group_input_memory,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -20,6 +20,11 @@ def sample_search_query():
|
||||||
return "What are the latest developments in quantum computing?"
|
return "What are the latest developments in quantum computing?"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_wolfram_alpha_query():
|
||||||
|
return "What is the square root of 16?"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_documents():
|
def sample_documents():
|
||||||
urls = [
|
urls = [
|
||||||
|
@ -61,6 +66,24 @@ class TestTools:
|
||||||
assert len(response.content) > 0
|
assert len(response.content) > 0
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_wolfram_alpha_tool(self, tools_stack, sample_wolfram_alpha_query):
|
||||||
|
"""Test the wolfram alpha tool functionality."""
|
||||||
|
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
|
||||||
|
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
|
||||||
|
|
||||||
|
tools_impl = tools_stack.impls[Api.tool_runtime]
|
||||||
|
|
||||||
|
response = await tools_impl.invoke_tool(
|
||||||
|
tool_name="wolfram_alpha", args={"query": sample_wolfram_alpha_query}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the response
|
||||||
|
assert isinstance(response, ToolInvocationResult)
|
||||||
|
assert response.content is not None
|
||||||
|
assert len(response.content) > 0
|
||||||
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_memory_tool(self, tools_stack, sample_documents):
|
async def test_memory_tool(self, tools_stack, sample_documents):
|
||||||
"""Test the memory tool functionality."""
|
"""Test the memory tool functionality."""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue