diff --git a/src/llama_stack/cli/stack/run.py b/src/llama_stack/cli/stack/run.py
index ae35664af..9ceb238fa 100644
--- a/src/llama_stack/cli/stack/run.py
+++ b/src/llama_stack/cli/stack/run.py
@@ -31,6 +31,7 @@ from llama_stack.core.storage.datatypes import (
 )
 from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
 from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
+from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.log import LoggingConfig, get_logger
 
 REPO_ROOT = Path(__file__).parent.parent.parent.parent
@@ -132,8 +133,14 @@ class StackRun(Subcommand):
                     )
                     sys.exit(1)
                 if provider_type in providers_for_api:
+                    config_type = instantiate_class_type(providers_for_api[provider_type].config_class)
+                    if config_type is not None and hasattr(config_type, "sample_run_config"):
+                        config = config_type.sample_run_config(__distro_dir__="~/.llama/distributions/providers-run")
+                    else:
+                        config = {}
                     provider = Provider(
                         provider_type=provider_type,
+                        config=config,
                         provider_id=provider_type.split("::")[1],
                     )
                     provider_list.setdefault(api, []).append(provider)
diff --git a/tests/unit/cli/test_stack_config.py b/tests/unit/cli/test_stack_config.py
index 5270b8614..6aefac003 100644
--- a/tests/unit/cli/test_stack_config.py
+++ b/tests/unit/cli/test_stack_config.py
@@ -268,3 +268,50 @@ def test_generate_run_config_from_providers():
     # Verify config can be parsed back
     parsed = parse_and_maybe_upgrade_config(config_dict)
     assert parsed.image_name == "providers-run"
+
+
+def test_providers_flag_generates_config_with_api_keys():
+    """Test that the --providers flag generates provider configs, including API keys.
+
+    This tests the fix where sample_run_config() is called to populate
+    API keys and other credentials for remote providers like remote::openai.
+    """
+    import argparse
+    from unittest.mock import patch
+
+    from llama_stack.cli.stack.run import StackRun
+
+    parser = argparse.ArgumentParser()
+    subparsers = parser.add_subparsers()
+    stack_run = StackRun(subparsers)
+
+    # Create args with the --providers flag set
+    args = argparse.Namespace(
+        providers="inference=remote::openai",
+        config=None,
+        port=8321,
+        image_type=None,
+        image_name=None,
+        enable_ui=False,
+    )
+
+    # Mock _uvicorn_run to prevent starting a server
+    with patch.object(stack_run, "_uvicorn_run"):
+        stack_run._run_stack_run_cmd(args)
+
+    # Read the generated config file
+    from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
+
+    config_file = DISTRIBS_BASE_DIR / "providers-run" / "run.yaml"
+    with open(config_file) as f:
+        config_dict = yaml.safe_load(f)
+
+    # Verify the provider has a config with API keys
+    inference_providers = config_dict["providers"]["inference"]
+    assert len(inference_providers) == 1
+
+    openai_provider = inference_providers[0]
+    assert openai_provider["provider_type"] == "remote::openai"
+    assert openai_provider["config"], "Provider config should not be empty"
+    assert "api_key" in openai_provider["config"], "API key should be in provider config"
+    assert "base_url" in openai_provider["config"], "Base URL should be in provider config"
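
For reviewers, here is a minimal, self-contained sketch of the config-class contract the CLI change relies on. The class name `ExampleRemoteConfig` and the exact placeholder values are assumptions for illustration only; the real provider config classes ship with llama_stack and are resolved at runtime via `instantiate_class_type`.

```python
from typing import Any


class ExampleRemoteConfig:
    """Illustrative stand-in for a provider config class (hypothetical name)."""

    @classmethod
    def sample_run_config(cls, __distro_dir__: str = "", **kwargs: Any) -> dict[str, Any]:
        # Provider configs following this convention return env-templated
        # values, so the generated run.yaml carries credential placeholders
        # instead of an empty config block.
        return {
            "api_key": "${env.OPENAI_API_KEY}",  # assumed placeholder syntax
            "base_url": "https://api.openai.com/v1",
        }


# The patched CLI path does essentially this for each selected provider:
config_type = ExampleRemoteConfig  # instantiate_class_type(...) in the diff
if config_type is not None and hasattr(config_type, "sample_run_config"):
    config = config_type.sample_run_config(__distro_dir__="~/.llama/distributions/providers-run")
else:
    config = {}
print(config)  # {'api_key': '${env.OPENAI_API_KEY}', 'base_url': 'https://api.openai.com/v1'}
```

Before the fix, `Provider` was built with no `config` argument, so remote providers such as `remote::openai` ended up with an empty config and no credential placeholders in the generated run.yaml; the new test pins down exactly that behavior.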