fix failing memory tests

This commit is contained in:
Dinesh Yeduguru 2024-12-12 11:43:06 -08:00
parent c38d377eb7
commit ee3f0c6b55
3 changed files with 17 additions and 17 deletions

View file

@ -58,10 +58,10 @@ DEFAULT_PROVIDER_COMBINATIONS = [
def pytest_addoption(parser): def pytest_addoption(parser):
parser.addoption( parser.addoption(
"--embedding-model", "--inference-model",
action="store", action="store",
default=None, default=None,
help="Specify the embedding model to use for testing", help="Specify the inference model to use for testing",
) )
@ -74,15 +74,15 @@ def pytest_configure(config):
def pytest_generate_tests(metafunc): def pytest_generate_tests(metafunc):
if "embedding_model" in metafunc.fixturenames: if "inference_model" in metafunc.fixturenames:
model = metafunc.config.getoption("--embedding-model") model = metafunc.config.getoption("--inference-model")
if not model: if not model:
raise ValueError( raise ValueError(
"No embedding model specified. Please provide a valid embedding model." "No inference model specified. Please provide a valid inference model."
) )
params = [pytest.param(model, id="")] params = [pytest.param(model, id="")]
metafunc.parametrize("embedding_model", params, indirect=True) metafunc.parametrize("inference_model", params, indirect=True)
if "memory_stack" in metafunc.fixturenames: if "memory_stack" in metafunc.fixturenames:
available_fixtures = { available_fixtures = {
"inference": INFERENCE_FIXTURES, "inference": INFERENCE_FIXTURES,

View file

@ -107,7 +107,7 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")
async def memory_stack(embedding_model, request): async def memory_stack(inference_model, request):
fixture_dict = request.param fixture_dict = request.param
providers = {} providers = {}
@ -124,7 +124,7 @@ async def memory_stack(embedding_model, request):
provider_data, provider_data,
models=[ models=[
ModelInput( ModelInput(
model_id=embedding_model, model_id=inference_model,
model_type=ModelType.embedding_model, model_type=ModelType.embedding_model,
metadata={ metadata={
"embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"), "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),

View file

@ -46,13 +46,13 @@ def sample_documents():
async def register_memory_bank( async def register_memory_bank(
banks_impl: MemoryBanks, embedding_model: str banks_impl: MemoryBanks, inference_model: str
) -> MemoryBank: ) -> MemoryBank:
bank_id = f"test_bank_{uuid.uuid4().hex}" bank_id = f"test_bank_{uuid.uuid4().hex}"
return await banks_impl.register_memory_bank( return await banks_impl.register_memory_bank(
memory_bank_id=bank_id, memory_bank_id=bank_id,
params=VectorMemoryBankParams( params=VectorMemoryBankParams(
embedding_model=embedding_model, embedding_model=inference_model,
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
overlap_size_in_tokens=64, overlap_size_in_tokens=64,
), ),
@ -61,11 +61,11 @@ async def register_memory_bank(
class TestMemory: class TestMemory:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_banks_list(self, memory_stack, embedding_model): async def test_banks_list(self, memory_stack, inference_model):
_, banks_impl = memory_stack _, banks_impl = memory_stack
# Register a test bank # Register a test bank
registered_bank = await register_memory_bank(banks_impl, embedding_model) registered_bank = await register_memory_bank(banks_impl, inference_model)
try: try:
# Verify our bank shows up in list # Verify our bank shows up in list
@ -86,7 +86,7 @@ class TestMemory:
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_banks_register(self, memory_stack, embedding_model): async def test_banks_register(self, memory_stack, inference_model):
_, banks_impl = memory_stack _, banks_impl = memory_stack
bank_id = f"test_bank_{uuid.uuid4().hex}" bank_id = f"test_bank_{uuid.uuid4().hex}"
@ -96,7 +96,7 @@ class TestMemory:
await banks_impl.register_memory_bank( await banks_impl.register_memory_bank(
memory_bank_id=bank_id, memory_bank_id=bank_id,
params=VectorMemoryBankParams( params=VectorMemoryBankParams(
embedding_model=embedding_model, embedding_model=inference_model,
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
overlap_size_in_tokens=64, overlap_size_in_tokens=64,
), ),
@ -111,7 +111,7 @@ class TestMemory:
await banks_impl.register_memory_bank( await banks_impl.register_memory_bank(
memory_bank_id=bank_id, memory_bank_id=bank_id,
params=VectorMemoryBankParams( params=VectorMemoryBankParams(
embedding_model=embedding_model, embedding_model=inference_model,
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
overlap_size_in_tokens=64, overlap_size_in_tokens=64,
), ),
@ -129,14 +129,14 @@ class TestMemory:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_documents( async def test_query_documents(
self, memory_stack, embedding_model, sample_documents self, memory_stack, inference_model, sample_documents
): ):
memory_impl, banks_impl = memory_stack memory_impl, banks_impl = memory_stack
with pytest.raises(ValueError): with pytest.raises(ValueError):
await memory_impl.insert_documents("test_bank", sample_documents) await memory_impl.insert_documents("test_bank", sample_documents)
registered_bank = await register_memory_bank(banks_impl, embedding_model) registered_bank = await register_memory_bank(banks_impl, inference_model)
await memory_impl.insert_documents( await memory_impl.insert_documents(
registered_bank.memory_bank_id, sample_documents registered_bank.memory_bank_id, sample_documents
) )