mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
When using bash style substitution env variable in distribution template, we are processing the string and convert it to the type associated with the provider's config class. This allows us to return the proper type. This is crucial for api key since they are not strings anymore but SecretStr. If the key is unset we will get an empty string which will result in a Pydantic error like: ``` ERROR 2025-09-25 21:40:44,565 __main__:527 core::server: Error creating app: 1 validation error for AnthropicConfig api_key Input should be a valid string For further information visit https://errors.pydantic.dev/2.11/v/string_type ``` Signed-off-by: Sébastien Han <seb@redhat.com>
581 lines
22 KiB
Python
581 lines
22 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
import importlib.resources
|
|
import os
|
|
import re
|
|
import tempfile
|
|
from typing import Any
|
|
|
|
import yaml
|
|
|
|
from llama_stack.apis.agents import Agents
|
|
from llama_stack.apis.benchmarks import Benchmarks
|
|
from llama_stack.apis.datasetio import DatasetIO
|
|
from llama_stack.apis.datasets import Datasets
|
|
from llama_stack.apis.eval import Eval
|
|
from llama_stack.apis.files import Files
|
|
from llama_stack.apis.inference import Inference
|
|
from llama_stack.apis.inspect import Inspect
|
|
from llama_stack.apis.models import Models
|
|
from llama_stack.apis.post_training import PostTraining
|
|
from llama_stack.apis.prompts import Prompts
|
|
from llama_stack.apis.providers import Providers
|
|
from llama_stack.apis.safety import Safety
|
|
from llama_stack.apis.scoring import Scoring
|
|
from llama_stack.apis.scoring_functions import ScoringFunctions
|
|
from llama_stack.apis.shields import Shields
|
|
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
|
|
from llama_stack.apis.telemetry import Telemetry
|
|
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
|
from llama_stack.apis.vector_dbs import VectorDBs
|
|
from llama_stack.apis.vector_io import VectorIO
|
|
from llama_stack.core.datatypes import Provider, StackRunConfig
|
|
from llama_stack.core.distribution import get_provider_registry
|
|
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
|
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
|
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
|
|
from llama_stack.core.resolver import ProviderRegistry, resolve_impls
|
|
from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
|
|
from llama_stack.core.store.registry import create_dist_registry
|
|
from llama_stack.core.utils.dynamic import instantiate_class_type
|
|
from llama_stack.log import get_logger
|
|
from llama_stack.providers.datatypes import Api
|
|
|
|
logger = get_logger(name=__name__, category="core")
|
|
|
|
|
|
class LlamaStack(
|
|
Providers,
|
|
VectorDBs,
|
|
Inference,
|
|
Agents,
|
|
Safety,
|
|
SyntheticDataGeneration,
|
|
Datasets,
|
|
Telemetry,
|
|
PostTraining,
|
|
VectorIO,
|
|
Eval,
|
|
Benchmarks,
|
|
Scoring,
|
|
ScoringFunctions,
|
|
DatasetIO,
|
|
Models,
|
|
Shields,
|
|
Inspect,
|
|
ToolGroups,
|
|
ToolRuntime,
|
|
RAGToolRuntime,
|
|
Files,
|
|
Prompts,
|
|
):
|
|
pass
|
|
|
|
|
|
RESOURCES = [
|
|
("models", Api.models, "register_model", "list_models"),
|
|
("shields", Api.shields, "register_shield", "list_shields"),
|
|
("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
|
|
("datasets", Api.datasets, "register_dataset", "list_datasets"),
|
|
(
|
|
"scoring_fns",
|
|
Api.scoring_functions,
|
|
"register_scoring_function",
|
|
"list_scoring_functions",
|
|
),
|
|
("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"),
|
|
("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
|
|
]
|
|
|
|
|
|
REGISTRY_REFRESH_INTERVAL_SECONDS = 300
|
|
REGISTRY_REFRESH_TASK = None
|
|
TEST_RECORDING_CONTEXT = None
|
|
|
|
|
|
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
|
for rsrc, api, register_method, list_method in RESOURCES:
|
|
objects = getattr(run_config, rsrc)
|
|
if api not in impls:
|
|
continue
|
|
|
|
method = getattr(impls[api], register_method)
|
|
for obj in objects:
|
|
if hasattr(obj, "provider_id"):
|
|
# Do not register models on disabled providers
|
|
if not obj.provider_id or obj.provider_id == "__disabled__":
|
|
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
|
|
continue
|
|
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
|
|
|
|
# we want to maintain the type information in arguments to method.
|
|
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
|
|
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
|
|
await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
|
|
|
|
method = getattr(impls[api], list_method)
|
|
response = await method()
|
|
|
|
objects_to_process = response.data if hasattr(response, "data") else response
|
|
|
|
for obj in objects_to_process:
|
|
logger.debug(
|
|
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
|
|
)
|
|
|
|
|
|
class EnvVarError(Exception):
|
|
def __init__(self, var_name: str, path: str = ""):
|
|
self.var_name = var_name
|
|
self.path = path
|
|
super().__init__(
|
|
f"Environment variable '{var_name}' not set or empty {f'at {path}' if path else ''}. "
|
|
f"Use ${{env.{var_name}:=default_value}} to provide a default value, "
|
|
f"${{env.{var_name}:+value_if_set}} to make the field conditional, "
|
|
f"or ensure the environment variable is set."
|
|
)
|
|
|
|
|
|
def replace_env_vars(
|
|
config: Any,
|
|
path: str = "",
|
|
provider_registry: dict[Api, dict[str, Any]] | None = None,
|
|
current_provider_context: dict[str, Any] | None = None,
|
|
) -> Any:
|
|
if isinstance(config, dict):
|
|
result = {}
|
|
for k, v in config.items():
|
|
try:
|
|
result[k] = replace_env_vars(
|
|
v, f"{path}.{k}" if path else k, provider_registry, current_provider_context
|
|
)
|
|
except EnvVarError as e:
|
|
raise EnvVarError(e.var_name, e.path) from None
|
|
return result
|
|
|
|
elif isinstance(config, list):
|
|
result = []
|
|
for i, v in enumerate(config):
|
|
try:
|
|
# Special handling for providers: first resolve the provider_id to check if provider
|
|
# is disabled so that we can skip config env variable expansion and avoid validation errors
|
|
if isinstance(v, dict) and "provider_id" in v:
|
|
try:
|
|
resolved_provider_id = replace_env_vars(
|
|
v["provider_id"], f"{path}[{i}].provider_id", provider_registry, current_provider_context
|
|
)
|
|
if resolved_provider_id == "__disabled__":
|
|
logger.debug(
|
|
f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}"
|
|
)
|
|
# Create a copy with resolved provider_id but original config
|
|
disabled_provider = v.copy()
|
|
disabled_provider["provider_id"] = resolved_provider_id
|
|
result.append(disabled_provider)
|
|
continue
|
|
except EnvVarError:
|
|
# If we can't resolve the provider_id, continue with normal processing
|
|
pass
|
|
|
|
# Set up provider context for config processing
|
|
provider_context = current_provider_context
|
|
if isinstance(v, dict) and "provider_id" in v and "provider_type" in v and provider_registry:
|
|
provider_context = _get_provider_context(v, provider_registry)
|
|
|
|
# Normal processing for non-disabled providers
|
|
result.append(replace_env_vars(v, f"{path}[{i}]", provider_registry, provider_context))
|
|
except EnvVarError as e:
|
|
raise EnvVarError(e.var_name, e.path) from None
|
|
return result
|
|
|
|
elif isinstance(config, str):
|
|
# Pattern supports bash-like syntax: := for default and :+ for conditional and a optional value
|
|
pattern = r"\${env\.([A-Z0-9_]+)(?::([=+])([^}]*))?}"
|
|
|
|
def get_env_var(match: re.Match):
|
|
env_var = match.group(1)
|
|
operator = match.group(2) # '=' for default, '+' for conditional
|
|
value_expr = match.group(3)
|
|
|
|
env_value = os.environ.get(env_var)
|
|
|
|
if operator == "=": # Default value syntax: ${env.FOO:=default}
|
|
# If the env is set like ${env.FOO:=default} then use the env value when set
|
|
if env_value:
|
|
value = env_value
|
|
else:
|
|
# If the env is not set, look for a default value
|
|
# value_expr returns empty string (not None) when not matched
|
|
# This means ${env.FOO:=} and it's accepted and returns empty string - just like bash
|
|
if value_expr == "":
|
|
return ""
|
|
else:
|
|
value = value_expr
|
|
|
|
elif operator == "+": # Conditional value syntax: ${env.FOO:+value_if_set}
|
|
# If the env is set like ${env.FOO:+value_if_set} then use the value_if_set
|
|
if env_value:
|
|
if value_expr:
|
|
value = value_expr
|
|
# This means ${env.FOO:+}
|
|
else:
|
|
# Just like bash, this doesn't care whether the env is set or not and applies
|
|
# the value, in this case the empty string
|
|
return ""
|
|
else:
|
|
# Just like bash, this doesn't care whether the env is set or not, since it's not set
|
|
# we return an empty string
|
|
value = ""
|
|
else: # No operator case: ${env.FOO}
|
|
if not env_value:
|
|
raise EnvVarError(env_var, path)
|
|
value = env_value
|
|
|
|
# expand "~" from the values
|
|
return os.path.expanduser(value)
|
|
|
|
try:
|
|
result = re.sub(pattern, get_env_var, config)
|
|
# Only apply type conversion if substitution actually happened
|
|
if result != config:
|
|
return _convert_string_to_proper_type_with_config(result, path, current_provider_context)
|
|
return result
|
|
except EnvVarError as e:
|
|
raise EnvVarError(e.var_name, e.path) from None
|
|
|
|
return config
|
|
|
|
|
|
def _get_provider_context(
|
|
provider_dict: dict[str, Any], provider_registry: dict[Api, dict[str, Any]]
|
|
) -> dict[str, Any] | None:
|
|
"""Get provider context information including config class for type conversion."""
|
|
try:
|
|
provider_type = provider_dict.get("provider_type")
|
|
if not provider_type:
|
|
return None
|
|
|
|
for api, providers in provider_registry.items():
|
|
if provider_type in providers:
|
|
provider_spec = providers[provider_type]
|
|
|
|
config_class = instantiate_class_type(provider_spec.config_class)
|
|
|
|
return {
|
|
"api": api,
|
|
"provider_type": provider_type,
|
|
"config_class": config_class,
|
|
"provider_spec": provider_spec,
|
|
}
|
|
except Exception as e:
|
|
logger.debug(f"Failed to get provider context: {e}")
|
|
return None
|
|
|
|
|
|
def _convert_string_to_proper_type_with_config(value: str, path: str, provider_context: dict[str, Any] | None) -> Any:
|
|
"""Convert string to proper type using provider config class field information."""
|
|
if not provider_context or not provider_context.get("config_class"):
|
|
# best effort conversion if we don't have the config class
|
|
return _convert_string_to_proper_type(value)
|
|
|
|
try:
|
|
# Extract field name from path (e.g., "providers.inference[0].config.api_key" -> "api_key")
|
|
field_name = path.split(".")[-1] if "." in path else path
|
|
|
|
config_class = provider_context["config_class"]
|
|
|
|
if hasattr(config_class, "model_fields") and field_name in config_class.model_fields:
|
|
field_info = config_class.model_fields[field_name]
|
|
field_type = field_info.annotation
|
|
return _convert_value_by_field_type(value, field_type)
|
|
else:
|
|
return _convert_string_to_proper_type(value)
|
|
|
|
except Exception as e:
|
|
logger.debug(f"Failed to convert using config class: {e}")
|
|
return _convert_string_to_proper_type(value)
|
|
|
|
|
|
def _convert_value_by_field_type(value: str, field_type: Any) -> Any:
|
|
"""Convert string value based on Pydantic field type annotation."""
|
|
import typing
|
|
from typing import get_args, get_origin
|
|
|
|
if value == "":
|
|
if field_type is None or (hasattr(typing, "get_origin") and get_origin(field_type) is type(None)):
|
|
return None
|
|
if hasattr(typing, "get_origin") and get_origin(field_type) is typing.Union:
|
|
args = get_args(field_type)
|
|
if type(None) in args:
|
|
return None
|
|
return ""
|
|
|
|
if field_type is bool or (hasattr(typing, "get_origin") and get_origin(field_type) is bool):
|
|
lowered = value.lower()
|
|
if lowered == "true":
|
|
return True
|
|
elif lowered == "false":
|
|
return False
|
|
else:
|
|
return value
|
|
|
|
if field_type is int or (hasattr(typing, "get_origin") and get_origin(field_type) is int):
|
|
try:
|
|
return int(value)
|
|
except ValueError:
|
|
return value
|
|
|
|
if field_type is float or (hasattr(typing, "get_origin") and get_origin(field_type) is float):
|
|
try:
|
|
return float(value)
|
|
except ValueError:
|
|
return value
|
|
|
|
if hasattr(typing, "get_origin") and get_origin(field_type) is typing.Union:
|
|
args = get_args(field_type)
|
|
# Try to convert to the first non-None type
|
|
for arg in args:
|
|
if arg is not type(None):
|
|
try:
|
|
return _convert_value_by_field_type(value, arg)
|
|
except Exception:
|
|
continue
|
|
|
|
return value
|
|
|
|
|
|
def _convert_string_to_proper_type(value: str) -> Any:
|
|
# Fallback function for when provider config class is not available
|
|
# The main type conversion logic is now in _convert_string_to_proper_type_with_config
|
|
if value == "":
|
|
return None
|
|
|
|
lowered = value.lower()
|
|
if lowered == "true":
|
|
return True
|
|
elif lowered == "false":
|
|
return False
|
|
|
|
try:
|
|
return int(value)
|
|
except ValueError:
|
|
pass
|
|
|
|
try:
|
|
return float(value)
|
|
except ValueError:
|
|
pass
|
|
|
|
return value
|
|
|
|
|
|
def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|
"""Ensure that any value for a key 'image_name' in a config_dict is a string"""
|
|
if "image_name" in config_dict and config_dict["image_name"] is not None:
|
|
config_dict["image_name"] = str(config_dict["image_name"])
|
|
return config_dict
|
|
|
|
|
|
def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
|
"""Validate and split an environment variable key-value pair."""
|
|
try:
|
|
key, value = env_pair.split("=", 1)
|
|
key = key.strip()
|
|
if not key:
|
|
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
|
|
if not all(c.isalnum() or c == "_" for c in key):
|
|
raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
|
|
return key, value
|
|
except ValueError as e:
|
|
raise ValueError(
|
|
f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value"
|
|
) from e
|
|
|
|
|
|
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
|
|
"""Add internal implementations (inspect and providers) to the implementations dictionary.
|
|
|
|
Args:
|
|
impls: Dictionary of API implementations
|
|
run_config: Stack run configuration
|
|
"""
|
|
inspect_impl = DistributionInspectImpl(
|
|
DistributionInspectConfig(run_config=run_config),
|
|
deps=impls,
|
|
)
|
|
impls[Api.inspect] = inspect_impl
|
|
|
|
providers_impl = ProviderImpl(
|
|
ProviderImplConfig(run_config=run_config),
|
|
deps=impls,
|
|
)
|
|
impls[Api.providers] = providers_impl
|
|
|
|
prompts_impl = PromptServiceImpl(
|
|
PromptServiceConfig(run_config=run_config),
|
|
deps=impls,
|
|
)
|
|
impls[Api.prompts] = prompts_impl
|
|
|
|
|
|
class Stack:
|
|
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
|
|
self.run_config = run_config
|
|
self.provider_registry = provider_registry
|
|
self.impls = None
|
|
|
|
# Produces a stack of providers for the given run config. Not all APIs may be
|
|
# asked for in the run config.
|
|
async def initialize(self):
|
|
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
|
|
from llama_stack.testing.inference_recorder import setup_inference_recording
|
|
|
|
global TEST_RECORDING_CONTEXT
|
|
TEST_RECORDING_CONTEXT = setup_inference_recording()
|
|
if TEST_RECORDING_CONTEXT:
|
|
TEST_RECORDING_CONTEXT.__enter__()
|
|
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
|
|
|
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
|
|
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
|
|
impls = await resolve_impls(
|
|
self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
|
|
)
|
|
|
|
# Add internal implementations after all other providers are resolved
|
|
add_internal_implementations(impls, self.run_config)
|
|
|
|
if Api.prompts in impls:
|
|
await impls[Api.prompts].initialize()
|
|
|
|
await register_resources(self.run_config, impls)
|
|
|
|
await refresh_registry_once(impls)
|
|
self.impls = impls
|
|
|
|
def create_registry_refresh_task(self):
|
|
assert self.impls is not None, "Must call initialize() before starting"
|
|
|
|
global REGISTRY_REFRESH_TASK
|
|
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
|
|
|
|
def cb(task):
|
|
import traceback
|
|
|
|
if task.cancelled():
|
|
logger.error("Model refresh task cancelled")
|
|
elif task.exception():
|
|
logger.error(f"Model refresh task failed: {task.exception()}")
|
|
traceback.print_exception(task.exception())
|
|
else:
|
|
logger.debug("Model refresh task completed")
|
|
|
|
REGISTRY_REFRESH_TASK.add_done_callback(cb)
|
|
|
|
async def shutdown(self):
|
|
for impl in self.impls.values():
|
|
impl_name = impl.__class__.__name__
|
|
logger.info(f"Shutting down {impl_name}")
|
|
try:
|
|
if hasattr(impl, "shutdown"):
|
|
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
|
else:
|
|
logger.warning(f"No shutdown method for {impl_name}")
|
|
except TimeoutError:
|
|
logger.exception(f"Shutdown timeout for {impl_name}")
|
|
except (Exception, asyncio.CancelledError) as e:
|
|
logger.exception(f"Failed to shutdown {impl_name}: {e}")
|
|
|
|
global TEST_RECORDING_CONTEXT
|
|
if TEST_RECORDING_CONTEXT:
|
|
try:
|
|
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
|
|
except Exception as e:
|
|
logger.error(f"Error during inference recording cleanup: {e}")
|
|
|
|
global REGISTRY_REFRESH_TASK
|
|
if REGISTRY_REFRESH_TASK:
|
|
REGISTRY_REFRESH_TASK.cancel()
|
|
|
|
|
|
async def refresh_registry_once(impls: dict[Api, Any]):
|
|
logger.debug("refreshing registry")
|
|
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
|
|
for routing_table in routing_tables:
|
|
await routing_table.refresh()
|
|
|
|
|
|
async def refresh_registry_task(impls: dict[Api, Any]):
|
|
logger.info("starting registry refresh task")
|
|
while True:
|
|
await refresh_registry_once(impls)
|
|
|
|
await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
|
|
|
|
|
|
def get_stack_run_config_from_distro(distro: str) -> StackRunConfig:
|
|
distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro}/run.yaml"
|
|
|
|
with importlib.resources.as_file(distro_path) as path:
|
|
if not path.exists():
|
|
raise ValueError(f"Distribution '{distro}' not found at {distro_path}")
|
|
run_config = yaml.safe_load(path.open())
|
|
|
|
return StackRunConfig(**replace_env_vars(run_config, provider_registry=get_provider_registry()))
|
|
|
|
|
|
def run_config_from_adhoc_config_spec(
|
|
adhoc_config_spec: str, provider_registry: ProviderRegistry | None = None
|
|
) -> StackRunConfig:
|
|
"""
|
|
Create an adhoc distribution from a list of API providers.
|
|
|
|
The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
|
|
multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
|
|
"""
|
|
|
|
api_providers = adhoc_config_spec.replace(";", ",").split(",")
|
|
provider_registry = provider_registry or get_provider_registry()
|
|
|
|
distro_dir = tempfile.mkdtemp()
|
|
provider_configs_by_api = {}
|
|
for api_provider in api_providers:
|
|
api_str, provider = api_provider.split("=")
|
|
api = Api(api_str)
|
|
|
|
providers_by_type = provider_registry[api]
|
|
provider_spec = providers_by_type.get(provider)
|
|
if not provider_spec:
|
|
provider_spec = providers_by_type.get(f"inline::{provider}")
|
|
if not provider_spec:
|
|
provider_spec = providers_by_type.get(f"remote::{provider}")
|
|
|
|
if not provider_spec:
|
|
raise ValueError(
|
|
f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
|
|
)
|
|
|
|
# call method "sample_run_config" on the provider spec config class
|
|
provider_config_type = instantiate_class_type(provider_spec.config_class)
|
|
provider_config = replace_env_vars(
|
|
provider_config_type.sample_run_config(__distro_dir__=distro_dir), provider_registry=provider_registry
|
|
)
|
|
|
|
provider_configs_by_api[api_str] = [
|
|
Provider(
|
|
provider_id=provider,
|
|
provider_type=provider_spec.provider_type,
|
|
config=provider_config,
|
|
)
|
|
]
|
|
config = StackRunConfig(
|
|
image_name="distro-test",
|
|
apis=list(provider_configs_by_api.keys()),
|
|
providers=provider_configs_by_api,
|
|
)
|
|
return config
|