persist registered objects with distribution (#354)

* persist registered objects with distribution

* linter fixes

* comment

* use annotate and field discriminator

* workign tests

* donot use global state

* precommit failures fixed

* add back Any

* fix imports

* remove unnecessary changes in ollama

* precommit failures fixed

* make kvstore configurable for dist and rename registry

* add comment about registry list return

* fix linter errors

* use registry to hydrate

* remove debug print

* linter fixes

* remove kvstore.db

* rename distribution_registry_store

---------

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-04 17:25:06 -08:00 committed by GitHub
parent c9bf1d7d0b
commit 663883cc29
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 401 additions and 46 deletions

View file

@ -31,6 +31,8 @@ from llama_stack.distribution.distribution import (
get_provider_registry,
)
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
@ -38,9 +40,10 @@ from llama_stack.providers.utils.telemetry.tracing import (
start_trace,
)
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from .endpoints import get_all_api_endpoints
@ -278,8 +281,23 @@ def main(
config = StackRunConfig(**yaml.safe_load(fp))
app = FastAPI()
# instantiate kvstore for storing and retrieving distribution metadata
if config.metadata_store:
dist_kvstore = asyncio.run(kvstore_impl(config.metadata_store))
else:
dist_kvstore = asyncio.run(
kvstore_impl(
SqliteKVStoreConfig(
db_path=(
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
).as_posix()
)
)
)
impls = asyncio.run(resolve_impls(config, get_provider_registry()))
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])