diff --git a/llama_stack/providers/tests/tools/fixtures.py b/llama_stack/providers/tests/tools/fixtures.py index f7580ee2f..845e0dba4 100644 --- a/llama_stack/providers/tests/tools/fixtures.py +++ b/llama_stack/providers/tests/tools/fixtures.py @@ -17,6 +17,7 @@ from llama_stack.apis.tools import ( ToolParameter, UserDefinedToolGroupDef, ) +from llama_stack.apis.tools.tools import BuiltinTool from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.tests.resolver import construct_stack_for_test @@ -47,12 +48,12 @@ TOOL_RUNTIME_FIXTURES = ["memory_and_search"] @pytest_asyncio.fixture(scope="session") -async def tools_stack(request, inference_model, safety_shield): +async def tools_stack(request, inference_model): fixture_dict = request.param providers = {} provider_data = {} - for key in ["inference", "memory", "tools", "tool_runtime"]: + for key in ["inference", "memory", "tool_runtime"]: fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") providers[key] = fixture.providers if key == "inference": @@ -91,8 +92,7 @@ async def tools_stack(request, inference_model, safety_shield): tool_group=UserDefinedToolGroupDef( tools=[ BuiltInToolDef( - name="brave_search", - description="Search the web using Brave Search", + built_in_type=BuiltinTool.brave_search, metadata={}, ), ], @@ -108,19 +108,19 @@ async def tools_stack(request, inference_model, safety_shield): description="Query the memory bank", parameters=[ ToolParameter( - name="query", - description="The query to search for in memory", - parameter_type="string", - required=True, - ), - ToolParameter( - name="memory_bank_id", - description="The ID of the memory bank to search", - parameter_type="string", + name="input_messages", + description="The input messages to search for in memory", + parameter_type="list", required=True, ), ], - metadata={}, + metadata={ + "config": { + "memory_bank_configs": [ + {"bank_id": "test_bank", "type": "vector"} + ] + } + }, ) ], ), @@ -129,7 +129,7 @@ async def tools_stack(request, inference_model, safety_shield): ] test_stack = await construct_stack_for_test( - [Api.tools, Api.inference, Api.memory], + [Api.tool_groups, Api.inference, Api.memory, Api.tool_runtime], providers, provider_data, models=models, diff --git a/llama_stack/providers/tests/tools/test_tools.py b/llama_stack/providers/tests/tools/test_tools.py index 96a80414c..08c7afe1e 100644 --- a/llama_stack/providers/tests/tools/test_tools.py +++ b/llama_stack/providers/tests/tools/test_tools.py @@ -8,11 +8,14 @@ import os import pytest +from llama_stack.apis.inference import UserMessage from llama_stack.apis.memory import MemoryBankDocument from llama_stack.apis.memory_banks import VectorMemoryBankParams from llama_stack.apis.tools import ToolInvocationResult from llama_stack.providers.datatypes import Api +from .fixtures import tool_runtime_memory_and_search # noqa: F401 + @pytest.fixture def sample_search_query(): @@ -51,7 +54,7 @@ class TestTools: # Execute the tool response = await tools_impl.invoke_tool( - tool_name="brave_search", tool_args={"query": sample_search_query} + tool_name="brave_search", args={"query": sample_search_query} ) # Verify the response @@ -65,11 +68,11 @@ class TestTools: """Test the memory tool functionality.""" memory_banks_impl = tools_stack.impls[Api.memory_banks] memory_impl = tools_stack.impls[Api.memory] - tools_impl = tools_stack.impls[Api.tools] + tools_impl = tools_stack.impls[Api.tool_runtime] # Register memory bank await memory_banks_impl.register_memory_bank( - memory_bank_id="test_memory_bank", + memory_bank_id="test_bank", params=VectorMemoryBankParams( embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512, @@ -79,16 +82,20 @@ class TestTools: ) # Insert documents into memory - memory_impl.insert_documents( - bank_id="test_memory_bank", + await memory_impl.insert_documents( + bank_id="test_bank", documents=sample_documents, ) # Execute the memory tool response = await tools_impl.invoke_tool( tool_name="memory", - tool_args={ - "query": "What are the main topics covered in the documentation?", + args={ + "input_messages": [ + UserMessage( + content="What are the main topics covered in the documentation?", + ) + ], }, ) @@ -96,4 +103,3 @@ class TestTools: assert isinstance(response, ToolInvocationResult) assert response.content is not None assert len(response.content) > 0 - assert isinstance(response.content, str)