Memory tests pass now
parent e51154964f
commit 59ce047aea
23 changed files with 122 additions and 81 deletions
@@ -81,13 +81,13 @@ def pytest_addoption(parser):
     parser.addoption(
         "--inference-model",
         action="store",
-        default="meta-llama/Llama-3.1-8B-Instruct",
+        default="meta-llama/Llama-3.2-3B-Instruct",
         help="Specify the inference model to use for testing",
     )
     parser.addoption(
         "--safety-shield",
         action="store",
-        default="meta-llama/Llama-Guard-3-8B",
+        default="meta-llama/Llama-Guard-3-1B",
         help="Specify the safety shield to use for testing",
     )
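With these smaller defaults, the provider tests fall back to Llama-3.2-3B-Instruct and Llama-Guard-3-1B out of the box; pass --inference-model or --safety-shield explicitly to test against other models.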
@@ -192,6 +192,19 @@ def inference_tgi() -> ProviderFixture:
     )


+@pytest.fixture(scope="session")
+def inference_sentence_transformers() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="sentence_transformers",
+                provider_type="inline::sentence-transformers",
+                config={},
+            )
+        ]
+    )
+
+
 def get_model_short_name(model_name: str) -> str:
     """Convert model name to a short test identifier.
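For orientation, a minimal self-contained sketch of the shape of this fixture, with hypothetical stand-ins for Provider and ProviderFixture (the real classes come from the test suite's shared conftest and provider datatypes). The point of inline::sentence-transformers is that embeddings run in-process, so memory tests no longer depend on an external inference server:

# Hypothetical stand-ins so the fixture can be read (and run) in isolation.
from dataclasses import dataclass, field
from typing import Any, Dict, List

import pytest


@dataclass
class Provider:
    provider_id: str
    provider_type: str
    config: Dict[str, Any]


@dataclass
class ProviderFixture:
    providers: List[Provider] = field(default_factory=list)


@pytest.fixture(scope="session")
def inference_sentence_transformers() -> ProviderFixture:
    # Inline provider: embeddings are computed in-process, so no running
    # inference backend (ollama, tgi, together, ...) is required.
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="sentence_transformers",
                provider_type="inline::sentence-transformers",
                config={},
            )
        ]
    )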
@@ -15,23 +15,23 @@ from .fixtures import MEMORY_FIXTURES
 DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
-            "inference": "meta_reference",
+            "inference": "sentence_transformers",
             "memory": "faiss",
         },
-        id="meta_reference",
-        marks=pytest.mark.meta_reference,
+        id="sentence_transformers",
+        marks=pytest.mark.sentence_transformers,
     ),
     pytest.param(
         {
             "inference": "ollama",
-            "memory": "pgvector",
+            "memory": "faiss",
         },
         id="ollama",
         marks=pytest.mark.ollama,
     ),
     pytest.param(
         {
-            "inference": "together",
+            "inference": "sentence_transformers",
             "memory": "chroma",
         },
         id="chroma",
@@ -58,10 +58,10 @@ DEFAULT_PROVIDER_COMBINATIONS = [

 def pytest_addoption(parser):
     parser.addoption(
-        "--inference-model",
+        "--embedding-model",
         action="store",
         default=None,
-        help="Specify the inference model to use for testing",
+        help="Specify the embedding model to use for testing",
     )
@@ -74,15 +74,15 @@ def pytest_configure(config):


 def pytest_generate_tests(metafunc):
-    if "inference_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--inference-model")
-        if not model:
-            raise ValueError(
-                "No inference model specified. Please provide a valid inference model."
-            )
-        params = [pytest.param(model, id="")]
+    if "embedding_model" in metafunc.fixturenames:
+        model = metafunc.config.getoption("--embedding-model")
+        if model:
+            params = [pytest.param(model, id="")]
+        else:
+            params = [pytest.param("all-MiniLM-L6-v2", id="")]

-        metafunc.parametrize("inference_model", params, indirect=True)
+        metafunc.parametrize("embedding_model", params, indirect=True)
+
     if "memory_stack" in metafunc.fixturenames:
         available_fixtures = {
             "inference": INFERENCE_FIXTURES,
@@ -24,6 +24,13 @@ from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail


+@pytest.fixture(scope="session")
+def embedding_model(request):
+    if hasattr(request, "param"):
+        return request.param
+    return request.config.getoption("--embedding-model", None)
+
+
 @pytest.fixture(scope="session")
 def memory_remote() -> ProviderFixture:
     return remote_stack_fixture()
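Together with the --embedding-model option and pytest_generate_tests changes above, this fixture is fed through pytest's indirect parametrization. A minimal self-contained sketch of the pattern, runnable as its own conftest.py (the final test body is hypothetical):

import pytest


def pytest_addoption(parser):
    parser.addoption(
        "--embedding-model",
        action="store",
        default=None,
        help="Specify the embedding model to use for testing",
    )


def pytest_generate_tests(metafunc):
    if "embedding_model" in metafunc.fixturenames:
        model = metafunc.config.getoption("--embedding-model")
        # Fall back to a small default instead of raising, so the memory
        # tests run with no CLI flags at all.
        params = [pytest.param(model or "all-MiniLM-L6-v2", id="")]
        metafunc.parametrize("embedding_model", params, indirect=True)


@pytest.fixture(scope="session")
def embedding_model(request):
    # With indirect=True the parametrized value arrives as request.param;
    # otherwise fall back to reading the CLI option directly.
    if hasattr(request, "param"):
        return request.param
    return request.config.getoption("--embedding-model", None)


def test_embedding_model_resolves(embedding_model):
    # Hypothetical test: sees either the CLI value or "all-MiniLM-L6-v2".
    assert embedding_model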
@@ -107,7 +114,7 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]


 @pytest_asyncio.fixture(scope="session")
-async def memory_stack(inference_model, request):
+async def memory_stack(embedding_model, request):
     fixture_dict = request.param

     providers = {}
@@ -124,7 +131,7 @@ async def memory_stack(inference_model, request):
         provider_data,
         models=[
             ModelInput(
-                model_id=inference_model,
+                model_id=embedding_model,
                 model_type=ModelType.embedding,
                 metadata={
                     "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
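Note the environment contract this keeps: the embedding dimension is read from EMBEDDING_DIMENSION rather than derived from the model. A sketch of the assumed behavior of the repo's get_env_or_fail helper (the stand-in below is illustrative, not the actual implementation):

import os


def get_env_or_fail(key: str) -> str:
    # Illustrative stand-in: return the variable's value or fail loudly.
    value = os.environ.get(key)
    if not value:
        raise ValueError(f"Missing required environment variable: {key}")
    return value


# e.g. all-MiniLM-L6-v2 produces 384-dimensional embeddings, so these tests
# would run with EMBEDDING_DIMENSION=384 set in the environment.
dimension = int(get_env_or_fail("EMBEDDING_DIMENSION"))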
@@ -46,13 +46,13 @@ def sample_documents():


 async def register_memory_bank(
-    banks_impl: MemoryBanks, inference_model: str
+    banks_impl: MemoryBanks, embedding_model: str
 ) -> MemoryBank:
     bank_id = f"test_bank_{uuid.uuid4().hex}"
     return await banks_impl.register_memory_bank(
         memory_bank_id=bank_id,
         params=VectorMemoryBankParams(
-            embedding_model=inference_model,
+            embedding_model=embedding_model,
             chunk_size_in_tokens=512,
             overlap_size_in_tokens=64,
         ),
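The rename here is purely semantic: the value was already passed as embedding_model into VectorMemoryBankParams, so the parameter name now says what the argument actually is, and the tests below switch from the inference_model fixture to embedding_model accordingly.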
@@ -61,11 +61,11 @@ async def register_memory_bank(

 class TestMemory:
     @pytest.mark.asyncio
-    async def test_banks_list(self, memory_stack, inference_model):
+    async def test_banks_list(self, memory_stack, embedding_model):
         _, banks_impl = memory_stack

         # Register a test bank
-        registered_bank = await register_memory_bank(banks_impl, inference_model)
+        registered_bank = await register_memory_bank(banks_impl, embedding_model)

         try:
             # Verify our bank shows up in list
@@ -86,7 +86,7 @@ class TestMemory:
         )

     @pytest.mark.asyncio
-    async def test_banks_register(self, memory_stack, inference_model):
+    async def test_banks_register(self, memory_stack, embedding_model):
         _, banks_impl = memory_stack

         bank_id = f"test_bank_{uuid.uuid4().hex}"
@@ -96,7 +96,7 @@ class TestMemory:
         await banks_impl.register_memory_bank(
             memory_bank_id=bank_id,
             params=VectorMemoryBankParams(
-                embedding_model=inference_model,
+                embedding_model=embedding_model,
                 chunk_size_in_tokens=512,
                 overlap_size_in_tokens=64,
             ),
@@ -111,7 +111,7 @@ class TestMemory:
         await banks_impl.register_memory_bank(
             memory_bank_id=bank_id,
             params=VectorMemoryBankParams(
-                embedding_model=inference_model,
+                embedding_model=embedding_model,
                 chunk_size_in_tokens=512,
                 overlap_size_in_tokens=64,
             ),
@@ -129,14 +129,14 @@ class TestMemory:

     @pytest.mark.asyncio
     async def test_query_documents(
-        self, memory_stack, inference_model, sample_documents
+        self, memory_stack, embedding_model, sample_documents
     ):
         memory_impl, banks_impl = memory_stack

         with pytest.raises(ValueError):
             await memory_impl.insert_documents("test_bank", sample_documents)

-        registered_bank = await register_memory_bank(banks_impl, inference_model)
+        registered_bank = await register_memory_bank(banks_impl, embedding_model)
         await memory_impl.insert_documents(
             registered_bank.memory_bank_id, sample_documents
         )
@@ -74,7 +74,9 @@ def pytest_addoption(parser):


 SAFETY_SHIELD_PARAMS = [
-    pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
+    pytest.param(
+        "meta-llama/Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"
+    ),
 ]
@@ -86,6 +88,7 @@ def pytest_generate_tests(metafunc):
     if "safety_shield" in metafunc.fixturenames:
         shield_id = metafunc.config.getoption("--safety-shield")
         if shield_id:
+            assert shield_id.startswith("meta-llama/")
             params = [pytest.param(shield_id, id="")]
         else:
             params = SAFETY_SHIELD_PARAMS
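The default shield id and any --safety-shield value must now be fully qualified with the meta-llama/ namespace; bare model names fail fast at collection time. A quick illustration:

# Bare model names are rejected; fully qualified ids pass the new assert.
for shield_id in ("Llama-Guard-3-1B", "meta-llama/Llama-Guard-3-1B"):
    status = "ok" if shield_id.startswith("meta-llama/") else "rejected"
    print(f"{shield_id}: {status}")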
@@ -10,6 +10,7 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403

 from llama_stack.distribution.datatypes import *  # noqa: F403
+from llama_stack.apis.inference import UserMessage

 # How to run this test:
 #