save changes

This commit is contained in:
Raghotham Murthy 2025-10-24 14:28:49 -07:00
parent 66158d1999
commit e21db79d6c
3 changed files with 250 additions and 39 deletions

View file

@ -34,6 +34,8 @@ from .utils.config import redact_sensitive_fields
logger = get_logger(name=__name__, category="core")
# Storage constants for dynamic provider connections
# Use composite key format: provider_connections:v1::{api}::{provider_id}
# This allows the same provider_id to be used for different APIs
PROVIDER_CONNECTIONS_PREFIX = "provider_connections:v1::"
@ -55,6 +57,8 @@ class ProviderImpl(Providers):
self.config = config
self.deps = deps
self.kvstore = None # KVStore for dynamic provider persistence
# Runtime cache uses composite key: "{api}::{provider_id}"
# This allows the same provider_id to be used for different APIs
self.dynamic_providers: dict[str, ProviderConnectionInfo] = {} # Runtime cache
self.dynamic_provider_impls: dict[str, Any] = {} # Initialized provider instances
@ -65,36 +69,39 @@ class ProviderImpl(Providers):
async def initialize(self) -> None:
# Initialize kvstore for dynamic providers
# Reuse the same kvstore as the distribution registry if available
if hasattr(self.config.run_config, "metadata_store") and self.config.run_config.metadata_store:
from llama_stack.providers.utils.kvstore import kvstore_impl
# Use the metadata store from the new storage config structure
if not (self.config.run_config.storage and self.config.run_config.storage.stores.metadata):
raise RuntimeError(
"No metadata store configured in storage.stores.metadata. "
"Provider management requires a configured metadata store (kv_memory, kv_sqlite, etc)."
)
self.kvstore = await kvstore_impl(self.config.run_config.metadata_store)
logger.info("Initialized kvstore for dynamic provider management")
from llama_stack.providers.utils.kvstore import kvstore_impl
# Load existing dynamic providers from kvstore
await self._load_dynamic_providers()
logger.info(f"Loaded {len(self.dynamic_providers)} dynamic providers from kvstore")
self.kvstore = await kvstore_impl(self.config.run_config.storage.stores.metadata)
logger.info("✅ Initialized kvstore for dynamic provider management")
# Auto-instantiate connected providers on startup
if self.provider_registry:
for provider_id, conn_info in self.dynamic_providers.items():
if conn_info.status == ProviderConnectionStatus.connected:
try:
impl = await self._instantiate_provider(conn_info)
self.dynamic_provider_impls[provider_id] = impl
logger.info(f"Auto-instantiated provider {provider_id} from kvstore")
except Exception as e:
logger.error(f"Failed to auto-instantiate provider {provider_id}: {e}")
# Update status to failed
conn_info.status = ProviderConnectionStatus.failed
conn_info.error_message = str(e)
conn_info.updated_at = datetime.now(UTC)
await self._store_connection(conn_info)
else:
logger.warning("Provider registry not available, skipping auto-instantiation")
# Load existing dynamic providers from kvstore
await self._load_dynamic_providers()
logger.info(f"📦 Loaded {len(self.dynamic_providers)} existing dynamic providers from kvstore")
# Auto-instantiate connected providers on startup
if self.provider_registry:
for provider_id, conn_info in self.dynamic_providers.items():
if conn_info.status == ProviderConnectionStatus.connected:
try:
impl = await self._instantiate_provider(conn_info)
self.dynamic_provider_impls[provider_id] = impl
logger.info(f"♻️ Auto-instantiated provider {provider_id} from kvstore")
except Exception as e:
logger.error(f"Failed to auto-instantiate provider {provider_id}: {e}")
# Update status to failed
conn_info.status = ProviderConnectionStatus.failed
conn_info.error_message = str(e)
conn_info.updated_at = datetime.now(UTC)
await self._store_connection(conn_info)
else:
logger.warning("No metadata_store configured, dynamic provider management disabled")
logger.warning("Provider registry not available, skipping auto-instantiation")
async def shutdown(self) -> None:
logger.debug("ProviderImpl.shutdown")
@ -245,9 +252,10 @@ class ProviderImpl(Providers):
if not self.kvstore:
raise RuntimeError("KVStore not initialized")
key = f"{PROVIDER_CONNECTIONS_PREFIX}{info.provider_id}"
# Use composite key: provider_connections:v1::{api}::{provider_id}
key = f"{PROVIDER_CONNECTIONS_PREFIX}{info.api}::{info.provider_id}"
await self.kvstore.set(key, info.model_dump_json())
logger.debug(f"Stored provider connection: {info.provider_id}")
logger.debug(f"Stored provider connection: {info.api}::{info.provider_id}")
async def _load_connection(self, provider_id: str) -> ProviderConnectionInfo | None:
"""Load provider connection info from kvstore.
@ -293,8 +301,10 @@ class ProviderImpl(Providers):
"""Load dynamic providers from kvstore into runtime cache."""
connections = await self._list_connections()
for conn in connections:
self.dynamic_providers[conn.provider_id] = conn
logger.debug(f"Loaded dynamic provider: {conn.provider_id} (status: {conn.status})")
# Use composite key for runtime cache
cache_key = f"{conn.api}::{conn.provider_id}"
self.dynamic_providers[cache_key] = conn
logger.debug(f"Loaded dynamic provider: {cache_key} (status: {conn.status})")
# Helper methods for dynamic provider management
@ -384,12 +394,17 @@ class ProviderImpl(Providers):
All providers are stored in kvstore and treated equally.
"""
logger.info(f"📝 REGISTER_PROVIDER called: provider_id={provider_id}, api={api}, type={provider_type}")
if not self.kvstore:
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
# Check if provider_id already exists
if provider_id in self.dynamic_providers:
raise ValueError(f"Provider {provider_id} already exists")
# Use composite key to allow same provider_id for different APIs
cache_key = f"{api}::{provider_id}"
# Check if provider already exists for this API
if cache_key in self.dynamic_providers:
raise ValueError(f"Provider {provider_id} already exists for API {api}")
# Get authenticated user as owner
user = get_authenticated_user()
@ -415,7 +430,8 @@ class ProviderImpl(Providers):
# Instantiate provider if we have a provider registry
if self.provider_registry:
impl = await self._instantiate_provider(conn_info)
self.dynamic_provider_impls[provider_id] = impl
# Use composite key for impl cache too
self.dynamic_provider_impls[cache_key] = impl
# Update status to connected after successful instantiation
conn_info.status = ProviderConnectionStatus.connected
@ -434,8 +450,8 @@ class ProviderImpl(Providers):
# Store updated status
await self._store_connection(conn_info)
# Add to runtime cache
self.dynamic_providers[provider_id] = conn_info
# Add to runtime cache using composite key
self.dynamic_providers[cache_key] = conn_info
return RegisterProviderResponse(provider=conn_info)
@ -445,7 +461,7 @@ class ProviderImpl(Providers):
conn_info.error_message = str(e)
conn_info.updated_at = datetime.now(UTC)
await self._store_connection(conn_info)
self.dynamic_providers[provider_id] = conn_info
self.dynamic_providers[cache_key] = conn_info
logger.error(f"Failed to register provider {provider_id}: {e}")
raise RuntimeError(f"Failed to register provider: {e}") from e
@ -461,6 +477,8 @@ class ProviderImpl(Providers):
Updates persist to kvstore and survive server restarts.
This works for all providers (whether originally from run.yaml or API).
"""
logger.info(f"🔄 UPDATE_PROVIDER called: provider_id={provider_id}, has_config={config is not None}, has_attributes={attributes is not None}")
if not self.kvstore:
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
@ -531,6 +549,8 @@ class ProviderImpl(Providers):
Removes the provider from kvstore and shuts down its instance.
This works for all providers (whether originally from run.yaml or API).
"""
logger.info(f"🗑️ UNREGISTER_PROVIDER called: provider_id={provider_id}")
if not self.kvstore:
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
@ -560,6 +580,8 @@ class ProviderImpl(Providers):
async def test_provider_connection(self, provider_id: str) -> TestProviderConnectionResponse:
"""Test a provider connection."""
logger.info(f"🔍 TEST_PROVIDER_CONNECTION called: provider_id={provider_id}")
# Check if provider exists (static or dynamic)
provider_impl = None

View file

@ -42,6 +42,8 @@ from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceI
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
from llama_stack.core.resolver import ProviderRegistry, resolve_impls
from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.store.registry import DistributionRegistry
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
@ -406,6 +408,187 @@ def _initialize_storage(run_config: StackRunConfig):
register_sqlstore_backends(sql_backends)
def _collect_available_deps(provider: Any, impls: dict[Any, Any]) -> dict[Any, Any]:
    """Return the already-resolved implementations that ``provider`` depends on.

    Required (``spec.api_dependencies``) and optional
    (``spec.optional_api_dependencies``) dependencies are treated alike: each
    one is included only if it is already present in ``impls``; anything not
    yet resolved is left out rather than raising.
    """
    spec = provider.spec
    wanted = [*spec.api_dependencies, *spec.optional_api_dependencies]
    return {dep: impls[dep] for dep in wanted if dep in impls}


async def resolve_impls_via_provider_registration(
    run_config: StackRunConfig,
    provider_registry: ProviderRegistry,
    dist_registry: DistributionRegistry,
    policy: list[AccessRule],
    internal_impls: dict[Api, Any],
) -> dict[Api, Any]:
    """
    Resolves provider implementations by registering them through ProviderImpl.

    This ensures all providers (startup and runtime) go through the same registration code path.
    Telemetry, inner providers, routing tables and routers are infrastructure and are
    instantiated directly via ``instantiate_provider``; every other provider is registered
    through ``ProviderImpl.register_provider``.

    Args:
        run_config: Stack run configuration with providers from run.yaml
        provider_registry: Registry of available provider types
        dist_registry: Distribution registry
        policy: Access control policy
        internal_impls: Internal implementations (inspect, providers) already initialized.
            Must contain ``Api.providers``. The mapping itself is copied, but the
            ``deps`` dict of the ProviderImpl it holds is updated in place.

    Returns:
        Dictionary mapping API to implementation instances. Inner providers are
        additionally stored under string keys of the form ``"inner-<api>"``, each
        mapping provider_id to its implementation.

    Raises:
        Exception: any error raised while instantiating or registering a provider
            is logged and re-raised unchanged.
    """
    from llama_stack.core.distribution import builtin_automatically_routed_apis
    from llama_stack.core.resolver import (
        instantiate_provider,
        sort_providers_by_deps,
        specs_for_autorouted_apis,
        validate_and_prepare_providers,
    )

    # Query the auto-routed API table once; everything below is derived from it.
    autorouted = list(builtin_automatically_routed_apis())
    routing_table_apis = {info.routing_table_api for info in autorouted}
    router_apis = {info.router_api for info in autorouted}
    # routing-table API -> router API whose inner providers the table manages
    # (e.g. the tool_groups routing table needs the inner-tool_runtime providers).
    autorouted_map = {info.routing_table_api: info.router_api for info in autorouted}

    # Validate and prepare providers from run.yaml
    providers_with_specs = validate_and_prepare_providers(
        run_config, provider_registry, routing_table_apis, router_apis
    )

    apis_to_serve = run_config.apis or set(
        list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
    )

    providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))

    # Sort providers in dependency order
    sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)

    # Get the ProviderImpl instance
    providers_impl = internal_impls[Api.providers]

    # Register each provider through ProviderImpl
    impls = internal_impls.copy()

    logger.info(f"🚀 Starting provider registration for {len(sorted_providers)} providers from run.yaml")

    for api_str, provider in sorted_providers:
        # Skip providers that are not enabled
        if provider.provider_id is None:
            continue

        # Skip internal APIs that are already initialized as internal_impls
        if api_str in ("providers", "inspect"):
            continue

        # Telemetry is internal observability: instantiate it directly, with no inner impls
        if api_str == "telemetry":
            logger.info(f"Instantiating {provider.provider_id} for {api_str}")
            deps = _collect_available_deps(provider, impls)
            impl = await instantiate_provider(provider, deps, {}, dist_registry, run_config, policy)
            api = Api(api_str)
            impls[api] = impl
            providers_impl.deps[api] = impl
            continue

        # Handle different provider types
        try:
            is_inner = api_str.startswith("inner-")
            provider_type = provider.spec.provider_type
            # "inner-*" keys are not Api values, so only build Api(api_str) for the rest
            is_router = not is_inner and (Api(api_str) in router_apis or provider_type == "router")

            if is_inner or provider_type == "routing_table":
                # Inner providers or routing tables cannot be registered through the API
                # They need to be instantiated directly
                logger.info(f"Instantiating {provider.provider_id} for {api_str}")
                deps = _collect_available_deps(provider, impls)

                inner_impls: dict[str, Any] = {}
                if provider_type == "routing_table":
                    # Routing tables of autorouted APIs take the inner impls of their router API
                    router_api = autorouted_map.get(Api(api_str))
                    if router_api is not None:
                        inner_impls = impls.get(f"inner-{router_api.value}", {})
                else:
                    # For regular inner providers, use their own inner key
                    inner_impls = impls.get(f"inner-{api_str}", {})

                impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)

                # Store appropriately
                if is_inner:
                    if api_str not in impls:
                        impls[api_str] = {}
                    impls[api_str][provider.provider_id] = impl
                else:
                    api = Api(api_str)
                    impls[api] = impl
                    # Update providers_impl.deps so subsequent providers can depend on this
                    providers_impl.deps[api] = impl

            elif is_router:
                # Router providers also need special handling
                logger.info(f"Instantiating router {provider.provider_id} for {api_str}")
                deps = _collect_available_deps(provider, impls)

                # A router fronts the inner providers registered for its own API
                inner_impls = impls.get(f"inner-{api_str}", {})

                impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
                api = Api(api_str)
                impls[api] = impl
                # Update providers_impl.deps so subsequent providers can depend on this
                providers_impl.deps[api] = impl

            else:
                # Regular providers - register through ProviderImpl
                api = Api(api_str)
                logger.info(f"Registering {provider.provider_id} for {api.value}")

                await providers_impl.register_provider(
                    provider_id=provider.provider_id,
                    api=api.value,
                    provider_type=provider.spec.provider_type,
                    config=provider.config,
                    attributes=getattr(provider, "attributes", None),
                )

                # Get the instantiated impl from dynamic_provider_impls using composite key
                cache_key = f"{api.value}::{provider.provider_id}"
                impl = providers_impl.dynamic_provider_impls[cache_key]
                impls[api] = impl

                # IMPORTANT: Update providers_impl.deps so subsequent providers can depend on this one
                providers_impl.deps[api] = impl

            logger.info(f"✅ Successfully registered startup provider: {provider.provider_id}")

        except Exception as e:
            logger.error(f"❌ Failed to handle provider {provider.provider_id}: {e}")
            raise

    return impls
class Stack:
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
self.run_config = run_config
@ -441,7 +624,13 @@ class Stack:
policy=policy,
)
impls = await resolve_impls(
# Initialize the ProviderImpl so it has access to kvstore
print("DEBUG: About to initialize ProviderImpl")
await internal_impls[Api.providers].initialize()
print("DEBUG: ProviderImpl initialized, about to call resolve_impls_via_provider_registration")
# Register all providers from run.yaml through ProviderImpl
impls = await resolve_impls_via_provider_registration(
self.run_config,
provider_registry,
dist_registry,