mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-27 06:48:05 +00:00
Being able to update providers means that admins should have the ability to turn this feature off. Introduce `immutable` as a field in the Provider class. Defaulting to `False` means all providers can be updated by default, but an admin at runtime can choose to set this to `True` to disable provider updating. Signed-off-by: Charlie Doern <cdoern@redhat.com>
218 lines
8.4 KiB
Python
218 lines
8.4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
import copy
|
|
from typing import Any
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
|
|
from llama_stack.log import get_logger
|
|
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
|
|
|
|
from .datatypes import Provider, StackRunConfig
|
|
from .utils.config import redact_sensitive_fields
|
|
|
|
# Module-level logger for this file, registered under the "core" category.
logger = get_logger(category="core", name=__name__)
|
|
|
|
|
|
# Configuration model for the Providers API implementation. It wraps the stack run
# config; ProviderImpl reads `config.run_config` (presumably an instance of this model
# is what get_provider_impl receives as `config` -- confirm against the resolver).
# NOTE: comments are used instead of a docstring because pydantic would copy a class
# docstring into the generated JSON schema description.
class ProviderImplConfig(BaseModel):
    # Full stack run configuration; ProviderImpl reads its `providers` mapping.
    run_config: StackRunConfig
|
|
|
|
|
|
async def get_provider_impl(config, deps):
    """Create the Providers API implementation and run its async initialization.

    ``config`` and ``deps`` are handed to ``ProviderImpl`` unchanged.

    Returns:
        The initialized ``ProviderImpl`` instance.
    """
    providers_api = ProviderImpl(config, deps)
    await providers_api.initialize()
    return providers_api
|
|
|
|
|
|
class ProviderImpl(Providers):
    """Providers API implementation backed by the stack run configuration.

    ``config`` must expose ``run_config`` (used as a ``StackRunConfig``). ``deps`` is a
    mapping whose values are the resolved API implementations that get health-checked.
    NOTE(review): ``deps`` is presumably keyed by API -- confirm against the resolver.
    """

    def __init__(self, config, deps):
        self.config = config
        self.deps = deps

    async def initialize(self) -> None:
        """No asynchronous setup is required."""

    async def shutdown(self) -> None:
        """Nothing to release; only logs the shutdown."""
        logger.debug("ProviderImpl.shutdown")

    @staticmethod
    def _not_implemented_health() -> HealthResponse:
        """Return a fresh placeholder health entry for providers without a health result."""
        return HealthResponse(status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check")

    async def list_providers(self) -> ListProvidersResponse:
        """List every configured provider with its redacted config and health status.

        Sensitive config values are redacted with ``redact_sensitive_fields`` before
        being returned. Providers whose ``provider_id`` is ``None`` are skipped.
        """
        run_config = self.config.run_config
        safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
        providers_health = await self.get_providers_health()
        ret = []
        for api, providers in safe_config.providers.items():
            api_health = providers_health.get(api, {})
            for p in providers:
                # Skip providers that are not enabled
                if p.provider_id is None:
                    continue
                ret.append(
                    ProviderInfo(
                        api=api,
                        provider_id=p.provider_id,
                        provider_type=p.provider_type,
                        config=p.config,
                        health=api_health.get(p.provider_id, self._not_implemented_health()),
                    )
                )

        return ListProvidersResponse(data=ret)

    async def inspect_provider(self, provider_id: str) -> ProviderInfo:
        """Return the ``list_providers`` entry whose ``provider_id`` matches.

        Raises:
            ValueError: if no listed provider has that ID.
        """
        all_providers = await self.list_providers()
        for p in all_providers.data:
            if p.provider_id == provider_id:
                return p

        raise ValueError(f"Provider {provider_id} not found")

    async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
        """Get health status for all providers.

        Each dependency's ``health()`` is awaited concurrently with a per-check timeout.
        When the check times out, raises, or ``health()`` does not exist, the resulting
        status is recorded for every provider configured under that API. This keeps
        every value in the documented shape.

        Returns:
            dict[str, dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
                Each API maps to a dictionary of provider IDs to their health responses.
        """
        providers_health: dict[str, dict[str, HealthResponse]] = {}
        timeout = 1.0

        def for_each_provider(api_name: str, status: HealthStatus, message: str) -> dict[str, HealthResponse]:
            # Build one fresh response per provider configured for this API, so that
            # lookups by provider_id (as list_providers does) find the real status.
            configured = self.config.run_config.providers.get(api_name, [])
            return {
                p.provider_id: HealthResponse(status=status, message=message)
                for p in configured
                if p.provider_id is not None
            }

        async def check_provider_health(impl: Any) -> tuple[str, dict[str, HealthResponse]] | None:
            # Skip special implementations (inspect/providers) that don't have provider specs
            if not hasattr(impl, "__provider_spec__"):
                return None
            api_name = impl.__provider_spec__.api.name
            if not hasattr(impl, "health"):
                return api_name, for_each_provider(
                    api_name, HealthStatus.NOT_IMPLEMENTED, "Provider does not implement health check"
                )

            try:
                health = await asyncio.wait_for(impl.health(), timeout=timeout)
            except TimeoutError:
                return api_name, for_each_provider(
                    api_name, HealthStatus.ERROR, f"Health check timed out after {timeout} seconds"
                )
            except Exception as e:
                return api_name, for_each_provider(api_name, HealthStatus.ERROR, f"Health check failed: {str(e)}")
            # NOTE(review): stored as-is; this assumes impl.health() already returns
            # {provider_id: HealthResponse} as documented above -- confirm for every API.
            return api_name, health

        # Run all health checks concurrently and wait for every one to finish.
        results = await asyncio.gather(*(check_provider_health(impl) for impl in self.deps.values()))

        # Organize results by API (each value is keyed by provider ID)
        for result in results:
            if result is None:  # Skip special implementations
                continue
            api_name, health_by_provider = result
            providers_health[api_name] = health_by_provider

        return providers_health

    async def update_provider(
        self, api: str, provider_id: str, provider_type: str, config: dict[str, Any]
    ) -> ProviderInfo:
        """Merge ``config`` into the stored config of a registered, mutable provider.

        The new values are recursively merged over the provider's existing config (see
        ``merge_dicts``), and the result replaces the config held in the run configuration.
        NOTE(review): nothing here re-initializes the running provider implementation;
        confirm whether callers expect the change to take effect immediately.

        Args:
            api: API name the provider is registered under.
            provider_id: ID of the provider to update.
            provider_type: Provider type sent by the caller. It is not applied; the
                returned ``ProviderInfo`` reports the registered provider's type.
            config: Partial config whose keys override the existing ones.

        Returns:
            ``ProviderInfo`` for the updated provider, with sensitive config fields
            redacted the same way as in ``list_providers``.

        Raises:
            ValueError: if no provider_id is given, the provider is not registered
                under ``api``, or the provider is marked immutable.
        """
        prov = Provider(
            provider_id=provider_id,
            provider_type=provider_type,
            config=config,
        )
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if prov.provider_id is None:
            raise ValueError("A provider_id is required to update a provider.")

        # If the provider isn't there or the API is invalid, we should not continue.
        existing_provider = next(
            (p for p in self.config.run_config.providers.get(api, []) if p.provider_id == provider_id),
            None,
        )
        if existing_provider is None:
            raise ValueError(f"Provider {provider_id} not found, you can only update already registered providers.")
        # The provider needs to be mutable for us to update its config.
        if existing_provider.immutable:
            raise ValueError(f"Provider {provider_id} is immutable, you can only update mutable providers.")
        if provider_type != existing_provider.provider_type:
            logger.warning(
                f"Ignoring provider_type {provider_type!r} for provider {provider_id}: "
                f"it is registered as {existing_provider.provider_type!r}."
            )

        new_config = self.merge_providers(existing_provider, prov)
        existing_provider.config = new_config

        providers_health = await self.get_providers_health()
        return ProviderInfo(
            api=api,
            provider_id=existing_provider.provider_id,
            provider_type=existing_provider.provider_type,
            # Redact secrets like list_providers does; deepcopy so the stored config is never touched.
            config=redact_sensitive_fields(copy.deepcopy(new_config)),
            health=providers_health.get(api, {}).get(existing_provider.provider_id, self._not_implemented_health()),
        )

    def merge_dicts(self, base: dict[str, Any], overrides: dict[str, Any]) -> dict[str, Any]:
        """Recursively merges `overrides` into `base`, replacing only specified keys.

        Nested dicts present in both inputs are merged key by key; any other value in
        ``overrides`` replaces the corresponding value. Neither input is modified, and
        the returned dict shares no mutable objects with either input.
        """
        merged = copy.deepcopy(base)  # Copy once; the helper then merges in place.
        self._merge_into(merged, overrides)
        return merged

    @staticmethod
    def _merge_into(target: dict[str, Any], overrides: dict[str, Any]) -> None:
        """Merge ``overrides`` into ``target`` in place (``target`` must be a private copy)."""
        for key, value in overrides.items():
            current = target.get(key)
            if isinstance(value, dict) and isinstance(current, dict):
                # Recursively merge if both are dictionaries
                ProviderImpl._merge_into(current, value)
            else:
                # Otherwise, directly override (copied so the result doesn't alias `overrides`)
                target[key] = copy.deepcopy(value)

    def merge_configs(
        self, global_config: dict[str, list[Provider]], new_config: dict[str, list[Provider]]
    ) -> dict[str, list[Provider]]:
        """Merge per-API provider lists from ``new_config`` into a copy of ``global_config``.

        A provider whose ``provider_id`` already exists under the same API has its config
        merged via ``merge_dicts``. Other providers are appended, and APIs missing from
        ``global_config`` are added. Neither input is modified, and the result holds
        copies rather than the provider objects from ``new_config``.
        """
        merged_config = copy.deepcopy(global_config)  # Preserve original structure

        for key, new_providers in new_config.items():
            if key not in merged_config:
                # Add new category entirely
                merged_config[key] = copy.deepcopy(new_providers)
                continue

            existing_providers = {p.provider_id: p for p in merged_config[key]}
            for new_provider in new_providers:
                if new_provider.provider_id in existing_providers:
                    # Override settings of existing provider
                    existing = existing_providers[new_provider.provider_id]
                    existing.config = self.merge_dicts(existing.config, new_provider.config)
                else:
                    # Append (a copy of) the new provider
                    merged_config[key].append(copy.deepcopy(new_provider))

        return merged_config

    def merge_providers(self, current_provider: Provider, new_provider: Provider) -> dict[str, Any]:
        """Return ``current_provider.config`` with ``new_provider.config`` merged over it."""
        return self.merge_dicts(current_provider.config, new_provider.config)
|