diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 9d2a10adb..e9ec1b9e6 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -21,6 +21,7 @@ from llama_stack.apis.inference import Inference from llama_stack.apis.memory import Memory from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring +from llama_stack.providers.utils.kvstore.config import KVStoreConfig LLAMA_STACK_BUILD_CONFIG_VERSION = "2" LLAMA_STACK_RUN_CONFIG_VERSION = "2" @@ -138,6 +139,13 @@ One or more providers to use for each API. The same provider_type (e.g., meta-re can be instantiated multiple times (with different configs) if necessary. """, ) + distribution_registry_store: Optional[KVStoreConfig] = Field( + default=None, + description=""" +Configuration for the persistence store used by the distribution registry. If not specified, +a default SQLite store will be used.""", + ) + class BuildConfig(BaseModel): diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 6181465e1..96b4b81e6 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -26,7 +26,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry from llama_stack.distribution.distribution import builtin_automatically_routed_apis -from llama_stack.distribution.store import Registry as DistributionRegistry +from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 4d65d1056..8060f1450 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -8,7 +8,7 @@ from typing import Any from llama_stack.distribution.datatypes import * # noqa: F403 -from llama_stack.distribution.store import Registry as DistributionRegistry +from llama_stack.distribution.store import DistributionRegistry from .routing_tables import ( DatasetsRoutingTable, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 332bdf259..d314614dd 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -13,7 +13,7 @@ from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 -from llama_stack.distribution.store import Registry as DistributionRegistry +from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.datatypes import * # noqa: F403 diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 2a20597f8..bb23a6a4d 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -42,7 +42,7 @@ from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import resolve_impls -from llama_stack.distribution.store import DiskRegistry +from llama_stack.distribution.store import DiskDistributionRegistry from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig from .endpoints import get_all_api_endpoints @@ -282,8 +282,13 @@ def main( app = FastAPI() # instantiate kvstore for storing and retrieving distribution metadata - dist_kvstore = asyncio.run( - kvstore_impl( + if config.distribution_registry_store: + dist_kvstore = asyncio.run( + kvstore_impl(config.distribution_registry_store) + ) + else: + dist_kvstore = asyncio.run( + kvstore_impl( SqliteKVStoreConfig( db_path=( DISTRIBS_BASE_DIR / config.image_name / "kvstore.db" @@ -292,7 +297,7 @@ def main( ) ) - dist_registry = DiskRegistry(dist_kvstore) + dist_registry = DiskDistributionRegistry(dist_kvstore) impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry)) if Api.telemetry in impls: diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 72ce20245..e4881638f 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -15,7 +15,7 @@ from llama_stack.distribution.datatypes import RoutableObjectWithProvider from llama_stack.providers.utils.kvstore import KVStore -class Registry(Protocol): +class DistributionRegistry(Protocol): async def get(self, identifier: str) -> [RoutableObjectWithProvider]: ... async def register(self, obj: RoutableObjectWithProvider) -> None: ... @@ -23,7 +23,7 @@ class Registry(Protocol): KEY_FORMAT = "distributions:registry:{}" -class DiskRegistry(Registry): +class DiskDistributionRegistry(DistributionRegistry): def __init__(self, kvstore: KVStore): self.kvstore = kvstore @@ -33,7 +33,6 @@ class DiskRegistry(Registry): if not json_str: return [] - # Parse JSON string into list of objects objects_data = json.loads(json_str) return [ diff --git a/llama_stack/distribution/store/tests/test_registry.py b/llama_stack/distribution/store/tests/test_registry.py index 69c690673..ab9457707 100644 --- a/llama_stack/distribution/store/tests/test_registry.py +++ b/llama_stack/distribution/store/tests/test_registry.py @@ -19,7 +19,7 @@ async def test_registry(): # delete the file if it exists if os.path.exists(config.db_path): os.remove(config.db_path) - registry = DiskRegistry(await kvstore_impl(config)) + registry = DiskDistributionRegistry(await kvstore_impl(config)) bank = VectorMemoryBankDef( identifier="test_bank", embedding_model="all-MiniLM-L6-v2", diff --git a/~/kvstore.db b/~/kvstore.db new file mode 100644 index 000000000..dfd84a63f Binary files /dev/null and b/~/kvstore.db differ