diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 3cc69c4a4..6181465e1 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -11,7 +11,6 @@ from typing import Any, Dict, List, Set from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 -import llama_stack.distribution.store as distribution_store from llama_stack.apis.agents import Agents from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets @@ -27,6 +26,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry from llama_stack.distribution.distribution import builtin_automatically_routed_apis +from llama_stack.distribution.store import Registry as DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type @@ -68,7 +68,7 @@ class ProviderWithSpec(Provider): async def resolve_impls( run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]], - dist_registry: distribution_store.Registry, + dist_registry: DistributionRegistry, ) -> Dict[Api, Any]: """ Does two things: @@ -241,7 +241,7 @@ async def instantiate_provider( provider: ProviderWithSpec, deps: Dict[str, Any], inner_impls: Dict[str, Any], - dist_registry: distribution_store.Registry, + dist_registry: DistributionRegistry, ): protocols = api_protocol_map() additional_protocols = additional_protocols_map() diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 87b746c51..4d65d1056 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -7,7 +7,8 @@ from typing import Any from llama_stack.distribution.datatypes import * # noqa: F403 -import llama_stack.distribution.store as distribution_store + +from llama_stack.distribution.store import Registry as DistributionRegistry from .routing_tables import ( DatasetsRoutingTable, @@ -22,7 +23,7 @@ async def get_routing_table_impl( api: Api, impls_by_provider_id: Dict[str, RoutedProtocol], _deps, - dist_registry: distribution_store.Registry, + dist_registry: DistributionRegistry, ) -> Any: api_to_tables = { "memory_banks": MemoryBanksRoutingTable, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 9508512c9..332bdf259 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -13,8 +13,8 @@ from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.distribution.store import Registry as DistributionRegistry from llama_stack.distribution.datatypes import * # noqa: F403 -import llama_stack.distribution.store as distribution_store def get_impl_api(p: Any) -> Api: @@ -53,7 +53,7 @@ class CommonRoutingTableImpl(RoutingTable): def __init__( self, impls_by_provider_id: Dict[str, RoutedProtocol], - dist_registry: distribution_store.Registry, + dist_registry: DistributionRegistry, ) -> None: self.impls_by_provider_id = impls_by_provider_id self.dist_registry = dist_registry diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index aafa6741d..44e88e2ea 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -22,9 +22,6 @@ import yaml from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse -from pydantic import BaseModel, ValidationError -from termcolor import cprint -from typing_extensions import Annotated from llama_stack.distribution.distribution import ( builtin_automatically_routed_apis, @@ -39,10 +36,13 @@ from llama_stack.providers.utils.telemetry.tracing import ( SpanStatus, start_trace, ) +from pydantic import BaseModel, ValidationError +from termcolor import cprint +from typing_extensions import Annotated from llama_stack.distribution.datatypes import * # noqa: F403 -import llama_stack.distribution.store as distribution_store from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import resolve_impls +from llama_stack.distribution.store import DiskRegistry from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig from .endpoints import get_all_api_endpoints @@ -292,7 +292,7 @@ def main( ) ) - dist_registry = distribution_store.DiskRegistry(dist_kvstore) + dist_registry = DiskRegistry(dist_kvstore) impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry)) if Api.telemetry in impls: