This commit is contained in:
Raghotham Murthy 2025-10-27 10:52:20 -07:00
parent e21db79d6c
commit 13b6f3df65
12 changed files with 1238 additions and 605 deletions

View file

@ -79,29 +79,24 @@ class ProviderImpl(Providers):
from llama_stack.providers.utils.kvstore import kvstore_impl
self.kvstore = await kvstore_impl(self.config.run_config.storage.stores.metadata)
logger.info("Initialized kvstore for dynamic provider management")
logger.info("Initialized kvstore for dynamic provider management")
# Load existing dynamic providers from kvstore
await self._load_dynamic_providers()
logger.info(f"📦 Loaded {len(self.dynamic_providers)} existing dynamic providers from kvstore")
logger.info(f"Loaded {len(self.dynamic_providers)} existing dynamic providers from kvstore")
# Auto-instantiate connected providers on startup
if self.provider_registry:
for provider_id, conn_info in self.dynamic_providers.items():
if conn_info.status == ProviderConnectionStatus.connected:
try:
impl = await self._instantiate_provider(conn_info)
self.dynamic_provider_impls[provider_id] = impl
logger.info(f"♻️ Auto-instantiated provider {provider_id} from kvstore")
except Exception as e:
logger.error(f"Failed to auto-instantiate provider {provider_id}: {e}")
# Update status to failed
conn_info.status = ProviderConnectionStatus.failed
conn_info.error_message = str(e)
conn_info.updated_at = datetime.now(UTC)
await self._store_connection(conn_info)
else:
logger.warning("Provider registry not available, skipping auto-instantiation")
for provider_id, conn_info in self.dynamic_providers.items():
if conn_info.status == ProviderConnectionStatus.connected:
try:
impl = await self._instantiate_provider(conn_info)
self.dynamic_provider_impls[provider_id] = impl
except Exception as e:
logger.error(f"Failed to instantiate provider {provider_id}: {e}")
# Update status to failed
conn_info.status = ProviderConnectionStatus.failed
conn_info.error_message = str(e)
conn_info.updated_at = datetime.now(UTC)
await self._store_connection(conn_info)
async def shutdown(self) -> None:
logger.debug("ProviderImpl.shutdown")
@ -174,13 +169,34 @@ class ProviderImpl(Providers):
return ListProvidersResponse(data=ret)
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
async def inspect_provider(self, provider_id: str) -> ListProvidersResponse:
"""Get all providers with the given provider_id (deprecated).
Returns all providers across all APIs that have this provider_id.
This is deprecated - use inspect_provider_for_api() for unambiguous access.
"""
all_providers = await self.list_providers()
matching = [p for p in all_providers.data if p.provider_id == provider_id]
if not matching:
raise ValueError(f"Provider {provider_id} not found")
return ListProvidersResponse(data=matching)
async def list_providers_for_api(self, api: str) -> ListProvidersResponse:
"""List providers for a specific API."""
all_providers = await self.list_providers()
filtered = [p for p in all_providers.data if p.api == api]
return ListProvidersResponse(data=filtered)
async def inspect_provider_for_api(self, api: str, provider_id: str) -> ProviderInfo:
"""Get a specific provider for a specific API."""
all_providers = await self.list_providers()
for p in all_providers.data:
if p.provider_id == provider_id:
if p.api == api and p.provider_id == provider_id:
return p
raise ValueError(f"Provider {provider_id} not found")
raise ValueError(f"Provider {provider_id} not found for API {api}")
async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
"""Get health status for all providers.
@ -272,17 +288,19 @@ class ProviderImpl(Providers):
return ProviderConnectionInfo.model_validate_json(value)
return None
async def _delete_connection(self, provider_id: str) -> None:
async def _delete_connection(self, api: str, provider_id: str) -> None:
"""Delete provider connection from kvstore.
:param api: API namespace
:param provider_id: Provider ID to delete
"""
if not self.kvstore:
raise RuntimeError("KVStore not initialized")
key = f"{PROVIDER_CONNECTIONS_PREFIX}{provider_id}"
# Use composite key: provider_connections:v1::{api}::{provider_id}
key = f"{PROVIDER_CONNECTIONS_PREFIX}{api}::{provider_id}"
await self.kvstore.delete(key)
logger.debug(f"Deleted provider connection: {provider_id}")
logger.debug(f"Deleted provider connection: {api}::{provider_id}")
async def _list_connections(self) -> list[ProviderConnectionInfo]:
"""List all dynamic provider connections from kvstore.
@ -306,6 +324,17 @@ class ProviderImpl(Providers):
self.dynamic_providers[cache_key] = conn
logger.debug(f"Loaded dynamic provider: {cache_key} (status: {conn.status})")
def _find_provider_cache_key(self, provider_id: str) -> str | None:
"""Find the cache key for a provider by its provider_id.
Since we use composite keys ({api}::{provider_id}), this searches for the matching key.
Returns None if not found.
"""
for key in self.dynamic_providers.keys():
if key.endswith(f"::{provider_id}"):
return key
return None
# Helper methods for dynamic provider management
def _redact_sensitive_config(self, config: dict[str, Any]) -> dict[str, Any]:
@ -380,8 +409,8 @@ class ProviderImpl(Providers):
async def register_provider(
self,
provider_id: str,
api: str,
provider_id: str,
provider_type: str,
config: dict[str, Any],
attributes: dict[str, list[str]] | None = None,
@ -394,7 +423,6 @@ class ProviderImpl(Providers):
All providers are stored in kvstore and treated equally.
"""
logger.info(f"📝 REGISTER_PROVIDER called: provider_id={provider_id}, api={api}, type={provider_type}")
if not self.kvstore:
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
@ -427,25 +455,15 @@ class ProviderImpl(Providers):
# Store in kvstore
await self._store_connection(conn_info)
# Instantiate provider if we have a provider registry
if self.provider_registry:
impl = await self._instantiate_provider(conn_info)
# Use composite key for impl cache too
self.dynamic_provider_impls[cache_key] = impl
impl = await self._instantiate_provider(conn_info)
# Use composite key for impl cache too
self.dynamic_provider_impls[cache_key] = impl
# Update status to connected after successful instantiation
conn_info.status = ProviderConnectionStatus.connected
conn_info.updated_at = datetime.now(UTC)
# Update status to connected after successful instantiation
conn_info.status = ProviderConnectionStatus.connected
conn_info.updated_at = datetime.now(UTC)
logger.info(
f"Registered and instantiated dynamic provider {provider_id} (api={api}, type={provider_type})"
)
else:
# No registry available - just mark as connected without instantiation
# This can happen during testing or if provider management is disabled
conn_info.status = ProviderConnectionStatus.connected
conn_info.updated_at = datetime.now(UTC)
logger.warning(f"Registered provider {provider_id} without instantiation (no registry)")
logger.info(f"Registered and instantiated dynamic provider {provider_id} (api={api}, type={provider_type})")
# Store updated status
await self._store_connection(conn_info)
@ -468,6 +486,7 @@ class ProviderImpl(Providers):
async def update_provider(
self,
api: str,
provider_id: str,
config: dict[str, Any] | None = None,
attributes: dict[str, list[str]] | None = None,
@ -477,16 +496,16 @@ class ProviderImpl(Providers):
Updates persist to kvstore and survive server restarts.
This works for all providers (whether originally from run.yaml or API).
"""
logger.info(f"🔄 UPDATE_PROVIDER called: provider_id={provider_id}, has_config={config is not None}, has_attributes={attributes is not None}")
if not self.kvstore:
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
# Check if provider exists
if provider_id not in self.dynamic_providers:
raise ValueError(f"Provider {provider_id} not found")
# Use composite key
cache_key = f"{api}::{provider_id}"
if cache_key not in self.dynamic_providers:
raise ValueError(f"Provider {provider_id} not found for API {api}")
conn_info = self.dynamic_providers[provider_id]
conn_info = self.dynamic_providers[cache_key]
# Update config if provided
if config is not None:
@ -504,33 +523,26 @@ class ProviderImpl(Providers):
await self._store_connection(conn_info)
# Hot-reload: Shutdown old instance and reinstantiate with new config
if self.provider_registry:
# Shutdown old instance if it exists
if provider_id in self.dynamic_provider_impls:
old_impl = self.dynamic_provider_impls[provider_id]
if hasattr(old_impl, "shutdown"):
try:
await old_impl.shutdown()
logger.debug(f"Shutdown old instance of provider {provider_id}")
except Exception as e:
logger.warning(f"Error shutting down old instance of {provider_id}: {e}")
# Shutdown old instance if it exists
if cache_key in self.dynamic_provider_impls:
old_impl = self.dynamic_provider_impls[cache_key]
if hasattr(old_impl, "shutdown"):
try:
await old_impl.shutdown()
logger.debug(f"Shutdown old instance of provider {provider_id}")
except Exception as e:
logger.warning(f"Error shutting down old instance of {provider_id}: {e}")
# Reinstantiate with new config
impl = await self._instantiate_provider(conn_info)
self.dynamic_provider_impls[provider_id] = impl
# Reinstantiate with new config
impl = await self._instantiate_provider(conn_info)
self.dynamic_provider_impls[cache_key] = impl
# Update status to connected after successful reinstantiation
conn_info.status = ProviderConnectionStatus.connected
conn_info.updated_at = datetime.now(UTC)
await self._store_connection(conn_info)
# Update status to connected after successful reinstantiation
conn_info.status = ProviderConnectionStatus.connected
conn_info.updated_at = datetime.now(UTC)
await self._store_connection(conn_info)
logger.info(f"Hot-reloaded dynamic provider {provider_id}")
else:
# No registry - just update config without reinstantiation
conn_info.status = ProviderConnectionStatus.connected
conn_info.updated_at = datetime.now(UTC)
await self._store_connection(conn_info)
logger.warning(f"Updated provider {provider_id} config without hot-reload (no registry)")
logger.info(f"Hot-reloaded dynamic provider {provider_id}")
return UpdateProviderResponse(provider=conn_info)
@ -543,34 +555,36 @@ class ProviderImpl(Providers):
logger.error(f"Failed to update provider {provider_id}: {e}")
raise RuntimeError(f"Failed to update provider: {e}") from e
async def unregister_provider(self, provider_id: str) -> None:
async def unregister_provider(self, api: str, provider_id: str) -> None:
"""Unregister a provider.
Removes the provider from kvstore and shuts down its instance.
This works for all providers (whether originally from run.yaml or API).
"""
logger.info(f"🗑️ UNREGISTER_PROVIDER called: provider_id={provider_id}")
if not self.kvstore:
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
# Check if provider exists
if provider_id not in self.dynamic_providers:
raise ValueError(f"Provider {provider_id} not found")
# Use composite key
cache_key = f"{api}::{provider_id}"
if cache_key not in self.dynamic_providers:
raise ValueError(f"Provider {provider_id} not found for API {api}")
conn_info = self.dynamic_providers[cache_key]
try:
# Shutdown provider instance if it exists
if provider_id in self.dynamic_provider_impls:
impl = self.dynamic_provider_impls[provider_id]
if cache_key in self.dynamic_provider_impls:
impl = self.dynamic_provider_impls[cache_key]
if hasattr(impl, "shutdown"):
await impl.shutdown()
del self.dynamic_provider_impls[provider_id]
del self.dynamic_provider_impls[cache_key]
# Remove from kvstore
await self._delete_connection(provider_id)
# Remove from kvstore (using the api and provider_id from conn_info)
await self._delete_connection(conn_info.api, provider_id)
# Remove from runtime cache
del self.dynamic_providers[provider_id]
del self.dynamic_providers[cache_key]
logger.info(f"Unregistered dynamic provider {provider_id}")
@ -578,23 +592,24 @@ class ProviderImpl(Providers):
logger.error(f"Failed to unregister provider {provider_id}: {e}")
raise RuntimeError(f"Failed to unregister provider: {e}") from e
async def test_provider_connection(self, provider_id: str) -> TestProviderConnectionResponse:
async def test_provider_connection(self, api: str, provider_id: str) -> TestProviderConnectionResponse:
"""Test a provider connection."""
logger.info(f"🔍 TEST_PROVIDER_CONNECTION called: provider_id={provider_id}")
# Check if provider exists (static or dynamic)
provider_impl = None
cache_key = f"{api}::{provider_id}"
# Check dynamic providers first (using composite keys)
if cache_key in self.dynamic_provider_impls:
provider_impl = self.dynamic_provider_impls[cache_key]
# Check dynamic providers first
if provider_id in self.dynamic_provider_impls:
provider_impl = self.dynamic_provider_impls[provider_id]
# Check static providers
elif provider_id in self.deps:
if not provider_impl and provider_id in self.deps:
provider_impl = self.deps[provider_id]
if not provider_impl:
return TestProviderConnectionResponse(
success=False, error_message=f"Provider {provider_id} not found or not initialized"
success=False, error_message=f"Provider {provider_id} not found for API {api}"
)
# Check if provider has health method
@ -611,8 +626,8 @@ class ProviderImpl(Providers):
health_result = await asyncio.wait_for(provider_impl.health(), timeout=5.0)
# Update health in dynamic provider cache if applicable
if provider_id in self.dynamic_providers:
conn_info = self.dynamic_providers[provider_id]
if cache_key and cache_key in self.dynamic_providers:
conn_info = self.dynamic_providers[cache_key]
conn_info.health = ProviderHealth.from_health_response(health_result)
conn_info.last_health_check = datetime.now(UTC)
await self._store_connection(conn_info)