make distribution registry thread safe and other fixes (#449)

This PR makes the following changes:
1) Fixes the get_all and initialize impl to actually read the values
returned from the range call to kvstore and not keys.
2) The start_key and end_key are fixed to correct perform the range
query after the key format changes
3) Made the cache registry thread safe since there are multiple
initializes called for each routing table.

Tests:
* Start stack
* Register dataset
* Kill stack
* Bring stack up
* dataset list
```
 llama-stack-client datasets list
+--------------+---------------+---------------------------------------------------------------------------------+---------+
| identifier   | provider_id   | metadata                                                                        | type    |
+==============+===============+=================================================================================+=========+
| alpaca       | huggingface-0 | {}                                                                              | dataset |
+--------------+---------------+---------------------------------------------------------------------------------+---------+
| mmlu         | huggingface-0 | {'path': 'llama-stack/evals', 'name': 'evals__mmlu__details', 'split': 'train'} | dataset |
+--------------+---------------+---------------------------------------------------------------------------------+---------+
```

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-13 15:12:34 -08:00 committed by GitHub
parent 15dee2b8b8
commit e90ea1ab1e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 148 additions and 48 deletions

View file

@ -44,6 +44,7 @@ def sample_bank():
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
provider_resource_id="test_bank",
provider_id="test-provider",
)
@ -52,6 +53,7 @@ def sample_bank():
def sample_model():
return Model(
identifier="test_model",
provider_resource_id="test_model",
provider_id="test-provider",
)
@ -59,7 +61,7 @@ def sample_model():
@pytest.mark.asyncio
async def test_registry_initialization(registry):
# Test empty registry
results = await registry.get("nonexistent")
results = await registry.get("nonexistent", "nonexistent")
assert len(results) == 0
@ -70,7 +72,7 @@ async def test_basic_registration(registry, sample_bank, sample_model):
print(f"Registering {sample_model}")
await registry.register(sample_model)
print("Getting bank")
results = await registry.get("test_bank")
results = await registry.get("memory_bank", "test_bank")
assert len(results) == 1
result_bank = results[0]
assert result_bank.identifier == sample_bank.identifier
@ -79,7 +81,7 @@ async def test_basic_registration(registry, sample_bank, sample_model):
assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens
assert result_bank.provider_id == sample_bank.provider_id
results = await registry.get("test_model")
results = await registry.get("model", "test_model")
assert len(results) == 1
result_model = results[0]
assert result_model.identifier == sample_model.identifier
@ -98,7 +100,7 @@ async def test_cached_registry_initialization(config, sample_bank, sample_model)
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
results = await cached_registry.get("test_bank")
results = await cached_registry.get("memory_bank", "test_bank")
assert len(results) == 1
result_bank = results[0]
assert result_bank.identifier == sample_bank.identifier
@ -118,12 +120,13 @@ async def test_cached_registry_updates(config):
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=256,
overlap_size_in_tokens=32,
provider_resource_id="test_bank_2",
provider_id="baz",
)
await cached_registry.register(new_bank)
# Verify in cache
results = await cached_registry.get("test_bank_2")
results = await cached_registry.get("memory_bank", "test_bank_2")
assert len(results) == 1
result_bank = results[0]
assert result_bank.identifier == new_bank.identifier
@ -132,7 +135,7 @@ async def test_cached_registry_updates(config):
# Verify persisted to disk
new_registry = DiskDistributionRegistry(await kvstore_impl(config))
await new_registry.initialize()
results = await new_registry.get("test_bank_2")
results = await new_registry.get("memory_bank", "test_bank_2")
assert len(results) == 1
result_bank = results[0]
assert result_bank.identifier == new_bank.identifier
@ -149,6 +152,7 @@ async def test_duplicate_provider_registration(config):
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=256,
overlap_size_in_tokens=32,
provider_resource_id="test_bank_2",
provider_id="baz",
)
await cached_registry.register(original_bank)
@ -158,12 +162,54 @@ async def test_duplicate_provider_registration(config):
embedding_model="different-model",
chunk_size_in_tokens=128,
overlap_size_in_tokens=16,
provider_resource_id="test_bank_2",
provider_id="baz", # Same provider_id
)
await cached_registry.register(duplicate_bank)
results = await cached_registry.get("test_bank_2")
results = await cached_registry.get("memory_bank", "test_bank_2")
assert len(results) == 1 # Still only one result
assert (
results[0].embedding_model == original_bank.embedding_model
) # Original values preserved
@pytest.mark.asyncio
async def test_get_all_objects(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
# Create multiple test banks
test_banks = [
VectorMemoryBank(
identifier=f"test_bank_{i}",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=256,
overlap_size_in_tokens=32,
provider_resource_id=f"test_bank_{i}",
provider_id=f"provider_{i}",
)
for i in range(3)
]
# Register all banks
for bank in test_banks:
await cached_registry.register(bank)
# Test get_all retrieval
all_results = await cached_registry.get_all()
assert len(all_results) == 3
# Verify each bank was stored correctly
for original_bank in test_banks:
matching_banks = [
b for b in all_results if b.identifier == original_bank.identifier
]
assert len(matching_banks) == 1
stored_bank = matching_banks[0]
assert stored_bank.embedding_model == original_bank.embedding_model
assert stored_bank.provider_id == original_bank.provider_id
assert stored_bank.chunk_size_in_tokens == original_bank.chunk_size_in_tokens
assert (
stored_bank.overlap_size_in_tokens == original_bank.overlap_size_in_tokens
)