From abe1cc6303cebeb9dd4823c6722e20b12b194a92 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 11 Nov 2024 22:31:09 -0800 Subject: [PATCH] Fix agents tests --- .../agents/meta_reference/agent_instance.py | 11 +++--- .../providers/tests/agents/conftest.py | 6 ++-- .../providers/tests/agents/fixtures.py | 34 +++++++++++++++++-- .../providers/tests/agents/test_agents.py | 9 ++--- .../providers/tests/memory/fixtures.py | 8 ++--- .../providers/tests/memory/test_memory.py | 3 +- .../providers/tests/safety/fixtures.py | 4 +-- 7 files changed, 51 insertions(+), 24 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index a36a2c24f..2b3d0dbc4 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -641,12 +641,13 @@ class ChatAgent(ShieldRunnerMixin): if session_info.memory_bank_id is None: bank_id = f"memory_bank_{session_id}" - memory_bank = VectorMemoryBank( - identifier=bank_id, - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, + await self.memory_banks_api.register_memory_bank( + memory_bank_id=bank_id, + params=VectorMemoryBankParams( + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + ), ) - await self.memory_banks_api.register_memory_bank(memory_bank) await self.storage.add_memory_bank_to_session(session_id, bank_id) else: bank_id = session_info.memory_bank_id diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index c2e1261f7..aa3910b39 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -19,7 +19,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ { "inference": "meta_reference", "safety": "llama_guard", - "memory": "meta_reference", + "memory": "faiss", "agents": "meta_reference", }, id="meta_reference", @@ -29,7 +29,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ { "inference": "ollama", "safety": "llama_guard", - "memory": "meta_reference", + "memory": "faiss", "agents": "meta_reference", }, id="ollama", @@ -40,7 +40,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ "inference": "together", "safety": "llama_guard", # make this work with Weaviate which is what the together distro supports - "memory": "meta_reference", + "memory": "faiss", "agents": "meta_reference", }, id="together", diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 8330e2604..6ee17ff1f 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -9,6 +9,7 @@ import tempfile import pytest import pytest_asyncio +from llama_stack.apis.models import Model from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.agents.meta_reference import ( @@ -17,8 +18,18 @@ from llama_stack.providers.inline.agents.meta_reference import ( from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig - from ..conftest import ProviderFixture, remote_stack_fixture +from ..safety.fixtures import get_shield_to_register + + +def pick_inference_model(inference_model): + # This is not entirely satisfactory. The fixture `inference_model` can correspond to + # multiple models when you need to run a safety model in addition to normal agent + # inference model. We filter off the safety model by looking for "Llama-Guard" + if isinstance(inference_model, list): + inference_model = next(m for m in inference_model if "Llama-Guard" not in m) + assert inference_model is not None + return inference_model @pytest.fixture(scope="session") @@ -49,7 +60,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"] @pytest_asyncio.fixture(scope="session") -async def agents_stack(request): +async def agents_stack(request, inference_model, safety_model): fixture_dict = request.param providers = {} @@ -60,9 +71,28 @@ async def agents_stack(request): if fixture.provider_data: provider_data.update(fixture.provider_data) + inf_provider_id = providers["inference"][0].provider_id + safety_provider_id = providers["safety"][0].provider_id + + shield = get_shield_to_register( + providers["safety"][0].provider_type, safety_provider_id, safety_model + ) + + inference_models = ( + inference_model if isinstance(inference_model, list) else [inference_model] + ) impls = await resolve_impls_for_test_v2( [Api.agents, Api.inference, Api.safety, Api.memory], providers, provider_data, + models=[ + Model( + identifier=model, + provider_id=inf_provider_id, + provider_resource_id=model, + ) + for model in inference_models + ], + shields=[shield], ) return impls[Api.agents], impls[Api.memory] diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 5b1fe202a..b3f3dc31c 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -16,15 +16,12 @@ from llama_stack.providers.datatypes import * # noqa: F403 # pytest -v -s llama_stack/providers/tests/agents/test_agents.py # -m "meta_reference" +from .fixtures import pick_inference_model + @pytest.fixture def common_params(inference_model): - # This is not entirely satisfactory. The fixture `inference_model` can correspond to - # multiple models when you need to run a safety model in addition to normal agent - # inference model. We filter off the safety model by looking for "Llama-Guard" - if isinstance(inference_model, list): - inference_model = next(m for m in inference_model if "Llama-Guard" not in m) - assert inference_model is not None + inference_model = pick_inference_model(inference_model) return dict( model=inference_model, diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index 482049045..456e354b2 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -26,13 +26,13 @@ def memory_remote() -> ProviderFixture: @pytest.fixture(scope="session") -def memory_meta_reference() -> ProviderFixture: +def memory_faiss() -> ProviderFixture: temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") return ProviderFixture( providers=[ Provider( - provider_id="meta-reference", - provider_type="meta-reference", + provider_id="faiss", + provider_type="inline::faiss", config=FaissImplConfig( kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(), ).model_dump(), @@ -93,7 +93,7 @@ def memory_chroma() -> ProviderFixture: ) -MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote", "chroma"] +MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index a1befa6b0..24cef8a24 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -44,7 +44,6 @@ def sample_documents(): async def register_memory_bank(banks_impl: MemoryBanks): - return await banks_impl.register_memory_bank( memory_bank_id="test_bank", params=VectorMemoryBankParams( @@ -71,7 +70,7 @@ class TestMemory: # but so far we don't have an unregister API unfortunately, so be careful _, banks_impl = memory_stack - bank = await banks_impl.register_memory_bank( + await banks_impl.register_memory_bank( memory_bank_id="test_bank_no_provider", params=VectorMemoryBankParams( embedding_model="all-MiniLM-L6-v2", diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 942e6c116..5e553830c 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -99,7 +99,7 @@ async def safety_stack(inference_model, safety_model, request): provider_data.update(safety_fixture.provider_data) shield_provider_type = safety_fixture.providers[0].provider_type - shield = get_shield( + shield = get_shield_to_register( shield_provider_type, safety_fixture.providers[0].provider_id, safety_model ) @@ -120,7 +120,7 @@ async def safety_stack(inference_model, safety_model, request): return impls[Api.safety], impls[Api.shields], shield -def get_shield(provider_type: str, provider_id: str, safety_model: str): +def get_shield_to_register(provider_type: str, provider_id: str, safety_model: str): shield_config = {} shield_type = ShieldType.llama_guard identifier = "llama_guard"