Make distribution registry thread-safe and other fixes

This commit is contained in:
Dinesh Yeduguru 2024-11-13 14:47:30 -08:00
parent 787e2034b7
commit 40b55ed0d0
3 changed files with 148 additions and 48 deletions

View file

@ -302,7 +302,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> List[Dataset]:
return await self.get_all_with_type("dataset")
return await self.get_all_with_type(ResourceType.dataset.value)
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier("dataset", dataset_id)
@ -341,7 +341,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFn]:
return await self.get_all_with_type("scoring_function")
return await self.get_all_with_type(ResourceType.scoring_function.value)
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
@ -355,8 +355,6 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
provider_id: Optional[str] = None,
params: Optional[ScoringFnParams] = None,
) -> None:
if params is None:
params = {}
if provider_scoring_fn_id is None:
provider_scoring_fn_id = scoring_fn_id
if provider_id is None:
@ -371,6 +369,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
description=description,
return_type=return_type,
provider_resource_id=provider_scoring_fn_id,
provider_id=provider_id,
params=params,
)
scoring_fn.provider_id = provider_id
@ -379,7 +378,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
async def list_eval_tasks(self) -> List[EvalTask]:
return await self.get_all_with_type("eval_task")
return await self.get_all_with_type(ResourceType.eval_task.value)
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
return await self.get_object_by_identifier("eval_task", name)

View file

@ -4,7 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
from contextlib import asynccontextmanager
from typing import Dict, List, Optional, Protocol, Tuple
import pydantic
@ -35,8 +37,35 @@ class DistributionRegistry(Protocol):
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v1"
KEY_FORMAT = f"distributions:registry:{KEY_VERSION}::" + "{type}:{identifier}"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
def _get_registry_key_range() -> Tuple[str, str]:
    """Build the (start, end) key pair used for the registry range query.

    The start key is the registry prefix joined to the key version with a
    colon. The end key is the start key with a trailing U+00FF character
    appended.
    """
    versioned_prefix = ":".join((REGISTER_PREFIX, KEY_VERSION))
    upper_bound = versioned_prefix + "\xff"
    return versioned_prefix, upper_bound
def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]:
    """Parse raw registry values into RoutableObjectWithProvider objects.

    Each value is decoded as a JSON array whose elements are themselves
    JSON-encoded strings; every element is decoded and validated with
    ``pydantic.parse_obj_as``.

    Parsing is best-effort: if anything in a value fails to decode or
    validate, the error and its traceback are printed and that whole value
    is skipped (none of its elements are kept). Remaining values are still
    processed.

    Args:
        values: Raw string values, e.g. the result of a kvstore range query.

    Returns:
        A flat list of the parsed objects, in input order.
    """
    # Imported locally so the error handler below always has it in scope;
    # without it a malformed value would raise NameError instead of being
    # reported and skipped.
    import traceback

    all_objects = []
    for value in values:
        try:
            objects_data = json.loads(value)
            objects = [
                pydantic.parse_obj_as(
                    RoutableObjectWithProvider,
                    json.loads(obj_str),
                )
                for obj_str in objects_data
            ]
            # Extend only after the whole value parsed, so a bad element
            # never leaves a partially-loaded value behind.
            all_objects.extend(objects)
        except Exception as e:
            print(f"Error parsing value: {e}")
            traceback.print_exc()
    return all_objects
class DiskDistributionRegistry(DistributionRegistry):
@ -53,12 +82,9 @@ class DiskDistributionRegistry(DistributionRegistry):
return []
async def get_all(self) -> List[RoutableObjectWithProvider]:
start_key = KEY_FORMAT.format(type="", identifier="")
end_key = KEY_FORMAT.format(type="", identifier="\xff")
keys = await self.kvstore.range(start_key, end_key)
tuples = [(key.split(":")[-2], key.split(":")[-1]) for key in keys]
return [await self.get(type, identifier) for type, identifier in tuples]
start_key, end_key = _get_registry_key_range()
values = await self.kvstore.range(start_key, end_key)
return _parse_registry_values(values)
async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
json_str = await self.kvstore.get(
@ -99,55 +125,84 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
def __init__(self, kvstore: KVStore):
super().__init__(kvstore)
self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {}
self._initialized = False
self._initialize_lock = asyncio.Lock()
self._cache_lock = asyncio.Lock()
@asynccontextmanager
async def _locked_cache(self):
    """Yield the cache dict while the cache lock is held.

    The lock is taken before the cache is handed to the caller and is
    released when the caller's ``async with`` body exits.
    """
    cache_guard = self._cache_lock
    async with cache_guard:
        yield self.cache
async def _ensure_initialized(self):
    """Load the persisted registry into the cache the first time it is needed.

    Returns immediately once ``self._initialized`` is set. Otherwise the
    kvstore is range-scanned, the values are parsed, and each parsed object
    is added to the cache under its ``(type, identifier)`` key unless an
    entry with the same ``provider_id`` is already present for that key.
    """
    if self._initialized:
        return

    async with self._initialize_lock:
        # Checked again under the lock: a task that waited here may find
        # that the previous lock holder already completed the load.
        if self._initialized:
            return

        key_range = _get_registry_key_range()
        raw_values = await self.kvstore.range(*key_range)
        loaded_objects = _parse_registry_values(raw_values)

        async with self._locked_cache() as cache:
            for loaded in loaded_objects:
                bucket = cache.setdefault((loaded.type, loaded.identifier), [])
                for existing in bucket:
                    if existing.provider_id == loaded.provider_id:
                        # This provider is already cached for the key.
                        break
                else:
                    bucket.append(loaded)

        self._initialized = True
async def initialize(self) -> None:
start_key = KEY_FORMAT.format(type="", identifier="")
end_key = KEY_FORMAT.format(type="", identifier="\xff")
keys = await self.kvstore.range(start_key, end_key)
for key in keys:
type, identifier = key.split(":")[-2:]
objects = await super().get(type, identifier)
if objects:
self.cache[type, identifier] = objects
await self._ensure_initialized()
def get_cached(
self, type: str, identifier: str
) -> List[RoutableObjectWithProvider]:
return self.cache.get((type, identifier), [])
return self.cache.get((type, identifier), [])[:] # Return a copy
async def get_all(self) -> List[RoutableObjectWithProvider]:
return [item for sublist in self.cache.values() for item in sublist]
await self._ensure_initialized()
async with self._locked_cache() as cache:
return [item for sublist in cache.values() for item in sublist]
async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
cachekey = (type, identifier)
if cachekey in self.cache:
return self.cache[cachekey]
await self._ensure_initialized()
cache_key = (type, identifier)
async with self._locked_cache() as cache:
if cache_key in cache:
return cache[cache_key][:]
objects = await super().get(type, identifier)
if objects:
self.cache[cachekey] = objects
async with self._locked_cache() as cache:
cache[cache_key] = objects
return objects
async def register(self, obj: RoutableObjectWithProvider) -> bool:
# First update disk
await self._ensure_initialized()
success = await super().register(obj)
if success:
# Then update cache
cachekey = (obj.type, obj.identifier)
if cachekey not in self.cache:
self.cache[cachekey] = []
# Check if provider already exists in cache
for cached_obj in self.cache[cachekey]:
if cached_obj.provider_id == obj.provider_id:
return success
# If not, update cache
self.cache[cachekey].append(obj)
cache_key = (obj.type, obj.identifier)
async with self._locked_cache() as cache:
if cache_key not in cache:
cache[cache_key] = []
if not any(
cached_obj.provider_id == obj.provider_id
for cached_obj in cache[cache_key]
):
cache[cache_key].append(obj)
return success

View file

@ -44,6 +44,7 @@ def sample_bank():
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
provider_resource_id="test_bank",
provider_id="test-provider",
)
@ -52,6 +53,7 @@ def sample_bank():
def sample_model():
return Model(
identifier="test_model",
provider_resource_id="test_model",
provider_id="test-provider",
)
@ -59,7 +61,7 @@ def sample_model():
@pytest.mark.asyncio
async def test_registry_initialization(registry):
# Test empty registry
results = await registry.get("nonexistent")
results = await registry.get("nonexistent", "nonexistent")
assert len(results) == 0
@ -70,7 +72,7 @@ async def test_basic_registration(registry, sample_bank, sample_model):
print(f"Registering {sample_model}")
await registry.register(sample_model)
print("Getting bank")
results = await registry.get("test_bank")
results = await registry.get("memory_bank", "test_bank")
assert len(results) == 1
result_bank = results[0]
assert result_bank.identifier == sample_bank.identifier
@ -79,7 +81,7 @@ async def test_basic_registration(registry, sample_bank, sample_model):
assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens
assert result_bank.provider_id == sample_bank.provider_id
results = await registry.get("test_model")
results = await registry.get("model", "test_model")
assert len(results) == 1
result_model = results[0]
assert result_model.identifier == sample_model.identifier
@ -98,7 +100,7 @@ async def test_cached_registry_initialization(config, sample_bank, sample_model)
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
results = await cached_registry.get("test_bank")
results = await cached_registry.get("memory_bank", "test_bank")
assert len(results) == 1
result_bank = results[0]
assert result_bank.identifier == sample_bank.identifier
@ -118,12 +120,13 @@ async def test_cached_registry_updates(config):
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=256,
overlap_size_in_tokens=32,
provider_resource_id="test_bank_2",
provider_id="baz",
)
await cached_registry.register(new_bank)
# Verify in cache
results = await cached_registry.get("test_bank_2")
results = await cached_registry.get("memory_bank", "test_bank_2")
assert len(results) == 1
result_bank = results[0]
assert result_bank.identifier == new_bank.identifier
@ -132,7 +135,7 @@ async def test_cached_registry_updates(config):
# Verify persisted to disk
new_registry = DiskDistributionRegistry(await kvstore_impl(config))
await new_registry.initialize()
results = await new_registry.get("test_bank_2")
results = await new_registry.get("memory_bank", "test_bank_2")
assert len(results) == 1
result_bank = results[0]
assert result_bank.identifier == new_bank.identifier
@ -149,6 +152,7 @@ async def test_duplicate_provider_registration(config):
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=256,
overlap_size_in_tokens=32,
provider_resource_id="test_bank_2",
provider_id="baz",
)
await cached_registry.register(original_bank)
@ -158,12 +162,54 @@ async def test_duplicate_provider_registration(config):
embedding_model="different-model",
chunk_size_in_tokens=128,
overlap_size_in_tokens=16,
provider_resource_id="test_bank_2",
provider_id="baz", # Same provider_id
)
await cached_registry.register(duplicate_bank)
results = await cached_registry.get("test_bank_2")
results = await cached_registry.get("memory_bank", "test_bank_2")
assert len(results) == 1 # Still only one result
assert (
results[0].embedding_model == original_bank.embedding_model
) # Original values preserved
@pytest.mark.asyncio
async def test_get_all_objects(config):
    """get_all() returns every registered bank with its fields intact."""
    registry_under_test = CachedDiskDistributionRegistry(await kvstore_impl(config))
    await registry_under_test.initialize()

    # Three banks, each with its own identifier and provider.
    expected_banks = [
        VectorMemoryBank(
            identifier=f"test_bank_{i}",
            embedding_model="all-MiniLM-L6-v2",
            chunk_size_in_tokens=256,
            overlap_size_in_tokens=32,
            provider_resource_id=f"test_bank_{i}",
            provider_id=f"provider_{i}",
        )
        for i in range(3)
    ]
    for bank in expected_banks:
        await registry_under_test.register(bank)

    retrieved = await registry_under_test.get_all()
    assert len(retrieved) == 3

    # Every registered bank must appear exactly once, with matching fields.
    compared_fields = (
        "embedding_model",
        "provider_id",
        "chunk_size_in_tokens",
        "overlap_size_in_tokens",
    )
    for expected in expected_banks:
        candidates = [
            found for found in retrieved if found.identifier == expected.identifier
        ]
        assert len(candidates) == 1
        stored = candidates[0]
        for field_name in compared_fields:
            assert getattr(stored, field_name) == getattr(expected, field_name)