# Patch summary (free-form text placed ahead of the first "diff --git" header).
#
# llama_stack/core/routing_tables/common.py
#   Drops the ScoringFn import, imports ScoringFnWithOwner from
#   llama_stack.core.datatypes, and passes ScoringFnWithOwner (instead of
#   ScoringFn) to add_objects() for the scoring functions a provider lists.
#
# llama_stack/core/store/registry.py
#   DiskDistributionRegistry.register() used to return False when an object with
#   the same type/identifier and the same provider_id was already stored.  It now
#   raises ValueError when a stored object with that type/identifier compares
#   unequal (!=) to the incoming one; otherwise it falls through to the existing
#   kvstore.set() call.
#   NOTE(review): the comparison is whatever != means for these objects --
#   presumably pydantic model equality, which would also explain the switch to
#   the *WithOwner classes in callers and tests; confirm.
#
# tests/unit/distribution/routers/test_routing_tables.py
#   Adds five tests: identical vs. conflicting re-registration for models and for
#   scoring functions, plus the same model_id registered under two providers.
#
# tests/unit/registry/test_registry.py
#   test_duplicate_provider_registration now asserts the first register() result
#   and expects ValueError for the second.  Adds three tests covering identical
#   re-registration, conflicting re-registration, and the cached registry keeping
#   the original object after a rejected registration.
#
diff --git a/llama_stack/core/routing_tables/common.py b/llama_stack/core/routing_tables/common.py
index ca2f3af42..0800b909b 100644
--- a/llama_stack/core/routing_tables/common.py
+++ b/llama_stack/core/routing_tables/common.py
@@ -9,7 +9,6 @@ from typing import Any
 from llama_stack.apis.common.errors import ModelNotFoundError
 from llama_stack.apis.models import Model
 from llama_stack.apis.resource import ResourceType
-from llama_stack.apis.scoring_functions import ScoringFn
 from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
 from llama_stack.core.access_control.datatypes import Action
 from llama_stack.core.datatypes import (
@@ -17,6 +16,7 @@ from llama_stack.core.datatypes import (
     RoutableObject,
     RoutableObjectWithProvider,
     RoutedProtocol,
+    ScoringFnWithOwner,
 )
 from llama_stack.core.request_headers import get_authenticated_user
 from llama_stack.core.store import DistributionRegistry
@@ -114,7 +114,7 @@ class CommonRoutingTableImpl(RoutingTable):
             elif api == Api.scoring:
                 p.scoring_function_store = self
                 scoring_functions = await p.list_scoring_functions()
-                await add_objects(scoring_functions, pid, ScoringFn)
+                await add_objects(scoring_functions, pid, ScoringFnWithOwner)
             elif api == Api.eval:
                 p.benchmark_store = self
             elif api == Api.tool_runtime:
diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py
index 624dbd176..04581bab5 100644
--- a/llama_stack/core/store/registry.py
+++ b/llama_stack/core/store/registry.py
@@ -96,9 +96,11 @@ class DiskDistributionRegistry(DistributionRegistry):
 
     async def register(self, obj: RoutableObjectWithProvider) -> bool:
         existing_obj = await self.get(obj.type, obj.identifier)
-        # dont register if the object's providerid already exists
-        if existing_obj and existing_obj.provider_id == obj.provider_id:
-            return False
+        if existing_obj and existing_obj != obj:
+            raise ValueError(
+                f"Object of type '{obj.type}' and identifier '{obj.identifier}' already exists. "
+                "Unregister it first if you want to replace it."
+            )
 
         await self.kvstore.set(
             KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py
index a1c3d1e95..8b03ec260 100644
--- a/tests/unit/distribution/routers/test_routing_tables.py
+++ b/tests/unit/distribution/routers/test_routing_tables.py
@@ -354,6 +354,111 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
     assert len(scoring_functions_list_after_deletion.data) == 0
 
 
+async def test_double_registration_models_positive(cached_disk_dist_registry):
+    """Test that registering the same model twice with identical data succeeds."""
+    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
+    await table.initialize()
+
+    # Register a model
+    await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
+
+    # Register the exact same model again - should succeed (idempotent)
+    await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
+
+    # Verify only one model exists
+    models = await table.list_models()
+    assert len(models.data) == 1
+    assert models.data[0].identifier == "test_provider/test-model"
+
+
+async def test_double_registration_models_negative(cached_disk_dist_registry):
+    """Test that registering the same model with different data fails."""
+    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
+    await table.initialize()
+
+    # Register a model with specific metadata
+    await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
+
+    # Try to register the same model with different metadata - should fail
+    with pytest.raises(
+        ValueError, match="Object of type 'model' and identifier 'test_provider/test-model' already exists"
+    ):
+        await table.register_model(
+            model_id="test-model", provider_id="test_provider", metadata={"param1": "different_value"}
+        )
+
+
+async def test_double_registration_scoring_functions_positive(cached_disk_dist_registry):
+    """Test that registering the same scoring function twice with identical data succeeds."""
+    table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
+    await table.initialize()
+
+    # Register a scoring function
+    await table.register_scoring_function(
+        scoring_fn_id="test-scoring-fn",
+        provider_id="test_provider",
+        description="Test scoring function",
+        return_type=NumberType(),
+    )
+
+    # Register the exact same scoring function again - should succeed (idempotent)
+    await table.register_scoring_function(
+        scoring_fn_id="test-scoring-fn",
+        provider_id="test_provider",
+        description="Test scoring function",
+        return_type=NumberType(),
+    )
+
+    # Verify only one scoring function exists
+    scoring_functions = await table.list_scoring_functions()
+    assert len(scoring_functions.data) == 1
+    assert scoring_functions.data[0].identifier == "test-scoring-fn"
+
+
+async def test_double_registration_scoring_functions_negative(cached_disk_dist_registry):
+    """Test that registering the same scoring function with different data fails."""
+    table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
+    await table.initialize()
+
+    # Register a scoring function
+    await table.register_scoring_function(
+        scoring_fn_id="test-scoring-fn",
+        provider_id="test_provider",
+        description="Test scoring function",
+        return_type=NumberType(),
+    )
+
+    # Try to register the same scoring function with different description - should fail
+    with pytest.raises(
+        ValueError, match="Object of type 'scoring_function' and identifier 'test-scoring-fn' already exists"
+    ):
+        await table.register_scoring_function(
+            scoring_fn_id="test-scoring-fn",
+            provider_id="test_provider",
+            description="Different description",
+            return_type=NumberType(),
+        )
+
+
+async def test_double_registration_different_providers(cached_disk_dist_registry):
+    """Test that registering objects with same ID but different providers succeeds."""
+    impl1 = InferenceImpl()
+    impl2 = InferenceImpl()
+    table = ModelsRoutingTable({"provider1": impl1, "provider2": impl2}, cached_disk_dist_registry, {})
+    await table.initialize()
+
+    # Register same model ID with different providers - should succeed
+    await table.register_model(model_id="shared-model", provider_id="provider1")
+    await table.register_model(model_id="shared-model", provider_id="provider2")
+
+    # Verify both models exist with different identifiers
+    models = await table.list_models()
+    assert len(models.data) == 2
+    model_ids = {m.identifier for m in models.data}
+    assert "provider1/shared-model" in model_ids
+    assert "provider2/shared-model" in model_ids
+
+
 async def test_benchmarks_routing_table(cached_disk_dist_registry):
     table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
     await table.initialize()
diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py
index 4ea4a20b9..61afa0561 100644
--- a/tests/unit/registry/test_registry.py
+++ b/tests/unit/registry/test_registry.py
@@ -9,6 +9,7 @@ import pytest
 
 from llama_stack.apis.inference import Model
 from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.core.datatypes import VectorDBWithOwner
 from llama_stack.core.store.registry import (
     KEY_FORMAT,
     CachedDiskDistributionRegistry,
@@ -116,7 +117,7 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
         provider_resource_id="test_vector_db_2",
         provider_id="baz",
     )
-    await cached_disk_dist_registry.register(original_vector_db)
+    assert await cached_disk_dist_registry.register(original_vector_db)
 
     duplicate_vector_db = VectorDB(
         identifier="test_vector_db_2",
@@ -125,7 +126,8 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
         provider_resource_id="test_vector_db_2",
         provider_id="baz",  # Same provider_id
     )
-    await cached_disk_dist_registry.register(duplicate_vector_db)
+    with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db_2' already exists"):
+        await cached_disk_dist_registry.register(duplicate_vector_db)
 
     result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
     assert result is not None
@@ -229,3 +231,98 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
 
     invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db")
     assert invalid_obj is None
+
+
+async def test_double_registration_identical_objects(disk_dist_registry):
+    """Test that registering identical objects succeeds (idempotent)."""
+    vector_db = VectorDBWithOwner(
+        identifier="test_vector_db",
+        embedding_model="all-MiniLM-L6-v2",
+        embedding_dimension=384,
+        provider_resource_id="test_vector_db",
+        provider_id="test-provider",
+    )
+
+    # First registration should succeed
+    result1 = await disk_dist_registry.register(vector_db)
+    assert result1 is True
+
+    # Second registration of identical object should also succeed (idempotent)
+    result2 = await disk_dist_registry.register(vector_db)
+    assert result2 is True
+
+    # Verify object exists and is unchanged
+    retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
+    assert retrieved is not None
+    assert retrieved.identifier == vector_db.identifier
+    assert retrieved.embedding_model == vector_db.embedding_model
+
+
+async def test_double_registration_different_objects(disk_dist_registry):
+    """Test that registering different objects with same identifier fails."""
+    vector_db1 = VectorDBWithOwner(
+        identifier="test_vector_db",
+        embedding_model="all-MiniLM-L6-v2",
+        embedding_dimension=384,
+        provider_resource_id="test_vector_db",
+        provider_id="test-provider",
+    )
+
+    vector_db2 = VectorDBWithOwner(
+        identifier="test_vector_db",  # Same identifier
+        embedding_model="different-model",  # Different embedding model
+        embedding_dimension=384,
+        provider_resource_id="test_vector_db",
+        provider_id="test-provider",
+    )
+
+    # First registration should succeed
+    result1 = await disk_dist_registry.register(vector_db1)
+    assert result1 is True
+
+    # Second registration with different data should fail
+    with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db' already exists"):
+        await disk_dist_registry.register(vector_db2)
+
+    # Verify original object is unchanged
+    retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
+    assert retrieved is not None
+    assert retrieved.embedding_model == "all-MiniLM-L6-v2"  # Original value
+
+
+async def test_double_registration_with_cache(cached_disk_dist_registry):
+    """Test double registration behavior with caching enabled."""
+    from llama_stack.apis.models import ModelType
+    from llama_stack.core.datatypes import ModelWithOwner
+
+    model1 = ModelWithOwner(
+        identifier="test_model",
+        provider_resource_id="test_model",
+        provider_id="test-provider",
+        model_type=ModelType.llm,
+    )
+
+    model2 = ModelWithOwner(
+        identifier="test_model",  # Same identifier
+        provider_resource_id="test_model",
+        provider_id="test-provider",
+        model_type=ModelType.embedding,  # Different type
+    )
+
+    # First registration should succeed and populate cache
+    result1 = await cached_disk_dist_registry.register(model1)
+    assert result1 is True
+
+    # Verify in cache
+    cached_model = cached_disk_dist_registry.get_cached("model", "test_model")
+    assert cached_model is not None
+    assert cached_model.model_type == ModelType.llm
+
+    # Second registration with different data should fail
+    with pytest.raises(ValueError, match="Object of type 'model' and identifier 'test_model' already exists"):
+        await cached_disk_dist_registry.register(model2)
+
+    # Cache should still contain original model
+    cached_model_after = cached_disk_dist_registry.get_cached("model", "test_model")
+    assert cached_model_after is not None
+    assert cached_model_after.model_type == ModelType.llm