llama-stack-mirror/llama_stack/distribution/stack.py
Charlie Doern d994305f0a
fix: remove disabled providers from model dump (#2784)
# What does this PR do?

Currently, when running `llama stack run --template starter...`, the
`__disabled__` providers, their models, etc. are printed alongside the
enabled ones, making the output really confusing.

In server.py, add a utility `remove_disabled_providers` which
post-processes the model_dump output to remove any dict with
`provider_id: __disabled__`.

We also have `debug` logs printing the disabled providers, so I think
it's safe to say that is the only indicator we need when using starter.

<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->

## Test Plan

before (output truncated because it was huge):


```
...
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Llama-3.2-11B-Vision-Instruct
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Llama-3.2-11B-Vision-Instruct
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-3.2-11B-Vision-Instruct
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Llama-3.2-11B-Vision-Instruct
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Llama-3.2-90B-Vision-Instruct
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Llama-3.2-90B-Vision-Instruct
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-3.2-90B-Vision-Instruct
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Llama-3.2-90B-Vision-Instruct
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Llama-4-Scout-17B-16E-Instruct
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Llama-4-Scout-17B-16E-Instruct
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-4-Scout-17B-16E-Instruct
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Llama-4-Scout-17B-16E-Instruct
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Llama-4-Maverick-17B-128E-Instruct
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Llama-4-Maverick-17B-128E-Instruct
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-4-Maverick-17B-128E-Instruct
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Llama-4-Maverick-17B-128E-Instruct
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Meta-Llama-Guard-3-8B
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Meta-Llama-Guard-3-8B
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-Guard-3-8B
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Meta-Llama-Guard-3-8B
         - metadata:
             embedding_dimension: 384
           model_id: all-MiniLM-L6-v2
           model_type: embedding
           provider_id: sentence-transformers
           provider_model_id: null
         providers:
           agents:
           - config:
               persistence_store:
                 db_path: /Users/charliedoern/.llama/distributions/starter/agents_store.db
                 type: sqlite
               responses_store:
                 db_path: /Users/charliedoern/.llama/distributions/starter/responses_store.db
                 type: sqlite
             provider_id: meta-reference
             provider_type: inline::meta-reference
           datasetio:
           - config:
               kvstore:
                 db_path: /Users/charliedoern/.llama/distributions/starter/huggingface_datasetio.db
                 type: sqlite
             provider_id: huggingface
             provider_type: remote::huggingface
           - config:
               kvstore:
                 db_path: /Users/charliedoern/.llama/distributions/starter/localfs_datasetio.db
                 type: sqlite
             provider_id: localfs
             provider_type: inline::localfs
           eval:
           - config:
               kvstore:
                 db_path: /Users/charliedoern/.llama/distributions/starter/meta_reference_eval.db
                 type: sqlite
             provider_id: meta-reference
             provider_type: inline::meta-reference
           files:
           - config:
               metadata_store:
                 db_path: /Users/charliedoern/.llama/distributions/starter/files_metadata.db
                 type: sqlite
               storage_dir: /Users/charliedoern/.llama/distributions/starter/files
             provider_id: meta-reference-files
             provider_type: inline::localfs
           inference:
           - config:
               api_key: '********'
               base_url: https://api.cerebras.ai
             provider_id: __disabled__
             provider_type: remote::cerebras
           - config:
               url: http://localhost:11434
             provider_id: ollama
             provider_type: remote::ollama
           - config:
               api_token: '********'
               max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
               tls_verify: ${env.VLLM_TLS_VERIFY:=true}
               url: ${env.VLLM_URL}
             provider_id: __disabled__
             provider_type: remote::vllm
           - config:
               url: ${env.TGI_URL}
             provider_id: __disabled__
             provider_type: remote::tgi
           - config:
               api_token: '********'
               huggingface_repo: ${env.INFERENCE_MODEL}
             provider_id: __disabled__
             provider_type: remote::hf::serverless
           - config:
               api_token: '********'
               endpoint_name: ${env.INFERENCE_ENDPOINT_NAME}
             provider_id: __disabled__
             provider_type: remote::hf::endpoint
           - config:
               api_key: '********'
               url: https://api.fireworks.ai/inference/v1
             provider_id: __disabled__
             provider_type: remote::fireworks
           - config:
               api_key: '********'
               url: https://api.together.xyz/v1
             provider_id: __disabled__
             provider_type: remote::together
           - config: {}
             provider_id: __disabled__
             provider_type: remote::bedrock
           - config:
               api_token: '********'
               url: ${env.DATABRICKS_URL}
             provider_id: __disabled__
             provider_type: remote::databricks
           - config:
               api_key: '********'
               append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
               url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
             provider_id: __disabled__
             provider_type: remote::nvidia
           - config:
               api_token: '********'
               url: ${env.RUNPOD_URL:=}
             provider_id: __disabled__
             provider_type: remote::runpod
           - config:
               api_key: '********'
             provider_id: __disabled__
             provider_type: remote::openai
           - config:
               api_key: '********'
             provider_id: __disabled__
             provider_type: remote::anthropic
           - config:
               api_key: '********'
             provider_id: __disabled__
             provider_type: remote::gemini
           - config:
               api_key: '********'
               url: https://api.groq.com
             provider_id: __disabled__
             provider_type: remote::groq
           - config:
               api_key: '********'
               openai_compat_api_base: https://api.fireworks.ai/inference/v1
             provider_id: __disabled__
             provider_type: remote::fireworks-openai-compat
           - config:
               api_key: '********'
               openai_compat_api_base: https://api.llama.com/compat/v1/
             provider_id: __disabled__
             provider_type: remote::llama-openai-compat
           - config:
               api_key: '********'
               openai_compat_api_base: https://api.together.xyz/v1
             provider_id: __disabled__
             provider_type: remote::together-openai-compat
           - config:
               api_key: '********'
               openai_compat_api_base: https://api.groq.com/openai/v1
             provider_id: __disabled__
             provider_type: remote::groq-openai-compat
           - config:
               api_key: '********'
               openai_compat_api_base: https://api.sambanova.ai/v1
             provider_id: __disabled__
             provider_type: remote::sambanova-openai-compat
           - config:
               api_key: '********'
               openai_compat_api_base: https://api.cerebras.ai/v1
             provider_id: __disabled__
             provider_type: remote::cerebras-openai-compat
           - config:
               api_key: '********'
               url: https://api.sambanova.ai/v1
             provider_id: __disabled__
             provider_type: remote::sambanova
           - config:
               api_key: '********'
               url: ${env.PASSTHROUGH_URL}
             provider_id: __disabled__
             provider_type: remote::passthrough
           - config: {}
             provider_id: sentence-transformers
             provider_type: inline::sentence-transformers
           post_training:
           - config:
               checkpoint_format: huggingface
               device: cpu
               distributed_backend: null
             provider_id: huggingface
             provider_type: inline::huggingface
           safety:
           - config:
               excluded_categories: []
             provider_id: llama-guard
             provider_type: inline::llama-guard
           scoring:
           - config: {}
             provider_id: basic
             provider_type: inline::basic
           - config: {}
             provider_id: llm-as-judge
             provider_type: inline::llm-as-judge
           - config:
               openai_api_key: '********'
             provider_id: braintrust
             provider_type: inline::braintrust
           telemetry:
           - config:
               otel_exporter_otlp_endpoint: null
               service_name: "\u200B"
               sinks: console,sqlite
               sqlite_db_path: /Users/charliedoern/.llama/distributions/starter/trace_store.db
             provider_id: meta-reference
             provider_type: inline::meta-reference
           tool_runtime:
           - config:
               api_key: '********'
               max_results: 3
             provider_id: brave-search
             provider_type: remote::brave-search
           - config:
               api_key: '********'
               max_results: 3
             provider_id: tavily-search
             provider_type: remote::tavily-search
           - config: {}
             provider_id: rag-runtime
             provider_type: inline::rag-runtime
           - config: {}
             provider_id: model-context-protocol
             provider_type: remote::model-context-protocol
           vector_io:
           - config:
               kvstore:
                 db_path: /Users/charliedoern/.llama/distributions/starter/faiss_store.db
                 type: sqlite
             provider_id: faiss
             provider_type: inline::faiss
           - config:
               db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
               kvstore:
                 db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
                 type: sqlite
             provider_id: __disabled__
             provider_type: inline::sqlite-vec
           - config:
               db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
               kvstore:
                 db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
                 type: sqlite
             provider_id: __disabled__
             provider_type: inline::milvus
           - config:
               url: ${env.CHROMADB_URL:=}
             provider_id: __disabled__
             provider_type: remote::chromadb
           - config:
               db: ${env.PGVECTOR_DB:=}
               host: ${env.PGVECTOR_HOST:=localhost}
               kvstore:
                 db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
                 type: sqlite
               password: '********'
               port: ${env.PGVECTOR_PORT:=5432}
               user: ${env.PGVECTOR_USER:=}
             provider_id: __disabled__
             provider_type: remote::pgvector
         scoring_fns: []
         server:
           auth: null
           host: null
           port: 8321
           quota: null
           tls_cafile: null
           tls_certfile: null
           tls_keyfile: null
         shields:
         - params: null
           provider_id: null
           provider_shield_id: ollama/__disabled__
           shield_id: __disabled__
         tool_groups:
         - args: null
           mcp_endpoint: null
           provider_id: tavily-search
           toolgroup_id: builtin::websearch
         - args: null
           mcp_endpoint: null
           provider_id: rag-runtime
           toolgroup_id: builtin::rag
         vector_dbs: []
         version: 2

```

after:

```
INFO     2025-07-16 13:00:32,604 __main__:448 server: Run configuration:
INFO     2025-07-16 13:00:32,606 __main__:450 server: apis:
         - agents
         - datasetio
         - eval
         - files
         - inference
         - post_training
         - safety
         - scoring
         - telemetry
         - tool_runtime
         - vector_io
         benchmarks: []
         datasets: []
         image_name: starter
         inference_store:
           db_path: /Users/charliedoern/.llama/distributions/starter/inference_store.db
           type: sqlite
         metadata_store:
           db_path: /Users/charliedoern/.llama/distributions/starter/registry.db
           type: sqlite
         models:
         - metadata: {}
           model_id: ollama/llama3.2:3b
           model_type: llm
           provider_id: ollama
           provider_model_id: llama3.2:3b
         - metadata:
             embedding_dimension: 384
           model_id: all-MiniLM-L6-v2
           model_type: embedding
           provider_id: sentence-transformers
         providers:
           agents:
           - config:
               persistence_store:
                 db_path: /Users/charliedoern/.llama/distributions/starter/agents_store.db
                 type: sqlite
               responses_store:
                 db_path: /Users/charliedoern/.llama/distributions/starter/responses_store.db
                 type: sqlite
             provider_id: meta-reference
             provider_type: inline::meta-reference
           datasetio:
           - config:
               kvstore:
                 db_path: /Users/charliedoern/.llama/distributions/starter/huggingface_datasetio.db
                 type: sqlite
             provider_id: huggingface
             provider_type: remote::huggingface
           - config:
               kvstore:
                 db_path: /Users/charliedoern/.llama/distributions/starter/localfs_datasetio.db
                 type: sqlite
             provider_id: localfs
             provider_type: inline::localfs
           eval:
           - config:
               kvstore:
                 db_path: /Users/charliedoern/.llama/distributions/starter/meta_reference_eval.db
                 type: sqlite
             provider_id: meta-reference
             provider_type: inline::meta-reference
           files:
           - config:
               metadata_store:
                 db_path: /Users/charliedoern/.llama/distributions/starter/files_metadata.db
                 type: sqlite
               storage_dir: /Users/charliedoern/.llama/distributions/starter/files
             provider_id: meta-reference-files
             provider_type: inline::localfs
           inference:
           - config:
               url: http://localhost:11434
             provider_id: ollama
             provider_type: remote::ollama
           - config: {}
             provider_id: sentence-transformers
             provider_type: inline::sentence-transformers
           post_training:
           - config:
               checkpoint_format: huggingface
               device: cpu
             provider_id: huggingface
             provider_type: inline::huggingface
           safety:
           - config:
               excluded_categories: []
             provider_id: llama-guard
             provider_type: inline::llama-guard
           scoring:
           - config: {}
             provider_id: basic
             provider_type: inline::basic
           - config: {}
             provider_id: llm-as-judge
             provider_type: inline::llm-as-judge
           - config:
               openai_api_key: '********'
             provider_id: braintrust
             provider_type: inline::braintrust
           telemetry:
           - config:
               service_name: "\u200B"
               sinks: console,sqlite
               sqlite_db_path: /Users/charliedoern/.llama/distributions/starter/trace_store.db
             provider_id: meta-reference
             provider_type: inline::meta-reference
           tool_runtime:
           - config:
               api_key: '********'
               max_results: 3
             provider_id: brave-search
             provider_type: remote::brave-search
           - config:
               api_key: '********'
               max_results: 3
             provider_id: tavily-search
             provider_type: remote::tavily-search
           - config: {}
             provider_id: rag-runtime
             provider_type: inline::rag-runtime
           - config: {}
             provider_id: model-context-protocol
             provider_type: remote::model-context-protocol
           vector_io:
           - config:
               kvstore:
                 db_path: /Users/charliedoern/.llama/distributions/starter/faiss_store.db
                 type: sqlite
             provider_id: faiss
             provider_type: inline::faiss
         scoring_fns: []
         server:
           port: 8321
         shields: []
         tool_groups:
         - provider_id: tavily-search
           toolgroup_id: builtin::websearch
         - provider_id: rag-runtime
           toolgroup_id: builtin::rag
         vector_dbs: []
         version: 2
```

Signed-off-by: Charlie Doern <cdoern@redhat.com>
2025-07-18 10:44:35 -07:00

388 lines
15 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib.resources
import os
import re
import tempfile
from typing import Any
import yaml
from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.providers import Providers
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
# Module-level logger; it is created with the "core" category.
logger = get_logger(name=__name__, category="core")
class LlamaStack(
    Providers,
    VectorDBs,
    Inference,
    BatchInference,
    Agents,
    Safety,
    SyntheticDataGeneration,
    Datasets,
    Telemetry,
    PostTraining,
    VectorIO,
    Eval,
    Benchmarks,
    Scoring,
    ScoringFunctions,
    DatasetIO,
    Models,
    Shields,
    Inspect,
    ToolGroups,
    ToolRuntime,
    RAGToolRuntime,
    Files,
):
    """Aggregate of every API class this module imports from ``llama_stack.apis``.

    The class defines no members of its own; its whole surface comes from its bases.
    """

    # NOTE(review): presumably exists so the full stack API can be referenced as a
    # single type -- confirm against callers.
    pass
# Table of resource kinds consumed by `register_resources`. Each entry is
#   (StackRunConfig attribute name, Api member, register method name, list method name)
# The attribute is read from the run config with getattr, and the two method names are
# looked up with getattr on the implementation registered for that Api member.
RESOURCES = [
    ("models", Api.models, "register_model", "list_models"),
    ("shields", Api.shields, "register_shield", "list_shields"),
    ("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
    ("datasets", Api.datasets, "register_dataset", "list_datasets"),
    (
        "scoring_fns",
        Api.scoring_functions,
        "register_scoring_function",
        "list_scoring_functions",
    ),
    ("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"),
    ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
]
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
    """Register the resources declared in the run config with their API implementations.

    For every entry in ``RESOURCES`` whose API is present in ``impls``, each object listed
    under the matching ``run_config`` attribute is passed to the implementation's register
    method unless it is marked as disabled. Afterwards the implementation's list method is
    called and every returned object is logged at debug level.

    Args:
        run_config: Stack run configuration holding the resource lists.
        impls: Mapping from API to its resolved implementation.
    """
    for rsrc, api, register_method, list_method in RESOURCES:
        objects = getattr(run_config, rsrc)
        if api not in impls:
            continue

        method = getattr(impls[api], register_method)
        for obj in objects:
            # The skip checks below treat these attributes as optional, so read them
            # defensively once instead of dereferencing them unguarded in the log line.
            provider_id = getattr(obj, "provider_id", None)
            logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {provider_id}")

            # Do not register models on disabled providers
            if provider_id == "__disabled__":
                logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
                continue

            # In complex templates, like our starter template, we may have dynamic model ids
            # given by environment variables. This allows those environment variables to have
            # a default value of __disabled__ to skip registration of the model if not set.
            provider_model_id = getattr(obj, "provider_model_id", None)
            if provider_model_id is not None and "__disabled__" in provider_model_id:
                logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled model.")
                continue

            if getattr(obj, "shield_id", None) == "__disabled__":
                logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled shield.")
                continue

            # we want to maintain the type information in arguments to method.
            # instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
            # we use model_dump() to find all the attrs and then getattr to get the still typed value.
            await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})

        method = getattr(impls[api], list_method)
        response = await method()

        # Some list methods return a wrapper exposing `.data`; others return the objects directly.
        objects_to_process = response.data if hasattr(response, "data") else response

        for obj in objects_to_process:
            logger.debug(
                f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
            )
class EnvVarError(Exception):
    """Signals that a required environment variable is not set or is empty.

    The constructor stores its arguments on the instance:
        var_name: Name of the environment variable.
        path: Location of the reference within the config ("" when not known).
    """

    def __init__(self, var_name: str, path: str = ""):
        self.var_name = var_name
        self.path = path
        # Only mention the location when one is known, so the sentence does not end
        # in a dangling "empty ." when path is blank.
        location = f" at {path}" if path else ""
        super().__init__(
            f"Environment variable '{var_name}' not set or empty{location}. "
            f"Use ${{env.{var_name}:=default_value}} to provide a default value, "
            f"${{env.{var_name}:+value_if_set}} to make the field conditional, "
            f"or ensure the environment variable is set."
        )
def replace_env_vars(config: Any, path: str = "") -> Any:
    """Recursively expand ``${env.VAR}`` references inside a config structure.

    Dicts and lists are walked recursively and rebuilt as new containers. Strings have
    every reference substituted and are then passed through
    ``_convert_string_to_proper_type``. Any other value is returned unchanged.

    Supported reference forms:
        ``${env.FOO}``           value of FOO; raises ``EnvVarError`` if unset or empty
        ``${env.FOO:=default}``  value of FOO, or ``default`` if FOO is unset or empty
        ``${env.FOO:+value}``    ``value`` if FOO is set and non-empty, otherwise ""

    A leading "~" in a substituted value is expanded with ``os.path.expanduser``.

    List items that are dicts whose ``provider_id`` resolves to ``"__disabled__"`` are
    dropped from the resulting list without expanding the rest of their contents.

    Args:
        config: The value to process.
        path: Location of ``config`` within the overall structure; used in error messages.

    Returns:
        The processed value.

    Raises:
        EnvVarError: If a reference without an operator names an unset or empty variable.
    """
    if isinstance(config, dict):
        result = {}
        for k, v in config.items():
            try:
                result[k] = replace_env_vars(v, f"{path}.{k}" if path else k)
            except EnvVarError as e:
                # Re-raise with the same details; `from None` hides the nested exception chain
                raise EnvVarError(e.var_name, e.path) from None
        return result

    elif isinstance(config, list):
        result = []
        for i, v in enumerate(config):
            try:
                # Special handling for providers: first resolve the provider_id to check if provider
                # is disabled so that we can skip config env variable expansion and avoid validation errors
                if isinstance(v, dict) and "provider_id" in v:
                    try:
                        resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id")
                    except EnvVarError:
                        # If we can't resolve the provider_id, continue with normal processing
                        pass
                    else:
                        if resolved_provider_id == "__disabled__":
                            logger.debug(
                                f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}"
                            )
                            # Disabled providers are left out of the result entirely
                            continue

                # Normal processing for non-disabled providers
                result.append(replace_env_vars(v, f"{path}[{i}]"))
            except EnvVarError as e:
                raise EnvVarError(e.var_name, e.path) from None
        return result

    elif isinstance(config, str):
        # Pattern supports bash-like syntax: := for default and :+ for conditional and an optional value
        pattern = r"\${env\.([A-Z0-9_]+)(?::([=+])([^}]*))?}"

        def get_env_var(match: re.Match):
            env_var = match.group(1)
            operator = match.group(2)  # '=' for default, '+' for conditional
            # Always a string (possibly empty) whenever an operator matched
            value_expr = match.group(3)

            env_value = os.environ.get(env_var)

            if operator == "=":  # Default value syntax: ${env.FOO:=default}
                # Use the env value when set and non-empty, otherwise the default.
                # ${env.FOO:=} is accepted and yields the empty string - just like bash
                value = env_value if env_value else value_expr
            elif operator == "+":  # Conditional value syntax: ${env.FOO:+value_if_set}
                # Use value_if_set only when the env is set and non-empty; in every other
                # case (including ${env.FOO:+}) the result is the empty string, like bash
                value = value_expr if env_value else ""
            else:  # No operator case: ${env.FOO}
                if not env_value:
                    raise EnvVarError(env_var, path)
                value = env_value

            # expand "~" from the values
            return os.path.expanduser(value)

        try:
            result = re.sub(pattern, get_env_var, config)
            return _convert_string_to_proper_type(result)
        except EnvVarError as e:
            raise EnvVarError(e.var_name, e.path) from None

    return config
def _convert_string_to_proper_type(value: str) -> Any:
# This might be tricky depending on what the config type is, if 'str | None' we are
# good, if 'str' we need to keep the empty string... 'str | None' is more common and
# providers config should be typed this way.
# TODO: we could try to load the config class and see if the config has a field with type 'str | None'
# and then convert the empty string to None or not
if value == "":
return None
lowered = value.lower()
if lowered == "true":
return True
elif lowered == "false":
return False
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
pass
return value
def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
    """Convert a non-None ``image_name`` entry of ``config_dict`` to ``str``.

    The dict is modified in place and the same dict is returned. A missing or ``None``
    ``image_name`` is left untouched.
    """
    image_name = config_dict.get("image_name")
    if image_name is not None:
        config_dict["image_name"] = str(image_name)
    return config_dict
def validate_env_pair(env_pair: str) -> tuple[str, str]:
    """Validate and split an environment variable key-value pair.

    The pair is split on the first "=". Whitespace around the key is stripped; the value
    is returned exactly as given (it may be empty or contain further "=" characters).

    Args:
        env_pair: Text of the form ``KEY=value``.

    Returns:
        The ``(key, value)`` tuple.

    Raises:
        ValueError: If there is no "=", the key is empty, or the key contains characters
            other than alphanumerics and underscores.
    """
    key, separator, value = env_pair.partition("=")
    key = key.strip()

    # Check the separator explicitly so the user sees a meaningful reason rather than
    # Python's internal tuple-unpacking error text.
    if not separator:
        problem = "missing '=' separator"
    elif not key:
        problem = f"Empty key in environment variable pair: {env_pair}"
    elif not all(c.isalnum() or c == "_" for c in key):
        problem = f"Key must contain only alphanumeric characters and underscores: {key}"
    else:
        return key, value

    raise ValueError(f"Invalid environment variable format '{env_pair}': {problem}. Expected format: KEY=value")
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
    """Install the internal inspect and providers implementations into ``impls``.

    ``impls`` is modified in place: entries for ``Api.inspect`` and ``Api.providers`` are
    set, in that order. Each implementation receives ``impls`` itself as its ``deps``.

    Args:
        impls: Dictionary of API implementations
        run_config: Stack run configuration
    """
    inspect_config = DistributionInspectConfig(run_config=run_config)
    impls[Api.inspect] = DistributionInspectImpl(inspect_config, deps=impls)

    providers_config = ProviderImplConfig(run_config=run_config)
    impls[Api.providers] = ProviderImpl(providers_config, deps=impls)
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(
    run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
) -> dict[Api, Any]:
    """Build the API implementations for ``run_config`` and register its resources.

    Args:
        run_config: Stack run configuration.
        provider_registry: Registry to resolve providers from; when falsy, the result of
            ``get_provider_registry(run_config)`` is used instead.

    Returns:
        Mapping from API to implementation, including the internal implementations
        added by ``add_internal_implementations``.
    """
    dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)

    # Without auth configured there is no access policy to apply
    policy = run_config.server.auth.access_policy if run_config.server.auth else []

    registry = provider_registry or get_provider_registry(run_config)
    impls = await resolve_impls(run_config, registry, dist_registry, policy)

    # Add internal implementations after all other providers are resolved
    add_internal_implementations(impls, run_config)

    await register_resources(run_config, impls)
    return impls
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
    """Load the ``run.yaml`` packaged for ``template`` and build a ``StackRunConfig`` from it.

    The file is looked up at ``templates/<template>/run.yaml`` inside the ``llama_stack``
    package, parsed with ``yaml.safe_load``, and passed through ``replace_env_vars``
    before being used as keyword arguments for ``StackRunConfig``.

    Args:
        template: Name of the template directory.

    Raises:
        ValueError: If the template's ``run.yaml`` does not exist.
    """
    template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"

    with importlib.resources.as_file(template_path) as path:
        if not path.exists():
            raise ValueError(f"Template '{template}' not found at {template_path}")
        # Use a context manager so the file handle is closed deterministically
        with path.open() as run_yaml:
            run_config = yaml.safe_load(run_yaml)

    return StackRunConfig(**replace_env_vars(run_config))
def run_config_from_adhoc_config_spec(
    adhoc_config_spec: str, provider_registry: ProviderRegistry | None = None
) -> StackRunConfig:
    """
    Create an adhoc distribution from a list of API providers.

    The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
    multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"

    Whitespace around entries, API names and provider names is ignored, and empty entries
    (for example from a trailing separator) are skipped.

    Args:
        adhoc_config_spec: The "api=provider" list described above.
        provider_registry: Registry to look providers up in; when falsy, the result of
            ``get_provider_registry()`` is used.

    Raises:
        ValueError: If the spec contains no entries, an entry is not of the form
            "api=provider", or a provider cannot be found for its API.
    """
    provider_registry = provider_registry or get_provider_registry()

    distro_dir = tempfile.mkdtemp()

    provider_configs_by_api = {}
    for raw_entry in adhoc_config_spec.replace(";", ",").split(","):
        entry = raw_entry.strip()
        if not entry:
            # Tolerate trailing/duplicated separators
            continue

        api_str, separator, provider = entry.partition("=")
        api_str = api_str.strip()
        provider = provider.strip()
        if not separator or not api_str or not provider:
            raise ValueError(f"Invalid API provider spec '{entry}'. Expected format: api=provider")

        api = Api(api_str)

        providers_by_type = provider_registry[api]
        # Accept the bare provider name as well as its inline::/remote:: qualified forms
        provider_spec = providers_by_type.get(provider)
        if not provider_spec:
            provider_spec = providers_by_type.get(f"inline::{provider}")
        if not provider_spec:
            provider_spec = providers_by_type.get(f"remote::{provider}")

        if not provider_spec:
            raise ValueError(
                f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
            )

        # call method "sample_run_config" on the provider spec config class
        provider_config_type = instantiate_class_type(provider_spec.config_class)
        provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))

        provider_configs_by_api[api_str] = [
            Provider(
                provider_id=provider,
                provider_type=provider_spec.provider_type,
                config=provider_config,
            )
        ]

    if not provider_configs_by_api:
        # An empty spec was an error before (it failed while splitting); keep it one, with a clear message
        raise ValueError("No API providers specified. Expected format: api=provider[,api=provider...]")

    config = StackRunConfig(
        image_name="distro-test",
        apis=list(provider_configs_by_api.keys()),
        providers=provider_configs_by_api,
    )
    return config