diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 1f5461733..fc7eda012 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -152,7 +152,7 @@ class CommonRoutingTableImpl(RoutingTable): async def register_object(self, obj: RoutableObjectWithProvider): # Get existing objects from registry existing_objects = await self.dist_registry.get(obj.identifier) - + # Check for existing registration for existing_obj in existing_objects: if existing_obj.provider_id == obj.provider_id or not obj.provider_id: diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 6f8f91889..994fb475c 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -5,8 +5,7 @@ # the root directory of this source tree. import json -import asyncio -from typing import Protocol, Dict, List +from typing import Dict, List, Protocol import pydantic @@ -74,12 +73,15 @@ class DiskDistributionRegistry(DistributionRegistry): existing_objects.append(obj) - objects_json = [obj.model_dump_json() for obj in existing_objects] # Fixed variable name + objects_json = [ + obj.model_dump_json() for obj in existing_objects + ] # Fixed variable name await self.kvstore.set( KEY_FORMAT.format(obj.identifier), json.dumps(objects_json) ) return True + class CachedDiskDistributionRegistry(DiskDistributionRegistry): def __init__(self, kvstore: KVStore): super().__init__(kvstore) @@ -88,9 +90,9 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async def initialize(self) -> None: start_key = KEY_FORMAT.format("") end_key = KEY_FORMAT.format("\xff") - + keys = await self.kvstore.range(start_key, end_key) - + for key in keys: identifier = key.split(":")[-1] objects = await super().get(identifier) @@ -106,28 +108,28 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: if identifier in self.cache: return self.cache[identifier] - + objects = await super().get(identifier) if objects: self.cache[identifier] = objects - + return objects async def register(self, obj: RoutableObjectWithProvider) -> bool: # First update disk success = await super().register(obj) - + if success: # Then update cache if obj.identifier not in self.cache: self.cache[obj.identifier] = [] - + # Check if provider already exists in cache for cached_obj in self.cache[obj.identifier]: if cached_obj.provider_id == obj.provider_id: return success - + # If not, update cache self.cache[obj.identifier].append(obj) - - return success \ No newline at end of file + + return success diff --git a/llama_stack/distribution/store/tests/test_registry.py b/llama_stack/distribution/store/tests/test_registry.py index 1210c4bf7..a9df4bed6 100644 --- a/llama_stack/distribution/store/tests/test_registry.py +++ b/llama_stack/distribution/store/tests/test_registry.py @@ -5,14 +5,15 @@ # the root directory of this source tree. import os -import asyncio + import pytest import pytest_asyncio -from llama_stack.distribution.store import * -from llama_stack.apis.memory_banks import VectorMemoryBankDef +from llama_stack.distribution.store import * # noqa F403 from llama_stack.apis.inference import ModelDefWithProvider +from llama_stack.apis.memory_banks import VectorMemoryBankDef from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig -from llama_stack.distribution.datatypes import * +from llama_stack.distribution.datatypes import * # noqa F403 + @pytest.fixture def config(): @@ -21,45 +22,51 @@ def config(): os.remove(config.db_path) return config + @pytest_asyncio.fixture async def registry(config): registry = DiskDistributionRegistry(await kvstore_impl(config)) await registry.initialize() return registry + @pytest_asyncio.fixture async def cached_registry(config): registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) await registry.initialize() return registry + @pytest.fixture def sample_bank(): return VectorMemoryBankDef( identifier="test_bank", - embedding_model="all-MiniLM-L6-v2", + embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512, overlap_size_in_tokens=64, - provider_id="test-provider" + provider_id="test-provider", ) + @pytest.fixture def sample_model(): return ModelDefWithProvider( identifier="test_model", llama_model="Llama3.2-3B-Instruct", - provider_id="test-provider" + provider_id="test-provider", ) + @pytest.mark.asyncio async def test_registry_initialization(registry): # Test empty registry results = await registry.get("nonexistent") assert len(results) == 0 + @pytest.mark.asyncio async def test_basic_registration(registry, sample_bank, sample_model): - print(f"Registering {sample_bank}") + print(f"Registering {sample_bank}") await registry.register(sample_bank) print(f"Registering {sample_model}") await registry.register(sample_model) @@ -80,6 +87,7 @@ async def test_basic_registration(registry, sample_bank, sample_model): assert result_model.llama_model == sample_model.llama_model assert result_model.provider_id == sample_model.provider_id + @pytest.mark.asyncio async def test_cached_registry_initialization(config, sample_bank, sample_model): # First populate the disk registry @@ -101,6 +109,7 @@ async def test_cached_registry_initialization(config, sample_bank, sample_model) assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens assert result_bank.provider_id == sample_bank.provider_id + @pytest.mark.asyncio async def test_cached_registry_updates(config): cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) @@ -131,6 +140,7 @@ async def test_cached_registry_updates(config): assert result_bank.identifier == new_bank.identifier assert result_bank.provider_id == new_bank.provider_id + @pytest.mark.asyncio async def test_duplicate_provider_registration(config): cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) @@ -153,7 +163,9 @@ async def test_duplicate_provider_registration(config): provider_id="baz", # Same provider_id ) await cached_registry.register(duplicate_bank) - + results = await cached_registry.get("test_bank_2") assert len(results) == 1 # Still only one result - assert results[0].embedding_model == original_bank.embedding_model # Original values preserved + assert ( + results[0].embedding_model == original_bank.embedding_model + ) # Original values preserved