mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
chore: Add fixtures to conftest.py (#2067)
Add fixtures for SqliteKVStore, DiskDistributionRegistry and CachedDiskDistributionRegistry. And use them in tests that had all been duplicating similar setups. ## Test Plan unit tests continue to run Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
parent
4597145011
commit
2e807b38cc
8 changed files with 122 additions and 167 deletions
5
tests/unit/__init__.py
Normal file
5
tests/unit/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
9
tests/unit/conftest.py
Normal file
9
tests/unit/conftest.py
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# We need to import the fixtures here so that pytest can find them
|
||||||
|
# but ruff doesn't think they are used and removes the import. "noqa: F401" prevents them from being removed
|
||||||
|
from .fixtures import cached_disk_dist_registry, disk_dist_registry, sqlite_kvstore # noqa: F401
|
|
@ -26,20 +26,6 @@ from llama_stack.distribution.routers.routing_tables import (
|
||||||
ToolGroupsRoutingTable,
|
ToolGroupsRoutingTable,
|
||||||
VectorDBsRoutingTable,
|
VectorDBsRoutingTable,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|
||||||
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def dist_registry(tmp_path):
|
|
||||||
db_path = tmp_path / "test_kv.db"
|
|
||||||
kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
|
|
||||||
kvstore = SqliteKVStoreImpl(kvstore_config)
|
|
||||||
await kvstore.initialize()
|
|
||||||
registry = CachedDiskDistributionRegistry(kvstore)
|
|
||||||
await registry.initialize()
|
|
||||||
yield registry
|
|
||||||
|
|
||||||
|
|
||||||
class Impl:
|
class Impl:
|
||||||
|
@ -136,8 +122,8 @@ class ToolGroupsImpl(Impl):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_models_routing_table(dist_registry):
|
async def test_models_routing_table(cached_disk_dist_registry):
|
||||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, dist_registry)
|
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry)
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
# Register multiple models and verify listing
|
# Register multiple models and verify listing
|
||||||
|
@ -178,8 +164,8 @@ async def test_models_routing_table(dist_registry):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_shields_routing_table(dist_registry):
|
async def test_shields_routing_table(cached_disk_dist_registry):
|
||||||
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, dist_registry)
|
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry)
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
# Register multiple shields and verify listing
|
# Register multiple shields and verify listing
|
||||||
|
@ -194,11 +180,11 @@ async def test_shields_routing_table(dist_registry):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_vectordbs_routing_table(dist_registry):
|
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, dist_registry)
|
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry)
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, dist_registry)
|
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry)
|
||||||
await m_table.initialize()
|
await m_table.initialize()
|
||||||
await m_table.register_model(
|
await m_table.register_model(
|
||||||
model_id="test-model",
|
model_id="test-model",
|
||||||
|
@ -224,8 +210,8 @@ async def test_vectordbs_routing_table(dist_registry):
|
||||||
assert len(vector_dbs.data) == 0
|
assert len(vector_dbs.data) == 0
|
||||||
|
|
||||||
|
|
||||||
async def test_datasets_routing_table(dist_registry):
|
async def test_datasets_routing_table(cached_disk_dist_registry):
|
||||||
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, dist_registry)
|
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry)
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
# Register multiple datasets and verify listing
|
# Register multiple datasets and verify listing
|
||||||
|
@ -250,8 +236,8 @@ async def test_datasets_routing_table(dist_registry):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_scoring_functions_routing_table(dist_registry):
|
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
||||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, dist_registry)
|
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry)
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
# Register multiple scoring functions and verify listing
|
# Register multiple scoring functions and verify listing
|
||||||
|
@ -276,8 +262,8 @@ async def test_scoring_functions_routing_table(dist_registry):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_benchmarks_routing_table(dist_registry):
|
async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
||||||
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, dist_registry)
|
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry)
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
# Register multiple benchmarks and verify listing
|
# Register multiple benchmarks and verify listing
|
||||||
|
@ -294,8 +280,8 @@ async def test_benchmarks_routing_table(dist_registry):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_tool_groups_routing_table(dist_registry):
|
async def test_tool_groups_routing_table(cached_disk_dist_registry):
|
||||||
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, dist_registry)
|
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry)
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
# Register multiple tool groups and verify listing
|
# Register multiple tool groups and verify listing
|
||||||
|
|
34
tests/unit/fixtures.py
Normal file
34
tests/unit/fixtures.py
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry
|
||||||
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
async def sqlite_kvstore(tmp_path):
|
||||||
|
db_path = tmp_path / "test_kv.db"
|
||||||
|
kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
|
||||||
|
kvstore = SqliteKVStoreImpl(kvstore_config)
|
||||||
|
await kvstore.initialize()
|
||||||
|
yield kvstore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
async def disk_dist_registry(sqlite_kvstore):
|
||||||
|
registry = DiskDistributionRegistry(sqlite_kvstore)
|
||||||
|
await registry.initialize()
|
||||||
|
yield registry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
async def cached_disk_dist_registry(sqlite_kvstore):
|
||||||
|
registry = CachedDiskDistributionRegistry(sqlite_kvstore)
|
||||||
|
await registry.initialize()
|
||||||
|
yield registry
|
|
@ -4,9 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import tempfile
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
@ -17,20 +14,12 @@ from llama_stack.apis.agents import Turn
|
||||||
from llama_stack.apis.inference import CompletionMessage, StopReason
|
from llama_stack.apis.inference import CompletionMessage, StopReason
|
||||||
from llama_stack.distribution.datatypes import AccessAttributes
|
from llama_stack.distribution.datatypes import AccessAttributes
|
||||||
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
|
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|
||||||
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def test_setup():
|
async def test_setup(sqlite_kvstore):
|
||||||
temp_dir = tempfile.mkdtemp()
|
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore)
|
||||||
db_path = os.path.join(temp_dir, "test_persistence_access_control.db")
|
|
||||||
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
|
|
||||||
kvstore = SqliteKVStoreImpl(kvstore_config)
|
|
||||||
await kvstore.initialize()
|
|
||||||
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=kvstore)
|
|
||||||
yield agent_persistence
|
yield agent_persistence
|
||||||
shutil.rmtree(temp_dir)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
|
@ -4,10 +4,8 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
from llama_stack.apis.inference import Model
|
from llama_stack.apis.inference import Model
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
|
@ -20,28 +18,6 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def config():
|
|
||||||
config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db")
|
|
||||||
if os.path.exists(config.db_path):
|
|
||||||
os.remove(config.db_path)
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="function")
|
|
||||||
async def registry(config):
|
|
||||||
registry = DiskDistributionRegistry(await kvstore_impl(config))
|
|
||||||
await registry.initialize()
|
|
||||||
return registry
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="function")
|
|
||||||
async def cached_registry(config):
|
|
||||||
registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
|
||||||
await registry.initialize()
|
|
||||||
return registry
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_vector_db():
|
def sample_vector_db():
|
||||||
return VectorDB(
|
return VectorDB(
|
||||||
|
@ -63,41 +39,42 @@ def sample_model():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_registry_initialization(registry):
|
async def test_registry_initialization(disk_dist_registry):
|
||||||
# Test empty registry
|
# Test empty registry
|
||||||
result = await registry.get("nonexistent", "nonexistent")
|
result = await disk_dist_registry.get("nonexistent", "nonexistent")
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_basic_registration(registry, sample_vector_db, sample_model):
|
async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model):
|
||||||
print(f"Registering {sample_vector_db}")
|
print(f"Registering {sample_vector_db}")
|
||||||
await registry.register(sample_vector_db)
|
await disk_dist_registry.register(sample_vector_db)
|
||||||
print(f"Registering {sample_model}")
|
print(f"Registering {sample_model}")
|
||||||
await registry.register(sample_model)
|
await disk_dist_registry.register(sample_model)
|
||||||
print("Getting vector_db")
|
print("Getting vector_db")
|
||||||
result_vector_db = await registry.get("vector_db", "test_vector_db")
|
result_vector_db = await disk_dist_registry.get("vector_db", "test_vector_db")
|
||||||
assert result_vector_db is not None
|
assert result_vector_db is not None
|
||||||
assert result_vector_db.identifier == sample_vector_db.identifier
|
assert result_vector_db.identifier == sample_vector_db.identifier
|
||||||
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
|
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
|
||||||
assert result_vector_db.provider_id == sample_vector_db.provider_id
|
assert result_vector_db.provider_id == sample_vector_db.provider_id
|
||||||
|
|
||||||
result_model = await registry.get("model", "test_model")
|
result_model = await disk_dist_registry.get("model", "test_model")
|
||||||
assert result_model is not None
|
assert result_model is not None
|
||||||
assert result_model.identifier == sample_model.identifier
|
assert result_model.identifier == sample_model.identifier
|
||||||
assert result_model.provider_id == sample_model.provider_id
|
assert result_model.provider_id == sample_model.provider_id
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cached_registry_initialization(config, sample_vector_db, sample_model):
|
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model):
|
||||||
# First populate the disk registry
|
# First populate the disk registry
|
||||||
disk_registry = DiskDistributionRegistry(await kvstore_impl(config))
|
disk_registry = DiskDistributionRegistry(sqlite_kvstore)
|
||||||
await disk_registry.initialize()
|
await disk_registry.initialize()
|
||||||
await disk_registry.register(sample_vector_db)
|
await disk_registry.register(sample_vector_db)
|
||||||
await disk_registry.register(sample_model)
|
await disk_registry.register(sample_model)
|
||||||
|
|
||||||
# Test cached version loads from disk
|
# Test cached version loads from disk
|
||||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
db_path = sqlite_kvstore.db_path
|
||||||
|
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
|
||||||
await cached_registry.initialize()
|
await cached_registry.initialize()
|
||||||
|
|
||||||
result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
|
result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
|
||||||
|
@ -109,10 +86,7 @@ async def test_cached_registry_initialization(config, sample_vector_db, sample_m
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cached_registry_updates(config):
|
async def test_cached_registry_updates(cached_disk_dist_registry):
|
||||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
|
||||||
await cached_registry.initialize()
|
|
||||||
|
|
||||||
new_vector_db = VectorDB(
|
new_vector_db = VectorDB(
|
||||||
identifier="test_vector_db_2",
|
identifier="test_vector_db_2",
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
@ -120,16 +94,17 @@ async def test_cached_registry_updates(config):
|
||||||
provider_resource_id="test_vector_db_2",
|
provider_resource_id="test_vector_db_2",
|
||||||
provider_id="baz",
|
provider_id="baz",
|
||||||
)
|
)
|
||||||
await cached_registry.register(new_vector_db)
|
await cached_disk_dist_registry.register(new_vector_db)
|
||||||
|
|
||||||
# Verify in cache
|
# Verify in cache
|
||||||
result_vector_db = await cached_registry.get("vector_db", "test_vector_db_2")
|
result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
||||||
assert result_vector_db is not None
|
assert result_vector_db is not None
|
||||||
assert result_vector_db.identifier == new_vector_db.identifier
|
assert result_vector_db.identifier == new_vector_db.identifier
|
||||||
assert result_vector_db.provider_id == new_vector_db.provider_id
|
assert result_vector_db.provider_id == new_vector_db.provider_id
|
||||||
|
|
||||||
# Verify persisted to disk
|
# Verify persisted to disk
|
||||||
new_registry = DiskDistributionRegistry(await kvstore_impl(config))
|
db_path = cached_disk_dist_registry.kvstore.db_path
|
||||||
|
new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
|
||||||
await new_registry.initialize()
|
await new_registry.initialize()
|
||||||
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
|
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
|
||||||
assert result_vector_db is not None
|
assert result_vector_db is not None
|
||||||
|
@ -138,10 +113,7 @@ async def test_cached_registry_updates(config):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_duplicate_provider_registration(config):
|
async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
||||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
|
||||||
await cached_registry.initialize()
|
|
||||||
|
|
||||||
original_vector_db = VectorDB(
|
original_vector_db = VectorDB(
|
||||||
identifier="test_vector_db_2",
|
identifier="test_vector_db_2",
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
@ -149,7 +121,7 @@ async def test_duplicate_provider_registration(config):
|
||||||
provider_resource_id="test_vector_db_2",
|
provider_resource_id="test_vector_db_2",
|
||||||
provider_id="baz",
|
provider_id="baz",
|
||||||
)
|
)
|
||||||
await cached_registry.register(original_vector_db)
|
await cached_disk_dist_registry.register(original_vector_db)
|
||||||
|
|
||||||
duplicate_vector_db = VectorDB(
|
duplicate_vector_db = VectorDB(
|
||||||
identifier="test_vector_db_2",
|
identifier="test_vector_db_2",
|
||||||
|
@ -158,18 +130,16 @@ async def test_duplicate_provider_registration(config):
|
||||||
provider_resource_id="test_vector_db_2",
|
provider_resource_id="test_vector_db_2",
|
||||||
provider_id="baz", # Same provider_id
|
provider_id="baz", # Same provider_id
|
||||||
)
|
)
|
||||||
await cached_registry.register(duplicate_vector_db)
|
await cached_disk_dist_registry.register(duplicate_vector_db)
|
||||||
|
|
||||||
result = await cached_registry.get("vector_db", "test_vector_db_2")
|
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
|
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_all_objects(config):
|
async def test_get_all_objects(cached_disk_dist_registry):
|
||||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
# Create multiple test banks
|
||||||
await cached_registry.initialize()
|
|
||||||
|
|
||||||
# Create multiple test banks
|
# Create multiple test banks
|
||||||
test_vector_dbs = [
|
test_vector_dbs = [
|
||||||
VectorDB(
|
VectorDB(
|
||||||
|
@ -184,10 +154,10 @@ async def test_get_all_objects(config):
|
||||||
|
|
||||||
# Register all vector_dbs
|
# Register all vector_dbs
|
||||||
for vector_db in test_vector_dbs:
|
for vector_db in test_vector_dbs:
|
||||||
await cached_registry.register(vector_db)
|
await cached_disk_dist_registry.register(vector_db)
|
||||||
|
|
||||||
# Test get_all retrieval
|
# Test get_all retrieval
|
||||||
all_results = await cached_registry.get_all()
|
all_results = await cached_disk_dist_registry.get_all()
|
||||||
assert len(all_results) == 3
|
assert len(all_results) == 3
|
||||||
|
|
||||||
# Verify each vector_db was stored correctly
|
# Verify each vector_db was stored correctly
|
||||||
|
@ -201,9 +171,7 @@ async def test_get_all_objects(config):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_parse_registry_values_error_handling(config):
|
async def test_parse_registry_values_error_handling(sqlite_kvstore):
|
||||||
kvstore = await kvstore_impl(config)
|
|
||||||
|
|
||||||
valid_db = VectorDB(
|
valid_db = VectorDB(
|
||||||
identifier="valid_vector_db",
|
identifier="valid_vector_db",
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
@ -212,16 +180,18 @@ async def test_parse_registry_values_error_handling(config):
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
)
|
)
|
||||||
|
|
||||||
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json())
|
await sqlite_kvstore.set(
|
||||||
|
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
|
||||||
|
)
|
||||||
|
|
||||||
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
|
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
|
||||||
|
|
||||||
await kvstore.set(
|
await sqlite_kvstore.set(
|
||||||
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
|
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
|
||||||
'{"type": "vector_db", "identifier": "missing_fields"}',
|
'{"type": "vector_db", "identifier": "missing_fields"}',
|
||||||
)
|
)
|
||||||
|
|
||||||
test_registry = DiskDistributionRegistry(kvstore)
|
test_registry = DiskDistributionRegistry(sqlite_kvstore)
|
||||||
await test_registry.initialize()
|
await test_registry.initialize()
|
||||||
|
|
||||||
# Get all objects, which should only return the valid one
|
# Get all objects, which should only return the valid one
|
||||||
|
@ -240,9 +210,7 @@ async def test_parse_registry_values_error_handling(config):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cached_registry_error_handling(config):
|
async def test_cached_registry_error_handling(sqlite_kvstore):
|
||||||
kvstore = await kvstore_impl(config)
|
|
||||||
|
|
||||||
valid_db = VectorDB(
|
valid_db = VectorDB(
|
||||||
identifier="valid_cached_db",
|
identifier="valid_cached_db",
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
@ -251,14 +219,16 @@ async def test_cached_registry_error_handling(config):
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
)
|
)
|
||||||
|
|
||||||
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json())
|
await sqlite_kvstore.set(
|
||||||
|
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
|
||||||
|
)
|
||||||
|
|
||||||
await kvstore.set(
|
await sqlite_kvstore.set(
|
||||||
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"),
|
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"),
|
||||||
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
|
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
|
||||||
)
|
)
|
||||||
|
|
||||||
cached_registry = CachedDiskDistributionRegistry(kvstore)
|
cached_registry = CachedDiskDistributionRegistry(sqlite_kvstore)
|
||||||
await cached_registry.initialize()
|
await cached_registry.initialize()
|
||||||
|
|
||||||
all_objects = await cached_registry.get_all()
|
all_objects = await cached_registry.get_all()
|
||||||
|
|
|
@ -4,9 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -14,30 +11,10 @@ from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.distribution.datatypes import ModelWithACL
|
from llama_stack.distribution.datatypes import ModelWithACL
|
||||||
from llama_stack.distribution.server.auth_providers import AccessAttributes
|
from llama_stack.distribution.server.auth_providers import AccessAttributes
|
||||||
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|
||||||
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
async def kvstore():
|
|
||||||
temp_dir = tempfile.mkdtemp()
|
|
||||||
db_path = os.path.join(temp_dir, "test_registry_acl.db")
|
|
||||||
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
|
|
||||||
kvstore = SqliteKVStoreImpl(kvstore_config)
|
|
||||||
await kvstore.initialize()
|
|
||||||
yield kvstore
|
|
||||||
shutil.rmtree(temp_dir)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
async def registry(kvstore):
|
|
||||||
registry = CachedDiskDistributionRegistry(kvstore)
|
|
||||||
await registry.initialize()
|
|
||||||
return registry
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_registry_cache_with_acl(registry):
|
async def test_registry_cache_with_acl(cached_disk_dist_registry):
|
||||||
model = ModelWithACL(
|
model = ModelWithACL(
|
||||||
identifier="model-acl",
|
identifier="model-acl",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
|
@ -46,30 +23,30 @@ async def test_registry_cache_with_acl(registry):
|
||||||
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
|
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
success = await registry.register(model)
|
success = await cached_disk_dist_registry.register(model)
|
||||||
assert success
|
assert success
|
||||||
|
|
||||||
cached_model = registry.get_cached("model", "model-acl")
|
cached_model = cached_disk_dist_registry.get_cached("model", "model-acl")
|
||||||
assert cached_model is not None
|
assert cached_model is not None
|
||||||
assert cached_model.identifier == "model-acl"
|
assert cached_model.identifier == "model-acl"
|
||||||
assert cached_model.access_attributes.roles == ["admin"]
|
assert cached_model.access_attributes.roles == ["admin"]
|
||||||
assert cached_model.access_attributes.teams == ["ai-team"]
|
assert cached_model.access_attributes.teams == ["ai-team"]
|
||||||
|
|
||||||
fetched_model = await registry.get("model", "model-acl")
|
fetched_model = await cached_disk_dist_registry.get("model", "model-acl")
|
||||||
assert fetched_model is not None
|
assert fetched_model is not None
|
||||||
assert fetched_model.identifier == "model-acl"
|
assert fetched_model.identifier == "model-acl"
|
||||||
assert fetched_model.access_attributes.roles == ["admin"]
|
assert fetched_model.access_attributes.roles == ["admin"]
|
||||||
|
|
||||||
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
|
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
|
||||||
await registry.update(model)
|
await cached_disk_dist_registry.update(model)
|
||||||
|
|
||||||
updated_cached = registry.get_cached("model", "model-acl")
|
updated_cached = cached_disk_dist_registry.get_cached("model", "model-acl")
|
||||||
assert updated_cached is not None
|
assert updated_cached is not None
|
||||||
assert updated_cached.access_attributes.roles == ["admin", "user"]
|
assert updated_cached.access_attributes.roles == ["admin", "user"]
|
||||||
assert updated_cached.access_attributes.projects == ["project-x"]
|
assert updated_cached.access_attributes.projects == ["project-x"]
|
||||||
assert updated_cached.access_attributes.teams is None
|
assert updated_cached.access_attributes.teams is None
|
||||||
|
|
||||||
new_registry = CachedDiskDistributionRegistry(registry.kvstore)
|
new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
|
||||||
await new_registry.initialize()
|
await new_registry.initialize()
|
||||||
|
|
||||||
new_model = await new_registry.get("model", "model-acl")
|
new_model = await new_registry.get("model", "model-acl")
|
||||||
|
@ -81,7 +58,7 @@ async def test_registry_cache_with_acl(registry):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_registry_empty_acl(registry):
|
async def test_registry_empty_acl(cached_disk_dist_registry):
|
||||||
model = ModelWithACL(
|
model = ModelWithACL(
|
||||||
identifier="model-empty-acl",
|
identifier="model-empty-acl",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
|
@ -90,9 +67,9 @@ async def test_registry_empty_acl(registry):
|
||||||
access_attributes=AccessAttributes(),
|
access_attributes=AccessAttributes(),
|
||||||
)
|
)
|
||||||
|
|
||||||
await registry.register(model)
|
await cached_disk_dist_registry.register(model)
|
||||||
|
|
||||||
cached_model = registry.get_cached("model", "model-empty-acl")
|
cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl")
|
||||||
assert cached_model is not None
|
assert cached_model is not None
|
||||||
assert cached_model.access_attributes is not None
|
assert cached_model.access_attributes is not None
|
||||||
assert cached_model.access_attributes.roles is None
|
assert cached_model.access_attributes.roles is None
|
||||||
|
@ -100,7 +77,7 @@ async def test_registry_empty_acl(registry):
|
||||||
assert cached_model.access_attributes.projects is None
|
assert cached_model.access_attributes.projects is None
|
||||||
assert cached_model.access_attributes.namespaces is None
|
assert cached_model.access_attributes.namespaces is None
|
||||||
|
|
||||||
all_models = await registry.get_all()
|
all_models = await cached_disk_dist_registry.get_all()
|
||||||
assert len(all_models) == 1
|
assert len(all_models) == 1
|
||||||
|
|
||||||
model = ModelWithACL(
|
model = ModelWithACL(
|
||||||
|
@ -110,18 +87,18 @@ async def test_registry_empty_acl(registry):
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
)
|
)
|
||||||
|
|
||||||
await registry.register(model)
|
await cached_disk_dist_registry.register(model)
|
||||||
|
|
||||||
cached_model = registry.get_cached("model", "model-no-acl")
|
cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl")
|
||||||
assert cached_model is not None
|
assert cached_model is not None
|
||||||
assert cached_model.access_attributes is None
|
assert cached_model.access_attributes is None
|
||||||
|
|
||||||
all_models = await registry.get_all()
|
all_models = await cached_disk_dist_registry.get_all()
|
||||||
assert len(all_models) == 2
|
assert len(all_models) == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_registry_serialization(registry):
|
async def test_registry_serialization(cached_disk_dist_registry):
|
||||||
attributes = AccessAttributes(
|
attributes = AccessAttributes(
|
||||||
roles=["admin", "researcher"],
|
roles=["admin", "researcher"],
|
||||||
teams=["ai-team", "ml-team"],
|
teams=["ai-team", "ml-team"],
|
||||||
|
@ -137,9 +114,9 @@ async def test_registry_serialization(registry):
|
||||||
access_attributes=attributes,
|
access_attributes=attributes,
|
||||||
)
|
)
|
||||||
|
|
||||||
await registry.register(model)
|
await cached_disk_dist_registry.register(model)
|
||||||
|
|
||||||
new_registry = CachedDiskDistributionRegistry(registry.kvstore)
|
new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
|
||||||
await new_registry.initialize()
|
await new_registry.initialize()
|
||||||
|
|
||||||
loaded_model = await new_registry.get("model", "model-serialize")
|
loaded_model = await new_registry.get("model", "model-serialize")
|
||||||
|
|
|
@ -4,9 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import tempfile
|
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -15,9 +12,6 @@ from llama_stack.apis.datatypes import Api
|
||||||
from llama_stack.apis.models import ModelType
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
|
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
|
||||||
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
|
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
|
||||||
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|
||||||
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncMock(MagicMock):
|
class AsyncMock(MagicMock):
|
||||||
|
@ -30,25 +24,16 @@ def _return_model(model):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def test_setup():
|
async def test_setup(cached_disk_dist_registry):
|
||||||
temp_dir = tempfile.mkdtemp()
|
|
||||||
db_path = os.path.join(temp_dir, "test_access_control.db")
|
|
||||||
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
|
|
||||||
kvstore = SqliteKVStoreImpl(kvstore_config)
|
|
||||||
await kvstore.initialize()
|
|
||||||
registry = CachedDiskDistributionRegistry(kvstore)
|
|
||||||
await registry.initialize()
|
|
||||||
|
|
||||||
mock_inference = Mock()
|
mock_inference = Mock()
|
||||||
mock_inference.__provider_spec__ = MagicMock()
|
mock_inference.__provider_spec__ = MagicMock()
|
||||||
mock_inference.__provider_spec__.api = Api.inference
|
mock_inference.__provider_spec__.api = Api.inference
|
||||||
mock_inference.register_model = AsyncMock(side_effect=_return_model)
|
mock_inference.register_model = AsyncMock(side_effect=_return_model)
|
||||||
routing_table = ModelsRoutingTable(
|
routing_table = ModelsRoutingTable(
|
||||||
impls_by_provider_id={"test_provider": mock_inference},
|
impls_by_provider_id={"test_provider": mock_inference},
|
||||||
dist_registry=registry,
|
dist_registry=cached_disk_dist_registry,
|
||||||
)
|
)
|
||||||
yield registry, routing_table
|
yield cached_disk_dist_registry, routing_table
|
||||||
shutil.rmtree(temp_dir)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue