Factor out create_dist_registry (#398)

This commit is contained in:
Dalton Flanagan 2024-11-07 16:13:19 -05:00 committed by GitHub
parent 694c142b89
commit 345ae07317
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 30 additions and 19 deletions

View file

@ -31,7 +31,7 @@ from llama_stack.distribution.distribution import (
get_provider_registry,
)
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
@ -42,8 +42,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from .endpoints import get_all_api_endpoints
@ -281,21 +279,8 @@ def main(
config = StackRunConfig(**yaml.safe_load(fp))
app = FastAPI()
# instantiate kvstore for storing and retrieving distribution metadata
if config.metadata_store:
dist_kvstore = asyncio.run(kvstore_impl(config.metadata_store))
else:
dist_kvstore = asyncio.run(
kvstore_impl(
SqliteKVStoreConfig(
db_path=(
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
).as_posix()
)
)
)
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
if Api.telemetry in impls:

View file

@ -9,9 +9,17 @@ from typing import Dict, List, Protocol
import pydantic
from llama_stack.distribution.datatypes import RoutableObjectWithProvider
from llama_stack.distribution.datatypes import (
RoutableObjectWithProvider,
StackRunConfig,
)
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.kvstore import (
KVStore,
kvstore_impl,
SqliteKVStoreConfig,
)
class DistributionRegistry(Protocol):
@ -133,3 +141,21 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
self.cache[obj.identifier].append(obj)
return success
async def create_dist_registry(
config: StackRunConfig,
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
# instantiate kvstore for storing and retrieving distribution metadata
if config.metadata_store:
dist_kvstore = await kvstore_impl(config.metadata_store)
else:
dist_kvstore = await kvstore_impl(
SqliteKVStoreConfig(
db_path=(
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
).as_posix()
)
)
return CachedDiskDistributionRegistry(dist_kvstore), dist_kvstore