mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
Merge d5d2061c8c
into a09e30bd87
This commit is contained in:
commit
4d9142b81c
2 changed files with 24 additions and 6 deletions
|
@ -14,7 +14,10 @@ from llama_stack.core.datatypes import RoutableObjectWithProvider
|
|||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__, category="core::registry")
|
||||
|
||||
|
@ -98,7 +101,10 @@ class DiskDistributionRegistry(DistributionRegistry):
|
|||
existing_obj = await self.get(obj.type, obj.identifier)
|
||||
# dont register if the object's providerid already exists
|
||||
if existing_obj and existing_obj.provider_id == obj.provider_id:
|
||||
return False
|
||||
raise ValueError(
|
||||
f"Provider '{obj.provider_id}' is already registered."
|
||||
f"Unregister the existing provider first before registering it again."
|
||||
)
|
||||
|
||||
await self.kvstore.set(
|
||||
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
||||
|
|
|
@ -125,8 +125,15 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
|||
provider_resource_id="test_vector_db_2",
|
||||
provider_id="baz", # Same provider_id
|
||||
)
|
||||
|
||||
# Now we expect a ValueError to be raised for duplicate registration
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Provider 'baz' is already registered.*Unregister the existing provider first before registering it again.",
|
||||
):
|
||||
await cached_disk_dist_registry.register(duplicate_vector_db)
|
||||
|
||||
# Verify the original registration is still intact
|
||||
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result is not None
|
||||
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
|
||||
|
@ -174,10 +181,14 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
|
|||
)
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
|
||||
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"),
|
||||
valid_db.model_dump_json(),
|
||||
)
|
||||
|
||||
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"),
|
||||
"{not valid json",
|
||||
)
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
|
||||
|
@ -212,7 +223,8 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
|
|||
)
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
|
||||
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"),
|
||||
valid_db.model_dump_json(),
|
||||
)
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue