This commit is contained in:
Omar Abdelwahab 2025-10-03 16:35:47 +02:00 committed by GitHub
commit 4d9142b81c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 24 additions and 6 deletions

View file

@ -14,7 +14,10 @@ from llama_stack.core.datatypes import RoutableObjectWithProvider
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
logger = get_logger(__name__, category="core::registry")
@ -98,7 +101,10 @@ class DiskDistributionRegistry(DistributionRegistry):
existing_obj = await self.get(obj.type, obj.identifier)
# dont register if the object's providerid already exists
if existing_obj and existing_obj.provider_id == obj.provider_id:
return False
raise ValueError(
f"Provider '{obj.provider_id}' is already registered."
f"Unregister the existing provider first before registering it again."
)
await self.kvstore.set(
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),

View file

@ -125,8 +125,15 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
provider_resource_id="test_vector_db_2",
provider_id="baz", # Same provider_id
)
# Now we expect a ValueError to be raised for duplicate registration
with pytest.raises(
ValueError,
match=r"Provider 'baz' is already registered.*Unregister the existing provider first before registering it again.",
):
await cached_disk_dist_registry.register(duplicate_vector_db)
# Verify the original registration is still intact
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
assert result is not None
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
@ -174,10 +181,14 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
)
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"),
valid_db.model_dump_json(),
)
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"),
"{not valid json",
)
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
@ -212,7 +223,8 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
)
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"),
valid_db.model_dump_json(),
)
await sqlite_kvstore.set(