This commit is contained in:
raghotham 2025-10-27 11:36:30 -07:00 committed by GitHub
commit fb49732f2f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 4819 additions and 41 deletions

View file

@ -34,13 +34,20 @@ from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.distribution import builtin_automatically_routed_apis, get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
from llama_stack.core.resolver import ProviderRegistry, resolve_impls
from llama_stack.core.resolver import (
ProviderRegistry,
instantiate_provider,
sort_providers_by_deps,
specs_for_autorouted_apis,
validate_and_prepare_providers,
)
from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
@ -52,10 +59,12 @@ from llama_stack.core.storage.datatypes import (
StorageBackendConfig,
StorageConfig,
)
from llama_stack.core.store.registry import create_dist_registry
from llama_stack.core.store.registry import DistributionRegistry, create_dist_registry
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
logger = get_logger(name=__name__, category="core")
@ -341,12 +350,21 @@ def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
def add_internal_implementations(
impls: dict[Api, Any],
run_config: StackRunConfig,
provider_registry=None,
dist_registry=None,
policy=None,
) -> None:
"""Add internal implementations (inspect and providers) to the implementations dictionary.
Args:
impls: Dictionary of API implementations
run_config: Stack run configuration
provider_registry: Provider registry for dynamic provider instantiation
dist_registry: Distribution registry
policy: Access control policy
"""
inspect_impl = DistributionInspectImpl(
DistributionInspectConfig(run_config=run_config),
@ -355,7 +373,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
impls[Api.inspect] = inspect_impl
providers_impl = ProviderImpl(
ProviderImplConfig(run_config=run_config),
ProviderImplConfig(
run_config=run_config,
provider_registry=provider_registry,
dist_registry=dist_registry,
policy=policy,
),
deps=impls,
)
impls[Api.providers] = providers_impl
@ -385,13 +408,179 @@ def _initialize_storage(run_config: StackRunConfig):
else:
raise ValueError(f"Unknown storage backend type: {type}")
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
register_kvstore_backends(kv_backends)
register_sqlstore_backends(sql_backends)
async def resolve_impls_via_provider_registration(
run_config: StackRunConfig,
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
policy: list[AccessRule],
internal_impls: dict[Api, Any],
) -> dict[Api, Any]:
"""
Resolves provider implementations by registering them through ProviderImpl.
This ensures all providers (startup and runtime) go through the same registration code path.
Args:
run_config: Stack run configuration with providers from run.yaml
provider_registry: Registry of available provider types
dist_registry: Distribution registry
policy: Access control policy
internal_impls: Internal implementations (inspect, providers) already initialized
Returns:
Dictionary mapping API to implementation instances
"""
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
# Validate and prepare providers from run.yaml
providers_with_specs = validate_and_prepare_providers(
run_config, provider_registry, routing_table_apis, router_apis
)
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
)
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
# Sort providers in dependency order
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
# Get the ProviderImpl instance
providers_impl = internal_impls[Api.providers]
# Register each provider through ProviderImpl
impls = internal_impls.copy()
logger.info(f"Provider registration for {len(sorted_providers)} providers from run.yaml")
for api_str, provider in sorted_providers:
# Skip providers that are not enabled
if provider.provider_id is None:
continue
# Skip internal APIs (already initialized)
if api_str in ["providers", "inspect"]:
continue
# Handle different provider types
try:
# Check if this is a router (system infrastructure)
is_router = not api_str.startswith("inner-") and (
Api(api_str) in router_apis or provider.spec.provider_type == "router"
)
if api_str.startswith("inner-") or provider.spec.provider_type == "routing_table":
# Inner providers or routing tables cannot be registered through the API
# They need to be instantiated directly
logger.info(f"Instantiating {provider.provider_id} for {api_str}")
deps = {a: impls[a] for a in provider.spec.api_dependencies if a in impls}
for a in provider.spec.optional_api_dependencies:
if a in impls:
deps[a] = impls[a]
# Get inner impls if available
inner_impls = {}
# For routing tables of autorouted APIs, get inner impls from the router API
# E.g., tool_groups routing table needs inner-tool_runtime providers
if provider.spec.provider_type == "routing_table":
autorouted_map = {
info.routing_table_api: info.router_api for info in builtin_automatically_routed_apis()
}
if Api(api_str) in autorouted_map:
router_api_str = autorouted_map[Api(api_str)].value
inner_key = f"inner-{router_api_str}"
if inner_key in impls:
inner_impls = impls[inner_key]
else:
# For regular inner providers, use their own inner key
inner_key = f"inner-{api_str}"
if inner_key in impls:
inner_impls = impls[inner_key]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
# Store appropriately
if api_str.startswith("inner-"):
if api_str not in impls:
impls[api_str] = {}
impls[api_str][provider.provider_id] = impl
else:
api = Api(api_str)
impls[api] = impl
# Update providers_impl.deps so subsequent providers can depend on this
providers_impl.deps[api] = impl
elif is_router:
# Router providers also need special handling
logger.info(f"Instantiating router {provider.provider_id} for {api_str}")
deps = {a: impls[a] for a in provider.spec.api_dependencies if a in impls}
for a in provider.spec.optional_api_dependencies:
if a in impls:
deps[a] = impls[a]
# Get inner impls if this is a router
inner_impls = {}
inner_key = f"inner-{api_str}"
if inner_key in impls:
inner_impls = impls[inner_key]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
api = Api(api_str)
impls[api] = impl
# Update providers_impl.deps so subsequent providers can depend on this
providers_impl.deps[api] = impl
else:
# Regular providers - register through ProviderImpl
api = Api(api_str)
cache_key = f"{api.value}::{provider.provider_id}"
# Check if provider already exists (loaded from kvstore during initialization)
if cache_key in providers_impl.dynamic_providers:
logger.info(
f"Provider {provider.provider_id} for {api.value} already exists, using existing instance"
)
impl = providers_impl.dynamic_provider_impls.get(cache_key)
if impl is None:
# Provider exists but not instantiated, instantiate it
conn_info = providers_impl.dynamic_providers[cache_key]
impl = await providers_impl._instantiate_provider(conn_info)
providers_impl.dynamic_provider_impls[cache_key] = impl
else:
logger.info(f"Registering {provider.provider_id} for {api.value}")
await providers_impl.register_provider(
api=api.value,
provider_id=provider.provider_id,
provider_type=provider.spec.provider_type,
config=provider.config,
attributes=getattr(provider, "attributes", None),
)
# Get the instantiated impl from dynamic_provider_impls using composite key
impl = providers_impl.dynamic_provider_impls[cache_key]
logger.info(f"Successfully registered startup provider: {provider.provider_id}")
impls[api] = impl
# IMPORTANT: Update providers_impl.deps so subsequent providers can depend on this one
providers_impl.deps[api] = impl
except Exception as e:
logger.error(f"Failed to handle provider {provider.provider_id}: {e}")
raise
return impls
class Stack:
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
self.run_config = run_config
@ -416,13 +605,24 @@ class Stack:
raise ValueError("storage.stores.metadata must be configured with a kv_* backend")
dist_registry, _ = await create_dist_registry(stores.metadata, self.run_config.image_name)
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
provider_registry = self.provider_registry or get_provider_registry(self.run_config)
internal_impls = {}
add_internal_implementations(internal_impls, self.run_config)
impls = await resolve_impls(
add_internal_implementations(
internal_impls,
self.run_config,
self.provider_registry or get_provider_registry(self.run_config),
provider_registry=provider_registry,
dist_registry=dist_registry,
policy=policy,
)
# Initialize the ProviderImpl so it has access to kvstore
await internal_impls[Api.providers].initialize()
# Register all providers from run.yaml through ProviderImpl
impls = await resolve_impls_via_provider_registration(
self.run_config,
provider_registry,
dist_registry,
policy,
internal_impls,