From c2dd0cdc78e87d983797c9745885c7320116bdc4 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 30 Dec 2024 13:27:43 -0800 Subject: [PATCH] more test fixes --- .gitignore | 1 - .../providers/tests/agents/conftest.py | 3 +- .../providers/tests/agents/fixtures.py | 81 ++-------------- .../providers/tests/agents/test_agents.py | 2 + llama_stack/providers/tests/conftest.py | 1 + llama_stack/providers/tests/tools/conftest.py | 2 +- llama_stack/providers/tests/tools/fixtures.py | 93 ++++++++++--------- 7 files changed, 63 insertions(+), 120 deletions(-) diff --git a/.gitignore b/.gitignore index f3585a51f..421ff4db1 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,3 @@ Package.resolved _build docs/src pyrightconfig.json -.aider* diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index f805fbbbb..ecd05dcf8 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -10,7 +10,8 @@ from ..conftest import get_provider_fixture_overrides from ..inference.fixtures import INFERENCE_FIXTURES from ..memory.fixtures import MEMORY_FIXTURES from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield -from .fixtures import AGENTS_FIXTURES, TOOL_RUNTIME_FIXTURES +from ..tools.fixtures import TOOL_RUNTIME_FIXTURES +from .fixtures import AGENTS_FIXTURES DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 71e98102e..1b1781f36 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -4,21 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import os import tempfile import pytest import pytest_asyncio -from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.models import ModelInput, ModelType -from llama_stack.apis.tools import ( - BuiltInToolDef, - CustomToolDef, - ToolGroupInput, - ToolParameter, - UserDefinedToolGroupDef, -) from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.agents.meta_reference import ( MetaReferenceAgentsImplConfig, @@ -63,32 +54,17 @@ def agents_meta_reference() -> ProviderFixture: ) -@pytest.fixture(scope="session") -def tool_runtime_memory_and_search() -> ProviderFixture: - return ProviderFixture( - providers=[ - Provider( - provider_id="memory-runtime", - provider_type="inline::memory-runtime", - config={}, - ), - Provider( - provider_id="tavily-search", - provider_type="remote::tavily-search", - config={ - "api_key": os.environ["TAVILY_SEARCH_API_KEY"], - }, - ), - ], - ) - - AGENTS_FIXTURES = ["meta_reference", "remote"] -TOOL_RUNTIME_FIXTURES = ["memory_and_search"] @pytest_asyncio.fixture(scope="session") -async def agents_stack(request, inference_model, safety_shield): +async def agents_stack( + request, + inference_model, + safety_shield, + tool_group_input_memory, + tool_group_input_tavily_search, +): fixture_dict = request.param providers = {} @@ -140,47 +116,6 @@ async def agents_stack(request, inference_model, safety_shield): metadata={"embedding_dimension": 384}, ) ) - tool_groups = [ - ToolGroupInput( - tool_group_id="tavily_search_group", - tool_group=UserDefinedToolGroupDef( - tools=[ - BuiltInToolDef( - built_in_type=BuiltinTool.brave_search, - metadata={}, - ), - ], - ), - provider_id="tavily-search", - ), - ToolGroupInput( - tool_group_id="memory_group", - tool_group=UserDefinedToolGroupDef( - tools=[ - CustomToolDef( - name="memory", - description="memory", - parameters=[ - ToolParameter( - name="input_messages", - description="messages", - parameter_type="list", - required=True, - ), - ], - metadata={ - "config": { - "memory_bank_configs": [ - {"bank_id": "test_bank", "type": "vector"} - ] - } - }, - ) - ], - ), - provider_id="memory-runtime", - ), - ] test_stack = await construct_stack_for_test( [Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime], @@ -188,6 +123,6 @@ async def agents_stack(request, inference_model, safety_shield): provider_data, models=models, shields=[safety_shield] if safety_shield else [], - tool_groups=tool_groups, + tool_groups=[tool_group_input_memory, tool_group_input_tavily_search], ) return test_stack diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 3534e0f84..e02af9c92 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -22,6 +22,8 @@ from llama_stack.apis.agents import ( Turn, ) from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage +from llama_stack.apis.memory import MemoryBankDocument +from llama_stack.apis.memory_banks import VectorMemoryBankParams from llama_stack.apis.safety import ViolationLevel from llama_stack.providers.datatypes import Api diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index 4d7831ae3..7408a6375 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -157,4 +157,5 @@ pytest_plugins = [ "llama_stack.providers.tests.scoring.fixtures", "llama_stack.providers.tests.eval.fixtures", "llama_stack.providers.tests.post_training.fixtures", + "llama_stack.providers.tests.tools.fixtures", ] diff --git a/llama_stack/providers/tests/tools/conftest.py b/llama_stack/providers/tests/tools/conftest.py index 6de90dc48..11aad5ab6 100644 --- a/llama_stack/providers/tests/tools/conftest.py +++ b/llama_stack/providers/tests/tools/conftest.py @@ -10,7 +10,7 @@ from ..conftest import get_provider_fixture_overrides from ..inference.fixtures import INFERENCE_FIXTURES from ..memory.fixtures import MEMORY_FIXTURES from ..safety.fixtures import SAFETY_FIXTURES -from .fixtures import TOOL_RUNTIME_FIXTURES, tools_stack # noqa: F401 +from .fixtures import TOOL_RUNTIME_FIXTURES DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( diff --git a/llama_stack/providers/tests/tools/fixtures.py b/llama_stack/providers/tests/tools/fixtures.py index 5493a4987..911043011 100644 --- a/llama_stack/providers/tests/tools/fixtures.py +++ b/llama_stack/providers/tests/tools/fixtures.py @@ -44,11 +44,55 @@ def tool_runtime_memory_and_search() -> ProviderFixture: ) +@pytest.fixture(scope="session") +def tool_group_input_memory() -> ToolGroupInput: + return ToolGroupInput( + tool_group_id="memory_group", + tool_group=UserDefinedToolGroupDef( + tools=[ + CustomToolDef( + name="memory", + description="Query the memory bank", + parameters=[ + ToolParameter( + name="input_messages", + description="The input messages to search for in memory", + parameter_type="list", + required=True, + ), + ], + metadata={ + "config": { + "memory_bank_configs": [ + {"bank_id": "test_bank", "type": "vector"} + ] + } + }, + ) + ], + ), + provider_id="memory-runtime", + ) + + +@pytest.fixture(scope="session") +def tool_group_input_tavily_search() -> ToolGroupInput: + return ToolGroupInput( + tool_group_id="tavily_search_group", + tool_group=UserDefinedToolGroupDef( + tools=[BuiltInToolDef(built_in_type=BuiltinTool.brave_search, metadata={})], + ), + provider_id="tavily-search", + ) + + TOOL_RUNTIME_FIXTURES = ["memory_and_search"] @pytest_asyncio.fixture(scope="session") -async def tools_stack(request, inference_model): +async def tools_stack( + request, inference_model, tool_group_input_memory, tool_group_input_tavily_search +): fixture_dict = request.param providers = {} @@ -86,53 +130,14 @@ async def tools_stack(request, inference_model): ) ) - tool_groups = [ - ToolGroupInput( - tool_group_id="tavily_search_group", - tool_group=UserDefinedToolGroupDef( - tools=[ - BuiltInToolDef( - built_in_type=BuiltinTool.brave_search, - metadata={}, - ), - ], - ), - provider_id="tavily-search", - ), - ToolGroupInput( - tool_group_id="memory_group", - tool_group=UserDefinedToolGroupDef( - tools=[ - CustomToolDef( - name="memory", - description="Query the memory bank", - parameters=[ - ToolParameter( - name="input_messages", - description="The input messages to search for in memory", - parameter_type="list", - required=True, - ), - ], - metadata={ - "config": { - "memory_bank_configs": [ - {"bank_id": "test_bank", "type": "vector"} - ] - } - }, - ) - ], - ), - provider_id="memory-runtime", - ), - ] - test_stack = await construct_stack_for_test( [Api.tool_groups, Api.inference, Api.memory, Api.tool_runtime], providers, provider_data, models=models, - tool_groups=tool_groups, + tool_groups=[ + tool_group_input_tavily_search, + tool_group_input_memory, + ], ) return test_stack