mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 01:03:59 +00:00
fix failing memory tests
This commit is contained in:
parent
c38d377eb7
commit
ee3f0c6b55
3 changed files with 17 additions and 17 deletions
|
@ -58,10 +58,10 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
|
|
||||||
def pytest_addoption(parser):
|
def pytest_addoption(parser):
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--embedding-model",
|
"--inference-model",
|
||||||
action="store",
|
action="store",
|
||||||
default=None,
|
default=None,
|
||||||
help="Specify the embedding model to use for testing",
|
help="Specify the inference model to use for testing",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,15 +74,15 @@ def pytest_configure(config):
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
def pytest_generate_tests(metafunc):
|
||||||
if "embedding_model" in metafunc.fixturenames:
|
if "inference_model" in metafunc.fixturenames:
|
||||||
model = metafunc.config.getoption("--embedding-model")
|
model = metafunc.config.getoption("--inference-model")
|
||||||
if not model:
|
if not model:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No embedding model specified. Please provide a valid embedding model."
|
"No inference model specified. Please provide a valid inference model."
|
||||||
)
|
)
|
||||||
params = [pytest.param(model, id="")]
|
params = [pytest.param(model, id="")]
|
||||||
|
|
||||||
metafunc.parametrize("embedding_model", params, indirect=True)
|
metafunc.parametrize("inference_model", params, indirect=True)
|
||||||
if "memory_stack" in metafunc.fixturenames:
|
if "memory_stack" in metafunc.fixturenames:
|
||||||
available_fixtures = {
|
available_fixtures = {
|
||||||
"inference": INFERENCE_FIXTURES,
|
"inference": INFERENCE_FIXTURES,
|
||||||
|
|
|
@ -107,7 +107,7 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
async def memory_stack(embedding_model, request):
|
async def memory_stack(inference_model, request):
|
||||||
fixture_dict = request.param
|
fixture_dict = request.param
|
||||||
|
|
||||||
providers = {}
|
providers = {}
|
||||||
|
@ -124,7 +124,7 @@ async def memory_stack(embedding_model, request):
|
||||||
provider_data,
|
provider_data,
|
||||||
models=[
|
models=[
|
||||||
ModelInput(
|
ModelInput(
|
||||||
model_id=embedding_model,
|
model_id=inference_model,
|
||||||
model_type=ModelType.embedding_model,
|
model_type=ModelType.embedding_model,
|
||||||
metadata={
|
metadata={
|
||||||
"embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
|
"embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
|
||||||
|
|
|
@ -46,13 +46,13 @@ def sample_documents():
|
||||||
|
|
||||||
|
|
||||||
async def register_memory_bank(
|
async def register_memory_bank(
|
||||||
banks_impl: MemoryBanks, embedding_model: str
|
banks_impl: MemoryBanks, inference_model: str
|
||||||
) -> MemoryBank:
|
) -> MemoryBank:
|
||||||
bank_id = f"test_bank_{uuid.uuid4().hex}"
|
bank_id = f"test_bank_{uuid.uuid4().hex}"
|
||||||
return await banks_impl.register_memory_bank(
|
return await banks_impl.register_memory_bank(
|
||||||
memory_bank_id=bank_id,
|
memory_bank_id=bank_id,
|
||||||
params=VectorMemoryBankParams(
|
params=VectorMemoryBankParams(
|
||||||
embedding_model=embedding_model,
|
embedding_model=inference_model,
|
||||||
chunk_size_in_tokens=512,
|
chunk_size_in_tokens=512,
|
||||||
overlap_size_in_tokens=64,
|
overlap_size_in_tokens=64,
|
||||||
),
|
),
|
||||||
|
@ -61,11 +61,11 @@ async def register_memory_bank(
|
||||||
|
|
||||||
class TestMemory:
|
class TestMemory:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_banks_list(self, memory_stack, embedding_model):
|
async def test_banks_list(self, memory_stack, inference_model):
|
||||||
_, banks_impl = memory_stack
|
_, banks_impl = memory_stack
|
||||||
|
|
||||||
# Register a test bank
|
# Register a test bank
|
||||||
registered_bank = await register_memory_bank(banks_impl, embedding_model)
|
registered_bank = await register_memory_bank(banks_impl, inference_model)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Verify our bank shows up in list
|
# Verify our bank shows up in list
|
||||||
|
@ -86,7 +86,7 @@ class TestMemory:
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_banks_register(self, memory_stack, embedding_model):
|
async def test_banks_register(self, memory_stack, inference_model):
|
||||||
_, banks_impl = memory_stack
|
_, banks_impl = memory_stack
|
||||||
|
|
||||||
bank_id = f"test_bank_{uuid.uuid4().hex}"
|
bank_id = f"test_bank_{uuid.uuid4().hex}"
|
||||||
|
@ -96,7 +96,7 @@ class TestMemory:
|
||||||
await banks_impl.register_memory_bank(
|
await banks_impl.register_memory_bank(
|
||||||
memory_bank_id=bank_id,
|
memory_bank_id=bank_id,
|
||||||
params=VectorMemoryBankParams(
|
params=VectorMemoryBankParams(
|
||||||
embedding_model=embedding_model,
|
embedding_model=inference_model,
|
||||||
chunk_size_in_tokens=512,
|
chunk_size_in_tokens=512,
|
||||||
overlap_size_in_tokens=64,
|
overlap_size_in_tokens=64,
|
||||||
),
|
),
|
||||||
|
@ -111,7 +111,7 @@ class TestMemory:
|
||||||
await banks_impl.register_memory_bank(
|
await banks_impl.register_memory_bank(
|
||||||
memory_bank_id=bank_id,
|
memory_bank_id=bank_id,
|
||||||
params=VectorMemoryBankParams(
|
params=VectorMemoryBankParams(
|
||||||
embedding_model=embedding_model,
|
embedding_model=inference_model,
|
||||||
chunk_size_in_tokens=512,
|
chunk_size_in_tokens=512,
|
||||||
overlap_size_in_tokens=64,
|
overlap_size_in_tokens=64,
|
||||||
),
|
),
|
||||||
|
@ -129,14 +129,14 @@ class TestMemory:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_documents(
|
async def test_query_documents(
|
||||||
self, memory_stack, embedding_model, sample_documents
|
self, memory_stack, inference_model, sample_documents
|
||||||
):
|
):
|
||||||
memory_impl, banks_impl = memory_stack
|
memory_impl, banks_impl = memory_stack
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||||
|
|
||||||
registered_bank = await register_memory_bank(banks_impl, embedding_model)
|
registered_bank = await register_memory_bank(banks_impl, inference_model)
|
||||||
await memory_impl.insert_documents(
|
await memory_impl.insert_documents(
|
||||||
registered_bank.memory_bank_id, sample_documents
|
registered_bank.memory_bank_id, sample_documents
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue