mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-06 14:00:42 +00:00
This PR makes the following changes: 1) Fixes the get_all and initialize implementations to actually read the values returned from the range call to kvstore rather than the keys. 2) The start_key and end_key are fixed to correctly perform the range query after the key format changes. 3) Made the cache registry thread-safe since multiple initialize calls are made for each routing table. Tests: * Start stack * Register dataset * Kill stack * Bring stack up * dataset list ``` llama-stack-client datasets list +--------------+---------------+---------------------------------------------------------------------------------+---------+ | identifier | provider_id | metadata | type | +==============+===============+=================================================================================+=========+ | alpaca | huggingface-0 | {} | dataset | +--------------+---------------+---------------------------------------------------------------------------------+---------+ | mmlu | huggingface-0 | {'path': 'llama-stack/evals', 'name': 'evals__mmlu__details', 'split': 'train'} | dataset | +--------------+---------------+---------------------------------------------------------------------------------+---------+ ``` Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
215 lines
7.2 KiB
Python
215 lines
7.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import os
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from llama_stack.distribution.store import * # noqa F403
|
|
from llama_stack.apis.inference import Model
|
|
from llama_stack.apis.memory_banks import VectorMemoryBank
|
|
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
|
from llama_stack.distribution.datatypes import * # noqa F403
|
|
|
|
|
|
@pytest.fixture
def config():
    """Return a ``SqliteKVStoreConfig`` backed by a fresh database file.

    The config points at ``/tmp/test_registry.db``. Any file already present
    at that path is deleted first so each test starts without leftover rows.

    Returns:
        The ``SqliteKVStoreConfig`` for the test database.
    """
    kv_config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db")
    # EAFP: remove directly instead of checking existence first, which avoids
    # the window between the check and the delete.
    try:
        os.remove(kv_config.db_path)
    except FileNotFoundError:
        pass  # no database left behind by an earlier run
    return kv_config
|
|
|
|
|
|
@pytest_asyncio.fixture
async def registry(config):
    """Provide an initialized ``DiskDistributionRegistry``.

    The registry wraps the kvstore created by ``kvstore_impl(config)`` and has
    ``initialize()`` awaited before it is handed to the test.
    """
    kvstore = await kvstore_impl(config)
    disk_registry = DiskDistributionRegistry(kvstore)
    await disk_registry.initialize()
    return disk_registry
|
|
|
|
|
|
@pytest_asyncio.fixture
async def cached_registry(config):
    """Provide an initialized ``CachedDiskDistributionRegistry``.

    The registry wraps the kvstore created by ``kvstore_impl(config)`` and has
    ``initialize()`` awaited before it is handed to the test.
    """
    kvstore = await kvstore_impl(config)
    cache = CachedDiskDistributionRegistry(kvstore)
    await cache.initialize()
    return cache
|
|
|
|
|
|
@pytest.fixture
def sample_bank():
    """Return the ``VectorMemoryBank`` used as the sample bank in these tests."""
    bank = VectorMemoryBank(
        identifier="test_bank",
        provider_id="test-provider",
        provider_resource_id="test_bank",
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=512,
        overlap_size_in_tokens=64,
    )
    return bank
|
|
|
|
|
|
@pytest.fixture
def sample_model():
    """Return the ``Model`` used as the sample model in these tests."""
    model = Model(
        identifier="test_model",
        provider_id="test-provider",
        provider_resource_id="test_model",
    )
    return model
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_registry_initialization(registry):
    """Looking up an unregistered type/identifier pair yields no results."""
    # Nothing has been registered, so the lookup must come back empty.
    missing = await registry.get("nonexistent", "nonexistent")
    assert len(missing) == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_basic_registration(registry, sample_bank, sample_model):
    """Register a bank and a model, then read each back and compare fields."""
    for resource in (sample_bank, sample_model):
        print(f"Registering {resource}")
        await registry.register(resource)

    print("Getting bank")
    found_banks = await registry.get("memory_bank", "test_bank")
    assert len(found_banks) == 1
    stored_bank = found_banks[0]
    for field in (
        "identifier",
        "embedding_model",
        "chunk_size_in_tokens",
        "overlap_size_in_tokens",
        "provider_id",
    ):
        assert getattr(stored_bank, field) == getattr(sample_bank, field)

    found_models = await registry.get("model", "test_model")
    assert len(found_models) == 1
    stored_model = found_models[0]
    for field in ("identifier", "provider_id"):
        assert getattr(stored_model, field) == getattr(sample_model, field)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_cached_registry_initialization(config, sample_bank, sample_model):
    """A cached registry created after disk writes can return the stored bank."""
    # Populate the store through a plain disk registry first.
    writer = DiskDistributionRegistry(await kvstore_impl(config))
    await writer.initialize()
    for resource in (sample_bank, sample_model):
        await writer.register(resource)

    # A separate cached registry over the same config should see that data.
    reader = CachedDiskDistributionRegistry(await kvstore_impl(config))
    await reader.initialize()

    found = await reader.get("memory_bank", "test_bank")
    assert len(found) == 1
    stored_bank = found[0]
    for field in (
        "identifier",
        "embedding_model",
        "chunk_size_in_tokens",
        "overlap_size_in_tokens",
        "provider_id",
    ):
        assert getattr(stored_bank, field) == getattr(sample_bank, field)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_cached_registry_updates(config):
    """A bank registered via the cached registry is readable from it and from disk."""

    def expect_single_match(found, expected):
        # Exactly one entry, matching the expected identifier and provider.
        assert len(found) == 1
        stored = found[0]
        assert stored.identifier == expected.identifier
        assert stored.provider_id == expected.provider_id

    cache = CachedDiskDistributionRegistry(await kvstore_impl(config))
    await cache.initialize()

    new_bank = VectorMemoryBank(
        identifier="test_bank_2",
        provider_id="baz",
        provider_resource_id="test_bank_2",
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=256,
        overlap_size_in_tokens=32,
    )
    await cache.register(new_bank)

    # Readable through the cached registry that performed the registration.
    expect_single_match(await cache.get("memory_bank", "test_bank_2"), new_bank)

    # Readable through a fresh disk registry built from the same config.
    fresh_disk = DiskDistributionRegistry(await kvstore_impl(config))
    await fresh_disk.initialize()
    expect_single_match(
        await fresh_disk.get("memory_bank", "test_bank_2"), new_bank
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_duplicate_provider_registration(config):
    """Re-registering the same identifier/provider keeps the first registration."""
    cache = CachedDiskDistributionRegistry(await kvstore_impl(config))
    await cache.initialize()

    # Fields shared by both registrations; only the embedding settings differ.
    shared = {
        "identifier": "test_bank_2",
        "provider_resource_id": "test_bank_2",
        "provider_id": "baz",
    }

    original_bank = VectorMemoryBank(
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=256,
        overlap_size_in_tokens=32,
        **shared,
    )
    await cache.register(original_bank)

    duplicate_bank = VectorMemoryBank(
        embedding_model="different-model",
        chunk_size_in_tokens=128,
        overlap_size_in_tokens=16,
        **shared,
    )
    await cache.register(duplicate_bank)

    found = await cache.get("memory_bank", "test_bank_2")
    # The second registration must not add another entry ...
    assert len(found) == 1
    # ... and the stored embedding model must still be the original's.
    assert found[0].embedding_model == original_bank.embedding_model
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_all_objects(config):
    """``get_all`` returns every registered bank with its fields intact."""
    cache = CachedDiskDistributionRegistry(await kvstore_impl(config))
    await cache.initialize()

    # Build three banks with distinct identifiers and provider ids.
    banks = []
    for index in range(3):
        banks.append(
            VectorMemoryBank(
                identifier=f"test_bank_{index}",
                provider_id=f"provider_{index}",
                provider_resource_id=f"test_bank_{index}",
                embedding_model="all-MiniLM-L6-v2",
                chunk_size_in_tokens=256,
                overlap_size_in_tokens=32,
            )
        )

    for bank in banks:
        await cache.register(bank)

    everything = await cache.get_all()
    assert len(everything) == 3

    # Each registered bank must appear exactly once, with matching fields.
    compared_fields = (
        "embedding_model",
        "provider_id",
        "chunk_size_in_tokens",
        "overlap_size_in_tokens",
    )
    for expected in banks:
        hits = [
            candidate
            for candidate in everything
            if candidate.identifier == expected.identifier
        ]
        assert len(hits) == 1
        stored = hits[0]
        for field in compared_fields:
            assert getattr(stored, field) == getattr(expected, field)
|