diff --git a/llama_stack/core/providers.py b/llama_stack/core/providers.py
index ee432e793..9a0a478c2 100644
--- a/llama_stack/core/providers.py
+++ b/llama_stack/core/providers.py
@@ -34,6 +34,8 @@ from .utils.config import redact_sensitive_fields
 logger = get_logger(name=__name__, category="core")
 
 # Storage constants for dynamic provider connections
+# Use composite key format: provider_connections:v1::{api}::{provider_id}
+# This allows the same provider_id to be used for different APIs
 PROVIDER_CONNECTIONS_PREFIX = "provider_connections:v1::"
 
 
@@ -55,6 +57,8 @@ class ProviderImpl(Providers):
         self.config = config
         self.deps = deps
         self.kvstore = None  # KVStore for dynamic provider persistence
+        # Runtime cache uses composite key: "{api}::{provider_id}"
+        # This allows the same provider_id to be used for different APIs
         self.dynamic_providers: dict[str, ProviderConnectionInfo] = {}  # Runtime cache
         self.dynamic_provider_impls: dict[str, Any] = {}  # Initialized provider instances
 
@@ -65,36 +69,39 @@ class ProviderImpl(Providers):
 
     async def initialize(self) -> None:
         # Initialize kvstore for dynamic providers
-        # Reuse the same kvstore as the distribution registry if available
-        if hasattr(self.config.run_config, "metadata_store") and self.config.run_config.metadata_store:
-            from llama_stack.providers.utils.kvstore import kvstore_impl
+        # Use the metadata store from the new storage config structure
+        if not (self.config.run_config.storage and self.config.run_config.storage.stores.metadata):
+            raise RuntimeError(
+                "No metadata store configured in storage.stores.metadata. "
+                "Provider management requires a configured metadata store (kv_memory, kv_sqlite, etc)."
+            )
 
-            self.kvstore = await kvstore_impl(self.config.run_config.metadata_store)
-            logger.info("Initialized kvstore for dynamic provider management")
+        from llama_stack.providers.utils.kvstore import kvstore_impl
 
-            # Load existing dynamic providers from kvstore
-            await self._load_dynamic_providers()
-            logger.info(f"Loaded {len(self.dynamic_providers)} dynamic providers from kvstore")
+        self.kvstore = await kvstore_impl(self.config.run_config.storage.stores.metadata)
+        logger.info("✅ Initialized kvstore for dynamic provider management")
 
-            # Auto-instantiate connected providers on startup
-            if self.provider_registry:
-                for provider_id, conn_info in self.dynamic_providers.items():
-                    if conn_info.status == ProviderConnectionStatus.connected:
-                        try:
-                            impl = await self._instantiate_provider(conn_info)
-                            self.dynamic_provider_impls[provider_id] = impl
-                            logger.info(f"Auto-instantiated provider {provider_id} from kvstore")
-                        except Exception as e:
-                            logger.error(f"Failed to auto-instantiate provider {provider_id}: {e}")
-                            # Update status to failed
-                            conn_info.status = ProviderConnectionStatus.failed
-                            conn_info.error_message = str(e)
-                            conn_info.updated_at = datetime.now(UTC)
-                            await self._store_connection(conn_info)
-            else:
-                logger.warning("Provider registry not available, skipping auto-instantiation")
+        # Load existing dynamic providers from kvstore
+        await self._load_dynamic_providers()
+        logger.info(f"📦 Loaded {len(self.dynamic_providers)} existing dynamic providers from kvstore")
+
+        # Auto-instantiate connected providers on startup (cache keys are "{api}::{provider_id}")
+        if self.provider_registry:
+            for cache_key, conn_info in self.dynamic_providers.items():
+                if conn_info.status == ProviderConnectionStatus.connected:
+                    try:
+                        impl = await self._instantiate_provider(conn_info)
+                        self.dynamic_provider_impls[cache_key] = impl
+                        logger.info(f"♻️ Auto-instantiated provider {cache_key} from kvstore")
+                    except Exception as e:
+                        logger.error(f"Failed to auto-instantiate provider {cache_key}: {e}")
+                        # Update status to failed
+                        conn_info.status = ProviderConnectionStatus.failed
+                        conn_info.error_message = str(e)
+                        conn_info.updated_at = datetime.now(UTC)
+                        await self._store_connection(conn_info)
         else:
-            logger.warning("No metadata_store configured, dynamic provider management disabled")
+            logger.warning("Provider registry not available, skipping auto-instantiation")
 
     async def shutdown(self) -> None:
         logger.debug("ProviderImpl.shutdown")
@@ -245,9 +252,10 @@ class ProviderImpl(Providers):
         if not self.kvstore:
             raise RuntimeError("KVStore not initialized")
 
-        key = f"{PROVIDER_CONNECTIONS_PREFIX}{info.provider_id}"
+        # Use composite key: provider_connections:v1::{api}::{provider_id}
+        key = f"{PROVIDER_CONNECTIONS_PREFIX}{info.api}::{info.provider_id}"
         await self.kvstore.set(key, info.model_dump_json())
-        logger.debug(f"Stored provider connection: {info.provider_id}")
+        logger.debug(f"Stored provider connection: {info.api}::{info.provider_id}")
 
     async def _load_connection(self, provider_id: str) -> ProviderConnectionInfo | None:
         """Load provider connection info from kvstore.
@@ -293,8 +301,10 @@ class ProviderImpl(Providers):
         """Load dynamic providers from kvstore into runtime cache."""
         connections = await self._list_connections()
         for conn in connections:
-            self.dynamic_providers[conn.provider_id] = conn
-            logger.debug(f"Loaded dynamic provider: {conn.provider_id} (status: {conn.status})")
+            # Use composite key for runtime cache
+            cache_key = f"{conn.api}::{conn.provider_id}"
+            self.dynamic_providers[cache_key] = conn
+            logger.debug(f"Loaded dynamic provider: {cache_key} (status: {conn.status})")
 
     # Helper methods for dynamic provider management
 
@@ -384,12 +394,17 @@ class ProviderImpl(Providers):
 
         All providers are stored in kvstore and treated equally.
         """
+        logger.info(f"📝 REGISTER_PROVIDER called: provider_id={provider_id}, api={api}, type={provider_type}")
+
         if not self.kvstore:
             raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
 
-        # Check if provider_id already exists
-        if provider_id in self.dynamic_providers:
-            raise ValueError(f"Provider {provider_id} already exists")
+        # Use composite key to allow same provider_id for different APIs
+        cache_key = f"{api}::{provider_id}"
+
+        # Check if provider already exists for this API
+        if cache_key in self.dynamic_providers:
+            raise ValueError(f"Provider {provider_id} already exists for API {api}")
 
         # Get authenticated user as owner
         user = get_authenticated_user()
@@ -415,7 +430,8 @@ class ProviderImpl(Providers):
             # Instantiate provider if we have a provider registry
             if self.provider_registry:
                 impl = await self._instantiate_provider(conn_info)
-                self.dynamic_provider_impls[provider_id] = impl
+                # Use composite key for impl cache too
+                self.dynamic_provider_impls[cache_key] = impl
 
                 # Update status to connected after successful instantiation
                 conn_info.status = ProviderConnectionStatus.connected
@@ -434,8 +450,8 @@ class ProviderImpl(Providers):
             # Store updated status
             await self._store_connection(conn_info)
 
-            # Add to runtime cache
-            self.dynamic_providers[provider_id] = conn_info
+            # Add to runtime cache using composite key
+            self.dynamic_providers[cache_key] = conn_info
 
             return RegisterProviderResponse(provider=conn_info)
 
@@ -445,7 +461,7 @@ class ProviderImpl(Providers):
             conn_info.error_message = str(e)
             conn_info.updated_at = datetime.now(UTC)
             await self._store_connection(conn_info)
-            self.dynamic_providers[provider_id] = conn_info
+            self.dynamic_providers[cache_key] = conn_info
 
             logger.error(f"Failed to register provider {provider_id}: {e}")
             raise RuntimeError(f"Failed to register provider: {e}") from e
@@ -461,6 +477,8 @@ class ProviderImpl(Providers):
         Updates persist to kvstore and survive server restarts.
         This works for all providers (whether originally from run.yaml or API).
         """
+        logger.info(f"🔄 UPDATE_PROVIDER called: provider_id={provider_id}, has_config={config is not None}, has_attributes={attributes is not None}")
+
         if not self.kvstore:
             raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
 
@@ -531,6 +549,8 @@ class ProviderImpl(Providers):
         Removes the provider from kvstore and shuts down its instance.
         This works for all providers (whether originally from run.yaml or API).
         """
+        logger.info(f"🗑️ UNREGISTER_PROVIDER called: provider_id={provider_id}")
+
         if not self.kvstore:
             raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
 
@@ -560,6 +580,8 @@ class ProviderImpl(Providers):
 
     async def test_provider_connection(self, provider_id: str) -> TestProviderConnectionResponse:
         """Test a provider connection."""
+        logger.info(f"🔍 TEST_PROVIDER_CONNECTION called: provider_id={provider_id}")
+
         # Check if provider exists (static or dynamic)
         provider_impl = None
 
diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py
index fb0089432..6fadefeb3 100644
--- a/llama_stack/core/stack.py
+++ b/llama_stack/core/stack.py
@@ -42,6 +42,8 @@ from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceI
 from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
 from llama_stack.core.resolver import ProviderRegistry, resolve_impls
 from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
+from llama_stack.core.access_control.datatypes import AccessRule
+from llama_stack.core.store.registry import DistributionRegistry
 from llama_stack.core.storage.datatypes import (
     InferenceStoreReference,
     KVStoreReference,
@@ -406,6 +408,185 @@ def _initialize_storage(run_config: StackRunConfig):
     register_sqlstore_backends(sql_backends)
 
 
+async def resolve_impls_via_provider_registration(
+    run_config: StackRunConfig,
+    provider_registry: ProviderRegistry,
+    dist_registry: DistributionRegistry,
+    policy: list[AccessRule],
+    internal_impls: dict[Api, Any],
+) -> dict[Api, Any]:
+    """
+    Resolves provider implementations by registering them through ProviderImpl.
+    This ensures all providers (startup and runtime) go through the same registration code path.
+
+    Args:
+        run_config: Stack run configuration with providers from run.yaml
+        provider_registry: Registry of available provider types
+        dist_registry: Distribution registry
+        policy: Access control policy
+        internal_impls: Internal implementations (inspect, providers) already initialized
+
+    Returns:
+        Dictionary mapping API to implementation instances
+    """
+    from llama_stack.core.distribution import builtin_automatically_routed_apis
+    from llama_stack.core.resolver import (
+        instantiate_provider,
+        sort_providers_by_deps,
+        specs_for_autorouted_apis,
+        validate_and_prepare_providers,
+    )
+
+    routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
+    router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
+    # Routing-table API -> router API; loop-invariant, so it is built once up front
+    autorouted_map = {info.routing_table_api: info.router_api for info in builtin_automatically_routed_apis()}
+
+    # Validate and prepare providers from run.yaml
+    providers_with_specs = validate_and_prepare_providers(
+        run_config, provider_registry, routing_table_apis, router_apis
+    )
+
+    apis_to_serve = run_config.apis or set(
+        list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
+    )
+
+    providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
+
+    # Sort providers in dependency order
+    sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
+
+    # Get the ProviderImpl instance
+    providers_impl = internal_impls[Api.providers]
+
+    # Register each provider through ProviderImpl
+    impls = internal_impls.copy()
+
+    logger.info(f"🚀 Starting provider registration for {len(sorted_providers)} providers from run.yaml")
+
+    for api_str, provider in sorted_providers:
+        # Skip providers that are not enabled
+        if provider.provider_id is None:
+            continue
+
+        # Skip internal APIs that need special handling
+        # - providers: already initialized as internal_impls
+        # - inspect: already initialized as internal_impls
+        # - telemetry: internal observability, directly instantiated below
+        if api_str in ["providers", "inspect"]:
+            continue
+
+        # Telemetry is an internal API that should be directly instantiated
+        if api_str == "telemetry":
+            logger.info(f"Instantiating {provider.provider_id} for {api_str}")
+
+            deps = {a: impls[a] for a in provider.spec.api_dependencies if a in impls}
+            for a in provider.spec.optional_api_dependencies:
+                if a in impls:
+                    deps[a] = impls[a]
+
+            impl = await instantiate_provider(provider, deps, {}, dist_registry, run_config, policy)
+            api = Api(api_str)
+            impls[api] = impl
+            providers_impl.deps[api] = impl
+            continue
+
+        # Handle different provider types
+        try:
+            # Routers are system infrastructure; "inner-*" keys are not Api values, so guard before Api()
+            is_router = not api_str.startswith("inner-") and (Api(api_str) in router_apis or provider.spec.provider_type == "router")
+
+            if api_str.startswith("inner-") or provider.spec.provider_type == "routing_table":
+                # Inner providers or routing tables cannot be registered through the API
+                # They need to be instantiated directly
+                logger.info(f"Instantiating {provider.provider_id} for {api_str}")
+
+                deps = {a: impls[a] for a in provider.spec.api_dependencies if a in impls}
+                for a in provider.spec.optional_api_dependencies:
+                    if a in impls:
+                        deps[a] = impls[a]
+
+                # Get inner impls if available
+                inner_impls = {}
+
+                # For routing tables of autorouted APIs, get inner impls from the router API
+                # E.g., tool_groups routing table needs inner-tool_runtime providers
+                if provider.spec.provider_type == "routing_table":
+                    if Api(api_str) in autorouted_map:
+                        router_api_str = autorouted_map[Api(api_str)].value
+                        inner_key = f"inner-{router_api_str}"
+                        if inner_key in impls:
+                            inner_impls = impls[inner_key]
+                else:
+                    # For regular inner providers, use their own inner key
+                    inner_key = f"inner-{api_str}"
+                    if inner_key in impls:
+                        inner_impls = impls[inner_key]
+
+                impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
+
+                # Store appropriately
+                if api_str.startswith("inner-"):
+                    if api_str not in impls:
+                        impls[api_str] = {}
+                    impls[api_str][provider.provider_id] = impl
+                else:
+                    api = Api(api_str)
+                    impls[api] = impl
+                    # Update providers_impl.deps so subsequent providers can depend on this
+                    providers_impl.deps[api] = impl
+
+            elif is_router:
+                # Router providers also need special handling
+                logger.info(f"Instantiating router {provider.provider_id} for {api_str}")
+
+                deps = {a: impls[a] for a in provider.spec.api_dependencies if a in impls}
+                for a in provider.spec.optional_api_dependencies:
+                    if a in impls:
+                        deps[a] = impls[a]
+
+                # Get inner impls if this is a router
+                inner_impls = {}
+                inner_key = f"inner-{api_str}"
+                if inner_key in impls:
+                    inner_impls = impls[inner_key]
+
+                impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
+                api = Api(api_str)
+                impls[api] = impl
+                # Update providers_impl.deps so subsequent providers can depend on this
+                providers_impl.deps[api] = impl
+
+            else:
+                # Regular providers - register through ProviderImpl
+                api = Api(api_str)
+                logger.info(f"Registering {provider.provider_id} for {api.value}")
+
+                await providers_impl.register_provider(
+                    provider_id=provider.provider_id,
+                    api=api.value,
+                    provider_type=provider.spec.provider_type,
+                    config=provider.config,
+                    attributes=getattr(provider, "attributes", None),
+                )
+
+                # Get the instantiated impl from dynamic_provider_impls using composite key
+                cache_key = f"{api.value}::{provider.provider_id}"
+                impl = providers_impl.dynamic_provider_impls[cache_key]
+                impls[api] = impl
+
+                # IMPORTANT: Update providers_impl.deps so subsequent providers can depend on this one
+                providers_impl.deps[api] = impl
+
+                logger.info(f"✅ Successfully registered startup provider: {provider.provider_id}")
+
+        except Exception as e:
+            logger.error(f"❌ Failed to handle provider {provider.provider_id}: {e}")
+            raise
+
+    return impls
+
+
 class Stack:
     def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
         self.run_config = run_config
@@ -441,7 +622,13 @@ class Stack:
             policy=policy,
         )
 
-        impls = await resolve_impls(
+        # Initialize the ProviderImpl so it has access to kvstore
+        logger.debug("Initializing ProviderImpl before provider registration")
+        await internal_impls[Api.providers].initialize()
+        logger.debug("ProviderImpl initialized; registering providers from run.yaml")
+
+        # Register all providers from run.yaml through ProviderImpl
+        impls = await resolve_impls_via_provider_registration(
             self.run_config,
             provider_registry,
             dist_registry,
diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml
index ed880d4a0..6938dbb92 100644
--- a/llama_stack/distributions/ci-tests/run.yaml
+++ b/llama_stack/distributions/ci-tests/run.yaml
@@ -231,7 +231,7 @@ storage:
   backends:
     kv_default:
       type: kv_sqlite
-      db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/kvstore.db
+      db_path: ":memory:"
     sql_default:
       type: sql_sqlite
       db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sql_store.db