feat: load config class when doing variable substitution

When using bash style substitution env variable in distribution
template, we are processing the string and convert it to the type
associated with the provider's config class. This allows us to return
the proper type. This is crucial for api key since they are not strings
anymore but SecretStr. If the key is unset we will get an empty string
which will result in a Pydantic error like:

```
ERROR    2025-09-25 21:40:44,565 __main__:527 core::server: Error creating app: 1 validation error for AnthropicConfig
         api_key
           Input should be a valid string
             For further information visit
             https://errors.pydantic.dev/2.11/v/string_type
```

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-09-25 10:27:41 +02:00
parent 4af141292f
commit bc64635835
No known key found for this signature in database
79 changed files with 381 additions and 216 deletions

View file

@ -141,12 +141,19 @@ class EnvVarError(Exception):
)
def replace_env_vars(config: Any, path: str = "") -> Any:
def replace_env_vars(
config: Any,
path: str = "",
provider_registry: dict[Api, dict[str, Any]] | None = None,
current_provider_context: dict[str, Any] | None = None,
) -> Any:
if isinstance(config, dict):
result = {}
for k, v in config.items():
try:
result[k] = replace_env_vars(v, f"{path}.{k}" if path else k)
result[k] = replace_env_vars(
v, f"{path}.{k}" if path else k, provider_registry, current_provider_context
)
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return result
@ -159,7 +166,9 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
# is disabled so that we can skip config env variable expansion and avoid validation errors
if isinstance(v, dict) and "provider_id" in v:
try:
resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id")
resolved_provider_id = replace_env_vars(
v["provider_id"], f"{path}[{i}].provider_id", provider_registry, current_provider_context
)
if resolved_provider_id == "__disabled__":
logger.debug(
f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}"
@ -167,13 +176,19 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
# Create a copy with resolved provider_id but original config
disabled_provider = v.copy()
disabled_provider["provider_id"] = resolved_provider_id
result.append(disabled_provider)
continue
except EnvVarError:
# If we can't resolve the provider_id, continue with normal processing
pass
# Set up provider context for config processing
provider_context = current_provider_context
if isinstance(v, dict) and "provider_id" in v and "provider_type" in v and provider_registry:
provider_context = _get_provider_context(v, provider_registry)
# Normal processing for non-disabled providers
result.append(replace_env_vars(v, f"{path}[{i}]"))
result.append(replace_env_vars(v, f"{path}[{i}]", provider_registry, provider_context))
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return result
@ -228,7 +243,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
result = re.sub(pattern, get_env_var, config)
# Only apply type conversion if substitution actually happened
if result != config:
return _convert_string_to_proper_type(result)
return _convert_string_to_proper_type_with_config(result, path, current_provider_context)
return result
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
@ -236,12 +251,107 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
return config
def _get_provider_context(
provider_dict: dict[str, Any], provider_registry: dict[Api, dict[str, Any]]
) -> dict[str, Any] | None:
"""Get provider context information including config class for type conversion."""
try:
provider_type = provider_dict.get("provider_type")
if not provider_type:
return None
for api, providers in provider_registry.items():
if provider_type in providers:
provider_spec = providers[provider_type]
config_class = instantiate_class_type(provider_spec.config_class)
return {
"api": api,
"provider_type": provider_type,
"config_class": config_class,
"provider_spec": provider_spec,
}
except Exception as e:
logger.debug(f"Failed to get provider context: {e}")
return None
def _convert_string_to_proper_type_with_config(value: str, path: str, provider_context: dict[str, Any] | None) -> Any:
"""Convert string to proper type using provider config class field information."""
if not provider_context or not provider_context.get("config_class"):
# best effort conversion if we don't have the config class
return _convert_string_to_proper_type(value)
try:
# Extract field name from path (e.g., "providers.inference[0].config.api_key" -> "api_key")
field_name = path.split(".")[-1] if "." in path else path
config_class = provider_context["config_class"]
if hasattr(config_class, "model_fields") and field_name in config_class.model_fields:
field_info = config_class.model_fields[field_name]
field_type = field_info.annotation
return _convert_value_by_field_type(value, field_type)
else:
return _convert_string_to_proper_type(value)
except Exception as e:
logger.debug(f"Failed to convert using config class: {e}")
return _convert_string_to_proper_type(value)
def _convert_value_by_field_type(value: str, field_type: Any) -> Any:
"""Convert string value based on Pydantic field type annotation."""
import typing
from typing import get_args, get_origin
if value == "":
if field_type is None or (hasattr(typing, "get_origin") and get_origin(field_type) is type(None)):
return None
if hasattr(typing, "get_origin") and get_origin(field_type) is typing.Union:
args = get_args(field_type)
if type(None) in args:
return None
return ""
if field_type is bool or (hasattr(typing, "get_origin") and get_origin(field_type) is bool):
lowered = value.lower()
if lowered == "true":
return True
elif lowered == "false":
return False
else:
return value
if field_type is int or (hasattr(typing, "get_origin") and get_origin(field_type) is int):
try:
return int(value)
except ValueError:
return value
if field_type is float or (hasattr(typing, "get_origin") and get_origin(field_type) is float):
try:
return float(value)
except ValueError:
return value
if hasattr(typing, "get_origin") and get_origin(field_type) is typing.Union:
args = get_args(field_type)
# Try to convert to the first non-None type
for arg in args:
if arg is not type(None):
try:
return _convert_value_by_field_type(value, arg)
except Exception:
continue
return value
def _convert_string_to_proper_type(value: str) -> Any:
# This might be tricky depending on what the config type is, if 'str | None' we are
# good, if 'str' we need to keep the empty string... 'str | None' is more common and
# providers config should be typed this way.
# TODO: we could try to load the config class and see if the config has a field with type 'str | None'
# and then convert the empty string to None or not
# Fallback function for when provider config class is not available
# The main type conversion logic is now in _convert_string_to_proper_type_with_config
if value == "":
return None
@ -416,7 +526,7 @@ def get_stack_run_config_from_distro(distro: str) -> StackRunConfig:
raise ValueError(f"Distribution '{distro}' not found at {distro_path}")
run_config = yaml.safe_load(path.open())
return StackRunConfig(**replace_env_vars(run_config))
return StackRunConfig(**replace_env_vars(run_config, provider_registry=get_provider_registry()))
def run_config_from_adhoc_config_spec(
@ -452,7 +562,9 @@ def run_config_from_adhoc_config_spec(
# call method "sample_run_config" on the provider spec config class
provider_config_type = instantiate_class_type(provider_spec.config_class)
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
provider_config = replace_env_vars(
provider_config_type.sample_run_config(__distro_dir__=distro_dir), provider_registry=provider_registry
)
provider_configs_by_api[api_str] = [
Provider(