save changes

This commit is contained in:
Raghotham Murthy 2025-10-24 14:28:49 -07:00
parent 66158d1999
commit e21db79d6c
3 changed files with 250 additions and 39 deletions

View file

@ -34,6 +34,8 @@ from .utils.config import redact_sensitive_fields
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core")
# Storage constants for dynamic provider connections # Storage constants for dynamic provider connections
# Use composite key format: provider_connections:v1::{api}::{provider_id}
# This allows the same provider_id to be used for different APIs
PROVIDER_CONNECTIONS_PREFIX = "provider_connections:v1::" PROVIDER_CONNECTIONS_PREFIX = "provider_connections:v1::"
@ -55,6 +57,8 @@ class ProviderImpl(Providers):
self.config = config self.config = config
self.deps = deps self.deps = deps
self.kvstore = None # KVStore for dynamic provider persistence self.kvstore = None # KVStore for dynamic provider persistence
# Runtime cache uses composite key: "{api}::{provider_id}"
# This allows the same provider_id to be used for different APIs
self.dynamic_providers: dict[str, ProviderConnectionInfo] = {} # Runtime cache self.dynamic_providers: dict[str, ProviderConnectionInfo] = {} # Runtime cache
self.dynamic_provider_impls: dict[str, Any] = {} # Initialized provider instances self.dynamic_provider_impls: dict[str, Any] = {} # Initialized provider instances
@ -65,36 +69,39 @@ class ProviderImpl(Providers):
async def initialize(self) -> None: async def initialize(self) -> None:
# Initialize kvstore for dynamic providers # Initialize kvstore for dynamic providers
# Reuse the same kvstore as the distribution registry if available # Use the metadata store from the new storage config structure
if hasattr(self.config.run_config, "metadata_store") and self.config.run_config.metadata_store: if not (self.config.run_config.storage and self.config.run_config.storage.stores.metadata):
from llama_stack.providers.utils.kvstore import kvstore_impl raise RuntimeError(
"No metadata store configured in storage.stores.metadata. "
"Provider management requires a configured metadata store (kv_memory, kv_sqlite, etc)."
)
self.kvstore = await kvstore_impl(self.config.run_config.metadata_store) from llama_stack.providers.utils.kvstore import kvstore_impl
logger.info("Initialized kvstore for dynamic provider management")
# Load existing dynamic providers from kvstore self.kvstore = await kvstore_impl(self.config.run_config.storage.stores.metadata)
await self._load_dynamic_providers() logger.info("✅ Initialized kvstore for dynamic provider management")
logger.info(f"Loaded {len(self.dynamic_providers)} dynamic providers from kvstore")
# Auto-instantiate connected providers on startup # Load existing dynamic providers from kvstore
if self.provider_registry: await self._load_dynamic_providers()
for provider_id, conn_info in self.dynamic_providers.items(): logger.info(f"📦 Loaded {len(self.dynamic_providers)} existing dynamic providers from kvstore")
if conn_info.status == ProviderConnectionStatus.connected:
try: # Auto-instantiate connected providers on startup
impl = await self._instantiate_provider(conn_info) if self.provider_registry:
self.dynamic_provider_impls[provider_id] = impl for provider_id, conn_info in self.dynamic_providers.items():
logger.info(f"Auto-instantiated provider {provider_id} from kvstore") if conn_info.status == ProviderConnectionStatus.connected:
except Exception as e: try:
logger.error(f"Failed to auto-instantiate provider {provider_id}: {e}") impl = await self._instantiate_provider(conn_info)
# Update status to failed self.dynamic_provider_impls[provider_id] = impl
conn_info.status = ProviderConnectionStatus.failed logger.info(f"♻️ Auto-instantiated provider {provider_id} from kvstore")
conn_info.error_message = str(e) except Exception as e:
conn_info.updated_at = datetime.now(UTC) logger.error(f"Failed to auto-instantiate provider {provider_id}: {e}")
await self._store_connection(conn_info) # Update status to failed
else: conn_info.status = ProviderConnectionStatus.failed
logger.warning("Provider registry not available, skipping auto-instantiation") conn_info.error_message = str(e)
conn_info.updated_at = datetime.now(UTC)
await self._store_connection(conn_info)
else: else:
logger.warning("No metadata_store configured, dynamic provider management disabled") logger.warning("Provider registry not available, skipping auto-instantiation")
async def shutdown(self) -> None: async def shutdown(self) -> None:
logger.debug("ProviderImpl.shutdown") logger.debug("ProviderImpl.shutdown")
@ -245,9 +252,10 @@ class ProviderImpl(Providers):
if not self.kvstore: if not self.kvstore:
raise RuntimeError("KVStore not initialized") raise RuntimeError("KVStore not initialized")
key = f"{PROVIDER_CONNECTIONS_PREFIX}{info.provider_id}" # Use composite key: provider_connections:v1::{api}::{provider_id}
key = f"{PROVIDER_CONNECTIONS_PREFIX}{info.api}::{info.provider_id}"
await self.kvstore.set(key, info.model_dump_json()) await self.kvstore.set(key, info.model_dump_json())
logger.debug(f"Stored provider connection: {info.provider_id}") logger.debug(f"Stored provider connection: {info.api}::{info.provider_id}")
async def _load_connection(self, provider_id: str) -> ProviderConnectionInfo | None: async def _load_connection(self, provider_id: str) -> ProviderConnectionInfo | None:
"""Load provider connection info from kvstore. """Load provider connection info from kvstore.
@ -293,8 +301,10 @@ class ProviderImpl(Providers):
"""Load dynamic providers from kvstore into runtime cache.""" """Load dynamic providers from kvstore into runtime cache."""
connections = await self._list_connections() connections = await self._list_connections()
for conn in connections: for conn in connections:
self.dynamic_providers[conn.provider_id] = conn # Use composite key for runtime cache
logger.debug(f"Loaded dynamic provider: {conn.provider_id} (status: {conn.status})") cache_key = f"{conn.api}::{conn.provider_id}"
self.dynamic_providers[cache_key] = conn
logger.debug(f"Loaded dynamic provider: {cache_key} (status: {conn.status})")
# Helper methods for dynamic provider management # Helper methods for dynamic provider management
@ -384,12 +394,17 @@ class ProviderImpl(Providers):
All providers are stored in kvstore and treated equally. All providers are stored in kvstore and treated equally.
""" """
logger.info(f"📝 REGISTER_PROVIDER called: provider_id={provider_id}, api={api}, type={provider_type}")
if not self.kvstore: if not self.kvstore:
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)") raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
# Check if provider_id already exists # Use composite key to allow same provider_id for different APIs
if provider_id in self.dynamic_providers: cache_key = f"{api}::{provider_id}"
raise ValueError(f"Provider {provider_id} already exists")
# Check if provider already exists for this API
if cache_key in self.dynamic_providers:
raise ValueError(f"Provider {provider_id} already exists for API {api}")
# Get authenticated user as owner # Get authenticated user as owner
user = get_authenticated_user() user = get_authenticated_user()
@ -415,7 +430,8 @@ class ProviderImpl(Providers):
# Instantiate provider if we have a provider registry # Instantiate provider if we have a provider registry
if self.provider_registry: if self.provider_registry:
impl = await self._instantiate_provider(conn_info) impl = await self._instantiate_provider(conn_info)
self.dynamic_provider_impls[provider_id] = impl # Use composite key for impl cache too
self.dynamic_provider_impls[cache_key] = impl
# Update status to connected after successful instantiation # Update status to connected after successful instantiation
conn_info.status = ProviderConnectionStatus.connected conn_info.status = ProviderConnectionStatus.connected
@ -434,8 +450,8 @@ class ProviderImpl(Providers):
# Store updated status # Store updated status
await self._store_connection(conn_info) await self._store_connection(conn_info)
# Add to runtime cache # Add to runtime cache using composite key
self.dynamic_providers[provider_id] = conn_info self.dynamic_providers[cache_key] = conn_info
return RegisterProviderResponse(provider=conn_info) return RegisterProviderResponse(provider=conn_info)
@ -445,7 +461,7 @@ class ProviderImpl(Providers):
conn_info.error_message = str(e) conn_info.error_message = str(e)
conn_info.updated_at = datetime.now(UTC) conn_info.updated_at = datetime.now(UTC)
await self._store_connection(conn_info) await self._store_connection(conn_info)
self.dynamic_providers[provider_id] = conn_info self.dynamic_providers[cache_key] = conn_info
logger.error(f"Failed to register provider {provider_id}: {e}") logger.error(f"Failed to register provider {provider_id}: {e}")
raise RuntimeError(f"Failed to register provider: {e}") from e raise RuntimeError(f"Failed to register provider: {e}") from e
@ -461,6 +477,8 @@ class ProviderImpl(Providers):
Updates persist to kvstore and survive server restarts. Updates persist to kvstore and survive server restarts.
This works for all providers (whether originally from run.yaml or API). This works for all providers (whether originally from run.yaml or API).
""" """
logger.info(f"🔄 UPDATE_PROVIDER called: provider_id={provider_id}, has_config={config is not None}, has_attributes={attributes is not None}")
if not self.kvstore: if not self.kvstore:
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)") raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
@ -531,6 +549,8 @@ class ProviderImpl(Providers):
Removes the provider from kvstore and shuts down its instance. Removes the provider from kvstore and shuts down its instance.
This works for all providers (whether originally from run.yaml or API). This works for all providers (whether originally from run.yaml or API).
""" """
logger.info(f"🗑️ UNREGISTER_PROVIDER called: provider_id={provider_id}")
if not self.kvstore: if not self.kvstore:
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)") raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
@ -560,6 +580,8 @@ class ProviderImpl(Providers):
async def test_provider_connection(self, provider_id: str) -> TestProviderConnectionResponse: async def test_provider_connection(self, provider_id: str) -> TestProviderConnectionResponse:
"""Test a provider connection.""" """Test a provider connection."""
logger.info(f"🔍 TEST_PROVIDER_CONNECTION called: provider_id={provider_id}")
# Check if provider exists (static or dynamic) # Check if provider exists (static or dynamic)
provider_impl = None provider_impl = None

View file

@ -42,6 +42,8 @@ from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceI
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
from llama_stack.core.resolver import ProviderRegistry, resolve_impls from llama_stack.core.resolver import ProviderRegistry, resolve_impls
from llama_stack.core.routing_tables.common import CommonRoutingTableImpl from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.store.registry import DistributionRegistry
from llama_stack.core.storage.datatypes import ( from llama_stack.core.storage.datatypes import (
InferenceStoreReference, InferenceStoreReference,
KVStoreReference, KVStoreReference,
@ -406,6 +408,187 @@ def _initialize_storage(run_config: StackRunConfig):
register_sqlstore_backends(sql_backends) register_sqlstore_backends(sql_backends)
async def resolve_impls_via_provider_registration(
run_config: StackRunConfig,
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
policy: list[AccessRule],
internal_impls: dict[Api, Any],
) -> dict[Api, Any]:
"""
Resolves provider implementations by registering them through ProviderImpl.
This ensures all providers (startup and runtime) go through the same registration code path.
Args:
run_config: Stack run configuration with providers from run.yaml
provider_registry: Registry of available provider types
dist_registry: Distribution registry
policy: Access control policy
internal_impls: Internal implementations (inspect, providers) already initialized
Returns:
Dictionary mapping API to implementation instances
"""
from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.resolver import sort_providers_by_deps, specs_for_autorouted_apis, validate_and_prepare_providers
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
# Validate and prepare providers from run.yaml
providers_with_specs = validate_and_prepare_providers(
run_config, provider_registry, routing_table_apis, router_apis
)
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
)
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
# Sort providers in dependency order
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
# Get the ProviderImpl instance
providers_impl = internal_impls[Api.providers]
# Register each provider through ProviderImpl
impls = internal_impls.copy()
logger.info(f"🚀 Starting provider registration for {len(sorted_providers)} providers from run.yaml")
for api_str, provider in sorted_providers:
# Skip providers that are not enabled
if provider.provider_id is None:
continue
# Skip internal APIs that need special handling
# - providers: already initialized as internal_impls
# - inspect: already initialized as internal_impls
# - telemetry: internal observability, directly instantiated below
if api_str in ["providers", "inspect"]:
continue
# Telemetry is an internal API that should be directly instantiated
if api_str == "telemetry":
logger.info(f"Instantiating {provider.provider_id} for {api_str}")
from llama_stack.core.resolver import instantiate_provider
deps = {a: impls[a] for a in provider.spec.api_dependencies if a in impls}
for a in provider.spec.optional_api_dependencies:
if a in impls:
deps[a] = impls[a]
impl = await instantiate_provider(provider, deps, {}, dist_registry, run_config, policy)
api = Api(api_str)
impls[api] = impl
providers_impl.deps[api] = impl
continue
# Handle different provider types
try:
# Check if this is a routing table or router (system infrastructure)
is_routing_table = api_str.startswith("inner-") or provider.spec.provider_type in ["routing_table", "router"]
is_router = not api_str.startswith("inner-") and (Api(api_str) in router_apis or provider.spec.provider_type == "router")
if api_str.startswith("inner-") or provider.spec.provider_type == "routing_table":
# Inner providers or routing tables cannot be registered through the API
# They need to be instantiated directly
logger.info(f"Instantiating {provider.provider_id} for {api_str}")
from llama_stack.core.resolver import instantiate_provider
deps = {a: impls[a] for a in provider.spec.api_dependencies if a in impls}
for a in provider.spec.optional_api_dependencies:
if a in impls:
deps[a] = impls[a]
# Get inner impls if available
inner_impls = {}
# For routing tables of autorouted APIs, get inner impls from the router API
# E.g., tool_groups routing table needs inner-tool_runtime providers
if provider.spec.provider_type == "routing_table":
from llama_stack.core.distribution import builtin_automatically_routed_apis
autorouted_map = {info.routing_table_api: info.router_api for info in builtin_automatically_routed_apis()}
if Api(api_str) in autorouted_map:
router_api_str = autorouted_map[Api(api_str)].value
inner_key = f"inner-{router_api_str}"
if inner_key in impls:
inner_impls = impls[inner_key]
else:
# For regular inner providers, use their own inner key
inner_key = f"inner-{api_str}"
if inner_key in impls:
inner_impls = impls[inner_key]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
# Store appropriately
if api_str.startswith("inner-"):
if api_str not in impls:
impls[api_str] = {}
impls[api_str][provider.provider_id] = impl
else:
api = Api(api_str)
impls[api] = impl
# Update providers_impl.deps so subsequent providers can depend on this
providers_impl.deps[api] = impl
elif is_router:
# Router providers also need special handling
logger.info(f"Instantiating router {provider.provider_id} for {api_str}")
from llama_stack.core.resolver import instantiate_provider
deps = {a: impls[a] for a in provider.spec.api_dependencies if a in impls}
for a in provider.spec.optional_api_dependencies:
if a in impls:
deps[a] = impls[a]
# Get inner impls if this is a router
inner_impls = {}
inner_key = f"inner-{api_str}"
if inner_key in impls:
inner_impls = impls[inner_key]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
api = Api(api_str)
impls[api] = impl
# Update providers_impl.deps so subsequent providers can depend on this
providers_impl.deps[api] = impl
else:
# Regular providers - register through ProviderImpl
api = Api(api_str)
logger.info(f"Registering {provider.provider_id} for {api.value}")
response = await providers_impl.register_provider(
provider_id=provider.provider_id,
api=api.value,
provider_type=provider.spec.provider_type,
config=provider.config,
attributes=getattr(provider, "attributes", None),
)
# Get the instantiated impl from dynamic_provider_impls using composite key
cache_key = f"{api.value}::{provider.provider_id}"
impl = providers_impl.dynamic_provider_impls[cache_key]
impls[api] = impl
# IMPORTANT: Update providers_impl.deps so subsequent providers can depend on this one
providers_impl.deps[api] = impl
logger.info(f"✅ Successfully registered startup provider: {provider.provider_id}")
except Exception as e:
logger.error(f"❌ Failed to handle provider {provider.provider_id}: {e}")
raise
return impls
class Stack: class Stack:
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None): def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
self.run_config = run_config self.run_config = run_config
@ -441,7 +624,13 @@ class Stack:
policy=policy, policy=policy,
) )
impls = await resolve_impls( # Initialize the ProviderImpl so it has access to kvstore
print("DEBUG: About to initialize ProviderImpl")
await internal_impls[Api.providers].initialize()
print("DEBUG: ProviderImpl initialized, about to call resolve_impls_via_provider_registration")
# Register all providers from run.yaml through ProviderImpl
impls = await resolve_impls_via_provider_registration(
self.run_config, self.run_config,
provider_registry, provider_registry,
dist_registry, dist_registry,

View file

@ -231,7 +231,7 @@ storage:
backends: backends:
kv_default: kv_default:
type: kv_sqlite type: kv_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/kvstore.db db_path: ":memory:"
sql_default: sql_default:
type: sql_sqlite type: sql_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sql_store.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sql_store.db