mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-18 19:49:47 +00:00
remove mixin and test fixes
This commit is contained in:
parent
5bbeb985ca
commit
0e451525e5
9 changed files with 140 additions and 69 deletions
|
|
@ -6,9 +6,65 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from ..conftest import get_provider_fixture_overrides
|
||||
|
||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||
from .fixtures import MEMORY_FIXTURES
|
||||
|
||||
|
||||
DEFAULT_PROVIDER_COMBINATIONS = [
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "meta_reference",
|
||||
"memory": "faiss",
|
||||
},
|
||||
id="meta_reference",
|
||||
marks=pytest.mark.meta_reference,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "ollama",
|
||||
"memory": "pgvector",
|
||||
},
|
||||
id="ollama",
|
||||
marks=pytest.mark.ollama,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "together",
|
||||
"memory": "chroma",
|
||||
},
|
||||
id="chroma",
|
||||
marks=pytest.mark.chroma,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "bedrock",
|
||||
"memory": "qdrant",
|
||||
},
|
||||
id="qdrant",
|
||||
marks=pytest.mark.qdrant,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "fireworks",
|
||||
"memory": "weaviate",
|
||||
},
|
||||
id="weaviate",
|
||||
marks=pytest.mark.weaviate,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--embedding-model",
|
||||
action="store",
|
||||
default=None,
|
||||
help="Specify the embedding model to use for testing",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
for fixture_name in MEMORY_FIXTURES:
|
||||
config.addinivalue_line(
|
||||
|
|
@ -18,12 +74,22 @@ def pytest_configure(config):
|
|||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
if "embedding_model" in metafunc.fixturenames:
|
||||
model = metafunc.config.getoption("--embedding-model")
|
||||
if not model:
|
||||
raise ValueError(
|
||||
"No embedding model specified. Please provide a valid embedding model."
|
||||
)
|
||||
params = [pytest.param(model, id="")]
|
||||
|
||||
metafunc.parametrize("embedding_model", params, indirect=True)
|
||||
if "memory_stack" in metafunc.fixturenames:
|
||||
metafunc.parametrize(
|
||||
"memory_stack",
|
||||
[
|
||||
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
|
||||
for fixture_name in MEMORY_FIXTURES
|
||||
],
|
||||
indirect=True,
|
||||
available_fixtures = {
|
||||
"inference": INFERENCE_FIXTURES,
|
||||
"memory": MEMORY_FIXTURES,
|
||||
}
|
||||
combinations = (
|
||||
get_provider_fixture_overrides(metafunc.config, available_fixtures)
|
||||
or DEFAULT_PROVIDER_COMBINATIONS
|
||||
)
|
||||
metafunc.parametrize("memory_stack", combinations, indirect=True)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ import tempfile
|
|||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.inference import ModelInput, ModelType
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
|
||||
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
|
||||
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
|
||||
|
|
@ -97,14 +99,30 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
|
|||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def memory_stack(request):
|
||||
fixture_name = request.param
|
||||
fixture = request.getfixturevalue(f"memory_{fixture_name}")
|
||||
async def memory_stack(embedding_model, request):
|
||||
fixture_dict = request.param
|
||||
|
||||
providers = {}
|
||||
provider_data = {}
|
||||
for key in ["inference", "memory"]:
|
||||
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
||||
providers[key] = fixture.providers
|
||||
if fixture.provider_data:
|
||||
provider_data.update(fixture.provider_data)
|
||||
|
||||
test_stack = await construct_stack_for_test(
|
||||
[Api.memory],
|
||||
{"memory": fixture.providers},
|
||||
fixture.provider_data,
|
||||
[Api.memory, Api.inference],
|
||||
providers,
|
||||
provider_data,
|
||||
models=[
|
||||
ModelInput(
|
||||
model_id=embedding_model,
|
||||
model_type=ModelType.embedding_model,
|
||||
metadata={
|
||||
"embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
|
||||
|
|
|
|||
|
|
@ -45,12 +45,14 @@ def sample_documents():
|
|||
]
|
||||
|
||||
|
||||
async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
|
||||
async def register_memory_bank(
|
||||
banks_impl: MemoryBanks, embedding_model: str
|
||||
) -> MemoryBank:
|
||||
bank_id = f"test_bank_{uuid.uuid4().hex}"
|
||||
return await banks_impl.register_memory_bank(
|
||||
memory_bank_id=bank_id,
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_model=embedding_model,
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
|
|
@ -59,11 +61,11 @@ async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
|
|||
|
||||
class TestMemory:
|
||||
@pytest.mark.asyncio
|
||||
async def test_banks_list(self, memory_stack):
|
||||
async def test_banks_list(self, memory_stack, embedding_model):
|
||||
_, banks_impl = memory_stack
|
||||
|
||||
# Register a test bank
|
||||
registered_bank = await register_memory_bank(banks_impl)
|
||||
registered_bank = await register_memory_bank(banks_impl, embedding_model)
|
||||
|
||||
try:
|
||||
# Verify our bank shows up in list
|
||||
|
|
@ -84,7 +86,7 @@ class TestMemory:
|
|||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_banks_register(self, memory_stack):
|
||||
async def test_banks_register(self, memory_stack, embedding_model):
|
||||
_, banks_impl = memory_stack
|
||||
|
||||
bank_id = f"test_bank_{uuid.uuid4().hex}"
|
||||
|
|
@ -94,7 +96,7 @@ class TestMemory:
|
|||
await banks_impl.register_memory_bank(
|
||||
memory_bank_id=bank_id,
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_model=embedding_model,
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
|
|
@ -109,7 +111,7 @@ class TestMemory:
|
|||
await banks_impl.register_memory_bank(
|
||||
memory_bank_id=bank_id,
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_model=embedding_model,
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
|
|
@ -126,13 +128,15 @@ class TestMemory:
|
|||
await banks_impl.unregister_memory_bank(bank_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_documents(self, memory_stack, sample_documents):
|
||||
async def test_query_documents(
|
||||
self, memory_stack, embedding_model, sample_documents
|
||||
):
|
||||
memory_impl, banks_impl = memory_stack
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||
|
||||
registered_bank = await register_memory_bank(banks_impl)
|
||||
registered_bank = await register_memory_bank(banks_impl, embedding_model)
|
||||
await memory_impl.insert_documents(
|
||||
registered_bank.memory_bank_id, sample_documents
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue