chore: Add fixtures to conftest.py (#2067)

Add fixtures for SqliteKVStore, DiskDistributionRegistry and
CachedDiskDistributionRegistry. And use them in tests that had all been
duplicating similar setups.

## Test Plan
unit tests continue to run

Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
Derek Higgins 2025-05-06 12:57:48 +01:00 committed by GitHub
parent 4597145011
commit 2e807b38cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 122 additions and 167 deletions

View file

@ -26,20 +26,6 @@ from llama_stack.distribution.routers.routing_tables import (
ToolGroupsRoutingTable,
VectorDBsRoutingTable,
)
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture
async def dist_registry(tmp_path):
db_path = tmp_path / "test_kv.db"
kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
kvstore = SqliteKVStoreImpl(kvstore_config)
await kvstore.initialize()
registry = CachedDiskDistributionRegistry(kvstore)
await registry.initialize()
yield registry
class Impl:
@ -136,8 +122,8 @@ class ToolGroupsImpl(Impl):
@pytest.mark.asyncio
async def test_models_routing_table(dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, dist_registry)
async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry)
await table.initialize()
# Register multiple models and verify listing
@ -178,8 +164,8 @@ async def test_models_routing_table(dist_registry):
@pytest.mark.asyncio
async def test_shields_routing_table(dist_registry):
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, dist_registry)
async def test_shields_routing_table(cached_disk_dist_registry):
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry)
await table.initialize()
# Register multiple shields and verify listing
@ -194,11 +180,11 @@ async def test_shields_routing_table(dist_registry):
@pytest.mark.asyncio
async def test_vectordbs_routing_table(dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, dist_registry)
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry)
await table.initialize()
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, dist_registry)
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry)
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
@ -224,8 +210,8 @@ async def test_vectordbs_routing_table(dist_registry):
assert len(vector_dbs.data) == 0
async def test_datasets_routing_table(dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, dist_registry)
async def test_datasets_routing_table(cached_disk_dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry)
await table.initialize()
# Register multiple datasets and verify listing
@ -250,8 +236,8 @@ async def test_datasets_routing_table(dist_registry):
@pytest.mark.asyncio
async def test_scoring_functions_routing_table(dist_registry):
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, dist_registry)
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry)
await table.initialize()
# Register multiple scoring functions and verify listing
@ -276,8 +262,8 @@ async def test_scoring_functions_routing_table(dist_registry):
@pytest.mark.asyncio
async def test_benchmarks_routing_table(dist_registry):
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, dist_registry)
async def test_benchmarks_routing_table(cached_disk_dist_registry):
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry)
await table.initialize()
# Register multiple benchmarks and verify listing
@ -294,8 +280,8 @@ async def test_benchmarks_routing_table(dist_registry):
@pytest.mark.asyncio
async def test_tool_groups_routing_table(dist_registry):
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, dist_registry)
async def test_tool_groups_routing_table(cached_disk_dist_registry):
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry)
await table.initialize()
# Register multiple tool groups and verify listing