mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat: llama stack run --providers (#3989)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 1s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Python Package Build Test / build (3.13) (push) Failing after 1s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 5s
Python Package Build Test / build (3.12) (push) Failing after 3s
Pre-commit / pre-commit (push) Failing after 5s
Vector IO Integration Tests / test-matrix (push) Failing after 5s
Test Llama Stack Build / build-single-provider (push) Failing after 5s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 4s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 5s
API Conformance Tests / check-schema-compatibility (push) Successful in 10s
Unit Tests / unit-tests (3.13) (push) Failing after 4s
Unit Tests / unit-tests (3.12) (push) Failing after 5s
Test External API and Providers / test-external (venv) (push) Failing after 6s
Test Llama Stack Build / build (push) Failing after 4s
UI Tests / ui-tests (22) (push) Successful in 56s
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 1s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Python Package Build Test / build (3.13) (push) Failing after 1s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 5s
Python Package Build Test / build (3.12) (push) Failing after 3s
Pre-commit / pre-commit (push) Failing after 5s
Vector IO Integration Tests / test-matrix (push) Failing after 5s
Test Llama Stack Build / build-single-provider (push) Failing after 5s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 4s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 5s
API Conformance Tests / check-schema-compatibility (push) Successful in 10s
Unit Tests / unit-tests (3.13) (push) Failing after 4s
Unit Tests / unit-tests (3.12) (push) Failing after 5s
Test External API and Providers / test-external (venv) (push) Failing after 6s
Test Llama Stack Build / build (push) Failing after 4s
UI Tests / ui-tests (22) (push) Successful in 56s
# What does this PR do? llama stack run --providers takes a list of providers in the format of api1=provider1,api2=provider2 this allows users to run with a simple list of providers. given the architecture of `create_app`, this run config needs to be written to disk. use ~/.llama/distribution/providers-run/run.yaml each time for consistency resolves #3956 ## Test Plan new unit tests to ensure --providers. Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
b2a5428a14
commit
93401836b7
2 changed files with 143 additions and 1 deletions
|
|
@ -8,15 +8,28 @@ import argparse
|
||||||
import os
|
import os
|
||||||
import ssl
|
import ssl
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import yaml
|
import yaml
|
||||||
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack.cli.stack.utils import ImageType
|
from llama_stack.cli.stack.utils import ImageType
|
||||||
from llama_stack.cli.subcommand import Subcommand
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
from llama_stack.core.datatypes import StackRunConfig
|
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
|
||||||
|
from llama_stack.core.distribution import get_provider_registry
|
||||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
||||||
|
from llama_stack.core.storage.datatypes import (
|
||||||
|
InferenceStoreReference,
|
||||||
|
KVStoreReference,
|
||||||
|
ServerStoresConfig,
|
||||||
|
SqliteKVStoreConfig,
|
||||||
|
SqliteSqlStoreConfig,
|
||||||
|
SqlStoreReference,
|
||||||
|
StorageConfig,
|
||||||
|
)
|
||||||
|
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||||
from llama_stack.log import LoggingConfig, get_logger
|
from llama_stack.log import LoggingConfig, get_logger
|
||||||
|
|
||||||
|
|
@ -68,6 +81,12 @@ class StackRun(Subcommand):
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Start the UI server",
|
help="Start the UI server",
|
||||||
)
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--providers",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Run a stack with only a list of providers. This list is formatted like: api1=provider1,api1=provider2,api2=provider3. Where there can be multiple providers per API.",
|
||||||
|
)
|
||||||
|
|
||||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||||
import yaml
|
import yaml
|
||||||
|
|
@ -93,6 +112,49 @@ class StackRun(Subcommand):
|
||||||
config_file = resolve_config_or_distro(args.config, Mode.RUN)
|
config_file = resolve_config_or_distro(args.config, Mode.RUN)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
self.parser.error(str(e))
|
self.parser.error(str(e))
|
||||||
|
elif args.providers:
|
||||||
|
provider_list: dict[str, list[Provider]] = dict()
|
||||||
|
for api_provider in args.providers.split(","):
|
||||||
|
if "=" not in api_provider:
|
||||||
|
cprint(
|
||||||
|
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
||||||
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
api, provider_type = api_provider.split("=")
|
||||||
|
providers_for_api = get_provider_registry().get(Api(api), None)
|
||||||
|
if providers_for_api is None:
|
||||||
|
cprint(
|
||||||
|
f"{api} is not a valid API.",
|
||||||
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
if provider_type in providers_for_api:
|
||||||
|
provider = Provider(
|
||||||
|
provider_type=provider_type,
|
||||||
|
provider_id=provider_type.split("::")[1],
|
||||||
|
)
|
||||||
|
provider_list.setdefault(api, []).append(provider)
|
||||||
|
else:
|
||||||
|
cprint(
|
||||||
|
f"{provider} is not a valid provider for the {api} API.",
|
||||||
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
run_config = self._generate_run_config_from_providers(providers=provider_list)
|
||||||
|
config_dict = run_config.model_dump(mode="json")
|
||||||
|
|
||||||
|
# Write config to disk in providers-run directory
|
||||||
|
distro_dir = DISTRIBS_BASE_DIR / "providers-run"
|
||||||
|
config_file = distro_dir / "run.yaml"
|
||||||
|
|
||||||
|
logger.info(f"Writing generated config to: {config_file}")
|
||||||
|
with open(config_file, "w") as f:
|
||||||
|
yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
config_file = None
|
config_file = None
|
||||||
|
|
||||||
|
|
@ -214,3 +276,44 @@ class StackRun(Subcommand):
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
|
logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
|
||||||
|
|
||||||
|
def _generate_run_config_from_providers(self, providers: dict[str, list[Provider]]):
|
||||||
|
apis = list(providers.keys())
|
||||||
|
distro_dir = DISTRIBS_BASE_DIR / "providers-run"
|
||||||
|
# need somewhere to put the storage.
|
||||||
|
os.makedirs(distro_dir, exist_ok=True)
|
||||||
|
storage = StorageConfig(
|
||||||
|
backends={
|
||||||
|
"kv_default": SqliteKVStoreConfig(
|
||||||
|
db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/kvstore.db",
|
||||||
|
),
|
||||||
|
"sql_default": SqliteSqlStoreConfig(
|
||||||
|
db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/sql_store.db",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
stores=ServerStoresConfig(
|
||||||
|
metadata=KVStoreReference(
|
||||||
|
backend="kv_default",
|
||||||
|
namespace="registry",
|
||||||
|
),
|
||||||
|
inference=InferenceStoreReference(
|
||||||
|
backend="sql_default",
|
||||||
|
table_name="inference_store",
|
||||||
|
),
|
||||||
|
conversations=SqlStoreReference(
|
||||||
|
backend="sql_default",
|
||||||
|
table_name="openai_conversations",
|
||||||
|
),
|
||||||
|
prompts=KVStoreReference(
|
||||||
|
backend="kv_default",
|
||||||
|
namespace="prompts",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return StackRunConfig(
|
||||||
|
image_name="providers-run",
|
||||||
|
apis=apis,
|
||||||
|
providers=providers,
|
||||||
|
storage=storage,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -229,3 +229,42 @@ def test_parse_and_maybe_upgrade_config_preserves_custom_external_providers_dir(
|
||||||
|
|
||||||
# Verify the custom value was preserved
|
# Verify the custom value was preserved
|
||||||
assert str(result.external_providers_dir) == custom_dir
|
assert str(result.external_providers_dir) == custom_dir
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_run_config_from_providers():
|
||||||
|
"""Test that _generate_run_config_from_providers creates a valid config"""
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from llama_stack.cli.stack.run import StackRun
|
||||||
|
from llama_stack.core.datatypes import Provider
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
subparsers = parser.add_subparsers()
|
||||||
|
stack_run = StackRun(subparsers)
|
||||||
|
|
||||||
|
providers = {
|
||||||
|
"inference": [
|
||||||
|
Provider(
|
||||||
|
provider_type="inline::meta-reference",
|
||||||
|
provider_id="meta-reference",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
config = stack_run._generate_run_config_from_providers(providers=providers)
|
||||||
|
config_dict = config.model_dump(mode="json")
|
||||||
|
|
||||||
|
# Verify basic structure
|
||||||
|
assert config_dict["image_name"] == "providers-run"
|
||||||
|
assert "inference" in config_dict["apis"]
|
||||||
|
assert "inference" in config_dict["providers"]
|
||||||
|
|
||||||
|
# Verify storage has all required stores including prompts
|
||||||
|
assert "storage" in config_dict
|
||||||
|
stores = config_dict["storage"]["stores"]
|
||||||
|
assert "prompts" in stores
|
||||||
|
assert stores["prompts"]["namespace"] == "prompts"
|
||||||
|
|
||||||
|
# Verify config can be parsed back
|
||||||
|
parsed = parse_and_maybe_upgrade_config(config_dict)
|
||||||
|
assert parsed.image_name == "providers-run"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue