mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
feat: load config class when doing variable substitution
When using bash style substitution env variable in distribution template, we are processing the string and convert it to the type associated with the provider's config class. This allows us to return the proper type. This is crucial for api key since they are not strings anymore but SecretStr. If the key is unset we will get an empty string which will result in a Pydantic error like: ``` ERROR 2025-09-25 21:40:44,565 __main__:527 core::server: Error creating app: 1 validation error for AnthropicConfig api_key Input should be a valid string For further information visit https://errors.pydantic.dev/2.11/v/string_type ``` Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
4af141292f
commit
bc64635835
79 changed files with 381 additions and 216 deletions
|
@ -141,12 +141,19 @@ class EnvVarError(Exception):
|
|||
)
|
||||
|
||||
|
||||
def replace_env_vars(config: Any, path: str = "") -> Any:
|
||||
def replace_env_vars(
|
||||
config: Any,
|
||||
path: str = "",
|
||||
provider_registry: dict[Api, dict[str, Any]] | None = None,
|
||||
current_provider_context: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
if isinstance(config, dict):
|
||||
result = {}
|
||||
for k, v in config.items():
|
||||
try:
|
||||
result[k] = replace_env_vars(v, f"{path}.{k}" if path else k)
|
||||
result[k] = replace_env_vars(
|
||||
v, f"{path}.{k}" if path else k, provider_registry, current_provider_context
|
||||
)
|
||||
except EnvVarError as e:
|
||||
raise EnvVarError(e.var_name, e.path) from None
|
||||
return result
|
||||
|
@ -159,7 +166,9 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
# is disabled so that we can skip config env variable expansion and avoid validation errors
|
||||
if isinstance(v, dict) and "provider_id" in v:
|
||||
try:
|
||||
resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id")
|
||||
resolved_provider_id = replace_env_vars(
|
||||
v["provider_id"], f"{path}[{i}].provider_id", provider_registry, current_provider_context
|
||||
)
|
||||
if resolved_provider_id == "__disabled__":
|
||||
logger.debug(
|
||||
f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}"
|
||||
|
@ -167,13 +176,19 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
# Create a copy with resolved provider_id but original config
|
||||
disabled_provider = v.copy()
|
||||
disabled_provider["provider_id"] = resolved_provider_id
|
||||
result.append(disabled_provider)
|
||||
continue
|
||||
except EnvVarError:
|
||||
# If we can't resolve the provider_id, continue with normal processing
|
||||
pass
|
||||
|
||||
# Set up provider context for config processing
|
||||
provider_context = current_provider_context
|
||||
if isinstance(v, dict) and "provider_id" in v and "provider_type" in v and provider_registry:
|
||||
provider_context = _get_provider_context(v, provider_registry)
|
||||
|
||||
# Normal processing for non-disabled providers
|
||||
result.append(replace_env_vars(v, f"{path}[{i}]"))
|
||||
result.append(replace_env_vars(v, f"{path}[{i}]", provider_registry, provider_context))
|
||||
except EnvVarError as e:
|
||||
raise EnvVarError(e.var_name, e.path) from None
|
||||
return result
|
||||
|
@ -228,7 +243,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
result = re.sub(pattern, get_env_var, config)
|
||||
# Only apply type conversion if substitution actually happened
|
||||
if result != config:
|
||||
return _convert_string_to_proper_type(result)
|
||||
return _convert_string_to_proper_type_with_config(result, path, current_provider_context)
|
||||
return result
|
||||
except EnvVarError as e:
|
||||
raise EnvVarError(e.var_name, e.path) from None
|
||||
|
@ -236,12 +251,107 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
return config
|
||||
|
||||
|
||||
def _get_provider_context(
|
||||
provider_dict: dict[str, Any], provider_registry: dict[Api, dict[str, Any]]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Get provider context information including config class for type conversion."""
|
||||
try:
|
||||
provider_type = provider_dict.get("provider_type")
|
||||
if not provider_type:
|
||||
return None
|
||||
|
||||
for api, providers in provider_registry.items():
|
||||
if provider_type in providers:
|
||||
provider_spec = providers[provider_type]
|
||||
|
||||
config_class = instantiate_class_type(provider_spec.config_class)
|
||||
|
||||
return {
|
||||
"api": api,
|
||||
"provider_type": provider_type,
|
||||
"config_class": config_class,
|
||||
"provider_spec": provider_spec,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get provider context: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _convert_string_to_proper_type_with_config(value: str, path: str, provider_context: dict[str, Any] | None) -> Any:
|
||||
"""Convert string to proper type using provider config class field information."""
|
||||
if not provider_context or not provider_context.get("config_class"):
|
||||
# best effort conversion if we don't have the config class
|
||||
return _convert_string_to_proper_type(value)
|
||||
|
||||
try:
|
||||
# Extract field name from path (e.g., "providers.inference[0].config.api_key" -> "api_key")
|
||||
field_name = path.split(".")[-1] if "." in path else path
|
||||
|
||||
config_class = provider_context["config_class"]
|
||||
|
||||
if hasattr(config_class, "model_fields") and field_name in config_class.model_fields:
|
||||
field_info = config_class.model_fields[field_name]
|
||||
field_type = field_info.annotation
|
||||
return _convert_value_by_field_type(value, field_type)
|
||||
else:
|
||||
return _convert_string_to_proper_type(value)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to convert using config class: {e}")
|
||||
return _convert_string_to_proper_type(value)
|
||||
|
||||
|
||||
def _convert_value_by_field_type(value: str, field_type: Any) -> Any:
|
||||
"""Convert string value based on Pydantic field type annotation."""
|
||||
import typing
|
||||
from typing import get_args, get_origin
|
||||
|
||||
if value == "":
|
||||
if field_type is None or (hasattr(typing, "get_origin") and get_origin(field_type) is type(None)):
|
||||
return None
|
||||
if hasattr(typing, "get_origin") and get_origin(field_type) is typing.Union:
|
||||
args = get_args(field_type)
|
||||
if type(None) in args:
|
||||
return None
|
||||
return ""
|
||||
|
||||
if field_type is bool or (hasattr(typing, "get_origin") and get_origin(field_type) is bool):
|
||||
lowered = value.lower()
|
||||
if lowered == "true":
|
||||
return True
|
||||
elif lowered == "false":
|
||||
return False
|
||||
else:
|
||||
return value
|
||||
|
||||
if field_type is int or (hasattr(typing, "get_origin") and get_origin(field_type) is int):
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return value
|
||||
|
||||
if field_type is float or (hasattr(typing, "get_origin") and get_origin(field_type) is float):
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return value
|
||||
|
||||
if hasattr(typing, "get_origin") and get_origin(field_type) is typing.Union:
|
||||
args = get_args(field_type)
|
||||
# Try to convert to the first non-None type
|
||||
for arg in args:
|
||||
if arg is not type(None):
|
||||
try:
|
||||
return _convert_value_by_field_type(value, arg)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _convert_string_to_proper_type(value: str) -> Any:
|
||||
# This might be tricky depending on what the config type is, if 'str | None' we are
|
||||
# good, if 'str' we need to keep the empty string... 'str | None' is more common and
|
||||
# providers config should be typed this way.
|
||||
# TODO: we could try to load the config class and see if the config has a field with type 'str | None'
|
||||
# and then convert the empty string to None or not
|
||||
# Fallback function for when provider config class is not available
|
||||
# The main type conversion logic is now in _convert_string_to_proper_type_with_config
|
||||
if value == "":
|
||||
return None
|
||||
|
||||
|
@ -416,7 +526,7 @@ def get_stack_run_config_from_distro(distro: str) -> StackRunConfig:
|
|||
raise ValueError(f"Distribution '{distro}' not found at {distro_path}")
|
||||
run_config = yaml.safe_load(path.open())
|
||||
|
||||
return StackRunConfig(**replace_env_vars(run_config))
|
||||
return StackRunConfig(**replace_env_vars(run_config, provider_registry=get_provider_registry()))
|
||||
|
||||
|
||||
def run_config_from_adhoc_config_spec(
|
||||
|
@ -452,7 +562,9 @@ def run_config_from_adhoc_config_spec(
|
|||
|
||||
# call method "sample_run_config" on the provider spec config class
|
||||
provider_config_type = instantiate_class_type(provider_spec.config_class)
|
||||
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
|
||||
provider_config = replace_env_vars(
|
||||
provider_config_type.sample_run_config(__distro_dir__=distro_dir), provider_registry=provider_registry
|
||||
)
|
||||
|
||||
provider_configs_by_api[api_str] = [
|
||||
Provider(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue