# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from typing import Any

from pydantic import BaseModel

from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus

from .datatypes import StackRunConfig
from .utils.config import redact_sensitive_fields

logger = get_logger(name=__name__, category="core")


# Configuration handed to ProviderImpl; carries the stack run configuration
# whose `providers` section is listed by the API below.
class ProviderImplConfig(BaseModel):
    run_config: StackRunConfig


async def get_provider_impl(config, deps):
    """Create a ProviderImpl from ``config`` and ``deps``, initialize it, and return it."""
    impl = ProviderImpl(config, deps)
    await impl.initialize()
    return impl


class ProviderImpl(Providers):
    """Providers API implementation backed by the run configuration and the resolved impls."""

    def __init__(self, config, deps):
        # `config` must expose `run_config`; `deps` must expose `.values()`
        # yielding the implementations to health-check.
        self.config = config
        self.deps = deps

    async def initialize(self) -> None:
        """No-op; nothing needs to be set up."""
        pass

    async def shutdown(self) -> None:
        """Log the shutdown; there are no resources to release."""
        logger.debug("ProviderImpl.shutdown")

    async def list_providers(self) -> ListProvidersResponse:
        """List every configured provider with its redacted config and health.

        Returns:
            ListProvidersResponse: One ProviderInfo per provider in the run
            configuration. Providers with no health entry are reported as
            NOT_IMPLEMENTED.
        """
        run_config = self.config.run_config
        # Redact sensitive fields before any provider config leaves this method.
        safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
        providers_health = await self.get_providers_health()

        ret = []
        for api, providers in safe_config.providers.items():
            api_health = providers_health.get(api, {})
            for p in providers:
                ret.append(
                    ProviderInfo(
                        api=api,
                        provider_id=p.provider_id,
                        provider_type=p.provider_type,
                        config=p.config,
                        health=api_health.get(
                            p.provider_id,
                            HealthResponse(
                                status=HealthStatus.NOT_IMPLEMENTED,
                                message="Provider does not implement health check",
                            ),
                        ),
                    )
                )

        return ListProvidersResponse(data=ret)

    async def inspect_provider(self, provider_id: str) -> ProviderInfo:
        """Return the ProviderInfo whose ID equals ``provider_id``.

        Raises:
            ValueError: If no listed provider has that ID.
        """
        all_providers = await self.list_providers()
        for p in all_providers.data:
            if p.provider_id == provider_id:
                return p

        raise ValueError(f"Provider {provider_id} not found")

    async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
        """Get health status for all providers.

        Every implementation in ``self.deps`` that carries a ``__provider_spec__``
        is checked concurrently. When an implementation has no ``health`` method,
        times out, or raises, the resulting status is recorded for each provider
        configured under that API, so the per-provider lookup in
        ``list_providers`` surfaces it.

        Returns:
            Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
                Each API maps to a dictionary of provider IDs to their health responses.
        """
        timeout = 1.0  # seconds allowed for each implementation's health() call

        def fallback(api_name: str, status: HealthStatus, message: str) -> tuple[str, dict[str, HealthResponse]]:
            # Expand one API-level outcome into the provider-ID mapping promised
            # by the return type; a fresh response object is built per provider.
            configured = self.config.run_config.providers.get(api_name, [])
            return api_name, {p.provider_id: HealthResponse(status=status, message=message) for p in configured}

        async def check_provider_health(impl: Any) -> tuple[str, dict[str, HealthResponse]] | None:
            # Skip special implementations (inspect/providers) that don't have provider specs
            if not hasattr(impl, "__provider_spec__"):
                return None
            api_name = impl.__provider_spec__.api.name
            if not hasattr(impl, "health"):
                return fallback(
                    api_name, HealthStatus.NOT_IMPLEMENTED, "Provider does not implement health check"
                )

            try:
                # NOTE(review): assumes impl.health() returns a mapping of provider ID
                # to HealthResponse, as the lookup in list_providers requires - confirm.
                health = await asyncio.wait_for(impl.health(), timeout=timeout)
                return api_name, health
            except asyncio.TimeoutError:
                return fallback(
                    api_name, HealthStatus.ERROR, f"Health check timed out after {timeout} seconds"
                )
            except Exception as e:
                # Best effort: one failing implementation must not break the whole listing.
                return fallback(api_name, HealthStatus.ERROR, f"Health check failed: {str(e)}")

        # Run all health checks concurrently and wait for them to complete
        results = await asyncio.gather(*(check_provider_health(impl) for impl in self.deps.values()))

        # Organize results by API name, dropping the skipped (None) special implementations
        return dict(result for result in results if result is not None)