From 19d730917ae700cef8953dd42c504421ef714ed6 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 4 Nov 2024 15:37:40 -0800 Subject: [PATCH] use registry to hydrate --- .../distribution/routers/routing_tables.py | 57 ++++--- llama_stack/distribution/server/server.py | 4 +- llama_stack/distribution/store/registry.py | 90 ++++++++-- .../distribution/store/tests/test_registry.py | 157 +++++++++++++++--- 4 files changed, 241 insertions(+), 67 deletions(-) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index d314614dd..c103c7010 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -47,8 +47,6 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None: Registry = Dict[str, List[RoutableObjectWithProvider]] -# TODO: this routing table maintains state in memory purely. We need to -# add persistence to it when we add dynamic registration of objects. class CommonRoutingTableImpl(RoutingTable): def __init__( self, @@ -59,15 +57,13 @@ class CommonRoutingTableImpl(RoutingTable): self.dist_registry = dist_registry async def initialize(self) -> None: - self.registry: Registry = {} + # Initialize the registry if not already done + await self.dist_registry.initialize() - def add_objects( + async def add_objects( objs: List[RoutableObjectWithProvider], provider_id: str, cls ) -> None: for obj in objs: - if obj.identifier not in self.registry: - self.registry[obj.identifier] = [] - if cls is None: obj.provider_id = provider_id else: @@ -77,34 +73,36 @@ class CommonRoutingTableImpl(RoutingTable): obj.provider_id = provider_id else: obj = cls(**obj.model_dump(), provider_id=provider_id) - self.registry[obj.identifier].append(obj) + await self.dist_registry.register(obj) + # Register all objects from providers + print("impls_by_provider_id", self.impls_by_provider_id) for pid, p in self.impls_by_provider_id.items(): api = get_impl_api(p) if api == Api.inference: p.model_store = self models = await p.list_models() - add_objects(models, pid, ModelDefWithProvider) + await add_objects(models, pid, ModelDefWithProvider) elif api == Api.safety: p.shield_store = self shields = await p.list_shields() - add_objects(shields, pid, ShieldDefWithProvider) + await add_objects(shields, pid, ShieldDefWithProvider) elif api == Api.memory: p.memory_bank_store = self memory_banks = await p.list_memory_banks() - add_objects(memory_banks, pid, None) + await add_objects(memory_banks, pid, None) elif api == Api.datasetio: p.dataset_store = self datasets = await p.list_datasets() - add_objects(datasets, pid, DatasetDefWithProvider) + await add_objects(datasets, pid, DatasetDefWithProvider) elif api == Api.scoring: p.scoring_function_store = self scoring_functions = await p.list_scoring_functions() - add_objects(scoring_functions, pid, ScoringFnDefWithProvider) + await add_objects(scoring_functions, pid, ScoringFnDefWithProvider) async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): @@ -127,39 +125,44 @@ class CommonRoutingTableImpl(RoutingTable): else: raise ValueError("Unknown routing table type") - if routing_key not in self.registry: + # Get objects from disk registry + objects = self.dist_registry.get_cached(routing_key) + if not objects: apiname, objname = apiname_object() raise ValueError( f"`{routing_key}` not registered. Make sure there is an {apiname} provider serving this {objname}." ) - objs = self.registry[routing_key] - for obj in objs: + for obj in objects: if not provider_id or provider_id == obj.provider_id: return self.impls_by_provider_id[obj.provider_id] raise ValueError(f"Provider not found for `{routing_key}`") - def get_object_by_identifier( + async def get_object_by_identifier( self, identifier: str ) -> Optional[RoutableObjectWithProvider]: - objs = self.registry.get(identifier, []) - if not objs: + # Get from disk registry + objects = await self.dist_registry.get(identifier) + if not objects: return None # kind of ill-defined behavior here, but we'll just return the first one - return objs[0] + return objects[0] async def register_object(self, obj: RoutableObjectWithProvider): - entries = self.registry.get(obj.identifier, []) - for entry in entries: - if entry.provider_id == obj.provider_id or not obj.provider_id: + # Get existing objects from registry + existing_objects = await self.dist_registry.get(obj.identifier) + + # Check for existing registration + for existing_obj in existing_objects: + if existing_obj.provider_id == obj.provider_id or not obj.provider_id: print( - f"`{obj.identifier}` already registered with `{entry.provider_id}`" + f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`" ) return - # if provider_id is not specified, we'll pick an arbitrary one from existing entries + # if provider_id is not specified, pick an arbitrary one from existing entries if not obj.provider_id and len(self.impls_by_provider_id) > 0: obj.provider_id = list(self.impls_by_provider_id.keys())[0] @@ -169,10 +172,6 @@ class CommonRoutingTableImpl(RoutingTable): p = self.impls_by_provider_id[obj.provider_id] await register_object_with_provider(obj, p) - - if obj.identifier not in self.registry: - self.registry[obj.identifier] = [] - self.registry[obj.identifier].append(obj) await self.dist_registry.register(obj) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 35d9fe484..636da3916 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -42,7 +42,7 @@ from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import resolve_impls -from llama_stack.distribution.store import DiskDistributionRegistry +from llama_stack.distribution.store import CachedDiskDistributionRegistry from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig from .endpoints import get_all_api_endpoints @@ -295,7 +295,7 @@ def main( ) ) - dist_registry = DiskDistributionRegistry(dist_kvstore) + dist_registry = CachedDiskDistributionRegistry(dist_kvstore) impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry)) if Api.telemetry in impls: diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index d14c66c2a..6f8f91889 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -5,8 +5,8 @@ # the root directory of this source tree. import json - -from typing import Protocol +import asyncio +from typing import Protocol, Dict, List import pydantic @@ -16,14 +16,19 @@ from llama_stack.providers.utils.kvstore import KVStore class DistributionRegistry(Protocol): + async def get_all(self) -> List[RoutableObjectWithProvider]: ... - async def get(self, identifier: str) -> [RoutableObjectWithProvider]: ... + async def initialize(self) -> None: ... + + async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: ... + + def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: ... # The current data structure allows multiple objects with the same identifier but different providers. # This is not ideal - we should have a single object that can be served by multiple providers, # suggesting a data structure like (obj: Obj, providers: List[str]) rather than List[RoutableObjectWithProvider]. # The current approach could lead to inconsistencies if the same logical object has different data across providers. - async def register(self, obj: RoutableObjectWithProvider) -> None: ... + async def register(self, obj: RoutableObjectWithProvider) -> bool: ... KEY_FORMAT = "distributions:registry:{}" @@ -33,14 +38,25 @@ class DiskDistributionRegistry(DistributionRegistry): def __init__(self, kvstore: KVStore): self.kvstore = kvstore - async def get(self, identifier: str) -> [RoutableObjectWithProvider]: - # Get JSON string from kvstore + async def initialize(self) -> None: + pass + + def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: + # Disk registry does not have a cache + return [] + + async def get_all(self) -> List[RoutableObjectWithProvider]: + start_key = KEY_FORMAT.format("") + end_key = KEY_FORMAT.format("\xff") + keys = await self.kvstore.range(start_key, end_key) + return [await self.get(key.split(":")[-1]) for key in keys] + + async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: json_str = await self.kvstore.get(KEY_FORMAT.format(identifier)) if not json_str: return [] objects_data = json.loads(json_str) - return [ pydantic.parse_obj_as( RoutableObjectWithProvider, @@ -49,17 +65,69 @@ class DiskDistributionRegistry(DistributionRegistry): for obj_str in objects_data ] - # TODO: make it thread safe using CAS - async def register(self, obj: RoutableObjectWithProvider) -> None: + async def register(self, obj: RoutableObjectWithProvider) -> bool: existing_objects = await self.get(obj.identifier) # dont register if the object's providerid already exists for eobj in existing_objects: if eobj.provider_id == obj.provider_id: - return + return False existing_objects.append(obj) - objects_json = [obj.model_dump_json() for existing_object in existing_objects] + objects_json = [obj.model_dump_json() for obj in existing_objects] # Fixed variable name await self.kvstore.set( KEY_FORMAT.format(obj.identifier), json.dumps(objects_json) ) + return True + +class CachedDiskDistributionRegistry(DiskDistributionRegistry): + def __init__(self, kvstore: KVStore): + super().__init__(kvstore) + self.cache: Dict[str, List[RoutableObjectWithProvider]] = {} + + async def initialize(self) -> None: + start_key = KEY_FORMAT.format("") + end_key = KEY_FORMAT.format("\xff") + + keys = await self.kvstore.range(start_key, end_key) + + for key in keys: + identifier = key.split(":")[-1] + objects = await super().get(identifier) + if objects: + self.cache[identifier] = objects + + def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: + return self.cache.get(identifier, []) + + async def get_all(self) -> List[RoutableObjectWithProvider]: + return [item for sublist in self.cache.values() for item in sublist] + + async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: + if identifier in self.cache: + return self.cache[identifier] + + objects = await super().get(identifier) + if objects: + self.cache[identifier] = objects + + return objects + + async def register(self, obj: RoutableObjectWithProvider) -> bool: + # First update disk + success = await super().register(obj) + + if success: + # Then update cache + if obj.identifier not in self.cache: + self.cache[obj.identifier] = [] + + # Check if provider already exists in cache + for cached_obj in self.cache[obj.identifier]: + if cached_obj.provider_id == obj.provider_id: + return success + + # If not, update cache + self.cache[obj.identifier].append(obj) + + return success \ No newline at end of file diff --git a/llama_stack/distribution/store/tests/test_registry.py b/llama_stack/distribution/store/tests/test_registry.py index ab9457707..1210c4bf7 100644 --- a/llama_stack/distribution/store/tests/test_registry.py +++ b/llama_stack/distribution/store/tests/test_registry.py @@ -5,48 +5,155 @@ # the root directory of this source tree. import os - +import asyncio import pytest -from llama_stack.distribution.store import * # noqa: F403 +import pytest_asyncio +from llama_stack.distribution.store import * from llama_stack.apis.memory_banks import VectorMemoryBankDef +from llama_stack.apis.inference import ModelDefWithProvider from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.distribution.datatypes import * - -@pytest.mark.asyncio -async def test_registry(): +@pytest.fixture +def config(): config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db") - # delete the file if it exists if os.path.exists(config.db_path): os.remove(config.db_path) + return config + +@pytest_asyncio.fixture +async def registry(config): registry = DiskDistributionRegistry(await kvstore_impl(config)) - bank = VectorMemoryBankDef( + await registry.initialize() + return registry + +@pytest_asyncio.fixture +async def cached_registry(config): + registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) + await registry.initialize() + return registry + +@pytest.fixture +def sample_bank(): + return VectorMemoryBankDef( identifier="test_bank", - embedding_model="all-MiniLM-L6-v2", + embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512, overlap_size_in_tokens=64, - provider_id="bar", - ) - model = ModelDefWithProvider( - identifier="test_model", - llama_model="Llama3.2-3B-Instruct", - provider_id="foo", + provider_id="test-provider" ) - await registry.register(bank) - await registry.register(model) +@pytest.fixture +def sample_model(): + return ModelDefWithProvider( + identifier="test_model", + llama_model="Llama3.2-3B-Instruct", + provider_id="test-provider" + ) + +@pytest.mark.asyncio +async def test_registry_initialization(registry): + # Test empty registry + results = await registry.get("nonexistent") + assert len(results) == 0 + +@pytest.mark.asyncio +async def test_basic_registration(registry, sample_bank, sample_model): + print(f"Registering {sample_bank}") + await registry.register(sample_bank) + print(f"Registering {sample_model}") + await registry.register(sample_model) + print("Getting bank") results = await registry.get("test_bank") assert len(results) == 1 result_bank = results[0] - assert result_bank.identifier == bank.identifier - assert result_bank.embedding_model == bank.embedding_model - assert result_bank.chunk_size_in_tokens == bank.chunk_size_in_tokens - assert result_bank.overlap_size_in_tokens == bank.overlap_size_in_tokens - assert result_bank.provider_id == bank.provider_id + assert result_bank.identifier == sample_bank.identifier + assert result_bank.embedding_model == sample_bank.embedding_model + assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens + assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens + assert result_bank.provider_id == sample_bank.provider_id results = await registry.get("test_model") assert len(results) == 1 result_model = results[0] - assert result_model.identifier == model.identifier - assert result_model.llama_model == model.llama_model - assert result_model.provider_id == model.provider_id + assert result_model.identifier == sample_model.identifier + assert result_model.llama_model == sample_model.llama_model + assert result_model.provider_id == sample_model.provider_id + +@pytest.mark.asyncio +async def test_cached_registry_initialization(config, sample_bank, sample_model): + # First populate the disk registry + disk_registry = DiskDistributionRegistry(await kvstore_impl(config)) + await disk_registry.initialize() + await disk_registry.register(sample_bank) + await disk_registry.register(sample_model) + + # Test cached version loads from disk + cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) + await cached_registry.initialize() + + results = await cached_registry.get("test_bank") + assert len(results) == 1 + result_bank = results[0] + assert result_bank.identifier == sample_bank.identifier + assert result_bank.embedding_model == sample_bank.embedding_model + assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens + assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens + assert result_bank.provider_id == sample_bank.provider_id + +@pytest.mark.asyncio +async def test_cached_registry_updates(config): + cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) + await cached_registry.initialize() + + new_bank = VectorMemoryBankDef( + identifier="test_bank_2", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=256, + overlap_size_in_tokens=32, + provider_id="baz", + ) + await cached_registry.register(new_bank) + + # Verify in cache + results = await cached_registry.get("test_bank_2") + assert len(results) == 1 + result_bank = results[0] + assert result_bank.identifier == new_bank.identifier + assert result_bank.provider_id == new_bank.provider_id + + # Verify persisted to disk + new_registry = DiskDistributionRegistry(await kvstore_impl(config)) + await new_registry.initialize() + results = await new_registry.get("test_bank_2") + assert len(results) == 1 + result_bank = results[0] + assert result_bank.identifier == new_bank.identifier + assert result_bank.provider_id == new_bank.provider_id + +@pytest.mark.asyncio +async def test_duplicate_provider_registration(config): + cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) + await cached_registry.initialize() + + original_bank = VectorMemoryBankDef( + identifier="test_bank_2", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=256, + overlap_size_in_tokens=32, + provider_id="baz", + ) + await cached_registry.register(original_bank) + + duplicate_bank = VectorMemoryBankDef( + identifier="test_bank_2", + embedding_model="different-model", + chunk_size_in_tokens=128, + overlap_size_in_tokens=16, + provider_id="baz", # Same provider_id + ) + await cached_registry.register(duplicate_bank) + + results = await cached_registry.get("test_bank_2") + assert len(results) == 1 # Still only one result + assert results[0].embedding_model == original_bank.embedding_model # Original values preserved