mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 20:12:33 +00:00
only reject registrations that share an id and differ on content
This commit is contained in:
parent
142ea659da
commit
f2d821ab3e
3 changed files with 22 additions and 14 deletions
|
|
@ -96,9 +96,11 @@ class DiskDistributionRegistry(DistributionRegistry):
|
||||||
|
|
||||||
async def register(self, obj: RoutableObjectWithProvider) -> bool:
|
async def register(self, obj: RoutableObjectWithProvider) -> bool:
|
||||||
existing_obj = await self.get(obj.type, obj.identifier)
|
existing_obj = await self.get(obj.type, obj.identifier)
|
||||||
# dont register if the object's providerid already exists
|
if existing_obj and existing_obj != obj:
|
||||||
if existing_obj and existing_obj.provider_id == obj.provider_id:
|
raise ValueError(
|
||||||
return False
|
f"Object of type '{obj.type}' and identifier '{obj.identifier}' already exists. "
|
||||||
|
"Unregister it first if you want to replace it."
|
||||||
|
)
|
||||||
|
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
||||||
|
|
|
||||||
|
|
@ -360,10 +360,10 @@ async def test_double_registration_models_positive(cached_disk_dist_registry):
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
# Register a model
|
# Register a model
|
||||||
await table.register_model(model_id="test-model", provider_id="test_provider")
|
await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
|
||||||
|
|
||||||
# Register the exact same model again - should succeed (idempotent)
|
# Register the exact same model again - should succeed (idempotent)
|
||||||
await table.register_model(model_id="test-model", provider_id="test_provider")
|
await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
|
||||||
|
|
||||||
# Verify only one model exists
|
# Verify only one model exists
|
||||||
models = await table.list_models()
|
models = await table.list_models()
|
||||||
|
|
@ -380,7 +380,9 @@ async def test_double_registration_models_negative(cached_disk_dist_registry):
|
||||||
await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
|
await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
|
||||||
|
|
||||||
# Try to register the same model with different metadata - should fail
|
# Try to register the same model with different metadata - should fail
|
||||||
with pytest.raises(ValueError, match="Provider 'test_provider' is already registered"):
|
with pytest.raises(
|
||||||
|
ValueError, match="Object of type 'model' and identifier 'test_provider/test-model' already exists"
|
||||||
|
):
|
||||||
await table.register_model(
|
await table.register_model(
|
||||||
model_id="test-model", provider_id="test_provider", metadata={"param1": "different_value"}
|
model_id="test-model", provider_id="test_provider", metadata={"param1": "different_value"}
|
||||||
)
|
)
|
||||||
|
|
@ -427,7 +429,9 @@ async def test_double_registration_scoring_functions_negative(cached_disk_dist_r
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to register the same scoring function with different description - should fail
|
# Try to register the same scoring function with different description - should fail
|
||||||
with pytest.raises(ValueError, match="Provider 'test_provider' is already registered"):
|
with pytest.raises(
|
||||||
|
ValueError, match="Object of type 'scoring_function' and identifier 'test-scoring-fn' already exists"
|
||||||
|
):
|
||||||
await table.register_scoring_function(
|
await table.register_scoring_function(
|
||||||
scoring_fn_id="test-scoring-fn",
|
scoring_fn_id="test-scoring-fn",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import pytest
|
||||||
|
|
||||||
from llama_stack.apis.inference import Model
|
from llama_stack.apis.inference import Model
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
|
from llama_stack.core.datatypes import VectorDBWithOwner
|
||||||
from llama_stack.core.store.registry import (
|
from llama_stack.core.store.registry import (
|
||||||
KEY_FORMAT,
|
KEY_FORMAT,
|
||||||
CachedDiskDistributionRegistry,
|
CachedDiskDistributionRegistry,
|
||||||
|
|
@ -116,7 +117,7 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
||||||
provider_resource_id="test_vector_db_2",
|
provider_resource_id="test_vector_db_2",
|
||||||
provider_id="baz",
|
provider_id="baz",
|
||||||
)
|
)
|
||||||
await cached_disk_dist_registry.register(original_vector_db)
|
assert await cached_disk_dist_registry.register(original_vector_db)
|
||||||
|
|
||||||
duplicate_vector_db = VectorDB(
|
duplicate_vector_db = VectorDB(
|
||||||
identifier="test_vector_db_2",
|
identifier="test_vector_db_2",
|
||||||
|
|
@ -125,7 +126,8 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
||||||
provider_resource_id="test_vector_db_2",
|
provider_resource_id="test_vector_db_2",
|
||||||
provider_id="baz", # Same provider_id
|
provider_id="baz", # Same provider_id
|
||||||
)
|
)
|
||||||
await cached_disk_dist_registry.register(duplicate_vector_db)
|
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db_2' already exists"):
|
||||||
|
await cached_disk_dist_registry.register(duplicate_vector_db)
|
||||||
|
|
||||||
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
||||||
assert result is not None
|
assert result is not None
|
||||||
|
|
@ -233,7 +235,7 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
|
||||||
|
|
||||||
async def test_double_registration_identical_objects(disk_dist_registry):
|
async def test_double_registration_identical_objects(disk_dist_registry):
|
||||||
"""Test that registering identical objects succeeds (idempotent)."""
|
"""Test that registering identical objects succeeds (idempotent)."""
|
||||||
vector_db = VectorDB(
|
vector_db = VectorDBWithOwner(
|
||||||
identifier="test_vector_db",
|
identifier="test_vector_db",
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
embedding_dimension=384,
|
embedding_dimension=384,
|
||||||
|
|
@ -258,7 +260,7 @@ async def test_double_registration_identical_objects(disk_dist_registry):
|
||||||
|
|
||||||
async def test_double_registration_different_objects(disk_dist_registry):
|
async def test_double_registration_different_objects(disk_dist_registry):
|
||||||
"""Test that registering different objects with same identifier fails."""
|
"""Test that registering different objects with same identifier fails."""
|
||||||
vector_db1 = VectorDB(
|
vector_db1 = VectorDBWithOwner(
|
||||||
identifier="test_vector_db",
|
identifier="test_vector_db",
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
embedding_dimension=384,
|
embedding_dimension=384,
|
||||||
|
|
@ -266,7 +268,7 @@ async def test_double_registration_different_objects(disk_dist_registry):
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_db2 = VectorDB(
|
vector_db2 = VectorDBWithOwner(
|
||||||
identifier="test_vector_db", # Same identifier
|
identifier="test_vector_db", # Same identifier
|
||||||
embedding_model="different-model", # Different embedding model
|
embedding_model="different-model", # Different embedding model
|
||||||
embedding_dimension=384,
|
embedding_dimension=384,
|
||||||
|
|
@ -279,7 +281,7 @@ async def test_double_registration_different_objects(disk_dist_registry):
|
||||||
assert result1 is True
|
assert result1 is True
|
||||||
|
|
||||||
# Second registration with different data should fail
|
# Second registration with different data should fail
|
||||||
with pytest.raises(ValueError, match="Provider 'test-provider' is already registered"):
|
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db' already exists"):
|
||||||
await disk_dist_registry.register(vector_db2)
|
await disk_dist_registry.register(vector_db2)
|
||||||
|
|
||||||
# Verify original object is unchanged
|
# Verify original object is unchanged
|
||||||
|
|
@ -317,7 +319,7 @@ async def test_double_registration_with_cache(cached_disk_dist_registry):
|
||||||
assert cached_model.model_type == ModelType.llm
|
assert cached_model.model_type == ModelType.llm
|
||||||
|
|
||||||
# Second registration with different data should fail
|
# Second registration with different data should fail
|
||||||
with pytest.raises(ValueError, match="Provider 'test-provider' is already registered"):
|
with pytest.raises(ValueError, match="Object of type 'model' and identifier 'test_model' already exists"):
|
||||||
await cached_disk_dist_registry.register(model2)
|
await cached_disk_dist_registry.register(model2)
|
||||||
|
|
||||||
# Cache should still contain original model
|
# Cache should still contain original model
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue