passing tool tests

This commit is contained in:
Dinesh Yeduguru 2024-12-30 11:37:28 -08:00
parent 50852cadf3
commit b7ae86ae03
2 changed files with 29 additions and 23 deletions

View file

@ -17,6 +17,7 @@ from llama_stack.apis.tools import (
ToolParameter,
UserDefinedToolGroupDef,
)
from llama_stack.apis.tools.tools import BuiltinTool
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
@ -47,12 +48,12 @@ TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
@pytest_asyncio.fixture(scope="session")
async def tools_stack(request, inference_model, safety_shield):
async def tools_stack(request, inference_model):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["inference", "memory", "tools", "tool_runtime"]:
for key in ["inference", "memory", "tool_runtime"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if key == "inference":
@ -91,8 +92,7 @@ async def tools_stack(request, inference_model, safety_shield):
tool_group=UserDefinedToolGroupDef(
tools=[
BuiltInToolDef(
name="brave_search",
description="Search the web using Brave Search",
built_in_type=BuiltinTool.brave_search,
metadata={},
),
],
@ -108,19 +108,19 @@ async def tools_stack(request, inference_model, safety_shield):
description="Query the memory bank",
parameters=[
ToolParameter(
name="query",
description="The query to search for in memory",
parameter_type="string",
required=True,
),
ToolParameter(
name="memory_bank_id",
description="The ID of the memory bank to search",
parameter_type="string",
name="input_messages",
description="The input messages to search for in memory",
parameter_type="list",
required=True,
),
],
metadata={},
metadata={
"config": {
"memory_bank_configs": [
{"bank_id": "test_bank", "type": "vector"}
]
}
},
)
],
),
@ -129,7 +129,7 @@ async def tools_stack(request, inference_model, safety_shield):
]
test_stack = await construct_stack_for_test(
[Api.tools, Api.inference, Api.memory],
[Api.tool_groups, Api.inference, Api.memory, Api.tool_runtime],
providers,
provider_data,
models=models,