make distribution registry thread safe and other fixes (#449)

This PR makes the following changes:
1) Fixes the get_all and initialize impl to actually read the values
returned from the range call to kvstore and not keys.
2) The start_key and end_key are fixed to correct perform the range
query after the key format changes
3) Made the cache registry thread safe since there are multiple
initializes called for each routing table.

Tests:
* Start stack
* Register dataset
* Kill stack
* Bring stack up
* dataset list
```
 llama-stack-client datasets list
+--------------+---------------+---------------------------------------------------------------------------------+---------+
| identifier   | provider_id   | metadata                                                                        | type    |
+==============+===============+=================================================================================+=========+
| alpaca       | huggingface-0 | {}                                                                              | dataset |
+--------------+---------------+---------------------------------------------------------------------------------+---------+
| mmlu         | huggingface-0 | {'path': 'llama-stack/evals', 'name': 'evals__mmlu__details', 'split': 'train'} | dataset |
+--------------+---------------+---------------------------------------------------------------------------------+---------+
```

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-13 15:12:34 -08:00 committed by GitHub
parent 15dee2b8b8
commit e90ea1ab1e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 148 additions and 48 deletions

View file

@ -4,7 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
from contextlib import asynccontextmanager
from typing import Dict, List, Optional, Protocol, Tuple
import pydantic
@ -35,8 +37,35 @@ class DistributionRegistry(Protocol):
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v1"
KEY_FORMAT = f"distributions:registry:{KEY_VERSION}::" + "{type}:{identifier}"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
def _get_registry_key_range() -> Tuple[str, str]:
"""Returns the start and end keys for the registry range query."""
start_key = f"{REGISTER_PREFIX}:{KEY_VERSION}"
return start_key, f"{start_key}\xff"
def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]:
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = []
for value in values:
try:
objects_data = json.loads(value)
objects = [
pydantic.parse_obj_as(
RoutableObjectWithProvider,
json.loads(obj_str),
)
for obj_str in objects_data
]
all_objects.extend(objects)
except Exception as e:
print(f"Error parsing value: {e}")
traceback.print_exc()
return all_objects
class DiskDistributionRegistry(DistributionRegistry):
@ -53,12 +82,9 @@ class DiskDistributionRegistry(DistributionRegistry):
return []
async def get_all(self) -> List[RoutableObjectWithProvider]:
start_key = KEY_FORMAT.format(type="", identifier="")
end_key = KEY_FORMAT.format(type="", identifier="\xff")
keys = await self.kvstore.range(start_key, end_key)
tuples = [(key.split(":")[-2], key.split(":")[-1]) for key in keys]
return [await self.get(type, identifier) for type, identifier in tuples]
start_key, end_key = _get_registry_key_range()
values = await self.kvstore.range(start_key, end_key)
return _parse_registry_values(values)
async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
json_str = await self.kvstore.get(
@ -99,55 +125,84 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
def __init__(self, kvstore: KVStore):
super().__init__(kvstore)
self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {}
self._initialized = False
self._initialize_lock = asyncio.Lock()
self._cache_lock = asyncio.Lock()
@asynccontextmanager
async def _locked_cache(self):
"""Context manager for safely accessing the cache with a lock."""
async with self._cache_lock:
yield self.cache
async def _ensure_initialized(self):
"""Ensures the registry is initialized before operations."""
if self._initialized:
return
async with self._initialize_lock:
if self._initialized:
return
start_key, end_key = _get_registry_key_range()
values = await self.kvstore.range(start_key, end_key)
objects = _parse_registry_values(values)
async with self._locked_cache() as cache:
for obj in objects:
cache_key = (obj.type, obj.identifier)
if cache_key not in cache:
cache[cache_key] = []
if not any(
cached_obj.provider_id == obj.provider_id
for cached_obj in cache[cache_key]
):
cache[cache_key].append(obj)
self._initialized = True
async def initialize(self) -> None:
start_key = KEY_FORMAT.format(type="", identifier="")
end_key = KEY_FORMAT.format(type="", identifier="\xff")
keys = await self.kvstore.range(start_key, end_key)
for key in keys:
type, identifier = key.split(":")[-2:]
objects = await super().get(type, identifier)
if objects:
self.cache[type, identifier] = objects
await self._ensure_initialized()
def get_cached(
self, type: str, identifier: str
) -> List[RoutableObjectWithProvider]:
return self.cache.get((type, identifier), [])
return self.cache.get((type, identifier), [])[:] # Return a copy
async def get_all(self) -> List[RoutableObjectWithProvider]:
return [item for sublist in self.cache.values() for item in sublist]
await self._ensure_initialized()
async with self._locked_cache() as cache:
return [item for sublist in cache.values() for item in sublist]
async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
cachekey = (type, identifier)
if cachekey in self.cache:
return self.cache[cachekey]
await self._ensure_initialized()
cache_key = (type, identifier)
async with self._locked_cache() as cache:
if cache_key in cache:
return cache[cache_key][:]
objects = await super().get(type, identifier)
if objects:
self.cache[cachekey] = objects
async with self._locked_cache() as cache:
cache[cache_key] = objects
return objects
async def register(self, obj: RoutableObjectWithProvider) -> bool:
# First update disk
await self._ensure_initialized()
success = await super().register(obj)
if success:
# Then update cache
cachekey = (obj.type, obj.identifier)
if cachekey not in self.cache:
self.cache[cachekey] = []
# Check if provider already exists in cache
for cached_obj in self.cache[cachekey]:
if cached_obj.provider_id == obj.provider_id:
return success
# If not, update cache
self.cache[cachekey].append(obj)
cache_key = (obj.type, obj.identifier)
async with self._locked_cache() as cache:
if cache_key not in cache:
cache[cache_key] = []
if not any(
cached_obj.provider_id == obj.provider_id
for cached_obj in cache[cache_key]
):
cache[cache_key].append(obj)
return success