mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Merge d5d2061c8c
into a09e30bd87
This commit is contained in:
commit
4d9142b81c
2 changed files with 24 additions and 6 deletions
|
@ -14,7 +14,10 @@ from llama_stack.core.datatypes import RoutableObjectWithProvider
|
||||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import (
|
||||||
|
KVStoreConfig,
|
||||||
|
SqliteKVStoreConfig,
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__, category="core::registry")
|
logger = get_logger(__name__, category="core::registry")
|
||||||
|
|
||||||
|
@ -98,7 +101,10 @@ class DiskDistributionRegistry(DistributionRegistry):
|
||||||
existing_obj = await self.get(obj.type, obj.identifier)
|
existing_obj = await self.get(obj.type, obj.identifier)
|
||||||
# dont register if the object's providerid already exists
|
# dont register if the object's providerid already exists
|
||||||
if existing_obj and existing_obj.provider_id == obj.provider_id:
|
if existing_obj and existing_obj.provider_id == obj.provider_id:
|
||||||
return False
|
raise ValueError(
|
||||||
|
f"Provider '{obj.provider_id}' is already registered."
|
||||||
|
f"Unregister the existing provider first before registering it again."
|
||||||
|
)
|
||||||
|
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
||||||
|
|
|
@ -125,8 +125,15 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
||||||
provider_resource_id="test_vector_db_2",
|
provider_resource_id="test_vector_db_2",
|
||||||
provider_id="baz", # Same provider_id
|
provider_id="baz", # Same provider_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Now we expect a ValueError to be raised for duplicate registration
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r"Provider 'baz' is already registered.*Unregister the existing provider first before registering it again.",
|
||||||
|
):
|
||||||
await cached_disk_dist_registry.register(duplicate_vector_db)
|
await cached_disk_dist_registry.register(duplicate_vector_db)
|
||||||
|
|
||||||
|
# Verify the original registration is still intact
|
||||||
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
|
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
|
||||||
|
@ -174,10 +181,14 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
|
||||||
)
|
)
|
||||||
|
|
||||||
await sqlite_kvstore.set(
|
await sqlite_kvstore.set(
|
||||||
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
|
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"),
|
||||||
|
valid_db.model_dump_json(),
|
||||||
)
|
)
|
||||||
|
|
||||||
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
|
await sqlite_kvstore.set(
|
||||||
|
KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"),
|
||||||
|
"{not valid json",
|
||||||
|
)
|
||||||
|
|
||||||
await sqlite_kvstore.set(
|
await sqlite_kvstore.set(
|
||||||
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
|
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
|
||||||
|
@ -212,7 +223,8 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
|
||||||
)
|
)
|
||||||
|
|
||||||
await sqlite_kvstore.set(
|
await sqlite_kvstore.set(
|
||||||
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
|
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"),
|
||||||
|
valid_db.model_dump_json(),
|
||||||
)
|
)
|
||||||
|
|
||||||
await sqlite_kvstore.set(
|
await sqlite_kvstore.set(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue