chore: Add fixtures to conftest.py (#2067)

Add fixtures for SqliteKVStore, DiskDistributionRegistry and
CachedDiskDistributionRegistry. And use them in tests that had all been
duplicating similar setups.

## Test Plan
unit tests continue to run

Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
Derek Higgins 2025-05-06 12:57:48 +01:00 committed by GitHub
parent 4597145011
commit 2e807b38cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 122 additions and 167 deletions

View file

@ -4,10 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
from llama_stack.apis.inference import Model
from llama_stack.apis.vector_dbs import VectorDB
@ -20,28 +18,6 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@pytest.fixture
def config():
config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db")
if os.path.exists(config.db_path):
os.remove(config.db_path)
return config
@pytest_asyncio.fixture(scope="function")
async def registry(config):
registry = DiskDistributionRegistry(await kvstore_impl(config))
await registry.initialize()
return registry
@pytest_asyncio.fixture(scope="function")
async def cached_registry(config):
registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await registry.initialize()
return registry
@pytest.fixture
def sample_vector_db():
return VectorDB(
@ -63,41 +39,42 @@ def sample_model():
@pytest.mark.asyncio
async def test_registry_initialization(registry):
async def test_registry_initialization(disk_dist_registry):
# Test empty registry
result = await registry.get("nonexistent", "nonexistent")
result = await disk_dist_registry.get("nonexistent", "nonexistent")
assert result is None
@pytest.mark.asyncio
async def test_basic_registration(registry, sample_vector_db, sample_model):
async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model):
print(f"Registering {sample_vector_db}")
await registry.register(sample_vector_db)
await disk_dist_registry.register(sample_vector_db)
print(f"Registering {sample_model}")
await registry.register(sample_model)
await disk_dist_registry.register(sample_model)
print("Getting vector_db")
result_vector_db = await registry.get("vector_db", "test_vector_db")
result_vector_db = await disk_dist_registry.get("vector_db", "test_vector_db")
assert result_vector_db is not None
assert result_vector_db.identifier == sample_vector_db.identifier
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
assert result_vector_db.provider_id == sample_vector_db.provider_id
result_model = await registry.get("model", "test_model")
result_model = await disk_dist_registry.get("model", "test_model")
assert result_model is not None
assert result_model.identifier == sample_model.identifier
assert result_model.provider_id == sample_model.provider_id
@pytest.mark.asyncio
async def test_cached_registry_initialization(config, sample_vector_db, sample_model):
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model):
# First populate the disk registry
disk_registry = DiskDistributionRegistry(await kvstore_impl(config))
disk_registry = DiskDistributionRegistry(sqlite_kvstore)
await disk_registry.initialize()
await disk_registry.register(sample_vector_db)
await disk_registry.register(sample_model)
# Test cached version loads from disk
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
db_path = sqlite_kvstore.db_path
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
await cached_registry.initialize()
result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
@ -109,10 +86,7 @@ async def test_cached_registry_initialization(config, sample_vector_db, sample_m
@pytest.mark.asyncio
async def test_cached_registry_updates(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
async def test_cached_registry_updates(cached_disk_dist_registry):
new_vector_db = VectorDB(
identifier="test_vector_db_2",
embedding_model="all-MiniLM-L6-v2",
@ -120,16 +94,17 @@ async def test_cached_registry_updates(config):
provider_resource_id="test_vector_db_2",
provider_id="baz",
)
await cached_registry.register(new_vector_db)
await cached_disk_dist_registry.register(new_vector_db)
# Verify in cache
result_vector_db = await cached_registry.get("vector_db", "test_vector_db_2")
result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
assert result_vector_db is not None
assert result_vector_db.identifier == new_vector_db.identifier
assert result_vector_db.provider_id == new_vector_db.provider_id
# Verify persisted to disk
new_registry = DiskDistributionRegistry(await kvstore_impl(config))
db_path = cached_disk_dist_registry.kvstore.db_path
new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
await new_registry.initialize()
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
assert result_vector_db is not None
@ -138,10 +113,7 @@ async def test_cached_registry_updates(config):
@pytest.mark.asyncio
async def test_duplicate_provider_registration(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
async def test_duplicate_provider_registration(cached_disk_dist_registry):
original_vector_db = VectorDB(
identifier="test_vector_db_2",
embedding_model="all-MiniLM-L6-v2",
@ -149,7 +121,7 @@ async def test_duplicate_provider_registration(config):
provider_resource_id="test_vector_db_2",
provider_id="baz",
)
await cached_registry.register(original_vector_db)
await cached_disk_dist_registry.register(original_vector_db)
duplicate_vector_db = VectorDB(
identifier="test_vector_db_2",
@ -158,18 +130,16 @@ async def test_duplicate_provider_registration(config):
provider_resource_id="test_vector_db_2",
provider_id="baz", # Same provider_id
)
await cached_registry.register(duplicate_vector_db)
await cached_disk_dist_registry.register(duplicate_vector_db)
result = await cached_registry.get("vector_db", "test_vector_db_2")
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
assert result is not None
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
@pytest.mark.asyncio
async def test_get_all_objects(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
async def test_get_all_objects(cached_disk_dist_registry):
# Create multiple test banks
# Create multiple test banks
test_vector_dbs = [
VectorDB(
@ -184,10 +154,10 @@ async def test_get_all_objects(config):
# Register all vector_dbs
for vector_db in test_vector_dbs:
await cached_registry.register(vector_db)
await cached_disk_dist_registry.register(vector_db)
# Test get_all retrieval
all_results = await cached_registry.get_all()
all_results = await cached_disk_dist_registry.get_all()
assert len(all_results) == 3
# Verify each vector_db was stored correctly
@ -201,9 +171,7 @@ async def test_get_all_objects(config):
@pytest.mark.asyncio
async def test_parse_registry_values_error_handling(config):
kvstore = await kvstore_impl(config)
async def test_parse_registry_values_error_handling(sqlite_kvstore):
valid_db = VectorDB(
identifier="valid_vector_db",
embedding_model="all-MiniLM-L6-v2",
@ -212,16 +180,18 @@ async def test_parse_registry_values_error_handling(config):
provider_id="test-provider",
)
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json())
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
)
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
await kvstore.set(
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
'{"type": "vector_db", "identifier": "missing_fields"}',
)
test_registry = DiskDistributionRegistry(kvstore)
test_registry = DiskDistributionRegistry(sqlite_kvstore)
await test_registry.initialize()
# Get all objects, which should only return the valid one
@ -240,9 +210,7 @@ async def test_parse_registry_values_error_handling(config):
@pytest.mark.asyncio
async def test_cached_registry_error_handling(config):
kvstore = await kvstore_impl(config)
async def test_cached_registry_error_handling(sqlite_kvstore):
valid_db = VectorDB(
identifier="valid_cached_db",
embedding_model="all-MiniLM-L6-v2",
@ -251,14 +219,16 @@ async def test_cached_registry_error_handling(config):
provider_id="test-provider",
)
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json())
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
)
await kvstore.set(
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"),
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
)
cached_registry = CachedDiskDistributionRegistry(kvstore)
cached_registry = CachedDiskDistributionRegistry(sqlite_kvstore)
await cached_registry.initialize()
all_objects = await cached_registry.get_all()