mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 21:48:36 +00:00
# What does this PR do? Objects (vector DBs, models, scoring functions, etc.) have an identifier and associated object values. We allow exact duplicate registrations. We reject registrations when the identifier exists and the associated object values differ. Note: models are namespaced, i.e. {provider_id}/{identifier}, while other object types are not. ## Test Plan CI with new tests
328 lines
12 KiB
Python
328 lines
12 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
|
|
import pytest
|
|
|
|
from llama_stack.apis.inference import Model
|
|
from llama_stack.apis.vector_dbs import VectorDB
|
|
from llama_stack.core.datatypes import VectorDBWithOwner
|
|
from llama_stack.core.store.registry import (
|
|
KEY_FORMAT,
|
|
CachedDiskDistributionRegistry,
|
|
DiskDistributionRegistry,
|
|
)
|
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|
|
|
|
|
@pytest.fixture
def sample_vector_db():
    """Provide a minimal VectorDB owned by ``test-provider`` for registry tests."""
    db_name = "test_vector_db"
    return VectorDB(
        identifier=db_name,
        provider_resource_id=db_name,
        provider_id="test-provider",
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
    )
|
|
|
|
|
|
@pytest.fixture
def sample_model():
    """Provide a minimal Model owned by ``test-provider`` for registry tests."""
    model_name = "test_model"
    return Model(
        identifier=model_name,
        provider_resource_id=model_name,
        provider_id="test-provider",
    )
|
|
|
|
|
|
async def test_registry_initialization(disk_dist_registry):
    """Looking up an unknown type/identifier in a fresh registry returns ``None``."""
    missing = await disk_dist_registry.get("nonexistent", "nonexistent")
    assert missing is None
|
|
|
|
|
|
async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model):
    """Objects registered on a DiskDistributionRegistry can be read back by type and identifier.

    Registers one vector DB and one model, then fetches each by its
    ``(type, identifier)`` pair and checks the stored fields match.
    """
    await disk_dist_registry.register(sample_vector_db)
    await disk_dist_registry.register(sample_model)

    result_vector_db = await disk_dist_registry.get("vector_db", "test_vector_db")
    assert result_vector_db is not None
    assert result_vector_db.identifier == sample_vector_db.identifier
    assert result_vector_db.embedding_model == sample_vector_db.embedding_model
    assert result_vector_db.embedding_dimension == sample_vector_db.embedding_dimension
    assert result_vector_db.provider_id == sample_vector_db.provider_id

    result_model = await disk_dist_registry.get("model", "test_model")
    assert result_model is not None
    assert result_model.identifier == sample_model.identifier
    assert result_model.provider_id == sample_model.provider_id
|
|
|
|
|
|
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model):
    """A CachedDiskDistributionRegistry loads objects that were already persisted to disk.

    Both a vector DB and a model are registered through a plain
    DiskDistributionRegistry. A cached registry is then opened on a new
    kvstore handle for the same SQLite file, and both objects must be retrievable.
    """
    # First populate the disk registry
    disk_registry = DiskDistributionRegistry(sqlite_kvstore)
    await disk_registry.initialize()
    await disk_registry.register(sample_vector_db)
    await disk_registry.register(sample_model)

    # Open a new kvstore handle on the same SQLite file so the cached registry
    # has to load what was written by the disk registry above.
    db_path = sqlite_kvstore.db_path
    cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
    await cached_registry.initialize()

    result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
    assert result_vector_db is not None
    assert result_vector_db.identifier == sample_vector_db.identifier
    assert result_vector_db.embedding_model == sample_vector_db.embedding_model
    assert result_vector_db.embedding_dimension == sample_vector_db.embedding_dimension
    assert result_vector_db.provider_id == sample_vector_db.provider_id

    # The model was registered as well and must also be served by the cached registry.
    result_model = await cached_registry.get("model", "test_model")
    assert result_model is not None
    assert result_model.identifier == sample_model.identifier
    assert result_model.provider_id == sample_model.provider_id
|
|
|
|
|
|
async def test_cached_registry_updates(cached_disk_dist_registry):
    """Registering through the cached registry updates the cache and writes through to disk."""
    vdb = VectorDB(
        identifier="test_vector_db_2",
        provider_resource_id="test_vector_db_2",
        provider_id="baz",
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
    )
    await cached_disk_dist_registry.register(vdb)

    # The cached registry should serve the new object right away.
    from_cache = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
    assert from_cache is not None
    assert (from_cache.identifier, from_cache.provider_id) == (vdb.identifier, vdb.provider_id)

    # An independent, uncached registry on the same SQLite file must see it too.
    reopened_store = await kvstore_impl(SqliteKVStoreConfig(db_path=cached_disk_dist_registry.kvstore.db_path))
    disk_only = DiskDistributionRegistry(reopened_store)
    await disk_only.initialize()

    from_disk = await disk_only.get("vector_db", "test_vector_db_2")
    assert from_disk is not None
    assert (from_disk.identifier, from_disk.provider_id) == (vdb.identifier, vdb.provider_id)
|
|
|
|
|
|
async def test_duplicate_provider_registration(cached_disk_dist_registry):
    """A conflicting re-registration under the same provider is rejected and the original is kept."""

    def build(embedding_model):
        # Everything except the embedding model is shared, including provider_id.
        return VectorDB(
            identifier="test_vector_db_2",
            embedding_model=embedding_model,
            embedding_dimension=384,
            provider_resource_id="test_vector_db_2",
            provider_id="baz",
        )

    first = build("all-MiniLM-L6-v2")
    assert await cached_disk_dist_registry.register(first)

    conflicting = build("different-model")
    with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db_2' already exists"):
        await cached_disk_dist_registry.register(conflicting)

    stored = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
    assert stored is not None
    # The rejected registration must not overwrite the original values.
    assert stored.embedding_model == first.embedding_model
|
|
|
|
|
|
async def test_get_all_objects(cached_disk_dist_registry):
    """get_all() returns every registered vector DB with its fields intact.

    Registers three vector DBs with distinct identifiers and providers, then
    checks that get_all() returns exactly three objects and that each original
    appears once with matching embedding model, provider and dimension.
    """
    # Create three vector DBs, each with its own identifier and provider
    test_vector_dbs = [
        VectorDB(
            identifier=f"test_vector_db_{i}",
            embedding_model="all-MiniLM-L6-v2",
            embedding_dimension=384,
            provider_resource_id=f"test_vector_db_{i}",
            provider_id=f"provider_{i}",
        )
        for i in range(3)
    ]

    # Register all vector_dbs
    for vector_db in test_vector_dbs:
        await cached_disk_dist_registry.register(vector_db)

    # Test get_all retrieval
    all_results = await cached_disk_dist_registry.get_all()
    assert len(all_results) == 3

    # Verify each vector_db was stored exactly once and with the same field values
    for original_vector_db in test_vector_dbs:
        matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier]
        assert len(matching_vector_dbs) == 1
        stored_vector_db = matching_vector_dbs[0]
        assert stored_vector_db.embedding_model == original_vector_db.embedding_model
        assert stored_vector_db.provider_id == original_vector_db.provider_id
        assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
|
|
|
|
|
|
async def test_parse_registry_values_error_handling(sqlite_kvstore):
    """Malformed registry rows are skipped by get_all() and read back as ``None`` by get()."""
    good = VectorDB(
        identifier="valid_vector_db",
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
        provider_resource_id="valid_vector_db",
        provider_id="test-provider",
    )

    # Seed the store directly: one valid row, one unparseable JSON row, and one
    # row whose JSON parses but lacks required fields. Insertion order is preserved.
    raw_rows = {
        "valid_vector_db": good.model_dump_json(),
        "corrupted_json": "{not valid json",
        "missing_fields": '{"type": "vector_db", "identifier": "missing_fields"}',
    }
    for ident, payload in raw_rows.items():
        await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier=ident), payload)

    registry = DiskDistributionRegistry(sqlite_kvstore)
    await registry.initialize()

    # Only the valid entry should survive listing.
    listed = await registry.get_all()
    assert len(listed) == 1
    assert listed[0].identifier == "valid_vector_db"

    # Direct lookups of the broken rows should yield None rather than raise.
    for bad_ident in ("corrupted_json", "missing_fields"):
        assert await registry.get("vector_db", bad_ident) is None
|
|
|
|
|
|
async def test_cached_registry_error_handling(sqlite_kvstore):
    """The cached registry ignores rows that cannot be parsed into valid objects."""
    good_db = VectorDB(
        identifier="valid_cached_db",
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
        provider_resource_id="valid_cached_db",
        provider_id="test-provider",
    )

    seeded_rows = [
        ("valid_cached_db", good_db.model_dump_json()),
        # embedding_model is an int here instead of a string, so this row is expected to be rejected.
        ("invalid_cached_db", '{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}'),
    ]
    for ident, payload in seeded_rows:
        await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier=ident), payload)

    registry = CachedDiskDistributionRegistry(sqlite_kvstore)
    await registry.initialize()

    loaded = await registry.get_all()
    assert len(loaded) == 1
    assert loaded[0].identifier == "valid_cached_db"

    assert await registry.get("vector_db", "invalid_cached_db") is None
|
|
|
|
|
|
async def test_double_registration_identical_objects(disk_dist_registry):
    """Registering the same object twice succeeds both times (idempotent registration)."""
    vector_db = VectorDBWithOwner(
        identifier="test_vector_db",
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
        provider_resource_id="test_vector_db",
        provider_id="test-provider",
    )

    # The initial registration and the identical repeat must both report True.
    for _attempt in range(2):
        assert await disk_dist_registry.register(vector_db) is True

    # The stored object is still the one that was registered.
    retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
    assert retrieved is not None
    assert (retrieved.identifier, retrieved.embedding_model) == (vector_db.identifier, vector_db.embedding_model)
|
|
|
|
|
|
async def test_double_registration_different_objects(disk_dist_registry):
    """A different object under an already-registered identifier is rejected; the original stays."""

    def make_vector_db(embedding_model):
        # Same identifier/provider for both variants; only the embedding model varies.
        return VectorDBWithOwner(
            identifier="test_vector_db",
            embedding_model=embedding_model,
            embedding_dimension=384,
            provider_resource_id="test_vector_db",
            provider_id="test-provider",
        )

    original = make_vector_db("all-MiniLM-L6-v2")
    conflicting = make_vector_db("different-model")

    assert await disk_dist_registry.register(original) is True

    with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db' already exists"):
        await disk_dist_registry.register(conflicting)

    stored = await disk_dist_registry.get("vector_db", "test_vector_db")
    assert stored is not None
    assert stored.embedding_model == "all-MiniLM-L6-v2"
|
|
|
|
|
|
async def test_double_registration_with_cache(cached_disk_dist_registry):
    """With caching enabled, a conflicting re-registration fails and the cached entry is untouched."""
    from llama_stack.apis.models import ModelType
    from llama_stack.core.datatypes import ModelWithOwner

    shared_fields = {
        "identifier": "test_model",
        "provider_resource_id": "test_model",
        "provider_id": "test-provider",
    }
    llm_model = ModelWithOwner(**shared_fields, model_type=ModelType.llm)
    # Same identifier but a different model_type, so the object values conflict.
    embedding_model = ModelWithOwner(**shared_fields, model_type=ModelType.embedding)

    assert await cached_disk_dist_registry.register(llm_model) is True

    def cached_model_type():
        # Read straight from the in-memory cache via get_cached.
        entry = cached_disk_dist_registry.get_cached("model", "test_model")
        assert entry is not None
        return entry.model_type

    assert cached_model_type() == ModelType.llm

    with pytest.raises(ValueError, match="Object of type 'model' and identifier 'test_model' already exists"):
        await cached_disk_dist_registry.register(embedding_model)

    # The rejected registration must not have replaced the cached entry.
    assert cached_model_type() == ModelType.llm
|