mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
make distribution registry thread safe and other fixes
This commit is contained in:
parent
787e2034b7
commit
40b55ed0d0
3 changed files with 148 additions and 48 deletions
|
@ -302,7 +302,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
|
|
||||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
async def list_datasets(self) -> List[Dataset]:
|
async def list_datasets(self) -> List[Dataset]:
|
||||||
return await self.get_all_with_type("dataset")
|
return await self.get_all_with_type(ResourceType.dataset.value)
|
||||||
|
|
||||||
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
|
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
|
||||||
return await self.get_object_by_identifier("dataset", dataset_id)
|
return await self.get_object_by_identifier("dataset", dataset_id)
|
||||||
|
@ -341,7 +341,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
|
|
||||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||||
return await self.get_all_with_type("scoring_function")
|
return await self.get_all_with_type(ResourceType.scoring_function.value)
|
||||||
|
|
||||||
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
|
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
|
||||||
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
||||||
|
@ -355,8 +355,6 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
params: Optional[ScoringFnParams] = None,
|
params: Optional[ScoringFnParams] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if params is None:
|
|
||||||
params = {}
|
|
||||||
if provider_scoring_fn_id is None:
|
if provider_scoring_fn_id is None:
|
||||||
provider_scoring_fn_id = scoring_fn_id
|
provider_scoring_fn_id = scoring_fn_id
|
||||||
if provider_id is None:
|
if provider_id is None:
|
||||||
|
@ -371,6 +369,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
description=description,
|
description=description,
|
||||||
return_type=return_type,
|
return_type=return_type,
|
||||||
provider_resource_id=provider_scoring_fn_id,
|
provider_resource_id=provider_scoring_fn_id,
|
||||||
|
provider_id=provider_id,
|
||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
scoring_fn.provider_id = provider_id
|
scoring_fn.provider_id = provider_id
|
||||||
|
@ -379,7 +378,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
|
|
||||||
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
|
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
|
||||||
async def list_eval_tasks(self) -> List[EvalTask]:
|
async def list_eval_tasks(self) -> List[EvalTask]:
|
||||||
return await self.get_all_with_type("eval_task")
|
return await self.get_all_with_type(ResourceType.eval_task.value)
|
||||||
|
|
||||||
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
|
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
|
||||||
return await self.get_object_by_identifier("eval_task", name)
|
return await self.get_object_by_identifier("eval_task", name)
|
||||||
|
|
|
@ -4,7 +4,9 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from typing import Dict, List, Optional, Protocol, Tuple
|
from typing import Dict, List, Optional, Protocol, Tuple
|
||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
|
@ -35,8 +37,35 @@ class DistributionRegistry(Protocol):
|
||||||
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
|
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
REGISTER_PREFIX = "distributions:registry"
|
||||||
KEY_VERSION = "v1"
|
KEY_VERSION = "v1"
|
||||||
KEY_FORMAT = f"distributions:registry:{KEY_VERSION}::" + "{type}:{identifier}"
|
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_registry_key_range() -> Tuple[str, str]:
    """Return the (start, end) key pair used for range queries over the registry.

    The start key is the registry prefix joined with the key version; the end
    key is that same string with a single U+00FF character appended.
    """
    range_start = ":".join((REGISTER_PREFIX, KEY_VERSION))
    range_end = range_start + "\xff"
    return range_start, range_end
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]:
    """Parse raw registry values into RoutableObjectWithProvider objects.

    Each value is decoded with ``json.loads`` and iterated; every element is
    decoded with ``json.loads`` again and validated through
    ``pydantic.parse_obj_as``. A value that fails at any of these steps is
    reported on stdout (message plus traceback) and skipped as a whole, and
    processing continues with the next value.

    Args:
        values: Raw strings as returned by the key-value store range query.

    Returns:
        A flat list of every object parsed from the values that decoded
        successfully, in input order.
    """
    # Imported here because the error path below depends on it and the
    # module-level imports do not visibly provide it; without this, a single
    # malformed value would surface as a NameError instead of being skipped.
    import traceback

    all_objects = []
    for value in values:
        try:
            objects_data = json.loads(value)
            objects = [
                pydantic.parse_obj_as(
                    RoutableObjectWithProvider,
                    json.loads(obj_str),
                )
                for obj_str in objects_data
            ]
        except Exception as e:
            # Best-effort parsing: report the bad value and keep going.
            print(f"Error parsing value: {e}")
            traceback.print_exc()
        else:
            all_objects.extend(objects)
    return all_objects
|
||||||
|
|
||||||
|
|
||||||
class DiskDistributionRegistry(DistributionRegistry):
|
class DiskDistributionRegistry(DistributionRegistry):
|
||||||
|
@ -53,12 +82,9 @@ class DiskDistributionRegistry(DistributionRegistry):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
||||||
start_key = KEY_FORMAT.format(type="", identifier="")
|
start_key, end_key = _get_registry_key_range()
|
||||||
end_key = KEY_FORMAT.format(type="", identifier="\xff")
|
values = await self.kvstore.range(start_key, end_key)
|
||||||
keys = await self.kvstore.range(start_key, end_key)
|
return _parse_registry_values(values)
|
||||||
|
|
||||||
tuples = [(key.split(":")[-2], key.split(":")[-1]) for key in keys]
|
|
||||||
return [await self.get(type, identifier) for type, identifier in tuples]
|
|
||||||
|
|
||||||
async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
|
async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
|
||||||
json_str = await self.kvstore.get(
|
json_str = await self.kvstore.get(
|
||||||
|
@ -99,55 +125,84 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
|
||||||
def __init__(self, kvstore: KVStore):
|
def __init__(self, kvstore: KVStore):
|
||||||
super().__init__(kvstore)
|
super().__init__(kvstore)
|
||||||
self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {}
|
self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {}
|
||||||
|
self._initialized = False
|
||||||
|
self._initialize_lock = asyncio.Lock()
|
||||||
|
self._cache_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@asynccontextmanager
async def _locked_cache(self):
    """Yield the cache dict while holding the cache lock.

    The lock is acquired before the cache is handed to the caller and is
    released when the ``async with`` body exits.
    """
    cache_lock = self._cache_lock
    async with cache_lock:
        yield self.cache
|
||||||
|
|
||||||
|
async def _ensure_initialized(self):
    """Populate the cache from the key-value store the first time it is needed.

    Returns immediately when the registry is already marked initialized.
    Otherwise the initialize lock is taken, the flag is re-checked, every
    registry value is read and parsed, and each parsed object is added to the
    cache under its (type, identifier) key unless an entry with the same
    provider_id is already present there.
    """
    if not self._initialized:
        async with self._initialize_lock:
            # Re-check after acquiring the lock: another caller may have
            # completed the load while this one was waiting.
            if not self._initialized:
                range_start, range_end = _get_registry_key_range()
                raw_values = await self.kvstore.range(range_start, range_end)
                parsed = _parse_registry_values(raw_values)

                async with self._locked_cache() as cache:
                    for entry in parsed:
                        bucket = cache.setdefault((entry.type, entry.identifier), [])
                        for existing in bucket:
                            if existing.provider_id == entry.provider_id:
                                break
                        else:
                            # No cached object from this provider yet.
                            bucket.append(entry)

                self._initialized = True
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
start_key = KEY_FORMAT.format(type="", identifier="")
|
await self._ensure_initialized()
|
||||||
end_key = KEY_FORMAT.format(type="", identifier="\xff")
|
|
||||||
|
|
||||||
keys = await self.kvstore.range(start_key, end_key)
|
|
||||||
|
|
||||||
for key in keys:
|
|
||||||
type, identifier = key.split(":")[-2:]
|
|
||||||
objects = await super().get(type, identifier)
|
|
||||||
if objects:
|
|
||||||
self.cache[type, identifier] = objects
|
|
||||||
|
|
||||||
def get_cached(
|
def get_cached(
|
||||||
self, type: str, identifier: str
|
self, type: str, identifier: str
|
||||||
) -> List[RoutableObjectWithProvider]:
|
) -> List[RoutableObjectWithProvider]:
|
||||||
return self.cache.get((type, identifier), [])
|
return self.cache.get((type, identifier), [])[:] # Return a copy
|
||||||
|
|
||||||
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
||||||
return [item for sublist in self.cache.values() for item in sublist]
|
await self._ensure_initialized()
|
||||||
|
async with self._locked_cache() as cache:
|
||||||
|
return [item for sublist in cache.values() for item in sublist]
|
||||||
|
|
||||||
async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
|
async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
|
||||||
cachekey = (type, identifier)
|
await self._ensure_initialized()
|
||||||
if cachekey in self.cache:
|
cache_key = (type, identifier)
|
||||||
return self.cache[cachekey]
|
|
||||||
|
async with self._locked_cache() as cache:
|
||||||
|
if cache_key in cache:
|
||||||
|
return cache[cache_key][:]
|
||||||
|
|
||||||
objects = await super().get(type, identifier)
|
objects = await super().get(type, identifier)
|
||||||
if objects:
|
if objects:
|
||||||
self.cache[cachekey] = objects
|
async with self._locked_cache() as cache:
|
||||||
|
cache[cache_key] = objects
|
||||||
|
|
||||||
return objects
|
return objects
|
||||||
|
|
||||||
async def register(self, obj: RoutableObjectWithProvider) -> bool:
|
async def register(self, obj: RoutableObjectWithProvider) -> bool:
|
||||||
# First update disk
|
await self._ensure_initialized()
|
||||||
success = await super().register(obj)
|
success = await super().register(obj)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
# Then update cache
|
cache_key = (obj.type, obj.identifier)
|
||||||
cachekey = (obj.type, obj.identifier)
|
async with self._locked_cache() as cache:
|
||||||
if cachekey not in self.cache:
|
if cache_key not in cache:
|
||||||
self.cache[cachekey] = []
|
cache[cache_key] = []
|
||||||
|
if not any(
|
||||||
# Check if provider already exists in cache
|
cached_obj.provider_id == obj.provider_id
|
||||||
for cached_obj in self.cache[cachekey]:
|
for cached_obj in cache[cache_key]
|
||||||
if cached_obj.provider_id == obj.provider_id:
|
):
|
||||||
return success
|
cache[cache_key].append(obj)
|
||||||
|
|
||||||
# If not, update cache
|
|
||||||
self.cache[cachekey].append(obj)
|
|
||||||
|
|
||||||
return success
|
return success
|
||||||
|
|
||||||
|
|
|
@ -44,6 +44,7 @@ def sample_bank():
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
chunk_size_in_tokens=512,
|
chunk_size_in_tokens=512,
|
||||||
overlap_size_in_tokens=64,
|
overlap_size_in_tokens=64,
|
||||||
|
provider_resource_id="test_bank",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -52,6 +53,7 @@ def sample_bank():
|
||||||
def sample_model():
|
def sample_model():
|
||||||
return Model(
|
return Model(
|
||||||
identifier="test_model",
|
identifier="test_model",
|
||||||
|
provider_resource_id="test_model",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -59,7 +61,7 @@ def sample_model():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_registry_initialization(registry):
|
async def test_registry_initialization(registry):
|
||||||
# Test empty registry
|
# Test empty registry
|
||||||
results = await registry.get("nonexistent")
|
results = await registry.get("nonexistent", "nonexistent")
|
||||||
assert len(results) == 0
|
assert len(results) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@ -70,7 +72,7 @@ async def test_basic_registration(registry, sample_bank, sample_model):
|
||||||
print(f"Registering {sample_model}")
|
print(f"Registering {sample_model}")
|
||||||
await registry.register(sample_model)
|
await registry.register(sample_model)
|
||||||
print("Getting bank")
|
print("Getting bank")
|
||||||
results = await registry.get("test_bank")
|
results = await registry.get("memory_bank", "test_bank")
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
result_bank = results[0]
|
result_bank = results[0]
|
||||||
assert result_bank.identifier == sample_bank.identifier
|
assert result_bank.identifier == sample_bank.identifier
|
||||||
|
@ -79,7 +81,7 @@ async def test_basic_registration(registry, sample_bank, sample_model):
|
||||||
assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens
|
assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens
|
||||||
assert result_bank.provider_id == sample_bank.provider_id
|
assert result_bank.provider_id == sample_bank.provider_id
|
||||||
|
|
||||||
results = await registry.get("test_model")
|
results = await registry.get("model", "test_model")
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
result_model = results[0]
|
result_model = results[0]
|
||||||
assert result_model.identifier == sample_model.identifier
|
assert result_model.identifier == sample_model.identifier
|
||||||
|
@ -98,7 +100,7 @@ async def test_cached_registry_initialization(config, sample_bank, sample_model)
|
||||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||||
await cached_registry.initialize()
|
await cached_registry.initialize()
|
||||||
|
|
||||||
results = await cached_registry.get("test_bank")
|
results = await cached_registry.get("memory_bank", "test_bank")
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
result_bank = results[0]
|
result_bank = results[0]
|
||||||
assert result_bank.identifier == sample_bank.identifier
|
assert result_bank.identifier == sample_bank.identifier
|
||||||
|
@ -118,12 +120,13 @@ async def test_cached_registry_updates(config):
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
chunk_size_in_tokens=256,
|
chunk_size_in_tokens=256,
|
||||||
overlap_size_in_tokens=32,
|
overlap_size_in_tokens=32,
|
||||||
|
provider_resource_id="test_bank_2",
|
||||||
provider_id="baz",
|
provider_id="baz",
|
||||||
)
|
)
|
||||||
await cached_registry.register(new_bank)
|
await cached_registry.register(new_bank)
|
||||||
|
|
||||||
# Verify in cache
|
# Verify in cache
|
||||||
results = await cached_registry.get("test_bank_2")
|
results = await cached_registry.get("memory_bank", "test_bank_2")
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
result_bank = results[0]
|
result_bank = results[0]
|
||||||
assert result_bank.identifier == new_bank.identifier
|
assert result_bank.identifier == new_bank.identifier
|
||||||
|
@ -132,7 +135,7 @@ async def test_cached_registry_updates(config):
|
||||||
# Verify persisted to disk
|
# Verify persisted to disk
|
||||||
new_registry = DiskDistributionRegistry(await kvstore_impl(config))
|
new_registry = DiskDistributionRegistry(await kvstore_impl(config))
|
||||||
await new_registry.initialize()
|
await new_registry.initialize()
|
||||||
results = await new_registry.get("test_bank_2")
|
results = await new_registry.get("memory_bank", "test_bank_2")
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
result_bank = results[0]
|
result_bank = results[0]
|
||||||
assert result_bank.identifier == new_bank.identifier
|
assert result_bank.identifier == new_bank.identifier
|
||||||
|
@ -149,6 +152,7 @@ async def test_duplicate_provider_registration(config):
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
chunk_size_in_tokens=256,
|
chunk_size_in_tokens=256,
|
||||||
overlap_size_in_tokens=32,
|
overlap_size_in_tokens=32,
|
||||||
|
provider_resource_id="test_bank_2",
|
||||||
provider_id="baz",
|
provider_id="baz",
|
||||||
)
|
)
|
||||||
await cached_registry.register(original_bank)
|
await cached_registry.register(original_bank)
|
||||||
|
@ -158,12 +162,54 @@ async def test_duplicate_provider_registration(config):
|
||||||
embedding_model="different-model",
|
embedding_model="different-model",
|
||||||
chunk_size_in_tokens=128,
|
chunk_size_in_tokens=128,
|
||||||
overlap_size_in_tokens=16,
|
overlap_size_in_tokens=16,
|
||||||
|
provider_resource_id="test_bank_2",
|
||||||
provider_id="baz", # Same provider_id
|
provider_id="baz", # Same provider_id
|
||||||
)
|
)
|
||||||
await cached_registry.register(duplicate_bank)
|
await cached_registry.register(duplicate_bank)
|
||||||
|
|
||||||
results = await cached_registry.get("test_bank_2")
|
results = await cached_registry.get("memory_bank", "test_bank_2")
|
||||||
assert len(results) == 1 # Still only one result
|
assert len(results) == 1 # Still only one result
|
||||||
assert (
|
assert (
|
||||||
results[0].embedding_model == original_bank.embedding_model
|
results[0].embedding_model == original_bank.embedding_model
|
||||||
) # Original values preserved
|
) # Original values preserved
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_all_objects(config):
    """Register three banks with distinct providers and check get_all returns them."""
    kvstore = await kvstore_impl(config)
    cached_registry = CachedDiskDistributionRegistry(kvstore)
    await cached_registry.initialize()

    # Build the fixtures up front, one bank per provider.
    test_banks = []
    for i in range(3):
        test_banks.append(
            VectorMemoryBank(
                identifier=f"test_bank_{i}",
                embedding_model="all-MiniLM-L6-v2",
                chunk_size_in_tokens=256,
                overlap_size_in_tokens=32,
                provider_resource_id=f"test_bank_{i}",
                provider_id=f"provider_{i}",
            )
        )

    # Push every bank through the registry.
    for bank in test_banks:
        await cached_registry.register(bank)

    # Everything registered above should come back from get_all.
    all_results = await cached_registry.get_all()
    assert len(all_results) == 3

    # Each registered bank must appear exactly once with its fields intact.
    for original_bank in test_banks:
        matching_banks = list(
            filter(lambda b: b.identifier == original_bank.identifier, all_results)
        )
        assert len(matching_banks) == 1
        (stored_bank,) = matching_banks
        assert stored_bank.embedding_model == original_bank.embedding_model
        assert stored_bank.provider_id == original_bank.provider_id
        assert stored_bank.chunk_size_in_tokens == original_bank.chunk_size_in_tokens
        assert (
            stored_bank.overlap_size_in_tokens == original_bank.overlap_size_in_tokens
        )
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue