# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import json from typing import Dict, List, Optional, Protocol, Tuple import pydantic from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.providers.utils.kvstore import ( KVStore, kvstore_impl, SqliteKVStoreConfig, ) class DistributionRegistry(Protocol): async def get_all(self) -> List[RoutableObjectWithProvider]: ... async def initialize(self) -> None: ... async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: ... def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: ... # The current data structure allows multiple objects with the same identifier but different providers. # This is not ideal - we should have a single object that can be served by multiple providers, # suggesting a data structure like (obj: Obj, providers: List[str]) rather than List[RoutableObjectWithProvider]. # The current approach could lead to inconsistencies if the same logical object has different data across providers. async def register(self, obj: RoutableObjectWithProvider) -> bool: ... KEY_VERSION = "v1" KEY_FORMAT = f"distributions:registry:{KEY_VERSION}::" + "{type}:{identifier}" class DiskDistributionRegistry(DistributionRegistry): def __init__(self, kvstore: KVStore): self.kvstore = kvstore async def initialize(self) -> None: pass def get_cached( self, type: str, identifier: str ) -> List[RoutableObjectWithProvider]: # Disk registry does not have a cache return [] async def get_all(self) -> List[RoutableObjectWithProvider]: start_key = KEY_FORMAT.format(type="", identifier="") end_key = KEY_FORMAT.format(type="", identifier="\xff") keys = await self.kvstore.range(start_key, end_key) tuples = [(key.split(":")[-2], key.split(":")[-1]) for key in keys] return [await self.get(type, identifier) for type, identifier in tuples] async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]: json_str = await self.kvstore.get( KEY_FORMAT.format(type=type, identifier=identifier) ) if not json_str: return [] objects_data = json.loads(json_str) return [ pydantic.parse_obj_as( RoutableObjectWithProvider, json.loads(obj_str), ) for obj_str in objects_data ] async def register(self, obj: RoutableObjectWithProvider) -> bool: existing_objects = await self.get(obj.type, obj.identifier) # dont register if the object's providerid already exists for eobj in existing_objects: if eobj.provider_id == obj.provider_id: return False existing_objects.append(obj) objects_json = [ obj.model_dump_json() for obj in existing_objects ] # Fixed variable name await self.kvstore.set( KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), json.dumps(objects_json), ) return True class CachedDiskDistributionRegistry(DiskDistributionRegistry): def __init__(self, kvstore: KVStore): super().__init__(kvstore) self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {} async def initialize(self) -> None: start_key = KEY_FORMAT.format(type="", identifier="") end_key = KEY_FORMAT.format(type="", identifier="\xff") keys = await self.kvstore.range(start_key, end_key) for key in keys: type, identifier = key.split(":")[-2:] objects = await super().get(type, identifier) if objects: self.cache[type, identifier] = objects def get_cached( self, type: str, identifier: str ) -> List[RoutableObjectWithProvider]: return self.cache.get((type, identifier), []) async def get_all(self) -> List[RoutableObjectWithProvider]: return [item for sublist in self.cache.values() for item in sublist] async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]: cachekey = (type, identifier) if cachekey in self.cache: return self.cache[cachekey] objects = await super().get(type, identifier) if objects: self.cache[cachekey] = objects return objects async def register(self, obj: RoutableObjectWithProvider) -> bool: # First update disk success = await super().register(obj) if success: # Then update cache cachekey = (obj.type, obj.identifier) if cachekey not in self.cache: self.cache[cachekey] = [] # Check if provider already exists in cache for cached_obj in self.cache[cachekey]: if cached_obj.provider_id == obj.provider_id: return success # If not, update cache self.cache[cachekey].append(obj) return success async def create_dist_registry( metadata_store: Optional[KVStoreConfig], image_name: str, ) -> tuple[CachedDiskDistributionRegistry, KVStore]: # instantiate kvstore for storing and retrieving distribution metadata if metadata_store: dist_kvstore = await kvstore_impl(metadata_store) else: dist_kvstore = await kvstore_impl( SqliteKVStoreConfig( db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix() ) ) return CachedDiskDistributionRegistry(dist_kvstore), dist_kvstore