mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
fix: generate provider config when using --providers
call the sample_run_config method for providers that have it when generating a run config using `llama stack run --providers`. This will propagate API keys. Resolves #4032. Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
7e294d33d9
commit
23b8535022
2 changed files with 53 additions and 0 deletions
|
|
@@ -31,6 +31,7 @@ from llama_stack.core.storage.datatypes import (
|
||||||
)
|
)
|
||||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||||
|
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||||
from llama_stack.log import LoggingConfig, get_logger
|
from llama_stack.log import LoggingConfig, get_logger
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
@@ -132,8 +133,14 @@ class StackRun(Subcommand):
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
if provider_type in providers_for_api:
|
if provider_type in providers_for_api:
|
||||||
|
config_type = instantiate_class_type(providers_for_api[provider_type].config_class)
|
||||||
|
if config_type is not None and hasattr(config_type, "sample_run_config"):
|
||||||
|
config = config_type.sample_run_config(__distro_dir__="~/.llama/distributions/providers-run")
|
||||||
|
else:
|
||||||
|
config = {}
|
||||||
provider = Provider(
|
provider = Provider(
|
||||||
provider_type=provider_type,
|
provider_type=provider_type,
|
||||||
|
config=config,
|
||||||
provider_id=provider_type.split("::")[1],
|
provider_id=provider_type.split("::")[1],
|
||||||
)
|
)
|
||||||
provider_list.setdefault(api, []).append(provider)
|
provider_list.setdefault(api, []).append(provider)
|
||||||
|
|
|
||||||
|
|
@@ -268,3 +268,49 @@ def test_generate_run_config_from_providers():
|
||||||
# Verify config can be parsed back
|
# Verify config can be parsed back
|
||||||
parsed = parse_and_maybe_upgrade_config(config_dict)
|
parsed = parse_and_maybe_upgrade_config(config_dict)
|
||||||
assert parsed.image_name == "providers-run"
|
assert parsed.image_name == "providers-run"
|
||||||
|
|
||||||
|
|
||||||
|
def test_providers_flag_generates_config_with_api_keys():
    """Test that --providers flag properly generates provider configs including API keys.

    This tests the fix where sample_run_config() is called to populate
    API keys and other credentials for remote providers like remote::openai.
    """
    import argparse
    from unittest.mock import MagicMock, patch

    from llama_stack.cli.stack.run import StackRun

    # Wire up a minimal CLI parser so StackRun registers its subcommand.
    cli_parser = argparse.ArgumentParser()
    sub = cli_parser.add_subparsers()
    run_cmd = StackRun(sub)

    # Simulate `llama stack run --providers inference=remote::openai`.
    run_args = MagicMock()
    run_args.providers = "inference=remote::openai"
    run_args.config = None
    run_args.port = 8321
    run_args.image_type = None
    run_args.image_name = None
    run_args.enable_ui = False

    # Patch the server launcher so the command only generates the config.
    with patch.object(run_cmd, "_uvicorn_run"):
        run_cmd._run_stack_run_cmd(run_args)

    # Load the run.yaml the command wrote out.
    from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR

    generated_path = DISTRIBS_BASE_DIR / "providers-run" / "run.yaml"
    with open(generated_path) as fh:
        parsed_yaml = yaml.safe_load(fh)

    # The inference API should have exactly one provider entry,
    # and its config must carry the sampled credentials.
    inference_entries = parsed_yaml["providers"]["inference"]
    assert len(inference_entries) == 1

    entry = inference_entries[0]
    assert entry["provider_type"] == "remote::openai"
    assert entry["config"], "Provider config should not be empty"
    assert "api_key" in entry["config"], "API key should be in provider config"
    assert "base_url" in entry["config"], "Base URL should be in provider config"
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue