forked from phoenix-oss/llama-stack-mirror
chore: Add fixtures to conftest.py (#2067)
Add fixtures for SqliteKVStore, DiskDistributionRegistry and CachedDiskDistributionRegistry. And use them in tests that had all been duplicating similar setups. ## Test Plan unit tests continue to run Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
parent
4597145011
commit
2e807b38cc
8 changed files with 122 additions and 167 deletions
|
@ -26,20 +26,6 @@ from llama_stack.distribution.routers.routing_tables import (
|
|||
ToolGroupsRoutingTable,
|
||||
VectorDBsRoutingTable,
|
||||
)
|
||||
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dist_registry(tmp_path):
|
||||
db_path = tmp_path / "test_kv.db"
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
|
||||
kvstore = SqliteKVStoreImpl(kvstore_config)
|
||||
await kvstore.initialize()
|
||||
registry = CachedDiskDistributionRegistry(kvstore)
|
||||
await registry.initialize()
|
||||
yield registry
|
||||
|
||||
|
||||
class Impl:
|
||||
|
@ -136,8 +122,8 @@ class ToolGroupsImpl(Impl):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_models_routing_table(dist_registry):
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, dist_registry)
|
||||
async def test_models_routing_table(cached_disk_dist_registry):
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry)
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple models and verify listing
|
||||
|
@ -178,8 +164,8 @@ async def test_models_routing_table(dist_registry):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shields_routing_table(dist_registry):
|
||||
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, dist_registry)
|
||||
async def test_shields_routing_table(cached_disk_dist_registry):
|
||||
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry)
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple shields and verify listing
|
||||
|
@ -194,11 +180,11 @@ async def test_shields_routing_table(dist_registry):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vectordbs_routing_table(dist_registry):
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, dist_registry)
|
||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry)
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, dist_registry)
|
||||
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry)
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
|
@ -224,8 +210,8 @@ async def test_vectordbs_routing_table(dist_registry):
|
|||
assert len(vector_dbs.data) == 0
|
||||
|
||||
|
||||
async def test_datasets_routing_table(dist_registry):
|
||||
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, dist_registry)
|
||||
async def test_datasets_routing_table(cached_disk_dist_registry):
|
||||
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry)
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple datasets and verify listing
|
||||
|
@ -250,8 +236,8 @@ async def test_datasets_routing_table(dist_registry):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoring_functions_routing_table(dist_registry):
|
||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, dist_registry)
|
||||
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry)
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple scoring functions and verify listing
|
||||
|
@ -276,8 +262,8 @@ async def test_scoring_functions_routing_table(dist_registry):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_benchmarks_routing_table(dist_registry):
|
||||
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, dist_registry)
|
||||
async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
||||
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry)
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple benchmarks and verify listing
|
||||
|
@ -294,8 +280,8 @@ async def test_benchmarks_routing_table(dist_registry):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_groups_routing_table(dist_registry):
|
||||
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, dist_registry)
|
||||
async def test_tool_groups_routing_table(cached_disk_dist_registry):
|
||||
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry)
|
||||
await table.initialize()
|
||||
|
||||
# Register multiple tool groups and verify listing
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue