mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
make distribution registry thread safe and other fixes (#449)
This PR makes the following changes: 1) Fixes the get_all and initialize implementations to read the values returned from the kvstore range call rather than treating them as keys. 2) Fixes start_key and end_key so that the range query is performed correctly after the key format changes. 3) Makes the cached registry thread safe, since initialize is called multiple times (once for each routing table). Tests: * Start stack * Register dataset * Kill stack * Bring stack up * dataset list ``` llama-stack-client datasets list +--------------+---------------+---------------------------------------------------------------------------------+---------+ | identifier | provider_id | metadata | type | +==============+===============+=================================================================================+=========+ | alpaca | huggingface-0 | {} | dataset | +--------------+---------------+---------------------------------------------------------------------------------+---------+ | mmlu | huggingface-0 | {'path': 'llama-stack/evals', 'name': 'evals__mmlu__details', 'split': 'train'} | dataset | +--------------+---------------+---------------------------------------------------------------------------------+---------+ ``` Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
15dee2b8b8
commit
e90ea1ab1e
3 changed files with 148 additions and 48 deletions
|
@ -302,7 +302,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
|
|
||||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
async def list_datasets(self) -> List[Dataset]:
|
async def list_datasets(self) -> List[Dataset]:
|
||||||
return await self.get_all_with_type("dataset")
|
return await self.get_all_with_type(ResourceType.dataset.value)
|
||||||
|
|
||||||
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
|
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
|
||||||
return await self.get_object_by_identifier("dataset", dataset_id)
|
return await self.get_object_by_identifier("dataset", dataset_id)
|
||||||
|
@ -341,7 +341,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
|
|
||||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||||
return await self.get_all_with_type("scoring_function")
|
return await self.get_all_with_type(ResourceType.scoring_function.value)
|
||||||
|
|
||||||
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
|
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
|
||||||
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
||||||
|
@ -355,8 +355,6 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
params: Optional[ScoringFnParams] = None,
|
params: Optional[ScoringFnParams] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if params is None:
|
|
||||||
params = {}
|
|
||||||
if provider_scoring_fn_id is None:
|
if provider_scoring_fn_id is None:
|
||||||
provider_scoring_fn_id = scoring_fn_id
|
provider_scoring_fn_id = scoring_fn_id
|
||||||
if provider_id is None:
|
if provider_id is None:
|
||||||
|
@ -371,6 +369,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
description=description,
|
description=description,
|
||||||
return_type=return_type,
|
return_type=return_type,
|
||||||
provider_resource_id=provider_scoring_fn_id,
|
provider_resource_id=provider_scoring_fn_id,
|
||||||
|
provider_id=provider_id,
|
||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
scoring_fn.provider_id = provider_id
|
scoring_fn.provider_id = provider_id
|
||||||
|
@ -379,7 +378,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
|
|
||||||
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
|
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
|
||||||
async def list_eval_tasks(self) -> List[EvalTask]:
|
async def list_eval_tasks(self) -> List[EvalTask]:
|
||||||
return await self.get_all_with_type("eval_task")
|
return await self.get_all_with_type(ResourceType.eval_task.value)
|
||||||
|
|
||||||
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
|
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
|
||||||
return await self.get_object_by_identifier("eval_task", name)
|
return await self.get_object_by_identifier("eval_task", name)
|
||||||
|
|
|
@ -4,7 +4,9 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from typing import Dict, List, Optional, Protocol, Tuple
|
from typing import Dict, List, Optional, Protocol, Tuple
|
||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
|
@ -35,8 +37,35 @@ class DistributionRegistry(Protocol):
|
||||||
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
|
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
REGISTER_PREFIX = "distributions:registry"
|
||||||
KEY_VERSION = "v1"
|
KEY_VERSION = "v1"
|
||||||
KEY_FORMAT = f"distributions:registry:{KEY_VERSION}::" + "{type}:{identifier}"
|
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_registry_key_range() -> Tuple[str, str]:
    """Build the (start, end) key pair used for the registry range query.

    The start key is ``"{REGISTER_PREFIX}:{KEY_VERSION}"``, which is the
    leading part of ``KEY_FORMAT``; the end key is the same string with
    ``"\\xff"`` appended.

    Returns:
        A ``(start_key, end_key)`` tuple.
    """
    # NOTE(review): assumes the kvstore compares keys lexicographically, so
    # the "\xff" suffix sorts after every key sharing the prefix -- confirm
    # against the kvstore implementation in use.
    prefix = f"{REGISTER_PREFIX}:{KEY_VERSION}"
    upper_bound = f"{prefix}\xff"
    return prefix, upper_bound
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]:
    """Parse raw registry values into RoutableObjectWithProvider objects.

    Each entry of ``values`` is decoded with ``json.loads`` and iterated; every
    element of the decoded value is itself decoded with ``json.loads`` and then
    validated via ``pydantic.parse_obj_as``.

    Parsing is best-effort: if anything goes wrong while handling one value,
    the error and its traceback are printed, that whole value is skipped, and
    processing continues with the next one.

    Args:
        values: Raw strings as returned by the kvstore range query.

    Returns:
        A flat list of every object parsed successfully, in input order.
    """
    # Imported here so the error handler below cannot itself fail with a
    # NameError: the module's visible top-level imports do not include
    # ``traceback``.
    import traceback

    all_objects = []
    for value in values:
        try:
            objects_data = json.loads(value)
            objects = [
                pydantic.parse_obj_as(
                    RoutableObjectWithProvider,
                    json.loads(obj_str),
                )
                for obj_str in objects_data
            ]
        except Exception as e:
            # Deliberately broad: one corrupt entry must not prevent the rest
            # of the registry from loading.
            print(f"Error parsing value: {e}")
            traceback.print_exc()
        else:
            all_objects.extend(objects)
    return all_objects
|
||||||
|
|
||||||
|
|
||||||
class DiskDistributionRegistry(DistributionRegistry):
|
class DiskDistributionRegistry(DistributionRegistry):
|
||||||
|
@ -53,12 +82,9 @@ class DiskDistributionRegistry(DistributionRegistry):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
||||||
start_key = KEY_FORMAT.format(type="", identifier="")
|
start_key, end_key = _get_registry_key_range()
|
||||||
end_key = KEY_FORMAT.format(type="", identifier="\xff")
|
values = await self.kvstore.range(start_key, end_key)
|
||||||
keys = await self.kvstore.range(start_key, end_key)
|
return _parse_registry_values(values)
|
||||||
|
|
||||||
tuples = [(key.split(":")[-2], key.split(":")[-1]) for key in keys]
|
|
||||||
return [await self.get(type, identifier) for type, identifier in tuples]
|
|
||||||
|
|
||||||
async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
|
async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
|
||||||
json_str = await self.kvstore.get(
|
json_str = await self.kvstore.get(
|
||||||
|
@ -99,55 +125,84 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
|
||||||
def __init__(self, kvstore: KVStore):
|
def __init__(self, kvstore: KVStore):
|
||||||
super().__init__(kvstore)
|
super().__init__(kvstore)
|
||||||
self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {}
|
self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {}
|
||||||
|
self._initialized = False
|
||||||
|
self._initialize_lock = asyncio.Lock()
|
||||||
|
self._cache_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@asynccontextmanager
async def _locked_cache(self):
    """Yield ``self.cache`` while holding ``self._cache_lock``.

    The lock is acquired before the cache is handed to the caller and is
    released when the ``async with`` body exits, whether normally or through
    an exception.
    """
    await self._cache_lock.acquire()
    try:
        yield self.cache
    finally:
        self._cache_lock.release()
|
||||||
|
|
||||||
|
async def _ensure_initialized(self):
    """Load the cache from the kvstore once, before first use.

    Returns immediately when ``self._initialized`` is already set. Otherwise
    takes ``self._initialize_lock``, re-checks the flag, runs the registry
    range query, and adds each parsed object to the cache under its
    ``(type, identifier)`` key unless an entry with the same ``provider_id``
    is already stored there.
    """
    if self._initialized:
        return

    async with self._initialize_lock:
        # Checked again under the lock: another caller may have completed the
        # load while this one was waiting to acquire it.
        if self._initialized:
            return

        range_start, range_end = _get_registry_key_range()
        raw_values = await self.kvstore.range(range_start, range_end)
        loaded = _parse_registry_values(raw_values)

        async with self._locked_cache() as cache:
            for obj in loaded:
                entries = cache.setdefault((obj.type, obj.identifier), [])
                for existing in entries:
                    if existing.provider_id == obj.provider_id:
                        # This provider is already cached for the key.
                        break
                else:
                    entries.append(obj)

        self._initialized = True
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
start_key = KEY_FORMAT.format(type="", identifier="")
|
await self._ensure_initialized()
|
||||||
end_key = KEY_FORMAT.format(type="", identifier="\xff")
|
|
||||||
|
|
||||||
keys = await self.kvstore.range(start_key, end_key)
|
|
||||||
|
|
||||||
for key in keys:
|
|
||||||
type, identifier = key.split(":")[-2:]
|
|
||||||
objects = await super().get(type, identifier)
|
|
||||||
if objects:
|
|
||||||
self.cache[type, identifier] = objects
|
|
||||||
|
|
||||||
def get_cached(
|
def get_cached(
|
||||||
self, type: str, identifier: str
|
self, type: str, identifier: str
|
||||||
) -> List[RoutableObjectWithProvider]:
|
) -> List[RoutableObjectWithProvider]:
|
||||||
return self.cache.get((type, identifier), [])
|
return self.cache.get((type, identifier), [])[:] # Return a copy
|
||||||
|
|
||||||
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
||||||
return [item for sublist in self.cache.values() for item in sublist]
|
await self._ensure_initialized()
|
||||||
|
async with self._locked_cache() as cache:
|
||||||
|
return [item for sublist in cache.values() for item in sublist]
|
||||||
|
|
||||||
async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
|
async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
|
||||||
cachekey = (type, identifier)
|
await self._ensure_initialized()
|
||||||
if cachekey in self.cache:
|
cache_key = (type, identifier)
|
||||||
return self.cache[cachekey]
|
|
||||||
|
async with self._locked_cache() as cache:
|
||||||
|
if cache_key in cache:
|
||||||
|
return cache[cache_key][:]
|
||||||
|
|
||||||
objects = await super().get(type, identifier)
|
objects = await super().get(type, identifier)
|
||||||
if objects:
|
if objects:
|
||||||
self.cache[cachekey] = objects
|
async with self._locked_cache() as cache:
|
||||||
|
cache[cache_key] = objects
|
||||||
|
|
||||||
return objects
|
return objects
|
||||||
|
|
||||||
async def register(self, obj: RoutableObjectWithProvider) -> bool:
|
async def register(self, obj: RoutableObjectWithProvider) -> bool:
|
||||||
# First update disk
|
await self._ensure_initialized()
|
||||||
success = await super().register(obj)
|
success = await super().register(obj)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
# Then update cache
|
cache_key = (obj.type, obj.identifier)
|
||||||
cachekey = (obj.type, obj.identifier)
|
async with self._locked_cache() as cache:
|
||||||
if cachekey not in self.cache:
|
if cache_key not in cache:
|
||||||
self.cache[cachekey] = []
|
cache[cache_key] = []
|
||||||
|
if not any(
|
||||||
# Check if provider already exists in cache
|
cached_obj.provider_id == obj.provider_id
|
||||||
for cached_obj in self.cache[cachekey]:
|
for cached_obj in cache[cache_key]
|
||||||
if cached_obj.provider_id == obj.provider_id:
|
):
|
||||||
return success
|
cache[cache_key].append(obj)
|
||||||
|
|
||||||
# If not, update cache
|
|
||||||
self.cache[cachekey].append(obj)
|
|
||||||
|
|
||||||
return success
|
return success
|
||||||
|
|
||||||
|
|
|
@ -44,6 +44,7 @@ def sample_bank():
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
chunk_size_in_tokens=512,
|
chunk_size_in_tokens=512,
|
||||||
overlap_size_in_tokens=64,
|
overlap_size_in_tokens=64,
|
||||||
|
provider_resource_id="test_bank",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -52,6 +53,7 @@ def sample_bank():
|
||||||
def sample_model():
|
def sample_model():
|
||||||
return Model(
|
return Model(
|
||||||
identifier="test_model",
|
identifier="test_model",
|
||||||
|
provider_resource_id="test_model",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -59,7 +61,7 @@ def sample_model():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_registry_initialization(registry):
|
async def test_registry_initialization(registry):
|
||||||
# Test empty registry
|
# Test empty registry
|
||||||
results = await registry.get("nonexistent")
|
results = await registry.get("nonexistent", "nonexistent")
|
||||||
assert len(results) == 0
|
assert len(results) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@ -70,7 +72,7 @@ async def test_basic_registration(registry, sample_bank, sample_model):
|
||||||
print(f"Registering {sample_model}")
|
print(f"Registering {sample_model}")
|
||||||
await registry.register(sample_model)
|
await registry.register(sample_model)
|
||||||
print("Getting bank")
|
print("Getting bank")
|
||||||
results = await registry.get("test_bank")
|
results = await registry.get("memory_bank", "test_bank")
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
result_bank = results[0]
|
result_bank = results[0]
|
||||||
assert result_bank.identifier == sample_bank.identifier
|
assert result_bank.identifier == sample_bank.identifier
|
||||||
|
@ -79,7 +81,7 @@ async def test_basic_registration(registry, sample_bank, sample_model):
|
||||||
assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens
|
assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens
|
||||||
assert result_bank.provider_id == sample_bank.provider_id
|
assert result_bank.provider_id == sample_bank.provider_id
|
||||||
|
|
||||||
results = await registry.get("test_model")
|
results = await registry.get("model", "test_model")
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
result_model = results[0]
|
result_model = results[0]
|
||||||
assert result_model.identifier == sample_model.identifier
|
assert result_model.identifier == sample_model.identifier
|
||||||
|
@ -98,7 +100,7 @@ async def test_cached_registry_initialization(config, sample_bank, sample_model)
|
||||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||||
await cached_registry.initialize()
|
await cached_registry.initialize()
|
||||||
|
|
||||||
results = await cached_registry.get("test_bank")
|
results = await cached_registry.get("memory_bank", "test_bank")
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
result_bank = results[0]
|
result_bank = results[0]
|
||||||
assert result_bank.identifier == sample_bank.identifier
|
assert result_bank.identifier == sample_bank.identifier
|
||||||
|
@ -118,12 +120,13 @@ async def test_cached_registry_updates(config):
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
chunk_size_in_tokens=256,
|
chunk_size_in_tokens=256,
|
||||||
overlap_size_in_tokens=32,
|
overlap_size_in_tokens=32,
|
||||||
|
provider_resource_id="test_bank_2",
|
||||||
provider_id="baz",
|
provider_id="baz",
|
||||||
)
|
)
|
||||||
await cached_registry.register(new_bank)
|
await cached_registry.register(new_bank)
|
||||||
|
|
||||||
# Verify in cache
|
# Verify in cache
|
||||||
results = await cached_registry.get("test_bank_2")
|
results = await cached_registry.get("memory_bank", "test_bank_2")
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
result_bank = results[0]
|
result_bank = results[0]
|
||||||
assert result_bank.identifier == new_bank.identifier
|
assert result_bank.identifier == new_bank.identifier
|
||||||
|
@ -132,7 +135,7 @@ async def test_cached_registry_updates(config):
|
||||||
# Verify persisted to disk
|
# Verify persisted to disk
|
||||||
new_registry = DiskDistributionRegistry(await kvstore_impl(config))
|
new_registry = DiskDistributionRegistry(await kvstore_impl(config))
|
||||||
await new_registry.initialize()
|
await new_registry.initialize()
|
||||||
results = await new_registry.get("test_bank_2")
|
results = await new_registry.get("memory_bank", "test_bank_2")
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
result_bank = results[0]
|
result_bank = results[0]
|
||||||
assert result_bank.identifier == new_bank.identifier
|
assert result_bank.identifier == new_bank.identifier
|
||||||
|
@ -149,6 +152,7 @@ async def test_duplicate_provider_registration(config):
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
chunk_size_in_tokens=256,
|
chunk_size_in_tokens=256,
|
||||||
overlap_size_in_tokens=32,
|
overlap_size_in_tokens=32,
|
||||||
|
provider_resource_id="test_bank_2",
|
||||||
provider_id="baz",
|
provider_id="baz",
|
||||||
)
|
)
|
||||||
await cached_registry.register(original_bank)
|
await cached_registry.register(original_bank)
|
||||||
|
@ -158,12 +162,54 @@ async def test_duplicate_provider_registration(config):
|
||||||
embedding_model="different-model",
|
embedding_model="different-model",
|
||||||
chunk_size_in_tokens=128,
|
chunk_size_in_tokens=128,
|
||||||
overlap_size_in_tokens=16,
|
overlap_size_in_tokens=16,
|
||||||
|
provider_resource_id="test_bank_2",
|
||||||
provider_id="baz", # Same provider_id
|
provider_id="baz", # Same provider_id
|
||||||
)
|
)
|
||||||
await cached_registry.register(duplicate_bank)
|
await cached_registry.register(duplicate_bank)
|
||||||
|
|
||||||
results = await cached_registry.get("test_bank_2")
|
results = await cached_registry.get("memory_bank", "test_bank_2")
|
||||||
assert len(results) == 1 # Still only one result
|
assert len(results) == 1 # Still only one result
|
||||||
assert (
|
assert (
|
||||||
results[0].embedding_model == original_bank.embedding_model
|
results[0].embedding_model == original_bank.embedding_model
|
||||||
) # Original values preserved
|
) # Original values preserved
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_all_objects(config):
    """Register three banks and check that get_all() returns each one intact."""
    cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
    await cached_registry.initialize()

    # Build three banks that differ only in their index-derived fields.
    test_banks = []
    for i in range(3):
        test_banks.append(
            VectorMemoryBank(
                identifier=f"test_bank_{i}",
                embedding_model="all-MiniLM-L6-v2",
                chunk_size_in_tokens=256,
                overlap_size_in_tokens=32,
                provider_resource_id=f"test_bank_{i}",
                provider_id=f"provider_{i}",
            )
        )

    # Register them one after another.
    for bank in test_banks:
        await cached_registry.register(bank)

    # get_all() should return exactly the registered banks.
    all_results = await cached_registry.get_all()
    assert len(all_results) == 3

    # Each registered bank must appear exactly once with matching fields.
    for original_bank in test_banks:
        wanted_id = original_bank.identifier
        matching_banks = [b for b in all_results if b.identifier == wanted_id]
        assert len(matching_banks) == 1
        stored_bank = matching_banks[0]
        assert stored_bank.embedding_model == original_bank.embedding_model
        assert stored_bank.provider_id == original_bank.provider_id
        assert stored_bank.chunk_size_in_tokens == original_bank.chunk_size_in_tokens
        assert (
            stored_bank.overlap_size_in_tokens == original_bank.overlap_size_in_tokens
        )
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue