mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
updates
This commit is contained in:
parent
e21db79d6c
commit
13b6f3df65
12 changed files with 1238 additions and 605 deletions
|
|
@ -79,29 +79,24 @@ class ProviderImpl(Providers):
|
|||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
self.kvstore = await kvstore_impl(self.config.run_config.storage.stores.metadata)
|
||||
logger.info("✅ Initialized kvstore for dynamic provider management")
|
||||
logger.info("Initialized kvstore for dynamic provider management")
|
||||
|
||||
# Load existing dynamic providers from kvstore
|
||||
await self._load_dynamic_providers()
|
||||
logger.info(f"📦 Loaded {len(self.dynamic_providers)} existing dynamic providers from kvstore")
|
||||
logger.info(f"Loaded {len(self.dynamic_providers)} existing dynamic providers from kvstore")
|
||||
|
||||
# Auto-instantiate connected providers on startup
|
||||
if self.provider_registry:
|
||||
for provider_id, conn_info in self.dynamic_providers.items():
|
||||
if conn_info.status == ProviderConnectionStatus.connected:
|
||||
try:
|
||||
impl = await self._instantiate_provider(conn_info)
|
||||
self.dynamic_provider_impls[provider_id] = impl
|
||||
logger.info(f"♻️ Auto-instantiated provider {provider_id} from kvstore")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to auto-instantiate provider {provider_id}: {e}")
|
||||
# Update status to failed
|
||||
conn_info.status = ProviderConnectionStatus.failed
|
||||
conn_info.error_message = str(e)
|
||||
conn_info.updated_at = datetime.now(UTC)
|
||||
await self._store_connection(conn_info)
|
||||
else:
|
||||
logger.warning("Provider registry not available, skipping auto-instantiation")
|
||||
for provider_id, conn_info in self.dynamic_providers.items():
|
||||
if conn_info.status == ProviderConnectionStatus.connected:
|
||||
try:
|
||||
impl = await self._instantiate_provider(conn_info)
|
||||
self.dynamic_provider_impls[provider_id] = impl
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to instantiate provider {provider_id}: {e}")
|
||||
# Update status to failed
|
||||
conn_info.status = ProviderConnectionStatus.failed
|
||||
conn_info.error_message = str(e)
|
||||
conn_info.updated_at = datetime.now(UTC)
|
||||
await self._store_connection(conn_info)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("ProviderImpl.shutdown")
|
||||
|
|
@ -174,13 +169,34 @@ class ProviderImpl(Providers):
|
|||
|
||||
return ListProvidersResponse(data=ret)
|
||||
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
|
||||
async def inspect_provider(self, provider_id: str) -> ListProvidersResponse:
|
||||
"""Get all providers with the given provider_id (deprecated).
|
||||
|
||||
Returns all providers across all APIs that have this provider_id.
|
||||
This is deprecated - use inspect_provider_for_api() for unambiguous access.
|
||||
"""
|
||||
all_providers = await self.list_providers()
|
||||
matching = [p for p in all_providers.data if p.provider_id == provider_id]
|
||||
|
||||
if not matching:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
return ListProvidersResponse(data=matching)
|
||||
|
||||
async def list_providers_for_api(self, api: str) -> ListProvidersResponse:
|
||||
"""List providers for a specific API."""
|
||||
all_providers = await self.list_providers()
|
||||
filtered = [p for p in all_providers.data if p.api == api]
|
||||
return ListProvidersResponse(data=filtered)
|
||||
|
||||
async def inspect_provider_for_api(self, api: str, provider_id: str) -> ProviderInfo:
|
||||
"""Get a specific provider for a specific API."""
|
||||
all_providers = await self.list_providers()
|
||||
for p in all_providers.data:
|
||||
if p.provider_id == provider_id:
|
||||
if p.api == api and p.provider_id == provider_id:
|
||||
return p
|
||||
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
raise ValueError(f"Provider {provider_id} not found for API {api}")
|
||||
|
||||
async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
|
||||
"""Get health status for all providers.
|
||||
|
|
@ -272,17 +288,19 @@ class ProviderImpl(Providers):
|
|||
return ProviderConnectionInfo.model_validate_json(value)
|
||||
return None
|
||||
|
||||
async def _delete_connection(self, provider_id: str) -> None:
|
||||
async def _delete_connection(self, api: str, provider_id: str) -> None:
|
||||
"""Delete provider connection from kvstore.
|
||||
|
||||
:param api: API namespace
|
||||
:param provider_id: Provider ID to delete
|
||||
"""
|
||||
if not self.kvstore:
|
||||
raise RuntimeError("KVStore not initialized")
|
||||
|
||||
key = f"{PROVIDER_CONNECTIONS_PREFIX}{provider_id}"
|
||||
# Use composite key: provider_connections:v1::{api}::{provider_id}
|
||||
key = f"{PROVIDER_CONNECTIONS_PREFIX}{api}::{provider_id}"
|
||||
await self.kvstore.delete(key)
|
||||
logger.debug(f"Deleted provider connection: {provider_id}")
|
||||
logger.debug(f"Deleted provider connection: {api}::{provider_id}")
|
||||
|
||||
async def _list_connections(self) -> list[ProviderConnectionInfo]:
|
||||
"""List all dynamic provider connections from kvstore.
|
||||
|
|
@ -306,6 +324,17 @@ class ProviderImpl(Providers):
|
|||
self.dynamic_providers[cache_key] = conn
|
||||
logger.debug(f"Loaded dynamic provider: {cache_key} (status: {conn.status})")
|
||||
|
||||
def _find_provider_cache_key(self, provider_id: str) -> str | None:
|
||||
"""Find the cache key for a provider by its provider_id.
|
||||
|
||||
Since we use composite keys ({api}::{provider_id}), this searches for the matching key.
|
||||
Returns None if not found.
|
||||
"""
|
||||
for key in self.dynamic_providers.keys():
|
||||
if key.endswith(f"::{provider_id}"):
|
||||
return key
|
||||
return None
|
||||
|
||||
# Helper methods for dynamic provider management
|
||||
|
||||
def _redact_sensitive_config(self, config: dict[str, Any]) -> dict[str, Any]:
|
||||
|
|
@ -380,8 +409,8 @@ class ProviderImpl(Providers):
|
|||
|
||||
async def register_provider(
|
||||
self,
|
||||
provider_id: str,
|
||||
api: str,
|
||||
provider_id: str,
|
||||
provider_type: str,
|
||||
config: dict[str, Any],
|
||||
attributes: dict[str, list[str]] | None = None,
|
||||
|
|
@ -394,7 +423,6 @@ class ProviderImpl(Providers):
|
|||
|
||||
All providers are stored in kvstore and treated equally.
|
||||
"""
|
||||
logger.info(f"📝 REGISTER_PROVIDER called: provider_id={provider_id}, api={api}, type={provider_type}")
|
||||
|
||||
if not self.kvstore:
|
||||
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
|
||||
|
|
@ -427,25 +455,15 @@ class ProviderImpl(Providers):
|
|||
# Store in kvstore
|
||||
await self._store_connection(conn_info)
|
||||
|
||||
# Instantiate provider if we have a provider registry
|
||||
if self.provider_registry:
|
||||
impl = await self._instantiate_provider(conn_info)
|
||||
# Use composite key for impl cache too
|
||||
self.dynamic_provider_impls[cache_key] = impl
|
||||
impl = await self._instantiate_provider(conn_info)
|
||||
# Use composite key for impl cache too
|
||||
self.dynamic_provider_impls[cache_key] = impl
|
||||
|
||||
# Update status to connected after successful instantiation
|
||||
conn_info.status = ProviderConnectionStatus.connected
|
||||
conn_info.updated_at = datetime.now(UTC)
|
||||
# Update status to connected after successful instantiation
|
||||
conn_info.status = ProviderConnectionStatus.connected
|
||||
conn_info.updated_at = datetime.now(UTC)
|
||||
|
||||
logger.info(
|
||||
f"Registered and instantiated dynamic provider {provider_id} (api={api}, type={provider_type})"
|
||||
)
|
||||
else:
|
||||
# No registry available - just mark as connected without instantiation
|
||||
# This can happen during testing or if provider management is disabled
|
||||
conn_info.status = ProviderConnectionStatus.connected
|
||||
conn_info.updated_at = datetime.now(UTC)
|
||||
logger.warning(f"Registered provider {provider_id} without instantiation (no registry)")
|
||||
logger.info(f"Registered and instantiated dynamic provider {provider_id} (api={api}, type={provider_type})")
|
||||
|
||||
# Store updated status
|
||||
await self._store_connection(conn_info)
|
||||
|
|
@ -468,6 +486,7 @@ class ProviderImpl(Providers):
|
|||
|
||||
async def update_provider(
|
||||
self,
|
||||
api: str,
|
||||
provider_id: str,
|
||||
config: dict[str, Any] | None = None,
|
||||
attributes: dict[str, list[str]] | None = None,
|
||||
|
|
@ -477,16 +496,16 @@ class ProviderImpl(Providers):
|
|||
Updates persist to kvstore and survive server restarts.
|
||||
This works for all providers (whether originally from run.yaml or API).
|
||||
"""
|
||||
logger.info(f"🔄 UPDATE_PROVIDER called: provider_id={provider_id}, has_config={config is not None}, has_attributes={attributes is not None}")
|
||||
|
||||
if not self.kvstore:
|
||||
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
|
||||
|
||||
# Check if provider exists
|
||||
if provider_id not in self.dynamic_providers:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
# Use composite key
|
||||
cache_key = f"{api}::{provider_id}"
|
||||
if cache_key not in self.dynamic_providers:
|
||||
raise ValueError(f"Provider {provider_id} not found for API {api}")
|
||||
|
||||
conn_info = self.dynamic_providers[provider_id]
|
||||
conn_info = self.dynamic_providers[cache_key]
|
||||
|
||||
# Update config if provided
|
||||
if config is not None:
|
||||
|
|
@ -504,33 +523,26 @@ class ProviderImpl(Providers):
|
|||
await self._store_connection(conn_info)
|
||||
|
||||
# Hot-reload: Shutdown old instance and reinstantiate with new config
|
||||
if self.provider_registry:
|
||||
# Shutdown old instance if it exists
|
||||
if provider_id in self.dynamic_provider_impls:
|
||||
old_impl = self.dynamic_provider_impls[provider_id]
|
||||
if hasattr(old_impl, "shutdown"):
|
||||
try:
|
||||
await old_impl.shutdown()
|
||||
logger.debug(f"Shutdown old instance of provider {provider_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error shutting down old instance of {provider_id}: {e}")
|
||||
# Shutdown old instance if it exists
|
||||
if cache_key in self.dynamic_provider_impls:
|
||||
old_impl = self.dynamic_provider_impls[cache_key]
|
||||
if hasattr(old_impl, "shutdown"):
|
||||
try:
|
||||
await old_impl.shutdown()
|
||||
logger.debug(f"Shutdown old instance of provider {provider_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error shutting down old instance of {provider_id}: {e}")
|
||||
|
||||
# Reinstantiate with new config
|
||||
impl = await self._instantiate_provider(conn_info)
|
||||
self.dynamic_provider_impls[provider_id] = impl
|
||||
# Reinstantiate with new config
|
||||
impl = await self._instantiate_provider(conn_info)
|
||||
self.dynamic_provider_impls[cache_key] = impl
|
||||
|
||||
# Update status to connected after successful reinstantiation
|
||||
conn_info.status = ProviderConnectionStatus.connected
|
||||
conn_info.updated_at = datetime.now(UTC)
|
||||
await self._store_connection(conn_info)
|
||||
# Update status to connected after successful reinstantiation
|
||||
conn_info.status = ProviderConnectionStatus.connected
|
||||
conn_info.updated_at = datetime.now(UTC)
|
||||
await self._store_connection(conn_info)
|
||||
|
||||
logger.info(f"Hot-reloaded dynamic provider {provider_id}")
|
||||
else:
|
||||
# No registry - just update config without reinstantiation
|
||||
conn_info.status = ProviderConnectionStatus.connected
|
||||
conn_info.updated_at = datetime.now(UTC)
|
||||
await self._store_connection(conn_info)
|
||||
logger.warning(f"Updated provider {provider_id} config without hot-reload (no registry)")
|
||||
logger.info(f"Hot-reloaded dynamic provider {provider_id}")
|
||||
|
||||
return UpdateProviderResponse(provider=conn_info)
|
||||
|
||||
|
|
@ -543,34 +555,36 @@ class ProviderImpl(Providers):
|
|||
logger.error(f"Failed to update provider {provider_id}: {e}")
|
||||
raise RuntimeError(f"Failed to update provider: {e}") from e
|
||||
|
||||
async def unregister_provider(self, provider_id: str) -> None:
|
||||
async def unregister_provider(self, api: str, provider_id: str) -> None:
|
||||
"""Unregister a provider.
|
||||
|
||||
Removes the provider from kvstore and shuts down its instance.
|
||||
This works for all providers (whether originally from run.yaml or API).
|
||||
"""
|
||||
logger.info(f"🗑️ UNREGISTER_PROVIDER called: provider_id={provider_id}")
|
||||
|
||||
if not self.kvstore:
|
||||
raise RuntimeError("Dynamic provider management is not enabled (no kvstore configured)")
|
||||
|
||||
# Check if provider exists
|
||||
if provider_id not in self.dynamic_providers:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
# Use composite key
|
||||
cache_key = f"{api}::{provider_id}"
|
||||
if cache_key not in self.dynamic_providers:
|
||||
raise ValueError(f"Provider {provider_id} not found for API {api}")
|
||||
|
||||
conn_info = self.dynamic_providers[cache_key]
|
||||
|
||||
try:
|
||||
# Shutdown provider instance if it exists
|
||||
if provider_id in self.dynamic_provider_impls:
|
||||
impl = self.dynamic_provider_impls[provider_id]
|
||||
if cache_key in self.dynamic_provider_impls:
|
||||
impl = self.dynamic_provider_impls[cache_key]
|
||||
if hasattr(impl, "shutdown"):
|
||||
await impl.shutdown()
|
||||
del self.dynamic_provider_impls[provider_id]
|
||||
del self.dynamic_provider_impls[cache_key]
|
||||
|
||||
# Remove from kvstore
|
||||
await self._delete_connection(provider_id)
|
||||
# Remove from kvstore (using the api and provider_id from conn_info)
|
||||
await self._delete_connection(conn_info.api, provider_id)
|
||||
|
||||
# Remove from runtime cache
|
||||
del self.dynamic_providers[provider_id]
|
||||
del self.dynamic_providers[cache_key]
|
||||
|
||||
logger.info(f"Unregistered dynamic provider {provider_id}")
|
||||
|
||||
|
|
@ -578,23 +592,24 @@ class ProviderImpl(Providers):
|
|||
logger.error(f"Failed to unregister provider {provider_id}: {e}")
|
||||
raise RuntimeError(f"Failed to unregister provider: {e}") from e
|
||||
|
||||
async def test_provider_connection(self, provider_id: str) -> TestProviderConnectionResponse:
|
||||
async def test_provider_connection(self, api: str, provider_id: str) -> TestProviderConnectionResponse:
|
||||
"""Test a provider connection."""
|
||||
logger.info(f"🔍 TEST_PROVIDER_CONNECTION called: provider_id={provider_id}")
|
||||
|
||||
# Check if provider exists (static or dynamic)
|
||||
provider_impl = None
|
||||
cache_key = f"{api}::{provider_id}"
|
||||
|
||||
# Check dynamic providers first (using composite keys)
|
||||
if cache_key in self.dynamic_provider_impls:
|
||||
provider_impl = self.dynamic_provider_impls[cache_key]
|
||||
|
||||
# Check dynamic providers first
|
||||
if provider_id in self.dynamic_provider_impls:
|
||||
provider_impl = self.dynamic_provider_impls[provider_id]
|
||||
# Check static providers
|
||||
elif provider_id in self.deps:
|
||||
if not provider_impl and provider_id in self.deps:
|
||||
provider_impl = self.deps[provider_id]
|
||||
|
||||
if not provider_impl:
|
||||
return TestProviderConnectionResponse(
|
||||
success=False, error_message=f"Provider {provider_id} not found or not initialized"
|
||||
success=False, error_message=f"Provider {provider_id} not found for API {api}"
|
||||
)
|
||||
|
||||
# Check if provider has health method
|
||||
|
|
@ -611,8 +626,8 @@ class ProviderImpl(Providers):
|
|||
health_result = await asyncio.wait_for(provider_impl.health(), timeout=5.0)
|
||||
|
||||
# Update health in dynamic provider cache if applicable
|
||||
if provider_id in self.dynamic_providers:
|
||||
conn_info = self.dynamic_providers[provider_id]
|
||||
if cache_key and cache_key in self.dynamic_providers:
|
||||
conn_info = self.dynamic_providers[cache_key]
|
||||
conn_info.health = ProviderHealth.from_health_response(health_result)
|
||||
conn_info.last_health_check = datetime.now(UTC)
|
||||
await self._store_connection(conn_info)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue