diff --git a/.gitignore b/.gitignore
index 421ff4db1..f3585a51f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,3 +19,4 @@ Package.resolved
 _build
 docs/src
 pyrightconfig.json
+.aider*
diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py
index 9a98526ab..b9dbb84f7 100644
--- a/llama_stack/providers/tests/memory/fixtures.py
+++ b/llama_stack/providers/tests/memory/fixtures.py
@@ -19,6 +19,7 @@ from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
 from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
 from llama_stack.providers.tests.resolver import construct_stack_for_test
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
+
 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail
 
diff --git a/llama_stack/providers/tests/tools/__init__.py b/llama_stack/providers/tests/tools/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/providers/tests/tools/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/tests/tools/conftest.py b/llama_stack/providers/tests/tools/conftest.py
new file mode 100644
index 000000000..11aad5ab6
--- /dev/null
+++ b/llama_stack/providers/tests/tools/conftest.py
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from ..conftest import get_provider_fixture_overrides
+from ..inference.fixtures import INFERENCE_FIXTURES
+from ..memory.fixtures import MEMORY_FIXTURES
+from ..safety.fixtures import SAFETY_FIXTURES
+from .fixtures import TOOL_RUNTIME_FIXTURES
+
+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "inference": "together",
+            "safety": "llama_guard",
+            "memory": "faiss",
+            "tool_runtime": "memory_and_search",
+        },
+        id="together",
+        marks=pytest.mark.together,
+    ),
+]
+
+
+def pytest_configure(config):
+    for mark in ["together"]:
+        config.addinivalue_line(
+            "markers",
+            f"{mark}: marks tests as {mark} specific",
+        )
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--inference-model",
+        action="store",
+        default="meta-llama/Llama-3.2-3B-Instruct",
+        help="Specify the inference model to use for testing",
+    )
+    parser.addoption(
+        "--safety-shield",
+        action="store",
+        default="meta-llama/Llama-Guard-3-1B",
+        help="Specify the safety shield to use for testing",
+    )
+
+
+def pytest_generate_tests(metafunc):
+    if "tools_stack" in metafunc.fixturenames:
+        available_fixtures = {
+            "inference": INFERENCE_FIXTURES,
+            "safety": SAFETY_FIXTURES,
+            "memory": MEMORY_FIXTURES,
+            "tool_runtime": TOOL_RUNTIME_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
+        )
+        print(combinations)
+        metafunc.parametrize("tools_stack", combinations, indirect=True)
diff --git a/llama_stack/providers/tests/tools/fixtures.py b/llama_stack/providers/tests/tools/fixtures.py
new file mode 100644
index 000000000..f7580ee2f
--- /dev/null
+++ b/llama_stack/providers/tests/tools/fixtures.py
@@ -0,0 +1,138 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+
+import pytest
+import pytest_asyncio
+
+from llama_stack.apis.models import ModelInput, ModelType
+from llama_stack.apis.tools import (
+    BuiltInToolDef,
+    CustomToolDef,
+    ToolGroupInput,
+    ToolParameter,
+    UserDefinedToolGroupDef,
+)
+from llama_stack.distribution.datatypes import Api, Provider
+from llama_stack.providers.tests.resolver import construct_stack_for_test
+
+from ..conftest import ProviderFixture
+
+
+@pytest.fixture(scope="session")
+def tool_runtime_memory_and_search() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="memory-runtime",
+                provider_type="inline::memory-runtime",
+                config={},
+            ),
+            Provider(
+                provider_id="tavily-search",
+                provider_type="remote::tavily-search",
+                config={
+                    "api_key": os.environ["TAVILY_SEARCH_API_KEY"],
+                },
+            ),
+        ],
+    )
+
+
+TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
+
+
+@pytest_asyncio.fixture(scope="session")
+async def tools_stack(request, inference_model, safety_shield):
+    fixture_dict = request.param
+
+    providers = {}
+    provider_data = {}
+    for key in ["inference", "safety", "memory", "tool_runtime"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if key == "inference":
+            providers[key].append(
+                Provider(
+                    provider_id="tools_memory_provider",
+                    provider_type="inline::sentence-transformers",
+                    config={},
+                )
+            )
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)
+    inference_models = (
+        inference_model if isinstance(inference_model, list) else [inference_model]
+    )
+    models = [
+        ModelInput(
+            model_id=model,
+            model_type=ModelType.llm,
+            provider_id=providers["inference"][0].provider_id,
+        )
+        for model in inference_models
+    ]
+    models.append(
+        ModelInput(
+            model_id="all-MiniLM-L6-v2",
+            model_type=ModelType.embedding,
+            provider_id="tools_memory_provider",
+            metadata={"embedding_dimension": 384},
+        )
+    )
+
+    tool_groups = [
+        ToolGroupInput(
+            tool_group_id="tavily_search_group",
+            tool_group=UserDefinedToolGroupDef(
+                tools=[
+                    BuiltInToolDef(
+                        name="brave_search",
+                        description="Search the web using Brave Search",
+                        metadata={},
+                    ),
+                ],
+            ),
+            provider_id="tavily-search",
+        ),
+        ToolGroupInput(
+            tool_group_id="memory_group",
+            tool_group=UserDefinedToolGroupDef(
+                tools=[
+                    CustomToolDef(
+                        name="memory",
+                        description="Query the memory bank",
+                        parameters=[
+                            ToolParameter(
+                                name="query",
+                                description="The query to search for in memory",
+                                parameter_type="string",
+                                required=True,
+                            ),
+                            ToolParameter(
+                                name="memory_bank_id",
+                                description="The ID of the memory bank to search",
+                                parameter_type="string",
+                                required=True,
+                            ),
+                        ],
+                        metadata={},
+                    )
+                ],
+            ),
+            provider_id="memory-runtime",
+        ),
+    ]
+
+    test_stack = await construct_stack_for_test(
+        [Api.tools, Api.inference, Api.memory],
+        providers,
+        provider_data,
+        models=models,
+        tool_groups=tool_groups,
+    )
+    return test_stack
diff --git a/llama_stack/providers/tests/tools/test_tools.py b/llama_stack/providers/tests/tools/test_tools.py
new file mode 100644
index 000000000..96a80414c
--- /dev/null
+++ b/llama_stack/providers/tests/tools/test_tools.py
@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+
+import pytest
+
+from llama_stack.apis.memory import MemoryBankDocument
+from llama_stack.apis.memory_banks import VectorMemoryBankParams
+from llama_stack.apis.tools import ToolInvocationResult
+from llama_stack.providers.datatypes import Api
+
+
+@pytest.fixture
+def sample_search_query():
+    return "What are the latest developments in quantum computing?"
+
+
+@pytest.fixture
+def sample_documents():
+    urls = [
+        "memory_optimizations.rst",
+        "chat.rst",
+        "llama3.rst",
+        "datasets.rst",
+        "qat_finetune.rst",
+        "lora_finetune.rst",
+    ]
+    return [
+        MemoryBankDocument(
+            document_id=f"num-{i}",
+            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
+            mime_type="text/plain",
+            metadata={},
+        )
+        for i, url in enumerate(urls)
+    ]
+
+
+class TestTools:
+    @pytest.mark.asyncio
+    async def test_brave_search_tool(self, tools_stack, sample_search_query):
+        """Test the Brave search tool functionality."""
+        if "TAVILY_SEARCH_API_KEY" not in os.environ:
+            pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
+
+        tools_impl = tools_stack.impls[Api.tool_runtime]
+
+        # Execute the tool
+        response = await tools_impl.invoke_tool(
+            tool_name="brave_search", tool_args={"query": sample_search_query}
+        )
+
+        # Verify the response
+        assert isinstance(response, ToolInvocationResult)
+        assert response.content is not None
+        assert len(response.content) > 0
+        assert isinstance(response.content, str)
+
+    @pytest.mark.asyncio
+    async def test_memory_tool(self, tools_stack, sample_documents):
+        """Test the memory tool functionality."""
+        memory_banks_impl = tools_stack.impls[Api.memory_banks]
+        memory_impl = tools_stack.impls[Api.memory]
+        tools_impl = tools_stack.impls[Api.tools]
+
+        # Register memory bank
+        await memory_banks_impl.register_memory_bank(
+            memory_bank_id="test_memory_bank",
+            params=VectorMemoryBankParams(
+                embedding_model="all-MiniLM-L6-v2",
+                chunk_size_in_tokens=512,
+                overlap_size_in_tokens=64,
+            ),
+            provider_id="faiss",
+        )
+
+        # Insert documents into memory
+        await memory_impl.insert_documents(
+            bank_id="test_memory_bank",
+            documents=sample_documents,
+        )
+
+        # Execute the memory tool
+        response = await tools_impl.invoke_tool(
+            tool_name="memory",
+            tool_args={
+                "query": "What are the main topics covered in the documentation?",
+            },
+        )
+
+        # Verify the response
+        assert isinstance(response, ToolInvocationResult)
+        assert response.content is not None
+        assert len(response.content) > 0
+        assert isinstance(response.content, str)