mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-10 13:28:40 +00:00
Merge 2a34226727
into ea15f2a270
This commit is contained in:
commit
79ced0c85b
94 changed files with 341 additions and 209 deletions
|
@ -165,7 +165,7 @@ def upgrade_from_routing_table(
|
|||
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
|
||||
version = config_dict.get("version", None)
|
||||
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
|
||||
processed_config_dict = replace_env_vars(config_dict)
|
||||
processed_config_dict = replace_env_vars(config_dict, provider_registry=get_provider_registry())
|
||||
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
|
||||
if "routing_table" in config_dict:
|
||||
|
@ -177,5 +177,5 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
|
|||
if not config_dict.get("external_providers_dir", None):
|
||||
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
|
||||
|
||||
processed_config_dict = replace_env_vars(config_dict)
|
||||
processed_config_dict = replace_env_vars(config_dict, provider_registry=get_provider_registry())
|
||||
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
|
|
|
@ -33,6 +33,7 @@ from termcolor import cprint
|
|||
from llama_stack.core.build import print_pip_install_help
|
||||
from llama_stack.core.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.core.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
|
@ -220,7 +221,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
config_path = Path(config_path_or_distro_name)
|
||||
if not config_path.exists():
|
||||
raise ValueError(f"Config file {config_path} does not exist")
|
||||
config_dict = replace_env_vars(yaml.safe_load(config_path.read_text()))
|
||||
config_dict = replace_env_vars(
|
||||
yaml.safe_load(config_path.read_text()), provider_registry=get_provider_registry()
|
||||
)
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
else:
|
||||
# distribution
|
||||
|
|
|
@ -43,7 +43,7 @@ from llama_stack.core.datatypes import (
|
|||
StackRunConfig,
|
||||
process_cors_config,
|
||||
)
|
||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.core.distribution import builtin_automatically_routed_apis, get_provider_registry
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
|
@ -371,7 +371,7 @@ def create_app(
|
|||
logger.error(f"Error: {str(e)}")
|
||||
raise ValueError(f"Invalid environment variable format: {env_pair}") from e
|
||||
|
||||
config = replace_env_vars(config_contents)
|
||||
config = replace_env_vars(config_contents, provider_registry=get_provider_registry())
|
||||
config = StackRunConfig(**cast_image_name_to_string(config))
|
||||
|
||||
_log_run_config(run_config=config)
|
||||
|
@ -524,7 +524,10 @@ def main(args: argparse.Namespace | None = None):
|
|||
env_vars=args.env,
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error(f"Error creating app: {str(e)}")
|
||||
logger.error(f"Stack trace:\n{traceback.format_exc()}")
|
||||
sys.exit(1)
|
||||
|
||||
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
|
||||
|
@ -534,7 +537,9 @@ def main(args: argparse.Namespace | None = None):
|
|||
logger_config = LoggingConfig(**cfg)
|
||||
else:
|
||||
logger_config = None
|
||||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
config = StackRunConfig(
|
||||
**cast_image_name_to_string(replace_env_vars(config_contents, provider_registry=get_provider_registry()))
|
||||
)
|
||||
|
||||
import uvicorn
|
||||
|
||||
|
|
|
@ -141,12 +141,19 @@ class EnvVarError(Exception):
|
|||
)
|
||||
|
||||
|
||||
def replace_env_vars(config: Any, path: str = "") -> Any:
|
||||
def replace_env_vars(
|
||||
config: Any,
|
||||
path: str = "",
|
||||
provider_registry: dict[Api, dict[str, Any]] | None = None,
|
||||
current_provider_context: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
if isinstance(config, dict):
|
||||
result = {}
|
||||
for k, v in config.items():
|
||||
try:
|
||||
result[k] = replace_env_vars(v, f"{path}.{k}" if path else k)
|
||||
result[k] = replace_env_vars(
|
||||
v, f"{path}.{k}" if path else k, provider_registry, current_provider_context
|
||||
)
|
||||
except EnvVarError as e:
|
||||
raise EnvVarError(e.var_name, e.path) from None
|
||||
return result
|
||||
|
@ -159,7 +166,9 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
# is disabled so that we can skip config env variable expansion and avoid validation errors
|
||||
if isinstance(v, dict) and "provider_id" in v:
|
||||
try:
|
||||
resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id")
|
||||
resolved_provider_id = replace_env_vars(
|
||||
v["provider_id"], f"{path}[{i}].provider_id", provider_registry, current_provider_context
|
||||
)
|
||||
if resolved_provider_id == "__disabled__":
|
||||
logger.debug(
|
||||
f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}"
|
||||
|
@ -167,13 +176,19 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
# Create a copy with resolved provider_id but original config
|
||||
disabled_provider = v.copy()
|
||||
disabled_provider["provider_id"] = resolved_provider_id
|
||||
result.append(disabled_provider)
|
||||
continue
|
||||
except EnvVarError:
|
||||
# If we can't resolve the provider_id, continue with normal processing
|
||||
pass
|
||||
|
||||
# Set up provider context for config processing
|
||||
provider_context = current_provider_context
|
||||
if isinstance(v, dict) and "provider_id" in v and "provider_type" in v and provider_registry:
|
||||
provider_context = _get_provider_context(v, provider_registry)
|
||||
|
||||
# Normal processing for non-disabled providers
|
||||
result.append(replace_env_vars(v, f"{path}[{i}]"))
|
||||
result.append(replace_env_vars(v, f"{path}[{i}]", provider_registry, provider_context))
|
||||
except EnvVarError as e:
|
||||
raise EnvVarError(e.var_name, e.path) from None
|
||||
return result
|
||||
|
@ -228,7 +243,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
result = re.sub(pattern, get_env_var, config)
|
||||
# Only apply type conversion if substitution actually happened
|
||||
if result != config:
|
||||
return _convert_string_to_proper_type(result)
|
||||
return _convert_string_to_proper_type_with_config(result, path, current_provider_context)
|
||||
return result
|
||||
except EnvVarError as e:
|
||||
raise EnvVarError(e.var_name, e.path) from None
|
||||
|
@ -236,12 +251,113 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
return config
|
||||
|
||||
|
||||
def _get_provider_context(
    provider_dict: dict[str, Any], provider_registry: dict[Api, dict[str, Any]]
) -> dict[str, Any] | None:
    """Describe the registry entry matching a provider so its config class can drive type conversion.

    Args:
        provider_dict: A provider entry from the run config; only its ``provider_type`` key is read.
        provider_registry: Mapping of API -> {provider_type -> provider spec}.

    Returns:
        A dict with keys ``api``, ``provider_type``, ``config_class`` and ``provider_spec`` for the
        first API (in registry iteration order) that registers the provider type, or ``None`` when
        the type is missing/empty, is not registered, or the lookup fails.
    """
    try:
        wanted_type = provider_dict.get("provider_type")
        if wanted_type:
            for api, specs_by_type in provider_registry.items():
                if wanted_type not in specs_by_type:
                    continue

                # NOTE(review): the first API that knows this provider_type wins; if the same
                # provider_type is registered under several APIs the match may not be the one
                # the caller is processing — confirm against the registry contents.
                spec = specs_by_type[wanted_type]
                return {
                    "api": api,
                    "provider_type": wanted_type,
                    "config_class": instantiate_class_type(spec.config_class),
                    "provider_spec": spec,
                }
    except Exception as e:
        # Best effort: callers fall back to untyped conversion when no context is available.
        logger.debug(f"Failed to get provider context: {e}")
    return None
|
||||
|
||||
|
||||
def _convert_string_to_proper_type_with_config(value: str, path: str, provider_context: dict[str, Any] | None) -> Any:
    """Convert a substituted string using the provider config class's field annotation, if known.

    Args:
        value: The string produced by environment variable substitution.
        path: Dotted config path of the value (e.g. "providers.inference[0].config.api_key");
            its last segment is used as the field name.
        provider_context: Context dict carrying a ``config_class`` entry (a class, or a string
            that is resolved via ``instantiate_class_type``), or ``None``.

    Returns:
        The value converted according to the matching field's annotation, or the result of the
        generic ``_convert_string_to_proper_type`` fallback when no context, no matching field,
        or an error is encountered.
    """
    config_class = provider_context.get("config_class") if provider_context else None
    if not config_class:
        # best effort conversion if we don't have the config class
        return _convert_string_to_proper_type(value)

    try:
        # Last dotted segment is the field name ("...config.api_key" -> "api_key").
        field_name = path.rsplit(".", 1)[-1]

        # The context may carry the class path as a string rather than the class itself, e.g.
        # when the context was built from a dumped provider spec (the adhoc config spec path
        # exercised by the unit tests does this), so resolve it lazily here.
        if isinstance(config_class, str):
            config_class = instantiate_class_type(config_class)

        if not hasattr(config_class, "model_fields") or field_name not in config_class.model_fields:
            return _convert_string_to_proper_type(value)

        annotation = config_class.model_fields[field_name].annotation
        return _convert_value_by_field_type(value, annotation)
    except Exception as e:
        logger.debug(f"Failed to convert using config class: {e}")
        return _convert_string_to_proper_type(value)
|
||||
|
||||
|
||||
def _convert_value_by_field_type(value: str, field_type: Any) -> Any:
|
||||
"""Convert string value based on Pydantic field type annotation."""
|
||||
import typing
|
||||
from typing import get_args, get_origin
|
||||
|
||||
if value == "":
|
||||
if field_type is None or (hasattr(typing, "get_origin") and get_origin(field_type) is type(None)):
|
||||
return None
|
||||
if hasattr(typing, "get_origin") and get_origin(field_type) is typing.Union:
|
||||
args = get_args(field_type)
|
||||
if type(None) in args:
|
||||
return None
|
||||
return ""
|
||||
|
||||
if field_type is bool or (hasattr(typing, "get_origin") and get_origin(field_type) is bool):
|
||||
lowered = value.lower()
|
||||
if lowered == "true":
|
||||
return True
|
||||
elif lowered == "false":
|
||||
return False
|
||||
else:
|
||||
return value
|
||||
|
||||
if field_type is int or (hasattr(typing, "get_origin") and get_origin(field_type) is int):
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return value
|
||||
|
||||
if field_type is float or (hasattr(typing, "get_origin") and get_origin(field_type) is float):
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return value
|
||||
|
||||
if hasattr(typing, "get_origin") and get_origin(field_type) is typing.Union:
|
||||
args = get_args(field_type)
|
||||
# Try to convert to the first non-None type
|
||||
for arg in args:
|
||||
if arg is not type(None):
|
||||
try:
|
||||
return _convert_value_by_field_type(value, arg)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _convert_string_to_proper_type(value: str) -> Any:
|
||||
# This might be tricky depending on what the config type is, if 'str | None' we are
|
||||
# good, if 'str' we need to keep the empty string... 'str | None' is more common and
|
||||
# providers config should be typed this way.
|
||||
# TODO: we could try to load the config class and see if the config has a field with type 'str | None'
|
||||
# and then convert the empty string to None or not
|
||||
# Fallback function for when provider config class is not available
|
||||
# The main type conversion logic is now in _convert_string_to_proper_type_with_config
|
||||
if value == "":
|
||||
return None
|
||||
|
||||
|
@ -416,7 +532,7 @@ def get_stack_run_config_from_distro(distro: str) -> StackRunConfig:
|
|||
raise ValueError(f"Distribution '{distro}' not found at {distro_path}")
|
||||
run_config = yaml.safe_load(path.open())
|
||||
|
||||
return StackRunConfig(**replace_env_vars(run_config))
|
||||
return StackRunConfig(**replace_env_vars(run_config, provider_registry=get_provider_registry()))
|
||||
|
||||
|
||||
def run_config_from_adhoc_config_spec(
|
||||
|
@ -452,7 +568,11 @@ def run_config_from_adhoc_config_spec(
|
|||
|
||||
# call method "sample_run_config" on the provider spec config class
|
||||
provider_config_type = instantiate_class_type(provider_spec.config_class)
|
||||
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
|
||||
provider_config = replace_env_vars(
|
||||
provider_config_type.sample_run_config(__distro_dir__=distro_dir),
|
||||
provider_registry=provider_registry,
|
||||
current_provider_context=provider_spec.model_dump(),
|
||||
)
|
||||
|
||||
provider_configs_by_api[api_str] = [
|
||||
Provider(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue