diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index a93cc1183..3cc69c4a4 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -11,6 +11,7 @@ from typing import Any, Dict, List, Set from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 +import llama_stack.distribution.store as distribution_store from llama_stack.apis.agents import Agents from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets @@ -65,7 +66,9 @@ class ProviderWithSpec(Provider): # TODO: this code is not very straightforward to follow and needs one more round of refactoring async def resolve_impls( - run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]] + run_config: StackRunConfig, + provider_registry: Dict[Api, Dict[str, ProviderSpec]], + dist_registry: distribution_store.Registry, ) -> Dict[Api, Any]: """ Does two things: @@ -189,6 +192,7 @@ async def resolve_impls( provider, deps, inner_impls, + dist_registry, ) # TODO: ugh slightly redesign this shady looking code if "inner-" in api_str: @@ -237,6 +241,7 @@ async def instantiate_provider( provider: ProviderWithSpec, deps: Dict[str, Any], inner_impls: Dict[str, Any], + dist_registry: distribution_store.Registry, ): protocols = api_protocol_map() additional_protocols = additional_protocols_map() @@ -270,7 +275,7 @@ async def instantiate_provider( method = "get_routing_table_impl" config = None - args = [provider_spec.api, inner_impls, deps] + args = [provider_spec.api, inner_impls, deps, dist_registry] else: method = "get_provider_impl" diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 2cc89848e..87b746c51 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -7,6 +7,8 @@ from typing import Any from llama_stack.distribution.datatypes import * # noqa: F403 +import llama_stack.distribution.store as distribution_store + from .routing_tables import ( DatasetsRoutingTable, MemoryBanksRoutingTable, @@ -20,6 +22,7 @@ async def get_routing_table_impl( api: Api, impls_by_provider_id: Dict[str, RoutedProtocol], _deps, + dist_registry: distribution_store.Registry, ) -> Any: api_to_tables = { "memory_banks": MemoryBanksRoutingTable, @@ -32,7 +35,7 @@ async def get_routing_table_impl( if api.value not in api_to_tables: raise ValueError(f"API {api.value} not found in router map") - impl = api_to_tables[api.value](impls_by_provider_id) + impl = api_to_tables[api.value](impls_by_provider_id, dist_registry) await impl.initialize() return impl diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 06fb49092..9508512c9 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -53,8 +53,10 @@ class CommonRoutingTableImpl(RoutingTable): def __init__( self, impls_by_provider_id: Dict[str, RoutedProtocol], + dist_registry: distribution_store.Registry, ) -> None: self.impls_by_provider_id = impls_by_provider_id + self.dist_registry = dist_registry async def initialize(self) -> None: self.registry: Registry = {} @@ -171,7 +173,7 @@ class CommonRoutingTableImpl(RoutingTable): if obj.identifier not in self.registry: self.registry[obj.identifier] = [] self.registry[obj.identifier].append(obj) - await distribution_store.REGISTRY.register(obj) + await self.dist_registry.register(obj) class ModelsRoutingTable(CommonRoutingTableImpl, Models): diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index cee517fa8..2d9777736 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -22,9 +22,6 @@ import yaml from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse -from pydantic import BaseModel, ValidationError -from termcolor import cprint -from typing_extensions import Annotated from llama_stack.distribution.distribution import ( builtin_automatically_routed_apis, @@ -39,6 +36,9 @@ from llama_stack.providers.utils.telemetry.tracing import ( SpanStatus, start_trace, ) +from pydantic import BaseModel, ValidationError +from termcolor import cprint +from typing_extensions import Annotated from llama_stack.distribution.datatypes import * # noqa: F403 import llama_stack.distribution.store as distribution_store from llama_stack.distribution.request_headers import set_request_provider_data @@ -292,9 +292,9 @@ def main( ) ) - distribution_store.REGISTRY = distribution_store.DiskRegistry(dist_kvstore) + dist_registry = distribution_store.DiskRegistry(dist_kvstore) - impls = asyncio.run(resolve_impls(config, get_provider_registry())) + impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry)) if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) diff --git a/llama_stack/distribution/store/__init__.py b/llama_stack/distribution/store/__init__.py index e03b931f1..b294efc73 100644 --- a/llama_stack/distribution/store/__init__.py +++ b/llama_stack/distribution/store/__init__.py @@ -5,5 +5,3 @@ # the root directory of this source tree. from .registry import DiskRegistry, Registry - -REGISTRY: Registry = None