# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import asyncio import importlib.resources import os import re import tempfile from typing import Any import yaml from llama_stack.apis.agents import Agents from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.eval import Eval from llama_stack.apis.files import Files from llama_stack.apis.inference import Inference from llama_stack.apis.inspect import Inspect from llama_stack.apis.models import Models from llama_stack.apis.post_training import PostTraining from llama_stack.apis.prompts import Prompts from llama_stack.apis.providers import Providers from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime from llama_stack.apis.vector_dbs import VectorDBs from llama_stack.apis.vector_io import VectorIO from llama_stack.core.datatypes import Provider, StackRunConfig from llama_stack.core.distribution import get_provider_registry from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl from llama_stack.core.providers import ProviderImpl, ProviderImplConfig from llama_stack.core.resolver import ProviderRegistry, resolve_impls from llama_stack.core.routing_tables.common import CommonRoutingTableImpl from llama_stack.core.store.registry import create_dist_registry from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api logger = get_logger(name=__name__, category="core") class LlamaStack( Providers, VectorDBs, Inference, Agents, Safety, SyntheticDataGeneration, Datasets, Telemetry, PostTraining, VectorIO, Eval, Benchmarks, Scoring, ScoringFunctions, DatasetIO, Models, Shields, Inspect, ToolGroups, ToolRuntime, RAGToolRuntime, Files, Prompts, ): pass RESOURCES = [ ("models", Api.models, "register_model", "list_models"), ("shields", Api.shields, "register_shield", "list_shields"), ("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"), ("datasets", Api.datasets, "register_dataset", "list_datasets"), ( "scoring_fns", Api.scoring_functions, "register_scoring_function", "list_scoring_functions", ), ("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"), ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"), ] REGISTRY_REFRESH_INTERVAL_SECONDS = 300 REGISTRY_REFRESH_TASK = None TEST_RECORDING_CONTEXT = None async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): for rsrc, api, register_method, list_method in RESOURCES: objects = getattr(run_config, rsrc) if api not in impls: continue method = getattr(impls[api], register_method) for obj in objects: if hasattr(obj, "provider_id"): # Do not register models on disabled providers if not obj.provider_id or obj.provider_id == "__disabled__": logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.") continue logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}") # we want to maintain the type information in arguments to method. # instead of method(**obj.model_dump()), which may convert a typed attr to a dict, # we use model_dump() to find all the attrs and then getattr to get the still typed value. await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()}) method = getattr(impls[api], list_method) response = await method() objects_to_process = response.data if hasattr(response, "data") else response for obj in objects_to_process: logger.debug( f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}", ) class EnvVarError(Exception): def __init__(self, var_name: str, path: str = ""): self.var_name = var_name self.path = path super().__init__( f"Environment variable '{var_name}' not set or empty {f'at {path}' if path else ''}. " f"Use ${{env.{var_name}:=default_value}} to provide a default value, " f"${{env.{var_name}:+value_if_set}} to make the field conditional, " f"or ensure the environment variable is set." ) def replace_env_vars( config: Any, path: str = "", provider_registry: dict[Api, dict[str, Any]] | None = None, current_provider_context: dict[str, Any] | None = None, ) -> Any: if isinstance(config, dict): result = {} for k, v in config.items(): try: result[k] = replace_env_vars( v, f"{path}.{k}" if path else k, provider_registry, current_provider_context ) except EnvVarError as e: raise EnvVarError(e.var_name, e.path) from None return result elif isinstance(config, list): result = [] for i, v in enumerate(config): try: # Special handling for providers: first resolve the provider_id to check if provider # is disabled so that we can skip config env variable expansion and avoid validation errors if isinstance(v, dict) and "provider_id" in v: try: resolved_provider_id = replace_env_vars( v["provider_id"], f"{path}[{i}].provider_id", provider_registry, current_provider_context ) if resolved_provider_id == "__disabled__": logger.debug( f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}" ) # Create a copy with resolved provider_id but original config disabled_provider = v.copy() disabled_provider["provider_id"] = resolved_provider_id result.append(disabled_provider) continue except EnvVarError: # If we can't resolve the provider_id, continue with normal processing pass # Set up provider context for config processing provider_context = current_provider_context if isinstance(v, dict) and "provider_id" in v and "provider_type" in v and provider_registry: provider_context = _get_provider_context(v, provider_registry) # Normal processing for non-disabled providers result.append(replace_env_vars(v, f"{path}[{i}]", provider_registry, provider_context)) except EnvVarError as e: raise EnvVarError(e.var_name, e.path) from None return result elif isinstance(config, str): # Pattern supports bash-like syntax: := for default and :+ for conditional and a optional value pattern = r"\${env\.([A-Z0-9_]+)(?::([=+])([^}]*))?}" def get_env_var(match: re.Match): env_var = match.group(1) operator = match.group(2) # '=' for default, '+' for conditional value_expr = match.group(3) env_value = os.environ.get(env_var) if operator == "=": # Default value syntax: ${env.FOO:=default} # If the env is set like ${env.FOO:=default} then use the env value when set if env_value: value = env_value else: # If the env is not set, look for a default value # value_expr returns empty string (not None) when not matched # This means ${env.FOO:=} and it's accepted and returns empty string - just like bash if value_expr == "": return "" else: value = value_expr elif operator == "+": # Conditional value syntax: ${env.FOO:+value_if_set} # If the env is set like ${env.FOO:+value_if_set} then use the value_if_set if env_value: if value_expr: value = value_expr # This means ${env.FOO:+} else: # Just like bash, this doesn't care whether the env is set or not and applies # the value, in this case the empty string return "" else: # Just like bash, this doesn't care whether the env is set or not, since it's not set # we return an empty string value = "" else: # No operator case: ${env.FOO} if not env_value: raise EnvVarError(env_var, path) value = env_value # expand "~" from the values return os.path.expanduser(value) try: result = re.sub(pattern, get_env_var, config) # Only apply type conversion if substitution actually happened if result != config: return _convert_string_to_proper_type_with_config(result, path, current_provider_context) return result except EnvVarError as e: raise EnvVarError(e.var_name, e.path) from None return config def _get_provider_context( provider_dict: dict[str, Any], provider_registry: dict[Api, dict[str, Any]] ) -> dict[str, Any] | None: """Get provider context information including config class for type conversion.""" try: provider_type = provider_dict.get("provider_type") if not provider_type: return None for api, providers in provider_registry.items(): if provider_type in providers: provider_spec = providers[provider_type] config_class = instantiate_class_type(provider_spec.config_class) return { "api": api, "provider_type": provider_type, "config_class": config_class, "provider_spec": provider_spec, } except Exception as e: logger.debug(f"Failed to get provider context: {e}") return None def _convert_string_to_proper_type_with_config(value: str, path: str, provider_context: dict[str, Any] | None) -> Any: """Convert string to proper type using provider config class field information.""" if not provider_context or not provider_context.get("config_class"): # best effort conversion if we don't have the config class return _convert_string_to_proper_type(value) try: # Extract field name from path (e.g., "providers.inference[0].config.api_key" -> "api_key") field_name = path.split(".")[-1] if "." in path else path config_class = provider_context["config_class"] # Only instantiate if the class hasn't been instantiated already # This handles the case we entered replace_env_vars() with a dict, which # could happen if we use a sample_run_config() method that returns a dict. Our unit tests do # this on the adhoc config spec creation. if isinstance(config_class, str): config_class = instantiate_class_type(config_class) if hasattr(config_class, "model_fields") and field_name in config_class.model_fields: field_info = config_class.model_fields[field_name] field_type = field_info.annotation return _convert_value_by_field_type(value, field_type) else: return _convert_string_to_proper_type(value) except Exception as e: logger.debug(f"Failed to convert using config class: {e}") return _convert_string_to_proper_type(value) def _convert_value_by_field_type(value: str, field_type: Any) -> Any: """Convert string value based on Pydantic field type annotation.""" import typing from typing import get_args, get_origin if value == "": if field_type is None or (hasattr(typing, "get_origin") and get_origin(field_type) is type(None)): return None if hasattr(typing, "get_origin") and get_origin(field_type) is typing.Union: args = get_args(field_type) if type(None) in args: return None return "" if field_type is bool or (hasattr(typing, "get_origin") and get_origin(field_type) is bool): lowered = value.lower() if lowered == "true": return True elif lowered == "false": return False else: return value if field_type is int or (hasattr(typing, "get_origin") and get_origin(field_type) is int): try: return int(value) except ValueError: return value if field_type is float or (hasattr(typing, "get_origin") and get_origin(field_type) is float): try: return float(value) except ValueError: return value if hasattr(typing, "get_origin") and get_origin(field_type) is typing.Union: args = get_args(field_type) # Try to convert to the first non-None type for arg in args: if arg is not type(None): try: return _convert_value_by_field_type(value, arg) except Exception: continue return value def _convert_string_to_proper_type(value: str) -> Any: # Fallback function for when provider config class is not available # The main type conversion logic is now in _convert_string_to_proper_type_with_config if value == "": return None lowered = value.lower() if lowered == "true": return True elif lowered == "false": return False try: return int(value) except ValueError: pass try: return float(value) except ValueError: pass return value def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]: """Ensure that any value for a key 'image_name' in a config_dict is a string""" if "image_name" in config_dict and config_dict["image_name"] is not None: config_dict["image_name"] = str(config_dict["image_name"]) return config_dict def validate_env_pair(env_pair: str) -> tuple[str, str]: """Validate and split an environment variable key-value pair.""" try: key, value = env_pair.split("=", 1) key = key.strip() if not key: raise ValueError(f"Empty key in environment variable pair: {env_pair}") if not all(c.isalnum() or c == "_" for c in key): raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}") return key, value except ValueError as e: raise ValueError( f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value" ) from e def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None: """Add internal implementations (inspect and providers) to the implementations dictionary. Args: impls: Dictionary of API implementations run_config: Stack run configuration """ inspect_impl = DistributionInspectImpl( DistributionInspectConfig(run_config=run_config), deps=impls, ) impls[Api.inspect] = inspect_impl providers_impl = ProviderImpl( ProviderImplConfig(run_config=run_config), deps=impls, ) impls[Api.providers] = providers_impl prompts_impl = PromptServiceImpl( PromptServiceConfig(run_config=run_config), deps=impls, ) impls[Api.prompts] = prompts_impl class Stack: def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None): self.run_config = run_config self.provider_registry = provider_registry self.impls = None # Produces a stack of providers for the given run config. Not all APIs may be # asked for in the run config. async def initialize(self): if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ: from llama_stack.testing.inference_recorder import setup_inference_recording global TEST_RECORDING_CONTEXT TEST_RECORDING_CONTEXT = setup_inference_recording() if TEST_RECORDING_CONTEXT: TEST_RECORDING_CONTEXT.__enter__() logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}") dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name) policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else [] impls = await resolve_impls( self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy ) # Add internal implementations after all other providers are resolved add_internal_implementations(impls, self.run_config) if Api.prompts in impls: await impls[Api.prompts].initialize() await register_resources(self.run_config, impls) await refresh_registry_once(impls) self.impls = impls def create_registry_refresh_task(self): assert self.impls is not None, "Must call initialize() before starting" global REGISTRY_REFRESH_TASK REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls)) def cb(task): import traceback if task.cancelled(): logger.error("Model refresh task cancelled") elif task.exception(): logger.error(f"Model refresh task failed: {task.exception()}") traceback.print_exception(task.exception()) else: logger.debug("Model refresh task completed") REGISTRY_REFRESH_TASK.add_done_callback(cb) async def shutdown(self): for impl in self.impls.values(): impl_name = impl.__class__.__name__ logger.info(f"Shutting down {impl_name}") try: if hasattr(impl, "shutdown"): await asyncio.wait_for(impl.shutdown(), timeout=5) else: logger.warning(f"No shutdown method for {impl_name}") except TimeoutError: logger.exception(f"Shutdown timeout for {impl_name}") except (Exception, asyncio.CancelledError) as e: logger.exception(f"Failed to shutdown {impl_name}: {e}") global TEST_RECORDING_CONTEXT if TEST_RECORDING_CONTEXT: try: TEST_RECORDING_CONTEXT.__exit__(None, None, None) except Exception as e: logger.error(f"Error during inference recording cleanup: {e}") global REGISTRY_REFRESH_TASK if REGISTRY_REFRESH_TASK: REGISTRY_REFRESH_TASK.cancel() async def refresh_registry_once(impls: dict[Api, Any]): logger.debug("refreshing registry") routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)] for routing_table in routing_tables: await routing_table.refresh() async def refresh_registry_task(impls: dict[Api, Any]): logger.info("starting registry refresh task") while True: await refresh_registry_once(impls) await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS) def get_stack_run_config_from_distro(distro: str) -> StackRunConfig: distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro}/run.yaml" with importlib.resources.as_file(distro_path) as path: if not path.exists(): raise ValueError(f"Distribution '{distro}' not found at {distro_path}") run_config = yaml.safe_load(path.open()) return StackRunConfig(**replace_env_vars(run_config, provider_registry=get_provider_registry())) def run_config_from_adhoc_config_spec( adhoc_config_spec: str, provider_registry: ProviderRegistry | None = None ) -> StackRunConfig: """ Create an adhoc distribution from a list of API providers. The list should be of the form "api=provider", e.g. "inference=fireworks". If you have multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference" """ api_providers = adhoc_config_spec.replace(";", ",").split(",") provider_registry = provider_registry or get_provider_registry() distro_dir = tempfile.mkdtemp() provider_configs_by_api = {} for api_provider in api_providers: api_str, provider = api_provider.split("=") api = Api(api_str) providers_by_type = provider_registry[api] provider_spec = providers_by_type.get(provider) if not provider_spec: provider_spec = providers_by_type.get(f"inline::{provider}") if not provider_spec: provider_spec = providers_by_type.get(f"remote::{provider}") if not provider_spec: raise ValueError( f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}" ) # call method "sample_run_config" on the provider spec config class provider_config_type = instantiate_class_type(provider_spec.config_class) provider_config = replace_env_vars( provider_config_type.sample_run_config(__distro_dir__=distro_dir), provider_registry=provider_registry, current_provider_context=provider_spec.model_dump(), ) provider_configs_by_api[api_str] = [ Provider( provider_id=provider, provider_type=provider_spec.provider_type, config=provider_config, ) ] config = StackRunConfig( image_name="distro-test", apis=list(provider_configs_by_api.keys()), providers=provider_configs_by_api, ) return config