mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-10 13:28:40 +00:00
feat: make object registration idempotent (#3752)
# What does this PR do? objects (vector dbs, models, scoring functions, etc) have an identifier and associated object values. we allow exact duplicate registrations. we reject registrations when the identifier exists and the associated object values differ. note: model are namespaced, i.e. {provider_id}/{identifier}, while other object types are not ## Test Plan ci w/ new tests
This commit is contained in:
parent
7ee0ee7843
commit
145b2bcf25
4 changed files with 211 additions and 7 deletions
|
@ -9,7 +9,6 @@ from typing import Any
|
|||
from llama_stack.apis.common.errors import ModelNotFoundError
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.core.access_control.datatypes import Action
|
||||
from llama_stack.core.datatypes import (
|
||||
|
@ -17,6 +16,7 @@ from llama_stack.core.datatypes import (
|
|||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
RoutedProtocol,
|
||||
ScoringFnWithOwner,
|
||||
)
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.core.store import DistributionRegistry
|
||||
|
@ -114,7 +114,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
elif api == Api.scoring:
|
||||
p.scoring_function_store = self
|
||||
scoring_functions = await p.list_scoring_functions()
|
||||
await add_objects(scoring_functions, pid, ScoringFn)
|
||||
await add_objects(scoring_functions, pid, ScoringFnWithOwner)
|
||||
elif api == Api.eval:
|
||||
p.benchmark_store = self
|
||||
elif api == Api.tool_runtime:
|
||||
|
|
|
@ -96,9 +96,11 @@ class DiskDistributionRegistry(DistributionRegistry):
|
|||
|
||||
async def register(self, obj: RoutableObjectWithProvider) -> bool:
|
||||
existing_obj = await self.get(obj.type, obj.identifier)
|
||||
# dont register if the object's providerid already exists
|
||||
if existing_obj and existing_obj.provider_id == obj.provider_id:
|
||||
return False
|
||||
if existing_obj and existing_obj != obj:
|
||||
raise ValueError(
|
||||
f"Object of type '{obj.type}' and identifier '{obj.identifier}' already exists. "
|
||||
"Unregister it first if you want to replace it."
|
||||
)
|
||||
|
||||
await self.kvstore.set(
|
||||
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
||||
|
|
|
@ -354,6 +354,111 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
|||
assert len(scoring_functions_list_after_deletion.data) == 0
|
||||
|
||||
|
||||
async def test_double_registration_models_positive(cached_disk_dist_registry):
|
||||
"""Test that registering the same model twice with identical data succeeds."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register a model
|
||||
await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
|
||||
|
||||
# Register the exact same model again - should succeed (idempotent)
|
||||
await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
|
||||
|
||||
# Verify only one model exists
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 1
|
||||
assert models.data[0].identifier == "test_provider/test-model"
|
||||
|
||||
|
||||
async def test_double_registration_models_negative(cached_disk_dist_registry):
|
||||
"""Test that registering the same model with different data fails."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register a model with specific metadata
|
||||
await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
|
||||
|
||||
# Try to register the same model with different metadata - should fail
|
||||
with pytest.raises(
|
||||
ValueError, match="Object of type 'model' and identifier 'test_provider/test-model' already exists"
|
||||
):
|
||||
await table.register_model(
|
||||
model_id="test-model", provider_id="test_provider", metadata={"param1": "different_value"}
|
||||
)
|
||||
|
||||
|
||||
async def test_double_registration_scoring_functions_positive(cached_disk_dist_registry):
|
||||
"""Test that registering the same scoring function twice with identical data succeeds."""
|
||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register a scoring function
|
||||
await table.register_scoring_function(
|
||||
scoring_fn_id="test-scoring-fn",
|
||||
provider_id="test_provider",
|
||||
description="Test scoring function",
|
||||
return_type=NumberType(),
|
||||
)
|
||||
|
||||
# Register the exact same scoring function again - should succeed (idempotent)
|
||||
await table.register_scoring_function(
|
||||
scoring_fn_id="test-scoring-fn",
|
||||
provider_id="test_provider",
|
||||
description="Test scoring function",
|
||||
return_type=NumberType(),
|
||||
)
|
||||
|
||||
# Verify only one scoring function exists
|
||||
scoring_functions = await table.list_scoring_functions()
|
||||
assert len(scoring_functions.data) == 1
|
||||
assert scoring_functions.data[0].identifier == "test-scoring-fn"
|
||||
|
||||
|
||||
async def test_double_registration_scoring_functions_negative(cached_disk_dist_registry):
|
||||
"""Test that registering the same scoring function with different data fails."""
|
||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register a scoring function
|
||||
await table.register_scoring_function(
|
||||
scoring_fn_id="test-scoring-fn",
|
||||
provider_id="test_provider",
|
||||
description="Test scoring function",
|
||||
return_type=NumberType(),
|
||||
)
|
||||
|
||||
# Try to register the same scoring function with different description - should fail
|
||||
with pytest.raises(
|
||||
ValueError, match="Object of type 'scoring_function' and identifier 'test-scoring-fn' already exists"
|
||||
):
|
||||
await table.register_scoring_function(
|
||||
scoring_fn_id="test-scoring-fn",
|
||||
provider_id="test_provider",
|
||||
description="Different description",
|
||||
return_type=NumberType(),
|
||||
)
|
||||
|
||||
|
||||
async def test_double_registration_different_providers(cached_disk_dist_registry):
|
||||
"""Test that registering objects with same ID but different providers succeeds."""
|
||||
impl1 = InferenceImpl()
|
||||
impl2 = InferenceImpl()
|
||||
table = ModelsRoutingTable({"provider1": impl1, "provider2": impl2}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register same model ID with different providers - should succeed
|
||||
await table.register_model(model_id="shared-model", provider_id="provider1")
|
||||
await table.register_model(model_id="shared-model", provider_id="provider2")
|
||||
|
||||
# Verify both models exist with different identifiers
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
model_ids = {m.identifier for m in models.data}
|
||||
assert "provider1/shared-model" in model_ids
|
||||
assert "provider2/shared-model" in model_ids
|
||||
|
||||
|
||||
async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
||||
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
|
||||
from llama_stack.apis.inference import Model
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.core.datatypes import VectorDBWithOwner
|
||||
from llama_stack.core.store.registry import (
|
||||
KEY_FORMAT,
|
||||
CachedDiskDistributionRegistry,
|
||||
|
@ -116,7 +117,7 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
|||
provider_resource_id="test_vector_db_2",
|
||||
provider_id="baz",
|
||||
)
|
||||
await cached_disk_dist_registry.register(original_vector_db)
|
||||
assert await cached_disk_dist_registry.register(original_vector_db)
|
||||
|
||||
duplicate_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
|
@ -125,6 +126,7 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
|||
provider_resource_id="test_vector_db_2",
|
||||
provider_id="baz", # Same provider_id
|
||||
)
|
||||
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db_2' already exists"):
|
||||
await cached_disk_dist_registry.register(duplicate_vector_db)
|
||||
|
||||
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
||||
|
@ -229,3 +231,98 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
|
|||
|
||||
invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db")
|
||||
assert invalid_obj is None
|
||||
|
||||
|
||||
async def test_double_registration_identical_objects(disk_dist_registry):
|
||||
"""Test that registering identical objects succeeds (idempotent)."""
|
||||
vector_db = VectorDBWithOwner(
|
||||
identifier="test_vector_db",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
# First registration should succeed
|
||||
result1 = await disk_dist_registry.register(vector_db)
|
||||
assert result1 is True
|
||||
|
||||
# Second registration of identical object should also succeed (idempotent)
|
||||
result2 = await disk_dist_registry.register(vector_db)
|
||||
assert result2 is True
|
||||
|
||||
# Verify object exists and is unchanged
|
||||
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
|
||||
assert retrieved is not None
|
||||
assert retrieved.identifier == vector_db.identifier
|
||||
assert retrieved.embedding_model == vector_db.embedding_model
|
||||
|
||||
|
||||
async def test_double_registration_different_objects(disk_dist_registry):
|
||||
"""Test that registering different objects with same identifier fails."""
|
||||
vector_db1 = VectorDBWithOwner(
|
||||
identifier="test_vector_db",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
vector_db2 = VectorDBWithOwner(
|
||||
identifier="test_vector_db", # Same identifier
|
||||
embedding_model="different-model", # Different embedding model
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
# First registration should succeed
|
||||
result1 = await disk_dist_registry.register(vector_db1)
|
||||
assert result1 is True
|
||||
|
||||
# Second registration with different data should fail
|
||||
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db' already exists"):
|
||||
await disk_dist_registry.register(vector_db2)
|
||||
|
||||
# Verify original object is unchanged
|
||||
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
|
||||
assert retrieved is not None
|
||||
assert retrieved.embedding_model == "all-MiniLM-L6-v2" # Original value
|
||||
|
||||
|
||||
async def test_double_registration_with_cache(cached_disk_dist_registry):
|
||||
"""Test double registration behavior with caching enabled."""
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import ModelWithOwner
|
||||
|
||||
model1 = ModelWithOwner(
|
||||
identifier="test_model",
|
||||
provider_resource_id="test_model",
|
||||
provider_id="test-provider",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
model2 = ModelWithOwner(
|
||||
identifier="test_model", # Same identifier
|
||||
provider_resource_id="test_model",
|
||||
provider_id="test-provider",
|
||||
model_type=ModelType.embedding, # Different type
|
||||
)
|
||||
|
||||
# First registration should succeed and populate cache
|
||||
result1 = await cached_disk_dist_registry.register(model1)
|
||||
assert result1 is True
|
||||
|
||||
# Verify in cache
|
||||
cached_model = cached_disk_dist_registry.get_cached("model", "test_model")
|
||||
assert cached_model is not None
|
||||
assert cached_model.model_type == ModelType.llm
|
||||
|
||||
# Second registration with different data should fail
|
||||
with pytest.raises(ValueError, match="Object of type 'model' and identifier 'test_model' already exists"):
|
||||
await cached_disk_dist_registry.register(model2)
|
||||
|
||||
# Cache should still contain original model
|
||||
cached_model_after = cached_disk_dist_registry.get_cached("model", "test_model")
|
||||
assert cached_model_after is not None
|
||||
assert cached_model_after.model_type == ModelType.llm
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue