This commit is contained in:
raghotham 2025-10-27 11:36:30 -07:00 committed by GitHub
commit fb49732f2f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 4819 additions and 41 deletions

View file

@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import UTC, datetime
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.core.datatypes import User
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class ProviderConnectionStatus(StrEnum):
"""Status of a dynamic provider connection.
:cvar pending: Configuration stored, not yet initialized
:cvar initializing: In the process of connecting
:cvar connected: Successfully connected and healthy
:cvar failed: Connection attempt failed
:cvar disconnected: Previously connected, now disconnected
:cvar testing: Health check in progress
"""
pending = "pending"
initializing = "initializing"
connected = "connected"
failed = "failed"
disconnected = "disconnected"
testing = "testing"
@json_schema_type
class ProviderHealth(BaseModel):
"""Structured wrapper around provider health status.
This wraps the existing dict-based HealthResponse for API responses
while maintaining backward compatibility with existing provider implementations.
:param status: Health status (OK, ERROR, NOT_IMPLEMENTED)
:param message: Optional error or status message
:param metrics: Provider-specific health metrics
:param last_checked: Timestamp of last health check
"""
status: HealthStatus
message: str | None = None
metrics: dict[str, Any] = Field(default_factory=dict)
last_checked: datetime
@classmethod
def from_health_response(cls, response: dict[str, Any]) -> "ProviderHealth":
"""Convert dict-based HealthResponse to ProviderHealth.
This allows us to maintain the existing dict[str, Any] return type
for provider.health() methods while providing a structured model
for API responses.
:param response: Dict with 'status' and optional 'message', 'metrics'
:returns: ProviderHealth instance
"""
return cls(
status=HealthStatus(response.get("status", HealthStatus.NOT_IMPLEMENTED)),
message=response.get("message"),
metrics=response.get("metrics", {}),
last_checked=datetime.now(UTC),
)
@json_schema_type
class ProviderConnectionInfo(BaseModel):
"""Information about a dynamically managed provider connection.
This model represents a provider that has been registered at runtime
via the /providers API, as opposed to static providers configured in run.yaml.
Dynamic providers support full lifecycle management including registration,
configuration updates, health monitoring, and removal.
:param provider_id: Unique identifier for this provider instance
:param api: API namespace (e.g., "inference", "vector_io", "safety")
:param provider_type: Provider type identifier (e.g., "remote::openai", "inline::faiss")
:param config: Provider-specific configuration (API keys, endpoints, etc.)
:param status: Current connection status
:param health: Most recent health check result
:param created_at: Timestamp when provider was registered
:param updated_at: Timestamp of last update
:param last_health_check: Timestamp of last health check
:param error_message: Error message if status is failed
:param metadata: User-defined metadata (deprecated, use attributes)
:param owner: User who created this provider connection
:param attributes: Key-value attributes for ABAC access control
"""
provider_id: str
api: str
provider_type: str
config: dict[str, Any]
status: ProviderConnectionStatus
health: ProviderHealth | None = None
created_at: datetime
updated_at: datetime
last_health_check: datetime | None = None
error_message: str | None = None
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Deprecated: use attributes for access control",
)
# ABAC fields (same as ResourceWithOwner)
owner: User | None = None
attributes: dict[str, list[str]] | None = None

View file

@ -8,6 +8,7 @@ from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.providers.connection import ProviderConnectionInfo
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.datatypes import HealthResponse
from llama_stack.schema_utils import json_schema_type, webmethod
@ -40,6 +41,85 @@ class ListProvidersResponse(BaseModel):
data: list[ProviderInfo]
# ===== Dynamic Provider Management API Models =====
@json_schema_type
class RegisterProviderRequest(BaseModel):
"""Request to register a new dynamic provider.
:param provider_id: Unique identifier for the provider instance
:param api: API namespace (e.g., 'inference', 'vector_io', 'safety')
:param provider_type: Provider type identifier (e.g., 'remote::openai', 'inline::faiss')
:param config: Provider-specific configuration (API keys, endpoints, etc.)
:param attributes: Optional key-value attributes for ABAC access control
"""
provider_id: str
api: str
provider_type: str
config: dict[str, Any]
attributes: dict[str, list[str]] | None = None
@json_schema_type
class RegisterProviderResponse(BaseModel):
"""Response after registering a provider.
:param provider: Information about the registered provider
"""
provider: ProviderConnectionInfo
@json_schema_type
class UpdateProviderRequest(BaseModel):
"""Request to update an existing provider's configuration.
:param config: New configuration parameters (will be merged with existing)
:param attributes: Optional updated attributes for access control
"""
config: dict[str, Any] | None = None
attributes: dict[str, list[str]] | None = None
@json_schema_type
class UpdateProviderResponse(BaseModel):
"""Response after updating a provider.
:param provider: Updated provider information
"""
provider: ProviderConnectionInfo
@json_schema_type
class UnregisterProviderResponse(BaseModel):
"""Response after unregistering a provider.
:param success: Whether the operation succeeded
:param message: Optional status message
"""
success: bool
message: str | None = None
@json_schema_type
class TestProviderConnectionResponse(BaseModel):
"""Response from testing a provider connection.
:param success: Whether the connection test succeeded
:param health: Health status from the provider
:param error_message: Error message if test failed
"""
success: bool
health: HealthResponse | None = None
error_message: str | None = None
@runtime_checkable
class Providers(Protocol):
"""Providers
@ -57,12 +137,107 @@ class Providers(Protocol):
"""
...
@webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1)
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
"""Get provider.
@webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
async def inspect_provider(self, provider_id: str) -> ListProvidersResponse:
"""Get providers by ID (deprecated - use /providers/{api}/{provider_id} instead).
Get detailed information about a specific provider.
DEPRECATED: Returns all providers with the given provider_id across all APIs.
This can return multiple providers if the same ID is used for different APIs.
Use /providers/{api}/{provider_id} for unambiguous access.
:param provider_id: The ID of the provider(s) to inspect.
:returns: A ListProvidersResponse containing all providers with matching provider_id.
"""
...
# ===== Dynamic Provider Management Methods =====
@webmethod(route="/admin/providers/{api}", method="POST", level=LLAMA_STACK_API_V1)
async def register_provider(
self,
api: str,
provider_id: str,
provider_type: str,
config: dict[str, Any],
attributes: dict[str, list[str]] | None = None,
) -> RegisterProviderResponse:
"""Register a new dynamic provider.
Register a new provider instance at runtime. The provider will be validated,
instantiated, and persisted to the kvstore. Requires appropriate ABAC permissions.
:param api: API namespace this provider implements (e.g., 'inference', 'vector_io').
:param provider_id: Unique identifier for this provider instance.
:param provider_type: Provider type (e.g., 'remote::openai').
:param config: Provider configuration (API keys, endpoints, etc.).
:param attributes: Optional attributes for ABAC access control.
:returns: RegisterProviderResponse with the registered provider info.
"""
...
@webmethod(route="/admin/providers/{api}/{provider_id}", method="PUT", level=LLAMA_STACK_API_V1)
async def update_provider(
self,
api: str,
provider_id: str,
config: dict[str, Any] | None = None,
attributes: dict[str, list[str]] | None = None,
) -> UpdateProviderResponse:
"""Update an existing provider's configuration.
Update the configuration and/or attributes of a dynamic provider. The provider
will be re-instantiated with the new configuration (hot-reload).
:param api: API namespace the provider implements
:param provider_id: ID of the provider to update
:param config: New configuration parameters (merged with existing)
:param attributes: New attributes for access control
:returns: UpdateProviderResponse with updated provider info
"""
...
@webmethod(route="/admin/providers/{api}/{provider_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_provider(self, api: str, provider_id: str) -> None:
"""Unregister a dynamic provider.
Remove a dynamic provider, shutting down its instance and removing it from
the kvstore.
:param api: API namespace the provider implements
:param provider_id: ID of the provider to unregister.
"""
...
@webmethod(route="/admin/providers/{api}/{provider_id}/test", method="POST", level=LLAMA_STACK_API_V1)
async def test_provider_connection(self, api: str, provider_id: str) -> TestProviderConnectionResponse:
"""Test a provider connection.
Execute a health check on a provider to verify it is reachable and functioning.
:param api: API namespace the provider implements.
:param provider_id: ID of the provider to test.
:returns: TestProviderConnectionResponse with health status.
"""
...
@webmethod(route="/providers/{api}", method="GET", level=LLAMA_STACK_API_V1)
async def list_providers_for_api(self, api: str) -> ListProvidersResponse:
"""List providers for a specific API.
List all providers that implement a specific API.
:param api: The API namespace to filter by (e.g., 'inference', 'vector_io')
:returns: A ListProvidersResponse containing providers for the specified API.
"""
...
@webmethod(route="/providers/{api}/{provider_id}", method="GET", level=LLAMA_STACK_API_V1)
async def inspect_provider_for_api(self, api: str, provider_id: str) -> ProviderInfo:
"""Get provider for specific API.
Get detailed information about a specific provider for a specific API.
:param api: The API namespace.
:param provider_id: The ID of the provider to inspect.
:returns: A ProviderInfo object containing the provider's details.
"""

View file

@ -5,22 +5,45 @@
# the root directory of this source tree.
import asyncio
from datetime import UTC, datetime
from typing import Any
from pydantic import BaseModel
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from llama_stack.apis.providers import (
ListProvidersResponse,
ProviderInfo,
Providers,
RegisterProviderResponse,
TestProviderConnectionResponse,
UpdateProviderResponse,
)
from llama_stack.apis.providers.connection import (
ProviderConnectionInfo,
ProviderConnectionStatus,
ProviderHealth,
)
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.resolver import ProviderWithSpec, instantiate_provider
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
from llama_stack.providers.datatypes import Api, HealthResponse, HealthStatus
from .datatypes import StackRunConfig
from .utils.config import redact_sensitive_fields
logger = get_logger(name=__name__, category="core")
# Storage constants for dynamic provider connections
# Use composite key format: provider_connections:v1::{api}::{provider_id}
# This allows the same provider_id to be used for different APIs
PROVIDER_CONNECTIONS_PREFIX = "provider_connections:v1::"
class ProviderImplConfig(BaseModel):
run_config: StackRunConfig
provider_registry: Any | None = None # ProviderRegistry from resolver
dist_registry: Any | None = None # DistributionRegistry
policy: list[Any] | None = None # list[AccessRule]
async def get_provider_impl(config, deps):
@ -33,19 +56,71 @@ class ProviderImpl(Providers):
def __init__(self, config, deps):
self.config = config
self.deps = deps
self.kvstore = None # KVStore for dynamic provider persistence
# Runtime cache uses composite key: "{api}::{provider_id}"
# This allows the same provider_id to be used for different APIs
self.dynamic_providers: dict[str, ProviderConnectionInfo] = {} # Runtime cache
self.dynamic_provider_impls: dict[str, Any] = {} # Initialized provider instances
# Store registry references for provider instantiation
self.provider_registry = config.provider_registry
self.dist_registry = config.dist_registry
self.policy = config.policy or []
async def initialize(self) -> None:
pass
# Initialize kvstore for dynamic providers
# Use the metadata store from the new storage config structure
if not (self.config.run_config.storage and self.config.run_config.storage.stores.metadata):
raise RuntimeError(
"No metadata store configured in storage.stores.metadata. "
"Provider management requires a configured metadata store (kv_memory, kv_sqlite, etc)."
)
from llama_stack.providers.utils.kvstore import kvstore_impl
self.kvstore = await kvstore_impl(self.config.run_config.storage.stores.metadata)
logger.info("Initialized kvstore for dynamic provider management")
# Load existing dynamic providers from kvstore
await self._load_dynamic_providers()
logger.info(f"Loaded {len(self.dynamic_providers)} existing dynamic providers from kvstore")
for provider_id, conn_info in self.dynamic_providers.items():
if conn_info.status == ProviderConnectionStatus.connected:
try:
impl = await self._instantiate_provider(conn_info)
self.dynamic_provider_impls[provider_id] = impl
except Exception as e:
logger.error(f"Failed to instantiate provider {provider_id}: {e}")
# Update status to failed
conn_info.status = ProviderConnectionStatus.failed
conn_info.error_message = str(e)
conn_info.updated_at = datetime.now(UTC)
await self._store_connection(conn_info)
async def shutdown(self) -> None:
logger.debug("ProviderImpl.shutdown")
pass
# Shutdown all dynamic provider instances
for provider_id, impl in self.dynamic_provider_impls.items():
try:
if hasattr(impl, "shutdown"):
await impl.shutdown()
logger.debug(f"Shutdown dynamic provider {provider_id}")
except Exception as e:
logger.warning(f"Error shutting down dynamic provider {provider_id}: {e}")
# Shutdown kvstore
if self.kvstore and hasattr(self.kvstore, "shutdown"):
await self.kvstore.shutdown()
async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
providers_health = await self.get_providers_health()
ret = []
# Add static providers (from run.yaml)
for api, providers in safe_config.providers.items():
for p in providers:
# Skip providers that are not enabled
@ -66,15 +141,62 @@ class ProviderImpl(Providers):
)
)
# Add dynamic providers (from kvstore)
for _provider_id, conn_info in self.dynamic_providers.items():
# Redact sensitive config for API response
redacted_config = self._redact_sensitive_config(conn_info.config)
# Convert ProviderHealth to HealthResponse dict for API compatibility
health_dict: HealthResponse | None = None
if conn_info.health:
health_dict = HealthResponse(
status=conn_info.health.status,
message=conn_info.health.message,
)
if conn_info.health.metrics:
health_dict["metrics"] = conn_info.health.metrics
ret.append(
ProviderInfo(
api=conn_info.api,
provider_id=conn_info.provider_id,
provider_type=conn_info.provider_type,
config=redacted_config,
health=health_dict
or HealthResponse(status=HealthStatus.NOT_IMPLEMENTED, message="No health check available"),
)
)
return ListProvidersResponse(data=ret)
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
async def inspect_provider(self, provider_id: str) -> ListProvidersResponse:
"""Get all providers with the given provider_id (deprecated).
Returns all providers across all APIs that have this provider_id.
This is deprecated - use inspect_provider_for_api() for unambiguous access.
"""
all_providers = await self.list_providers()
matching = [p for p in all_providers.data if p.provider_id == provider_id]
if not matching:
raise ValueError(f"Provider {provider_id} not found")
return ListProvidersResponse(data=matching)
async def list_providers_for_api(self, api: str) -> ListProvidersResponse:
"""List providers for a specific API."""
all_providers = await self.list_providers()
filtered = [p for p in all_providers.data if p.api == api]
return ListProvidersResponse(data=filtered)
async def inspect_provider_for_api(self, api: str, provider_id: str) -> ProviderInfo:
"""Get a specific provider for a specific API."""
all_providers = await self.list_providers()
for p in all_providers.data:
if p.provider_id == provider_id:
if p.api == api and p.provider_id == provider_id:
return p
raise ValueError(f"Provider {provider_id} not found")
raise ValueError(f"Provider {provider_id} not found for API {api}")
async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
"""Get health status for all providers.
@ -135,3 +257,392 @@ class ProviderImpl(Providers):
providers_health[api_name] = health_response
return providers_health
# Storage helper methods for dynamic providers
async def _store_connection(self, info: ProviderConnectionInfo) -> None:
"""Store provider connection info in kvstore.
:param info: ProviderConnectionInfo to store
"""
if not self.kvstore:
raise RuntimeError("KVStore not initialized")
# Use composite key: provider_connections:v1::{api}::{provider_id}
key = f"{PROVIDER_CONNECTIONS_PREFIX}{info.api}::{info.provider_id}"
await self.kvstore.set(key, info.model_dump_json())
logger.debug(f"Stored provider connection: {info.api}::{info.provider_id}")
async def _load_connection(self, provider_id: str) -> ProviderConnectionInfo | None:
"""Load provider connection info from kvstore.
:param provider_id: Provider ID to load
:returns: ProviderConnectionInfo if found, None otherwise
"""
if not self.kvstore:
return None
key = f"{PROVIDER_CONNECTIONS_PREFIX}{provider_id}"
value = await self.kvstore.get(key)
if value:
return ProviderConnectionInfo.model_validate_json(value)
return None
async def _delete_connection(self, api: str, provider_id: str) -> None:
"""Delete provider connection from kvstore.
:param api: API namespace
:param provider_id: Provider ID to delete
"""
if not self.kvstore:
raise RuntimeError("KVStore not initialized")
# Use composite key: provider_connections:v1::{api}::{provider_id}
key = f"{PROVIDER_CONNECTIONS_PREFIX}{api}::{provider_id}"
await self.kvstore.delete(key)
logger.debug(f"Deleted provider connection: {api}::{provider_id}")
async def _list_connections(self) -> list[ProviderConnectionInfo]:
"""List all dynamic provider connections from kvstore.
:returns: List of ProviderConnectionInfo
"""
if not self.kvstore:
return []
start_key = PROVIDER_CONNECTIONS_PREFIX
end_key = f"{PROVIDER_CONNECTIONS_PREFIX}\xff"
values = await self.kvstore.values_in_range(start_key, end_key)
return [ProviderConnectionInfo.model_validate_json(v) for v in values]
async def _load_dynamic_providers(self) -> None:
"""Load dynamic providers from kvstore into runtime cache."""
connections = await self._list_connections()
for conn in connections:
# Use composite key for runtime cache
cache_key = f"{conn.api}::{conn.provider_id}"
self.dynamic_providers[cache_key] = conn
logger.debug(f"Loaded dynamic provider: {cache_key} (status: {conn.status})")
def _find_provider_cache_key(self, provider_id: str) -> str | None:
"""Find the cache key for a provider by its provider_id.
Since we use composite keys ({api}::{provider_id}), this searches for the matching key.
Returns None if not found.
"""
for key in self.dynamic_providers.keys():
if key.endswith(f"::{provider_id}"):
return key
return None
# Helper methods for dynamic provider management
def _redact_sensitive_config(self, config: dict[str, Any]) -> dict[str, Any]:
"""Redact sensitive fields in provider config for API responses.
:param config: Provider configuration dict
:returns: Config with sensitive fields redacted
"""
return redact_sensitive_fields(config)
async def _instantiate_provider(self, conn_info: ProviderConnectionInfo) -> Any:
"""Instantiate a provider from connection info.
Uses the resolver's instantiate_provider() to create a provider instance
with all necessary dependencies.
:param conn_info: Provider connection information
:returns: Instantiated provider implementation
:raises RuntimeError: If provider cannot be instantiated
"""
if not self.provider_registry:
raise RuntimeError("Provider registry not available for provider instantiation")
if not self.dist_registry:
raise RuntimeError("Distribution registry not available for provider instantiation")
# Get provider spec from registry
api = Api(conn_info.api)
if api not in self.provider_registry:
raise ValueError(f"API {conn_info.api} not found in provider registry")
if conn_info.provider_type not in self.provider_registry[api]:
raise ValueError(f"Provider type {conn_info.provider_type} not found for API {conn_info.api}")
provider_spec = self.provider_registry[api][conn_info.provider_type]
# Create ProviderWithSpec for instantiation
provider_with_spec = ProviderWithSpec(
provider_id=conn_info.provider_id,
provider_type=conn_info.provider_type,
config=conn_info.config,
spec=provider_spec,
)
# Resolve dependencies
deps = {}
for dep_api in provider_spec.api_dependencies:
if dep_api not in self.deps:
raise RuntimeError(
f"Required dependency {dep_api.value} not available for provider {conn_info.provider_id}"
)
deps[dep_api] = self.deps[dep_api]
# Add optional dependencies if available
for dep_api in provider_spec.optional_api_dependencies:
if dep_api in self.deps:
deps[dep_api] = self.deps[dep_api]
# Instantiate provider using resolver
impl = await instantiate_provider(
provider_with_spec,
deps,
{}, # inner_impls (empty for dynamic providers)
self.dist_registry,
self.config.run_config,
self.policy,
)
logger.debug(f"Instantiated provider {conn_info.provider_id} (type={conn_info.provider_type})")
return impl
# Dynamic Provider Management Methods
async def register_provider(
self,
api: str,
provider_id: str,
provider_type: str,
config: dict[str, Any],
attributes: dict[str, list[str]] | None = None,
) -> RegisterProviderResponse:
"""Register a new provider.
This is used both for:
- Providers from run.yaml (registered at startup)
- Providers registered via API (registered at runtime)
All providers are stored in kvstore and treated equally.
"""
if not self.kvstore:
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
# Use composite key to allow same provider_id for different APIs
cache_key = f"{api}::{provider_id}"
# Check if provider already exists for this API
if cache_key in self.dynamic_providers:
raise ValueError(f"Provider {provider_id} already exists for API {api}")
# Get authenticated user as owner
user = get_authenticated_user()
# Create ProviderConnectionInfo
now = datetime.now(UTC)
conn_info = ProviderConnectionInfo(
provider_id=provider_id,
api=api,
provider_type=provider_type,
config=config,
status=ProviderConnectionStatus.initializing,
created_at=now,
updated_at=now,
owner=user,
attributes=attributes,
)
try:
# Store in kvstore
await self._store_connection(conn_info)
impl = await self._instantiate_provider(conn_info)
# Use composite key for impl cache too
self.dynamic_provider_impls[cache_key] = impl
# Update status to connected after successful instantiation
conn_info.status = ProviderConnectionStatus.connected
conn_info.updated_at = datetime.now(UTC)
logger.info(f"Registered and instantiated dynamic provider {provider_id} (api={api}, type={provider_type})")
# Store updated status
await self._store_connection(conn_info)
# Add to runtime cache using composite key
self.dynamic_providers[cache_key] = conn_info
return RegisterProviderResponse(provider=conn_info)
except Exception as e:
# Mark as failed and store
conn_info.status = ProviderConnectionStatus.failed
conn_info.error_message = str(e)
conn_info.updated_at = datetime.now(UTC)
await self._store_connection(conn_info)
self.dynamic_providers[cache_key] = conn_info
logger.error(f"Failed to register provider {provider_id}: {e}")
raise RuntimeError(f"Failed to register provider: {e}") from e
async def update_provider(
self,
api: str,
provider_id: str,
config: dict[str, Any] | None = None,
attributes: dict[str, list[str]] | None = None,
) -> UpdateProviderResponse:
"""Update an existing provider's configuration.
Updates persist to kvstore and survive server restarts.
This works for all providers (whether originally from run.yaml or API).
"""
if not self.kvstore:
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
# Use composite key
cache_key = f"{api}::{provider_id}"
if cache_key not in self.dynamic_providers:
raise ValueError(f"Provider {provider_id} not found for API {api}")
conn_info = self.dynamic_providers[cache_key]
# Update config if provided
if config is not None:
conn_info.config.update(config)
# Update attributes if provided
if attributes is not None:
conn_info.attributes = attributes
conn_info.updated_at = datetime.now(UTC)
conn_info.status = ProviderConnectionStatus.initializing
try:
# Store updated config
await self._store_connection(conn_info)
# Hot-reload: Shutdown old instance and reinstantiate with new config
# Shutdown old instance if it exists
if cache_key in self.dynamic_provider_impls:
old_impl = self.dynamic_provider_impls[cache_key]
if hasattr(old_impl, "shutdown"):
try:
await old_impl.shutdown()
logger.debug(f"Shutdown old instance of provider {provider_id}")
except Exception as e:
logger.warning(f"Error shutting down old instance of {provider_id}: {e}")
# Reinstantiate with new config
impl = await self._instantiate_provider(conn_info)
self.dynamic_provider_impls[cache_key] = impl
# Update status to connected after successful reinstantiation
conn_info.status = ProviderConnectionStatus.connected
conn_info.updated_at = datetime.now(UTC)
await self._store_connection(conn_info)
logger.info(f"Hot-reloaded dynamic provider {provider_id}")
return UpdateProviderResponse(provider=conn_info)
except Exception as e:
conn_info.status = ProviderConnectionStatus.failed
conn_info.error_message = str(e)
conn_info.updated_at = datetime.now(UTC)
await self._store_connection(conn_info)
logger.error(f"Failed to update provider {provider_id}: {e}")
raise RuntimeError(f"Failed to update provider: {e}") from e
async def unregister_provider(self, api: str, provider_id: str) -> None:
"""Unregister a provider.
Removes the provider from kvstore and shuts down its instance.
This works for all providers (whether originally from run.yaml or API).
"""
if not self.kvstore:
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
# Use composite key
cache_key = f"{api}::{provider_id}"
if cache_key not in self.dynamic_providers:
raise ValueError(f"Provider {provider_id} not found for API {api}")
conn_info = self.dynamic_providers[cache_key]
try:
# Shutdown provider instance if it exists
if cache_key in self.dynamic_provider_impls:
impl = self.dynamic_provider_impls[cache_key]
if hasattr(impl, "shutdown"):
await impl.shutdown()
del self.dynamic_provider_impls[cache_key]
# Remove from kvstore (using the api and provider_id from conn_info)
await self._delete_connection(conn_info.api, provider_id)
# Remove from runtime cache
del self.dynamic_providers[cache_key]
logger.info(f"Unregistered dynamic provider {provider_id}")
except Exception as e:
logger.error(f"Failed to unregister provider {provider_id}: {e}")
raise RuntimeError(f"Failed to unregister provider: {e}") from e
async def test_provider_connection(self, api: str, provider_id: str) -> TestProviderConnectionResponse:
"""Test a provider connection."""
# Check if provider exists (static or dynamic)
provider_impl = None
cache_key = f"{api}::{provider_id}"
# Check dynamic providers first (using composite keys)
if cache_key in self.dynamic_provider_impls:
provider_impl = self.dynamic_provider_impls[cache_key]
# Check static providers
if not provider_impl and provider_id in self.deps:
provider_impl = self.deps[provider_id]
if not provider_impl:
return TestProviderConnectionResponse(
success=False, error_message=f"Provider {provider_id} not found for API {api}"
)
# Check if provider has health method
if not hasattr(provider_impl, "health"):
return TestProviderConnectionResponse(
success=False,
health=HealthResponse(
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
),
)
# Call health check
try:
health_result = await asyncio.wait_for(provider_impl.health(), timeout=5.0)
# Update health in dynamic provider cache if applicable
if cache_key and cache_key in self.dynamic_providers:
conn_info = self.dynamic_providers[cache_key]
conn_info.health = ProviderHealth.from_health_response(health_result)
conn_info.last_health_check = datetime.now(UTC)
await self._store_connection(conn_info)
logger.debug(f"Tested provider connection {provider_id}: status={health_result.get('status', 'UNKNOWN')}")
return TestProviderConnectionResponse(
success=health_result.get("status") == HealthStatus.OK,
health=health_result,
)
except TimeoutError:
health = HealthResponse(status=HealthStatus.ERROR, message="Health check timed out after 5 seconds")
return TestProviderConnectionResponse(success=False, health=health)
except Exception as e:
health = HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
return TestProviderConnectionResponse(success=False, health=health, error_message=str(e))

View file

@ -34,13 +34,20 @@ from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.distribution import builtin_automatically_routed_apis, get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
from llama_stack.core.resolver import ProviderRegistry, resolve_impls
from llama_stack.core.resolver import (
ProviderRegistry,
instantiate_provider,
sort_providers_by_deps,
specs_for_autorouted_apis,
validate_and_prepare_providers,
)
from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
@ -52,10 +59,12 @@ from llama_stack.core.storage.datatypes import (
StorageBackendConfig,
StorageConfig,
)
from llama_stack.core.store.registry import create_dist_registry
from llama_stack.core.store.registry import DistributionRegistry, create_dist_registry
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
logger = get_logger(name=__name__, category="core")
@ -341,12 +350,21 @@ def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
def add_internal_implementations(
impls: dict[Api, Any],
run_config: StackRunConfig,
provider_registry=None,
dist_registry=None,
policy=None,
) -> None:
"""Add internal implementations (inspect and providers) to the implementations dictionary.
Args:
impls: Dictionary of API implementations
run_config: Stack run configuration
provider_registry: Provider registry for dynamic provider instantiation
dist_registry: Distribution registry
policy: Access control policy
"""
inspect_impl = DistributionInspectImpl(
DistributionInspectConfig(run_config=run_config),
@ -355,7 +373,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
impls[Api.inspect] = inspect_impl
providers_impl = ProviderImpl(
ProviderImplConfig(run_config=run_config),
ProviderImplConfig(
run_config=run_config,
provider_registry=provider_registry,
dist_registry=dist_registry,
policy=policy,
),
deps=impls,
)
impls[Api.providers] = providers_impl
@ -385,13 +408,179 @@ def _initialize_storage(run_config: StackRunConfig):
else:
raise ValueError(f"Unknown storage backend type: {type}")
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
register_kvstore_backends(kv_backends)
register_sqlstore_backends(sql_backends)
async def resolve_impls_via_provider_registration(
run_config: StackRunConfig,
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
policy: list[AccessRule],
internal_impls: dict[Api, Any],
) -> dict[Api, Any]:
"""
Resolves provider implementations by registering them through ProviderImpl.
This ensures all providers (startup and runtime) go through the same registration code path.
Args:
run_config: Stack run configuration with providers from run.yaml
provider_registry: Registry of available provider types
dist_registry: Distribution registry
policy: Access control policy
internal_impls: Internal implementations (inspect, providers) already initialized
Returns:
Dictionary mapping API to implementation instances
"""
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
# Validate and prepare providers from run.yaml
providers_with_specs = validate_and_prepare_providers(
run_config, provider_registry, routing_table_apis, router_apis
)
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
)
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
# Sort providers in dependency order
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
# Get the ProviderImpl instance
providers_impl = internal_impls[Api.providers]
# Register each provider through ProviderImpl
impls = internal_impls.copy()
logger.info(f"Provider registration for {len(sorted_providers)} providers from run.yaml")
for api_str, provider in sorted_providers:
# Skip providers that are not enabled
if provider.provider_id is None:
continue
# Skip internal APIs (already initialized)
if api_str in ["providers", "inspect"]:
continue
# Handle different provider types
try:
# Check if this is a router (system infrastructure)
is_router = not api_str.startswith("inner-") and (
Api(api_str) in router_apis or provider.spec.provider_type == "router"
)
if api_str.startswith("inner-") or provider.spec.provider_type == "routing_table":
# Inner providers or routing tables cannot be registered through the API
# They need to be instantiated directly
logger.info(f"Instantiating {provider.provider_id} for {api_str}")
deps = {a: impls[a] for a in provider.spec.api_dependencies if a in impls}
for a in provider.spec.optional_api_dependencies:
if a in impls:
deps[a] = impls[a]
# Get inner impls if available
inner_impls = {}
# For routing tables of autorouted APIs, get inner impls from the router API
# E.g., tool_groups routing table needs inner-tool_runtime providers
if provider.spec.provider_type == "routing_table":
autorouted_map = {
info.routing_table_api: info.router_api for info in builtin_automatically_routed_apis()
}
if Api(api_str) in autorouted_map:
router_api_str = autorouted_map[Api(api_str)].value
inner_key = f"inner-{router_api_str}"
if inner_key in impls:
inner_impls = impls[inner_key]
else:
# For regular inner providers, use their own inner key
inner_key = f"inner-{api_str}"
if inner_key in impls:
inner_impls = impls[inner_key]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
# Store appropriately
if api_str.startswith("inner-"):
if api_str not in impls:
impls[api_str] = {}
impls[api_str][provider.provider_id] = impl
else:
api = Api(api_str)
impls[api] = impl
# Update providers_impl.deps so subsequent providers can depend on this
providers_impl.deps[api] = impl
elif is_router:
# Router providers also need special handling
logger.info(f"Instantiating router {provider.provider_id} for {api_str}")
deps = {a: impls[a] for a in provider.spec.api_dependencies if a in impls}
for a in provider.spec.optional_api_dependencies:
if a in impls:
deps[a] = impls[a]
# Get inner impls if this is a router
inner_impls = {}
inner_key = f"inner-{api_str}"
if inner_key in impls:
inner_impls = impls[inner_key]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
api = Api(api_str)
impls[api] = impl
# Update providers_impl.deps so subsequent providers can depend on this
providers_impl.deps[api] = impl
else:
# Regular providers - register through ProviderImpl
api = Api(api_str)
cache_key = f"{api.value}::{provider.provider_id}"
# Check if provider already exists (loaded from kvstore during initialization)
if cache_key in providers_impl.dynamic_providers:
logger.info(
f"Provider {provider.provider_id} for {api.value} already exists, using existing instance"
)
impl = providers_impl.dynamic_provider_impls.get(cache_key)
if impl is None:
# Provider exists but not instantiated, instantiate it
conn_info = providers_impl.dynamic_providers[cache_key]
impl = await providers_impl._instantiate_provider(conn_info)
providers_impl.dynamic_provider_impls[cache_key] = impl
else:
logger.info(f"Registering {provider.provider_id} for {api.value}")
await providers_impl.register_provider(
api=api.value,
provider_id=provider.provider_id,
provider_type=provider.spec.provider_type,
config=provider.config,
attributes=getattr(provider, "attributes", None),
)
# Get the instantiated impl from dynamic_provider_impls using composite key
impl = providers_impl.dynamic_provider_impls[cache_key]
logger.info(f"Successfully registered startup provider: {provider.provider_id}")
impls[api] = impl
# IMPORTANT: Update providers_impl.deps so subsequent providers can depend on this one
providers_impl.deps[api] = impl
except Exception as e:
logger.error(f"Failed to handle provider {provider.provider_id}: {e}")
raise
return impls
class Stack:
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
self.run_config = run_config
@ -416,13 +605,24 @@ class Stack:
raise ValueError("storage.stores.metadata must be configured with a kv_* backend")
dist_registry, _ = await create_dist_registry(stores.metadata, self.run_config.image_name)
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
provider_registry = self.provider_registry or get_provider_registry(self.run_config)
internal_impls = {}
add_internal_implementations(internal_impls, self.run_config)
impls = await resolve_impls(
add_internal_implementations(
internal_impls,
self.run_config,
self.provider_registry or get_provider_registry(self.run_config),
provider_registry=provider_registry,
dist_registry=dist_registry,
policy=policy,
)
# Initialize the ProviderImpl so it has access to kvstore
await internal_impls[Api.providers].initialize()
# Register all providers from run.yaml through ProviderImpl
impls = await resolve_impls_via_provider_registration(
self.run_config,
provider_registry,
dist_registry,
policy,
internal_impls,