From 663883cc294c79239dd7c92503bad010bade5f51 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 4 Nov 2024 17:25:06 -0800 Subject: [PATCH] persist registered objects with distribution (#354) * persist registered objects with distribution * linter fixes * comment * use annotate and field discriminator * workign tests * donot use global state * precommit failures fixed * add back Any * fix imports * remove unnecessary changes in ollama * precommit failures fixed * make kvstore configurable for dist and rename registry * add comment about registry list return * fix linter errors * use registry to hydrate * remove debug print * linter fixes * remove kvstore.db * rename distribution_registry_store --------- Co-authored-by: Dinesh Yeduguru --- llama_stack/apis/datasets/datasets.py | 3 +- llama_stack/apis/models/models.py | 3 +- .../scoring_functions/scoring_functions.py | 3 +- llama_stack/apis/shields/shields.py | 3 +- llama_stack/distribution/datatypes.py | 23 ++- llama_stack/distribution/resolver.py | 9 +- llama_stack/distribution/routers/__init__.py | 6 +- .../distribution/routers/routing_tables.py | 62 +++---- llama_stack/distribution/server/server.py | 22 ++- llama_stack/distribution/store/__init__.py | 7 + llama_stack/distribution/store/registry.py | 135 ++++++++++++++ .../distribution/store/tests/test_registry.py | 171 ++++++++++++++++++ 12 files changed, 401 insertions(+), 46 deletions(-) create mode 100644 llama_stack/distribution/store/__init__.py create mode 100644 llama_stack/distribution/store/registry.py create mode 100644 llama_stack/distribution/store/tests/test_registry.py diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 7a56049bf..1695c888b 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Optional, Protocol +from typing import Any, Dict, List, Literal, Optional, Protocol from llama_models.llama3.api.datatypes import URL @@ -32,6 +32,7 @@ class DatasetDef(BaseModel): @json_schema_type class DatasetDefWithProvider(DatasetDef): + type: Literal["dataset"] = "dataset" provider_id: str = Field( description="ID of the provider which serves this dataset", ) diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 994c8e995..ffb3b022e 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Optional, Protocol, runtime_checkable +from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field @@ -25,6 +25,7 @@ class ModelDef(BaseModel): @json_schema_type class ModelDefWithProvider(ModelDef): + type: Literal["model"] = "model" provider_id: str = Field( description="The provider ID for this model", ) diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 2e5bf0aef..d0a9cc597 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Optional, Protocol, runtime_checkable +from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field @@ -53,6 +53,7 @@ class ScoringFnDef(BaseModel): @json_schema_type class ScoringFnDefWithProvider(ScoringFnDef): + type: Literal["scoring_fn"] = "scoring_fn" provider_id: str = Field( description="ID of the provider which serves this dataset", ) diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 7f003faa2..0d1177f5a 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import Any, Dict, List, Optional, Protocol, runtime_checkable +from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field @@ -34,6 +34,7 @@ class ShieldDef(BaseModel): @json_schema_type class ShieldDefWithProvider(ShieldDef): + type: Literal["shield"] = "shield" provider_id: str = Field( description="The provider ID for this shield type", ) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 9ad82cd79..3a4806e27 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -21,6 +21,7 @@ from llama_stack.apis.inference import Inference from llama_stack.apis.memory import Memory from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring +from llama_stack.providers.utils.kvstore.config import KVStoreConfig LLAMA_STACK_BUILD_CONFIG_VERSION = "2" LLAMA_STACK_RUN_CONFIG_VERSION = "2" @@ -37,12 +38,16 @@ RoutableObject = Union[ ScoringFnDef, ] -RoutableObjectWithProvider = Union[ - ModelDefWithProvider, - ShieldDefWithProvider, - MemoryBankDefWithProvider, - DatasetDefWithProvider, - ScoringFnDefWithProvider, + +RoutableObjectWithProvider = Annotated[ + Union[ + ModelDefWithProvider, + ShieldDefWithProvider, + MemoryBankDefWithProvider, + DatasetDefWithProvider, + ScoringFnDefWithProvider, + ], + Field(discriminator="type"), ] RoutedProtocol = Union[ @@ -134,6 +139,12 @@ One or more providers to use for each API. The same provider_type (e.g., meta-re can be instantiated multiple times (with different configs) if necessary. """, ) + metadata_store: Optional[KVStoreConfig] = Field( + default=None, + description=""" +Configuration for the persistence store used by the distribution registry. If not specified, +a default SQLite store will be used.""", + ) class BuildConfig(BaseModel): diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index a93cc1183..96b4b81e6 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -26,6 +26,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry from llama_stack.distribution.distribution import builtin_automatically_routed_apis +from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type @@ -65,7 +66,9 @@ class ProviderWithSpec(Provider): # TODO: this code is not very straightforward to follow and needs one more round of refactoring async def resolve_impls( - run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]] + run_config: StackRunConfig, + provider_registry: Dict[Api, Dict[str, ProviderSpec]], + dist_registry: DistributionRegistry, ) -> Dict[Api, Any]: """ Does two things: @@ -189,6 +192,7 @@ async def resolve_impls( provider, deps, inner_impls, + dist_registry, ) # TODO: ugh slightly redesign this shady looking code if "inner-" in api_str: @@ -237,6 +241,7 @@ async def instantiate_provider( provider: ProviderWithSpec, deps: Dict[str, Any], inner_impls: Dict[str, Any], + dist_registry: DistributionRegistry, ): protocols = api_protocol_map() additional_protocols = additional_protocols_map() @@ -270,7 +275,7 @@ async def instantiate_provider( method = "get_routing_table_impl" config = None - args = [provider_spec.api, inner_impls, deps] + args = [provider_spec.api, inner_impls, deps, dist_registry] else: method = "get_provider_impl" diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 2cc89848e..b3ebd1368 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -7,6 +7,9 @@ from typing import Any from llama_stack.distribution.datatypes import * # noqa: F403 + +from llama_stack.distribution.store import DistributionRegistry + from .routing_tables import ( DatasetsRoutingTable, MemoryBanksRoutingTable, @@ -20,6 +23,7 @@ async def get_routing_table_impl( api: Api, impls_by_provider_id: Dict[str, RoutedProtocol], _deps, + dist_registry: DistributionRegistry, ) -> Any: api_to_tables = { "memory_banks": MemoryBanksRoutingTable, @@ -32,7 +36,7 @@ async def get_routing_table_impl( if api.value not in api_to_tables: raise ValueError(f"API {api.value} not found in router map") - impl = api_to_tables[api.value](impls_by_provider_id) + impl = api_to_tables[api.value](impls_by_provider_id, dist_registry) await impl.initialize() return impl diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 4e462c54b..fc7eda012 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -13,6 +13,7 @@ from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.datatypes import * # noqa: F403 @@ -46,25 +47,23 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None: Registry = Dict[str, List[RoutableObjectWithProvider]] -# TODO: this routing table maintains state in memory purely. We need to -# add persistence to it when we add dynamic registration of objects. class CommonRoutingTableImpl(RoutingTable): def __init__( self, impls_by_provider_id: Dict[str, RoutedProtocol], + dist_registry: DistributionRegistry, ) -> None: self.impls_by_provider_id = impls_by_provider_id + self.dist_registry = dist_registry async def initialize(self) -> None: - self.registry: Registry = {} + # Initialize the registry if not already done + await self.dist_registry.initialize() - def add_objects( + async def add_objects( objs: List[RoutableObjectWithProvider], provider_id: str, cls ) -> None: for obj in objs: - if obj.identifier not in self.registry: - self.registry[obj.identifier] = [] - if cls is None: obj.provider_id = provider_id else: @@ -74,34 +73,35 @@ class CommonRoutingTableImpl(RoutingTable): obj.provider_id = provider_id else: obj = cls(**obj.model_dump(), provider_id=provider_id) - self.registry[obj.identifier].append(obj) + await self.dist_registry.register(obj) + # Register all objects from providers for pid, p in self.impls_by_provider_id.items(): api = get_impl_api(p) if api == Api.inference: p.model_store = self models = await p.list_models() - add_objects(models, pid, ModelDefWithProvider) + await add_objects(models, pid, ModelDefWithProvider) elif api == Api.safety: p.shield_store = self shields = await p.list_shields() - add_objects(shields, pid, ShieldDefWithProvider) + await add_objects(shields, pid, ShieldDefWithProvider) elif api == Api.memory: p.memory_bank_store = self memory_banks = await p.list_memory_banks() - add_objects(memory_banks, pid, None) + await add_objects(memory_banks, pid, None) elif api == Api.datasetio: p.dataset_store = self datasets = await p.list_datasets() - add_objects(datasets, pid, DatasetDefWithProvider) + await add_objects(datasets, pid, DatasetDefWithProvider) elif api == Api.scoring: p.scoring_function_store = self scoring_functions = await p.list_scoring_functions() - add_objects(scoring_functions, pid, ScoringFnDefWithProvider) + await add_objects(scoring_functions, pid, ScoringFnDefWithProvider) async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): @@ -124,39 +124,44 @@ class CommonRoutingTableImpl(RoutingTable): else: raise ValueError("Unknown routing table type") - if routing_key not in self.registry: + # Get objects from disk registry + objects = self.dist_registry.get_cached(routing_key) + if not objects: apiname, objname = apiname_object() raise ValueError( f"`{routing_key}` not registered. Make sure there is an {apiname} provider serving this {objname}." ) - objs = self.registry[routing_key] - for obj in objs: + for obj in objects: if not provider_id or provider_id == obj.provider_id: return self.impls_by_provider_id[obj.provider_id] raise ValueError(f"Provider not found for `{routing_key}`") - def get_object_by_identifier( + async def get_object_by_identifier( self, identifier: str ) -> Optional[RoutableObjectWithProvider]: - objs = self.registry.get(identifier, []) - if not objs: + # Get from disk registry + objects = await self.dist_registry.get(identifier) + if not objects: return None # kind of ill-defined behavior here, but we'll just return the first one - return objs[0] + return objects[0] async def register_object(self, obj: RoutableObjectWithProvider): - entries = self.registry.get(obj.identifier, []) - for entry in entries: - if entry.provider_id == obj.provider_id or not obj.provider_id: + # Get existing objects from registry + existing_objects = await self.dist_registry.get(obj.identifier) + + # Check for existing registration + for existing_obj in existing_objects: + if existing_obj.provider_id == obj.provider_id or not obj.provider_id: print( - f"`{obj.identifier}` already registered with `{entry.provider_id}`" + f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`" ) return - # if provider_id is not specified, we'll pick an arbitrary one from existing entries + # if provider_id is not specified, pick an arbitrary one from existing entries if not obj.provider_id and len(self.impls_by_provider_id) > 0: obj.provider_id = list(self.impls_by_provider_id.keys())[0] @@ -166,12 +171,7 @@ class CommonRoutingTableImpl(RoutingTable): p = self.impls_by_provider_id[obj.provider_id] await register_object_with_provider(obj, p) - - if obj.identifier not in self.registry: - self.registry[obj.identifier] = [] - self.registry[obj.identifier].append(obj) - - # TODO: persist this to a store + await self.dist_registry.register(obj) class ModelsRoutingTable(CommonRoutingTableImpl, Models): diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index b8fe4734e..2560f4070 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -31,6 +31,8 @@ from llama_stack.distribution.distribution import ( get_provider_registry, ) +from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR + from llama_stack.providers.utils.telemetry.tracing import ( end_trace, setup_logger, @@ -38,9 +40,10 @@ from llama_stack.providers.utils.telemetry.tracing import ( start_trace, ) from llama_stack.distribution.datatypes import * # noqa: F403 - from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import resolve_impls +from llama_stack.distribution.store import CachedDiskDistributionRegistry +from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig from .endpoints import get_all_api_endpoints @@ -278,8 +281,23 @@ def main( config = StackRunConfig(**yaml.safe_load(fp)) app = FastAPI() + # instantiate kvstore for storing and retrieving distribution metadata + if config.metadata_store: + dist_kvstore = asyncio.run(kvstore_impl(config.metadata_store)) + else: + dist_kvstore = asyncio.run( + kvstore_impl( + SqliteKVStoreConfig( + db_path=( + DISTRIBS_BASE_DIR / config.image_name / "kvstore.db" + ).as_posix() + ) + ) + ) - impls = asyncio.run(resolve_impls(config, get_provider_registry())) + dist_registry = CachedDiskDistributionRegistry(dist_kvstore) + + impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry)) if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) diff --git a/llama_stack/distribution/store/__init__.py b/llama_stack/distribution/store/__init__.py new file mode 100644 index 000000000..cd1080f3a --- /dev/null +++ b/llama_stack/distribution/store/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .registry import * # noqa: F401 F403 diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py new file mode 100644 index 000000000..994fb475c --- /dev/null +++ b/llama_stack/distribution/store/registry.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from typing import Dict, List, Protocol + +import pydantic + +from llama_stack.distribution.datatypes import RoutableObjectWithProvider + +from llama_stack.providers.utils.kvstore import KVStore + + +class DistributionRegistry(Protocol): + async def get_all(self) -> List[RoutableObjectWithProvider]: ... + + async def initialize(self) -> None: ... + + async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: ... + + def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: ... + + # The current data structure allows multiple objects with the same identifier but different providers. + # This is not ideal - we should have a single object that can be served by multiple providers, + # suggesting a data structure like (obj: Obj, providers: List[str]) rather than List[RoutableObjectWithProvider]. + # The current approach could lead to inconsistencies if the same logical object has different data across providers. + async def register(self, obj: RoutableObjectWithProvider) -> bool: ... + + +KEY_FORMAT = "distributions:registry:{}" + + +class DiskDistributionRegistry(DistributionRegistry): + def __init__(self, kvstore: KVStore): + self.kvstore = kvstore + + async def initialize(self) -> None: + pass + + def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: + # Disk registry does not have a cache + return [] + + async def get_all(self) -> List[RoutableObjectWithProvider]: + start_key = KEY_FORMAT.format("") + end_key = KEY_FORMAT.format("\xff") + keys = await self.kvstore.range(start_key, end_key) + return [await self.get(key.split(":")[-1]) for key in keys] + + async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: + json_str = await self.kvstore.get(KEY_FORMAT.format(identifier)) + if not json_str: + return [] + + objects_data = json.loads(json_str) + return [ + pydantic.parse_obj_as( + RoutableObjectWithProvider, + json.loads(obj_str), + ) + for obj_str in objects_data + ] + + async def register(self, obj: RoutableObjectWithProvider) -> bool: + existing_objects = await self.get(obj.identifier) + # dont register if the object's providerid already exists + for eobj in existing_objects: + if eobj.provider_id == obj.provider_id: + return False + + existing_objects.append(obj) + + objects_json = [ + obj.model_dump_json() for obj in existing_objects + ] # Fixed variable name + await self.kvstore.set( + KEY_FORMAT.format(obj.identifier), json.dumps(objects_json) + ) + return True + + +class CachedDiskDistributionRegistry(DiskDistributionRegistry): + def __init__(self, kvstore: KVStore): + super().__init__(kvstore) + self.cache: Dict[str, List[RoutableObjectWithProvider]] = {} + + async def initialize(self) -> None: + start_key = KEY_FORMAT.format("") + end_key = KEY_FORMAT.format("\xff") + + keys = await self.kvstore.range(start_key, end_key) + + for key in keys: + identifier = key.split(":")[-1] + objects = await super().get(identifier) + if objects: + self.cache[identifier] = objects + + def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: + return self.cache.get(identifier, []) + + async def get_all(self) -> List[RoutableObjectWithProvider]: + return [item for sublist in self.cache.values() for item in sublist] + + async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: + if identifier in self.cache: + return self.cache[identifier] + + objects = await super().get(identifier) + if objects: + self.cache[identifier] = objects + + return objects + + async def register(self, obj: RoutableObjectWithProvider) -> bool: + # First update disk + success = await super().register(obj) + + if success: + # Then update cache + if obj.identifier not in self.cache: + self.cache[obj.identifier] = [] + + # Check if provider already exists in cache + for cached_obj in self.cache[obj.identifier]: + if cached_obj.provider_id == obj.provider_id: + return success + + # If not, update cache + self.cache[obj.identifier].append(obj) + + return success diff --git a/llama_stack/distribution/store/tests/test_registry.py b/llama_stack/distribution/store/tests/test_registry.py new file mode 100644 index 000000000..a9df4bed6 --- /dev/null +++ b/llama_stack/distribution/store/tests/test_registry.py @@ -0,0 +1,171 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os + +import pytest +import pytest_asyncio +from llama_stack.distribution.store import * # noqa F403 +from llama_stack.apis.inference import ModelDefWithProvider +from llama_stack.apis.memory_banks import VectorMemoryBankDef +from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig +from llama_stack.distribution.datatypes import * # noqa F403 + + +@pytest.fixture +def config(): + config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db") + if os.path.exists(config.db_path): + os.remove(config.db_path) + return config + + +@pytest_asyncio.fixture +async def registry(config): + registry = DiskDistributionRegistry(await kvstore_impl(config)) + await registry.initialize() + return registry + + +@pytest_asyncio.fixture +async def cached_registry(config): + registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) + await registry.initialize() + return registry + + +@pytest.fixture +def sample_bank(): + return VectorMemoryBankDef( + identifier="test_bank", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + provider_id="test-provider", + ) + + +@pytest.fixture +def sample_model(): + return ModelDefWithProvider( + identifier="test_model", + llama_model="Llama3.2-3B-Instruct", + provider_id="test-provider", + ) + + +@pytest.mark.asyncio +async def test_registry_initialization(registry): + # Test empty registry + results = await registry.get("nonexistent") + assert len(results) == 0 + + +@pytest.mark.asyncio +async def test_basic_registration(registry, sample_bank, sample_model): + print(f"Registering {sample_bank}") + await registry.register(sample_bank) + print(f"Registering {sample_model}") + await registry.register(sample_model) + print("Getting bank") + results = await registry.get("test_bank") + assert len(results) == 1 + result_bank = results[0] + assert result_bank.identifier == sample_bank.identifier + assert result_bank.embedding_model == sample_bank.embedding_model + assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens + assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens + assert result_bank.provider_id == sample_bank.provider_id + + results = await registry.get("test_model") + assert len(results) == 1 + result_model = results[0] + assert result_model.identifier == sample_model.identifier + assert result_model.llama_model == sample_model.llama_model + assert result_model.provider_id == sample_model.provider_id + + +@pytest.mark.asyncio +async def test_cached_registry_initialization(config, sample_bank, sample_model): + # First populate the disk registry + disk_registry = DiskDistributionRegistry(await kvstore_impl(config)) + await disk_registry.initialize() + await disk_registry.register(sample_bank) + await disk_registry.register(sample_model) + + # Test cached version loads from disk + cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) + await cached_registry.initialize() + + results = await cached_registry.get("test_bank") + assert len(results) == 1 + result_bank = results[0] + assert result_bank.identifier == sample_bank.identifier + assert result_bank.embedding_model == sample_bank.embedding_model + assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens + assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens + assert result_bank.provider_id == sample_bank.provider_id + + +@pytest.mark.asyncio +async def test_cached_registry_updates(config): + cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) + await cached_registry.initialize() + + new_bank = VectorMemoryBankDef( + identifier="test_bank_2", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=256, + overlap_size_in_tokens=32, + provider_id="baz", + ) + await cached_registry.register(new_bank) + + # Verify in cache + results = await cached_registry.get("test_bank_2") + assert len(results) == 1 + result_bank = results[0] + assert result_bank.identifier == new_bank.identifier + assert result_bank.provider_id == new_bank.provider_id + + # Verify persisted to disk + new_registry = DiskDistributionRegistry(await kvstore_impl(config)) + await new_registry.initialize() + results = await new_registry.get("test_bank_2") + assert len(results) == 1 + result_bank = results[0] + assert result_bank.identifier == new_bank.identifier + assert result_bank.provider_id == new_bank.provider_id + + +@pytest.mark.asyncio +async def test_duplicate_provider_registration(config): + cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) + await cached_registry.initialize() + + original_bank = VectorMemoryBankDef( + identifier="test_bank_2", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=256, + overlap_size_in_tokens=32, + provider_id="baz", + ) + await cached_registry.register(original_bank) + + duplicate_bank = VectorMemoryBankDef( + identifier="test_bank_2", + embedding_model="different-model", + chunk_size_in_tokens=128, + overlap_size_in_tokens=16, + provider_id="baz", # Same provider_id + ) + await cached_registry.register(duplicate_bank) + + results = await cached_registry.get("test_bank_2") + assert len(results) == 1 # Still only one result + assert ( + results[0].embedding_model == original_bank.embedding_model + ) # Original values preserved