mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
save changes
This commit is contained in:
parent
66158d1999
commit
e21db79d6c
3 changed files with 250 additions and 39 deletions
|
|
@ -34,6 +34,8 @@ from .utils.config import redact_sensitive_fields
|
|||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
# Storage constants for dynamic provider connections
|
||||
# Use composite key format: provider_connections:v1::{api}::{provider_id}
|
||||
# This allows the same provider_id to be used for different APIs
|
||||
PROVIDER_CONNECTIONS_PREFIX = "provider_connections:v1::"
|
||||
|
||||
|
||||
|
|
@ -55,6 +57,8 @@ class ProviderImpl(Providers):
|
|||
self.config = config
|
||||
self.deps = deps
|
||||
self.kvstore = None # KVStore for dynamic provider persistence
|
||||
# Runtime cache uses composite key: "{api}::{provider_id}"
|
||||
# This allows the same provider_id to be used for different APIs
|
||||
self.dynamic_providers: dict[str, ProviderConnectionInfo] = {} # Runtime cache
|
||||
self.dynamic_provider_impls: dict[str, Any] = {} # Initialized provider instances
|
||||
|
||||
|
|
@ -65,16 +69,21 @@ class ProviderImpl(Providers):
|
|||
|
||||
async def initialize(self) -> None:
|
||||
# Initialize kvstore for dynamic providers
|
||||
# Reuse the same kvstore as the distribution registry if available
|
||||
if hasattr(self.config.run_config, "metadata_store") and self.config.run_config.metadata_store:
|
||||
# Use the metadata store from the new storage config structure
|
||||
if not (self.config.run_config.storage and self.config.run_config.storage.stores.metadata):
|
||||
raise RuntimeError(
|
||||
"No metadata store configured in storage.stores.metadata. "
|
||||
"Provider management requires a configured metadata store (kv_memory, kv_sqlite, etc)."
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
self.kvstore = await kvstore_impl(self.config.run_config.metadata_store)
|
||||
logger.info("Initialized kvstore for dynamic provider management")
|
||||
self.kvstore = await kvstore_impl(self.config.run_config.storage.stores.metadata)
|
||||
logger.info("✅ Initialized kvstore for dynamic provider management")
|
||||
|
||||
# Load existing dynamic providers from kvstore
|
||||
await self._load_dynamic_providers()
|
||||
logger.info(f"Loaded {len(self.dynamic_providers)} dynamic providers from kvstore")
|
||||
logger.info(f"📦 Loaded {len(self.dynamic_providers)} existing dynamic providers from kvstore")
|
||||
|
||||
# Auto-instantiate connected providers on startup
|
||||
if self.provider_registry:
|
||||
|
|
@ -83,7 +92,7 @@ class ProviderImpl(Providers):
|
|||
try:
|
||||
impl = await self._instantiate_provider(conn_info)
|
||||
self.dynamic_provider_impls[provider_id] = impl
|
||||
logger.info(f"Auto-instantiated provider {provider_id} from kvstore")
|
||||
logger.info(f"♻️ Auto-instantiated provider {provider_id} from kvstore")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to auto-instantiate provider {provider_id}: {e}")
|
||||
# Update status to failed
|
||||
|
|
@ -93,8 +102,6 @@ class ProviderImpl(Providers):
|
|||
await self._store_connection(conn_info)
|
||||
else:
|
||||
logger.warning("Provider registry not available, skipping auto-instantiation")
|
||||
else:
|
||||
logger.warning("No metadata_store configured, dynamic provider management disabled")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("ProviderImpl.shutdown")
|
||||
|
|
@ -245,9 +252,10 @@ class ProviderImpl(Providers):
|
|||
if not self.kvstore:
|
||||
raise RuntimeError("KVStore not initialized")
|
||||
|
||||
key = f"{PROVIDER_CONNECTIONS_PREFIX}{info.provider_id}"
|
||||
# Use composite key: provider_connections:v1::{api}::{provider_id}
|
||||
key = f"{PROVIDER_CONNECTIONS_PREFIX}{info.api}::{info.provider_id}"
|
||||
await self.kvstore.set(key, info.model_dump_json())
|
||||
logger.debug(f"Stored provider connection: {info.provider_id}")
|
||||
logger.debug(f"Stored provider connection: {info.api}::{info.provider_id}")
|
||||
|
||||
async def _load_connection(self, provider_id: str) -> ProviderConnectionInfo | None:
|
||||
"""Load provider connection info from kvstore.
|
||||
|
|
@ -293,8 +301,10 @@ class ProviderImpl(Providers):
|
|||
"""Load dynamic providers from kvstore into runtime cache."""
|
||||
connections = await self._list_connections()
|
||||
for conn in connections:
|
||||
self.dynamic_providers[conn.provider_id] = conn
|
||||
logger.debug(f"Loaded dynamic provider: {conn.provider_id} (status: {conn.status})")
|
||||
# Use composite key for runtime cache
|
||||
cache_key = f"{conn.api}::{conn.provider_id}"
|
||||
self.dynamic_providers[cache_key] = conn
|
||||
logger.debug(f"Loaded dynamic provider: {cache_key} (status: {conn.status})")
|
||||
|
||||
# Helper methods for dynamic provider management
|
||||
|
||||
|
|
@ -384,12 +394,17 @@ class ProviderImpl(Providers):
|
|||
|
||||
All providers are stored in kvstore and treated equally.
|
||||
"""
|
||||
logger.info(f"📝 REGISTER_PROVIDER called: provider_id={provider_id}, api={api}, type={provider_type}")
|
||||
|
||||
if not self.kvstore:
|
||||
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
|
||||
|
||||
# Check if provider_id already exists
|
||||
if provider_id in self.dynamic_providers:
|
||||
raise ValueError(f"Provider {provider_id} already exists")
|
||||
# Use composite key to allow same provider_id for different APIs
|
||||
cache_key = f"{api}::{provider_id}"
|
||||
|
||||
# Check if provider already exists for this API
|
||||
if cache_key in self.dynamic_providers:
|
||||
raise ValueError(f"Provider {provider_id} already exists for API {api}")
|
||||
|
||||
# Get authenticated user as owner
|
||||
user = get_authenticated_user()
|
||||
|
|
@ -415,7 +430,8 @@ class ProviderImpl(Providers):
|
|||
# Instantiate provider if we have a provider registry
|
||||
if self.provider_registry:
|
||||
impl = await self._instantiate_provider(conn_info)
|
||||
self.dynamic_provider_impls[provider_id] = impl
|
||||
# Use composite key for impl cache too
|
||||
self.dynamic_provider_impls[cache_key] = impl
|
||||
|
||||
# Update status to connected after successful instantiation
|
||||
conn_info.status = ProviderConnectionStatus.connected
|
||||
|
|
@ -434,8 +450,8 @@ class ProviderImpl(Providers):
|
|||
# Store updated status
|
||||
await self._store_connection(conn_info)
|
||||
|
||||
# Add to runtime cache
|
||||
self.dynamic_providers[provider_id] = conn_info
|
||||
# Add to runtime cache using composite key
|
||||
self.dynamic_providers[cache_key] = conn_info
|
||||
|
||||
return RegisterProviderResponse(provider=conn_info)
|
||||
|
||||
|
|
@ -445,7 +461,7 @@ class ProviderImpl(Providers):
|
|||
conn_info.error_message = str(e)
|
||||
conn_info.updated_at = datetime.now(UTC)
|
||||
await self._store_connection(conn_info)
|
||||
self.dynamic_providers[provider_id] = conn_info
|
||||
self.dynamic_providers[cache_key] = conn_info
|
||||
|
||||
logger.error(f"Failed to register provider {provider_id}: {e}")
|
||||
raise RuntimeError(f"Failed to register provider: {e}") from e
|
||||
|
|
@ -461,6 +477,8 @@ class ProviderImpl(Providers):
|
|||
Updates persist to kvstore and survive server restarts.
|
||||
This works for all providers (whether originally from run.yaml or API).
|
||||
"""
|
||||
logger.info(f"🔄 UPDATE_PROVIDER called: provider_id={provider_id}, has_config={config is not None}, has_attributes={attributes is not None}")
|
||||
|
||||
if not self.kvstore:
|
||||
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
|
||||
|
||||
|
|
@ -531,6 +549,8 @@ class ProviderImpl(Providers):
|
|||
Removes the provider from kvstore and shuts down its instance.
|
||||
This works for all providers (whether originally from run.yaml or API).
|
||||
"""
|
||||
logger.info(f"🗑️ UNREGISTER_PROVIDER called: provider_id={provider_id}")
|
||||
|
||||
if not self.kvstore:
|
||||
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
|
||||
|
||||
|
|
@ -560,6 +580,8 @@ class ProviderImpl(Providers):
|
|||
|
||||
async def test_provider_connection(self, provider_id: str) -> TestProviderConnectionResponse:
|
||||
"""Test a provider connection."""
|
||||
logger.info(f"🔍 TEST_PROVIDER_CONNECTION called: provider_id={provider_id}")
|
||||
|
||||
# Check if provider exists (static or dynamic)
|
||||
provider_impl = None
|
||||
|
||||
|
|
|
|||
|
|
@ -42,6 +42,8 @@ from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceI
|
|||
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
|
||||
from llama_stack.core.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
from llama_stack.core.store.registry import DistributionRegistry
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
|
|
@ -406,6 +408,187 @@ def _initialize_storage(run_config: StackRunConfig):
|
|||
register_sqlstore_backends(sql_backends)
|
||||
|
||||
|
||||
def _collect_available_deps(provider: Any, impls: dict[Any, Any]) -> dict[Any, Any]:
    """Return the already-resolved implementations that ``provider`` depends on.

    Both the required and the optional API dependencies declared on
    ``provider.spec`` are looked up; a dependency with no entry in ``impls``
    is left out of the result rather than treated as an error.
    """
    spec = provider.spec
    wanted = [*spec.api_dependencies, *spec.optional_api_dependencies]
    return {dep: impls[dep] for dep in wanted if dep in impls}


async def resolve_impls_via_provider_registration(
    run_config: StackRunConfig,
    provider_registry: ProviderRegistry,
    dist_registry: DistributionRegistry,
    policy: list[AccessRule],
    internal_impls: dict[Api, Any],
) -> dict[Api, Any]:
    """
    Resolves provider implementations by registering them through ProviderImpl.
    This ensures all providers (startup and runtime) go through the same registration code path.

    Providers are processed in the order returned by ``sort_providers_by_deps``.
    Depending on the entry, the implementation is obtained in one of three ways:

    * ``telemetry``, ``inner-*`` entries, routing tables and routers are
      instantiated directly with ``instantiate_provider``.
    * ``providers`` and ``inspect`` are skipped; they are expected to already be
      present in ``internal_impls``.
    * Everything else is registered via ``ProviderImpl.register_provider`` and
      the resulting instance is read back from ``dynamic_provider_impls``.

    Args:
        run_config: Stack run configuration with providers from run.yaml
        provider_registry: Registry of available provider types
        dist_registry: Distribution registry
        policy: Access control policy
        internal_impls: Internal implementations (inspect, providers) already initialized

    Returns:
        Dictionary mapping API to implementation instances. Entries for
        ``inner-*`` providers are stored under their string key as a
        ``{provider_id: impl}`` dictionary.

    Raises:
        Exception: Any error raised while instantiating or registering a
            non-telemetry provider is logged and re-raised unchanged. Errors
            from the telemetry provider propagate without the extra log line.
    """
    from llama_stack.core.distribution import builtin_automatically_routed_apis
    from llama_stack.core.resolver import (
        instantiate_provider,
        sort_providers_by_deps,
        specs_for_autorouted_apis,
        validate_and_prepare_providers,
    )

    # Query the auto-routed API table once; everything below is derived from it.
    auto_routed = builtin_automatically_routed_apis()
    routing_table_apis = {info.routing_table_api for info in auto_routed}
    router_apis = {info.router_api for info in auto_routed}
    # routing-table API -> router API whose inner providers it manages
    autorouted_map = {info.routing_table_api: info.router_api for info in auto_routed}

    # Validate and prepare providers from run.yaml
    providers_with_specs = validate_and_prepare_providers(
        run_config, provider_registry, routing_table_apis, router_apis
    )

    apis_to_serve = run_config.apis or set(
        list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
    )

    providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))

    # Sort providers in dependency order
    sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)

    # Get the ProviderImpl instance
    providers_impl = internal_impls[Api.providers]

    # Work on a copy so the caller's internal_impls mapping is not mutated.
    impls = internal_impls.copy()

    def _publish(api: Any, impl: Any) -> None:
        """Record ``impl`` for ``api`` and expose it as a dependency for later providers."""
        impls[api] = impl
        providers_impl.deps[api] = impl

    logger.info(f"🚀 Starting provider registration for {len(sorted_providers)} providers from run.yaml")

    for api_str, provider in sorted_providers:
        # Skip providers that are not enabled
        if provider.provider_id is None:
            continue

        # providers / inspect are already initialized as internal_impls
        if api_str in ("providers", "inspect"):
            continue

        # Telemetry is an internal API that should be directly instantiated
        if api_str == "telemetry":
            logger.info(f"Instantiating {provider.provider_id} for {api_str}")
            deps = _collect_available_deps(provider, impls)
            impl = await instantiate_provider(provider, deps, {}, dist_registry, run_config, policy)
            _publish(Api(api_str), impl)
            continue

        try:
            provider_type = provider.spec.provider_type
            is_inner = api_str.startswith("inner-")
            is_routing_table = provider_type == "routing_table"
            is_router = not is_inner and (Api(api_str) in router_apis or provider_type == "router")

            if is_inner or is_routing_table:
                # Inner providers or routing tables cannot be registered through the API
                # They need to be instantiated directly
                logger.info(f"Instantiating {provider.provider_id} for {api_str}")
                deps = _collect_available_deps(provider, impls)

                inner_impls: dict[str, Any] = {}
                if is_routing_table:
                    # For routing tables of autorouted APIs, get inner impls from the router API
                    # E.g., tool_groups routing table needs inner-tool_runtime providers
                    router_api = autorouted_map.get(Api(api_str))
                    if router_api is not None:
                        inner_impls = impls.get(f"inner-{router_api.value}", {})
                else:
                    # For regular inner providers, use their own inner key
                    inner_impls = impls.get(f"inner-{api_str}", {})

                impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)

                if is_inner:
                    # Inner providers are grouped per API, keyed by provider_id.
                    impls.setdefault(api_str, {})[provider.provider_id] = impl
                else:
                    _publish(Api(api_str), impl)

            elif is_router:
                # Router providers also need special handling
                logger.info(f"Instantiating router {provider.provider_id} for {api_str}")
                deps = _collect_available_deps(provider, impls)
                inner_impls = impls.get(f"inner-{api_str}", {})

                impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
                _publish(Api(api_str), impl)

            else:
                # Regular providers - register through ProviderImpl
                api = Api(api_str)
                logger.info(f"Registering {provider.provider_id} for {api.value}")

                await providers_impl.register_provider(
                    provider_id=provider.provider_id,
                    api=api.value,
                    provider_type=provider_type,
                    config=provider.config,
                    attributes=getattr(provider, "attributes", None),
                )

                # Get the instantiated impl from dynamic_provider_impls using composite key
                cache_key = f"{api.value}::{provider.provider_id}"
                _publish(api, providers_impl.dynamic_provider_impls[cache_key])

                logger.info(f"✅ Successfully registered startup provider: {provider.provider_id}")

        except Exception as e:
            logger.error(f"❌ Failed to handle provider {provider.provider_id}: {e}")
            raise

    return impls
|
||||
|
||||
|
||||
class Stack:
|
||||
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
|
||||
self.run_config = run_config
|
||||
|
|
@ -441,7 +624,13 @@ class Stack:
|
|||
policy=policy,
|
||||
)
|
||||
|
||||
impls = await resolve_impls(
|
||||
# Initialize the ProviderImpl so it has access to kvstore
|
||||
print("DEBUG: About to initialize ProviderImpl")
|
||||
await internal_impls[Api.providers].initialize()
|
||||
print("DEBUG: ProviderImpl initialized, about to call resolve_impls_via_provider_registration")
|
||||
|
||||
# Register all providers from run.yaml through ProviderImpl
|
||||
impls = await resolve_impls_via_provider_registration(
|
||||
self.run_config,
|
||||
provider_registry,
|
||||
dist_registry,
|
||||
|
|
|
|||
|
|
@ -231,7 +231,7 @@ storage:
|
|||
backends:
|
||||
kv_default:
|
||||
type: kv_sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/kvstore.db
|
||||
db_path: ":memory:"
|
||||
sql_default:
|
||||
type: sql_sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sql_store.db
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue