From ee3f0c6b5577d4412bd89e3421140955a094779b Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 12 Dec 2024 11:43:06 -0800 Subject: [PATCH] fix failing memory tests --- llama_stack/providers/tests/memory/conftest.py | 12 ++++++------ llama_stack/providers/tests/memory/fixtures.py | 4 ++-- .../providers/tests/memory/test_memory.py | 18 +++++++++--------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/llama_stack/providers/tests/memory/conftest.py b/llama_stack/providers/tests/memory/conftest.py index 023a1a156..7595538eb 100644 --- a/llama_stack/providers/tests/memory/conftest.py +++ b/llama_stack/providers/tests/memory/conftest.py @@ -58,10 +58,10 @@ DEFAULT_PROVIDER_COMBINATIONS = [ def pytest_addoption(parser): parser.addoption( - "--embedding-model", + "--inference-model", action="store", default=None, - help="Specify the embedding model to use for testing", + help="Specify the inference model to use for testing", ) @@ -74,15 +74,15 @@ def pytest_configure(config): def pytest_generate_tests(metafunc): - if "embedding_model" in metafunc.fixturenames: - model = metafunc.config.getoption("--embedding-model") + if "inference_model" in metafunc.fixturenames: + model = metafunc.config.getoption("--inference-model") if not model: raise ValueError( - "No embedding model specified. Please provide a valid embedding model." + "No inference model specified. Please provide a valid inference model." ) params = [pytest.param(model, id="")] - metafunc.parametrize("embedding_model", params, indirect=True) + metafunc.parametrize("inference_model", params, indirect=True) if "memory_stack" in metafunc.fixturenames: available_fixtures = { "inference": INFERENCE_FIXTURES, diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index 22406fe27..92fd1720e 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -107,7 +107,7 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"] @pytest_asyncio.fixture(scope="session") -async def memory_stack(embedding_model, request): +async def memory_stack(inference_model, request): fixture_dict = request.param providers = {} @@ -124,7 +124,7 @@ async def memory_stack(embedding_model, request): provider_data, models=[ ModelInput( - model_id=embedding_model, + model_id=inference_model, model_type=ModelType.embedding_model, metadata={ "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"), diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index 526aa646c..03597d073 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -46,13 +46,13 @@ def sample_documents(): async def register_memory_bank( - banks_impl: MemoryBanks, embedding_model: str + banks_impl: MemoryBanks, inference_model: str ) -> MemoryBank: bank_id = f"test_bank_{uuid.uuid4().hex}" return await banks_impl.register_memory_bank( memory_bank_id=bank_id, params=VectorMemoryBankParams( - embedding_model=embedding_model, + embedding_model=inference_model, chunk_size_in_tokens=512, overlap_size_in_tokens=64, ), @@ -61,11 +61,11 @@ async def register_memory_bank( class TestMemory: @pytest.mark.asyncio - async def test_banks_list(self, memory_stack, embedding_model): + async def test_banks_list(self, memory_stack, inference_model): _, banks_impl = memory_stack # Register a test bank - registered_bank = await register_memory_bank(banks_impl, embedding_model) + registered_bank = await register_memory_bank(banks_impl, inference_model) try: # Verify our bank shows up in list @@ -86,7 +86,7 @@ class TestMemory: ) @pytest.mark.asyncio - async def test_banks_register(self, memory_stack, embedding_model): + async def test_banks_register(self, memory_stack, inference_model): _, banks_impl = memory_stack bank_id = f"test_bank_{uuid.uuid4().hex}" @@ -96,7 +96,7 @@ class TestMemory: await banks_impl.register_memory_bank( memory_bank_id=bank_id, params=VectorMemoryBankParams( - embedding_model=embedding_model, + embedding_model=inference_model, chunk_size_in_tokens=512, overlap_size_in_tokens=64, ), @@ -111,7 +111,7 @@ class TestMemory: await banks_impl.register_memory_bank( memory_bank_id=bank_id, params=VectorMemoryBankParams( - embedding_model=embedding_model, + embedding_model=inference_model, chunk_size_in_tokens=512, overlap_size_in_tokens=64, ), @@ -129,14 +129,14 @@ class TestMemory: @pytest.mark.asyncio async def test_query_documents( - self, memory_stack, embedding_model, sample_documents + self, memory_stack, inference_model, sample_documents ): memory_impl, banks_impl = memory_stack with pytest.raises(ValueError): await memory_impl.insert_documents("test_bank", sample_documents) - registered_bank = await register_memory_bank(banks_impl, embedding_model) + registered_bank = await register_memory_bank(banks_impl, inference_model) await memory_impl.insert_documents( registered_bank.memory_bank_id, sample_documents )