diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 3345f4c26..8c1b0c1e7 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -302,7 +302,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> List[Dataset]: - return await self.get_all_with_type("dataset") + return await self.get_all_with_type(ResourceType.dataset.value) async def get_dataset(self, dataset_id: str) -> Optional[Dataset]: return await self.get_object_by_identifier("dataset", dataset_id) @@ -341,7 +341,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> List[ScoringFn]: - return await self.get_all_with_type("scoring_function") + return await self.get_all_with_type(ResourceType.scoring_function.value) async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: return await self.get_object_by_identifier("scoring_function", scoring_fn_id) @@ -355,8 +355,6 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): provider_id: Optional[str] = None, params: Optional[ScoringFnParams] = None, ) -> None: - if params is None: - params = {} if provider_scoring_fn_id is None: provider_scoring_fn_id = scoring_fn_id if provider_id is None: @@ -371,6 +369,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): description=description, return_type=return_type, provider_resource_id=provider_scoring_fn_id, + provider_id=provider_id, params=params, ) scoring_fn.provider_id = provider_id @@ -379,7 +378,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks): async def list_eval_tasks(self) -> List[EvalTask]: - return await self.get_all_with_type("eval_task") + return await self.get_all_with_type(ResourceType.eval_task.value) async def get_eval_task(self, name: str) -> Optional[EvalTask]: return await self.get_object_by_identifier("eval_task", name) diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index d837c4375..bb87c81fa 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -4,7 +4,9 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio import json +from contextlib import asynccontextmanager from typing import Dict, List, Optional, Protocol, Tuple import pydantic @@ -35,8 +37,35 @@ class DistributionRegistry(Protocol): async def register(self, obj: RoutableObjectWithProvider) -> bool: ... +REGISTER_PREFIX = "distributions:registry" KEY_VERSION = "v1" -KEY_FORMAT = f"distributions:registry:{KEY_VERSION}::" + "{type}:{identifier}" +KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" + + +def _get_registry_key_range() -> Tuple[str, str]: + """Returns the start and end keys for the registry range query.""" + start_key = f"{REGISTER_PREFIX}:{KEY_VERSION}" + return start_key, f"{start_key}\xff" + + +def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]: + """Utility function to parse registry values into RoutableObjectWithProvider objects.""" + all_objects = [] + for value in values: + try: + objects_data = json.loads(value) + objects = [ + pydantic.parse_obj_as( + RoutableObjectWithProvider, + json.loads(obj_str), + ) + for obj_str in objects_data + ] + all_objects.extend(objects) + except Exception as e: + print(f"Error parsing value: {e}") + traceback.print_exc() + return all_objects class DiskDistributionRegistry(DistributionRegistry): @@ -53,12 +82,9 @@ class DiskDistributionRegistry(DistributionRegistry): return [] async def get_all(self) -> List[RoutableObjectWithProvider]: - start_key = KEY_FORMAT.format(type="", identifier="") - end_key = KEY_FORMAT.format(type="", identifier="\xff") - keys = await self.kvstore.range(start_key, end_key) - - tuples = [(key.split(":")[-2], key.split(":")[-1]) for key in keys] - return [await self.get(type, identifier) for type, identifier in tuples] + start_key, end_key = _get_registry_key_range() + values = await self.kvstore.range(start_key, end_key) + return _parse_registry_values(values) async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]: json_str = await self.kvstore.get( @@ -99,55 +125,84 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): def __init__(self, kvstore: KVStore): super().__init__(kvstore) self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {} + self._initialized = False + self._initialize_lock = asyncio.Lock() + self._cache_lock = asyncio.Lock() + + @asynccontextmanager + async def _locked_cache(self): + """Context manager for safely accessing the cache with a lock.""" + async with self._cache_lock: + yield self.cache + + async def _ensure_initialized(self): + """Ensures the registry is initialized before operations.""" + if self._initialized: + return + + async with self._initialize_lock: + if self._initialized: + return + + start_key, end_key = _get_registry_key_range() + values = await self.kvstore.range(start_key, end_key) + objects = _parse_registry_values(values) + + async with self._locked_cache() as cache: + for obj in objects: + cache_key = (obj.type, obj.identifier) + if cache_key not in cache: + cache[cache_key] = [] + if not any( + cached_obj.provider_id == obj.provider_id + for cached_obj in cache[cache_key] + ): + cache[cache_key].append(obj) + + self._initialized = True async def initialize(self) -> None: - start_key = KEY_FORMAT.format(type="", identifier="") - end_key = KEY_FORMAT.format(type="", identifier="\xff") - - keys = await self.kvstore.range(start_key, end_key) - - for key in keys: - type, identifier = key.split(":")[-2:] - objects = await super().get(type, identifier) - if objects: - self.cache[type, identifier] = objects + await self._ensure_initialized() def get_cached( self, type: str, identifier: str ) -> List[RoutableObjectWithProvider]: - return self.cache.get((type, identifier), []) + return self.cache.get((type, identifier), [])[:] # Return a copy async def get_all(self) -> List[RoutableObjectWithProvider]: - return [item for sublist in self.cache.values() for item in sublist] + await self._ensure_initialized() + async with self._locked_cache() as cache: + return [item for sublist in cache.values() for item in sublist] async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]: - cachekey = (type, identifier) - if cachekey in self.cache: - return self.cache[cachekey] + await self._ensure_initialized() + cache_key = (type, identifier) + + async with self._locked_cache() as cache: + if cache_key in cache: + return cache[cache_key][:] objects = await super().get(type, identifier) if objects: - self.cache[cachekey] = objects + async with self._locked_cache() as cache: + cache[cache_key] = objects return objects async def register(self, obj: RoutableObjectWithProvider) -> bool: - # First update disk + await self._ensure_initialized() success = await super().register(obj) if success: - # Then update cache - cachekey = (obj.type, obj.identifier) - if cachekey not in self.cache: - self.cache[cachekey] = [] - - # Check if provider already exists in cache - for cached_obj in self.cache[cachekey]: - if cached_obj.provider_id == obj.provider_id: - return success - - # If not, update cache - self.cache[cachekey].append(obj) + cache_key = (obj.type, obj.identifier) + async with self._locked_cache() as cache: + if cache_key not in cache: + cache[cache_key] = [] + if not any( + cached_obj.provider_id == obj.provider_id + for cached_obj in cache[cache_key] + ): + cache[cache_key].append(obj) return success diff --git a/llama_stack/distribution/store/tests/test_registry.py b/llama_stack/distribution/store/tests/test_registry.py index e5b64bdc6..7e389cccd 100644 --- a/llama_stack/distribution/store/tests/test_registry.py +++ b/llama_stack/distribution/store/tests/test_registry.py @@ -44,6 +44,7 @@ def sample_bank(): embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512, overlap_size_in_tokens=64, + provider_resource_id="test_bank", provider_id="test-provider", ) @@ -52,6 +53,7 @@ def sample_bank(): def sample_model(): return Model( identifier="test_model", + provider_resource_id="test_model", provider_id="test-provider", ) @@ -59,7 +61,7 @@ def sample_model(): @pytest.mark.asyncio async def test_registry_initialization(registry): # Test empty registry - results = await registry.get("nonexistent") + results = await registry.get("nonexistent", "nonexistent") assert len(results) == 0 @@ -70,7 +72,7 @@ async def test_basic_registration(registry, sample_bank, sample_model): print(f"Registering {sample_model}") await registry.register(sample_model) print("Getting bank") - results = await registry.get("test_bank") + results = await registry.get("memory_bank", "test_bank") assert len(results) == 1 result_bank = results[0] assert result_bank.identifier == sample_bank.identifier @@ -79,7 +81,7 @@ async def test_basic_registration(registry, sample_bank, sample_model): assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens assert result_bank.provider_id == sample_bank.provider_id - results = await registry.get("test_model") + results = await registry.get("model", "test_model") assert len(results) == 1 result_model = results[0] assert result_model.identifier == sample_model.identifier @@ -98,7 +100,7 @@ async def test_cached_registry_initialization(config, sample_bank, sample_model) cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) await cached_registry.initialize() - results = await cached_registry.get("test_bank") + results = await cached_registry.get("memory_bank", "test_bank") assert len(results) == 1 result_bank = results[0] assert result_bank.identifier == sample_bank.identifier @@ -118,12 +120,13 @@ async def test_cached_registry_updates(config): embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=256, overlap_size_in_tokens=32, + provider_resource_id="test_bank_2", provider_id="baz", ) await cached_registry.register(new_bank) # Verify in cache - results = await cached_registry.get("test_bank_2") + results = await cached_registry.get("memory_bank", "test_bank_2") assert len(results) == 1 result_bank = results[0] assert result_bank.identifier == new_bank.identifier @@ -132,7 +135,7 @@ async def test_cached_registry_updates(config): # Verify persisted to disk new_registry = DiskDistributionRegistry(await kvstore_impl(config)) await new_registry.initialize() - results = await new_registry.get("test_bank_2") + results = await new_registry.get("memory_bank", "test_bank_2") assert len(results) == 1 result_bank = results[0] assert result_bank.identifier == new_bank.identifier @@ -149,6 +152,7 @@ async def test_duplicate_provider_registration(config): embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=256, overlap_size_in_tokens=32, + provider_resource_id="test_bank_2", provider_id="baz", ) await cached_registry.register(original_bank) @@ -158,12 +162,54 @@ async def test_duplicate_provider_registration(config): embedding_model="different-model", chunk_size_in_tokens=128, overlap_size_in_tokens=16, + provider_resource_id="test_bank_2", provider_id="baz", # Same provider_id ) await cached_registry.register(duplicate_bank) - results = await cached_registry.get("test_bank_2") + results = await cached_registry.get("memory_bank", "test_bank_2") assert len(results) == 1 # Still only one result assert ( results[0].embedding_model == original_bank.embedding_model ) # Original values preserved + + +@pytest.mark.asyncio +async def test_get_all_objects(config): + cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) + await cached_registry.initialize() + + # Create multiple test banks + test_banks = [ + VectorMemoryBank( + identifier=f"test_bank_{i}", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=256, + overlap_size_in_tokens=32, + provider_resource_id=f"test_bank_{i}", + provider_id=f"provider_{i}", + ) + for i in range(3) + ] + + # Register all banks + for bank in test_banks: + await cached_registry.register(bank) + + # Test get_all retrieval + all_results = await cached_registry.get_all() + assert len(all_results) == 3 + + # Verify each bank was stored correctly + for original_bank in test_banks: + matching_banks = [ + b for b in all_results if b.identifier == original_bank.identifier + ] + assert len(matching_banks) == 1 + stored_bank = matching_banks[0] + assert stored_bank.embedding_model == original_bank.embedding_model + assert stored_bank.provider_id == original_bank.provider_id + assert stored_bank.chunk_size_in_tokens == original_bank.chunk_size_in_tokens + assert ( + stored_bank.overlap_size_in_tokens == original_bank.overlap_size_in_tokens + )