mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 04:00:42 +00:00
save changes
This commit is contained in:
parent
66158d1999
commit
e21db79d6c
3 changed files with 250 additions and 39 deletions
|
|
@ -42,6 +42,8 @@ from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceI
|
|||
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
|
||||
from llama_stack.core.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
from llama_stack.core.store.registry import DistributionRegistry
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
|
|
@ -406,6 +408,187 @@ def _initialize_storage(run_config: StackRunConfig):
|
|||
register_sqlstore_backends(sql_backends)
|
||||
|
||||
|
||||
async def resolve_impls_via_provider_registration(
|
||||
run_config: StackRunConfig,
|
||||
provider_registry: ProviderRegistry,
|
||||
dist_registry: DistributionRegistry,
|
||||
policy: list[AccessRule],
|
||||
internal_impls: dict[Api, Any],
|
||||
) -> dict[Api, Any]:
|
||||
"""
|
||||
Resolves provider implementations by registering them through ProviderImpl.
|
||||
This ensures all providers (startup and runtime) go through the same registration code path.
|
||||
|
||||
Args:
|
||||
run_config: Stack run configuration with providers from run.yaml
|
||||
provider_registry: Registry of available provider types
|
||||
dist_registry: Distribution registry
|
||||
policy: Access control policy
|
||||
internal_impls: Internal implementations (inspect, providers) already initialized
|
||||
|
||||
Returns:
|
||||
Dictionary mapping API to implementation instances
|
||||
"""
|
||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.core.resolver import sort_providers_by_deps, specs_for_autorouted_apis, validate_and_prepare_providers
|
||||
|
||||
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
||||
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
|
||||
|
||||
# Validate and prepare providers from run.yaml
|
||||
providers_with_specs = validate_and_prepare_providers(
|
||||
run_config, provider_registry, routing_table_apis, router_apis
|
||||
)
|
||||
|
||||
apis_to_serve = run_config.apis or set(
|
||||
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
|
||||
)
|
||||
|
||||
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
|
||||
|
||||
# Sort providers in dependency order
|
||||
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
|
||||
|
||||
# Get the ProviderImpl instance
|
||||
providers_impl = internal_impls[Api.providers]
|
||||
|
||||
# Register each provider through ProviderImpl
|
||||
impls = internal_impls.copy()
|
||||
|
||||
logger.info(f"🚀 Starting provider registration for {len(sorted_providers)} providers from run.yaml")
|
||||
|
||||
for api_str, provider in sorted_providers:
|
||||
# Skip providers that are not enabled
|
||||
if provider.provider_id is None:
|
||||
continue
|
||||
|
||||
# Skip internal APIs that need special handling
|
||||
# - providers: already initialized as internal_impls
|
||||
# - inspect: already initialized as internal_impls
|
||||
# - telemetry: internal observability, directly instantiated below
|
||||
if api_str in ["providers", "inspect"]:
|
||||
continue
|
||||
|
||||
# Telemetry is an internal API that should be directly instantiated
|
||||
if api_str == "telemetry":
|
||||
logger.info(f"Instantiating {provider.provider_id} for {api_str}")
|
||||
|
||||
from llama_stack.core.resolver import instantiate_provider
|
||||
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies if a in impls}
|
||||
for a in provider.spec.optional_api_dependencies:
|
||||
if a in impls:
|
||||
deps[a] = impls[a]
|
||||
|
||||
impl = await instantiate_provider(provider, deps, {}, dist_registry, run_config, policy)
|
||||
api = Api(api_str)
|
||||
impls[api] = impl
|
||||
providers_impl.deps[api] = impl
|
||||
continue
|
||||
|
||||
# Handle different provider types
|
||||
try:
|
||||
# Check if this is a routing table or router (system infrastructure)
|
||||
is_routing_table = api_str.startswith("inner-") or provider.spec.provider_type in ["routing_table", "router"]
|
||||
is_router = not api_str.startswith("inner-") and (Api(api_str) in router_apis or provider.spec.provider_type == "router")
|
||||
|
||||
if api_str.startswith("inner-") or provider.spec.provider_type == "routing_table":
|
||||
# Inner providers or routing tables cannot be registered through the API
|
||||
# They need to be instantiated directly
|
||||
logger.info(f"Instantiating {provider.provider_id} for {api_str}")
|
||||
|
||||
from llama_stack.core.resolver import instantiate_provider
|
||||
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies if a in impls}
|
||||
for a in provider.spec.optional_api_dependencies:
|
||||
if a in impls:
|
||||
deps[a] = impls[a]
|
||||
|
||||
# Get inner impls if available
|
||||
inner_impls = {}
|
||||
|
||||
# For routing tables of autorouted APIs, get inner impls from the router API
|
||||
# E.g., tool_groups routing table needs inner-tool_runtime providers
|
||||
if provider.spec.provider_type == "routing_table":
|
||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||
autorouted_map = {info.routing_table_api: info.router_api for info in builtin_automatically_routed_apis()}
|
||||
if Api(api_str) in autorouted_map:
|
||||
router_api_str = autorouted_map[Api(api_str)].value
|
||||
inner_key = f"inner-{router_api_str}"
|
||||
if inner_key in impls:
|
||||
inner_impls = impls[inner_key]
|
||||
else:
|
||||
# For regular inner providers, use their own inner key
|
||||
inner_key = f"inner-{api_str}"
|
||||
if inner_key in impls:
|
||||
inner_impls = impls[inner_key]
|
||||
|
||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
|
||||
|
||||
# Store appropriately
|
||||
if api_str.startswith("inner-"):
|
||||
if api_str not in impls:
|
||||
impls[api_str] = {}
|
||||
impls[api_str][provider.provider_id] = impl
|
||||
else:
|
||||
api = Api(api_str)
|
||||
impls[api] = impl
|
||||
# Update providers_impl.deps so subsequent providers can depend on this
|
||||
providers_impl.deps[api] = impl
|
||||
|
||||
elif is_router:
|
||||
# Router providers also need special handling
|
||||
logger.info(f"Instantiating router {provider.provider_id} for {api_str}")
|
||||
|
||||
from llama_stack.core.resolver import instantiate_provider
|
||||
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies if a in impls}
|
||||
for a in provider.spec.optional_api_dependencies:
|
||||
if a in impls:
|
||||
deps[a] = impls[a]
|
||||
|
||||
# Get inner impls if this is a router
|
||||
inner_impls = {}
|
||||
inner_key = f"inner-{api_str}"
|
||||
if inner_key in impls:
|
||||
inner_impls = impls[inner_key]
|
||||
|
||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
|
||||
api = Api(api_str)
|
||||
impls[api] = impl
|
||||
# Update providers_impl.deps so subsequent providers can depend on this
|
||||
providers_impl.deps[api] = impl
|
||||
|
||||
else:
|
||||
# Regular providers - register through ProviderImpl
|
||||
api = Api(api_str)
|
||||
logger.info(f"Registering {provider.provider_id} for {api.value}")
|
||||
|
||||
response = await providers_impl.register_provider(
|
||||
provider_id=provider.provider_id,
|
||||
api=api.value,
|
||||
provider_type=provider.spec.provider_type,
|
||||
config=provider.config,
|
||||
attributes=getattr(provider, "attributes", None),
|
||||
)
|
||||
|
||||
# Get the instantiated impl from dynamic_provider_impls using composite key
|
||||
cache_key = f"{api.value}::{provider.provider_id}"
|
||||
impl = providers_impl.dynamic_provider_impls[cache_key]
|
||||
impls[api] = impl
|
||||
|
||||
# IMPORTANT: Update providers_impl.deps so subsequent providers can depend on this one
|
||||
providers_impl.deps[api] = impl
|
||||
|
||||
logger.info(f"✅ Successfully registered startup provider: {provider.provider_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to handle provider {provider.provider_id}: {e}")
|
||||
raise
|
||||
|
||||
return impls
|
||||
|
||||
|
||||
class Stack:
|
||||
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
|
||||
self.run_config = run_config
|
||||
|
|
@ -441,7 +624,13 @@ class Stack:
|
|||
policy=policy,
|
||||
)
|
||||
|
||||
impls = await resolve_impls(
|
||||
# Initialize the ProviderImpl so it has access to kvstore
|
||||
print("DEBUG: About to initialize ProviderImpl")
|
||||
await internal_impls[Api.providers].initialize()
|
||||
print("DEBUG: ProviderImpl initialized, about to call resolve_impls_via_provider_registration")
|
||||
|
||||
# Register all providers from run.yaml through ProviderImpl
|
||||
impls = await resolve_impls_via_provider_registration(
|
||||
self.run_config,
|
||||
provider_registry,
|
||||
dist_registry,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue