diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index f07a0f873..3fe615e6e 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -22,7 +22,7 @@ def available_templates_specs() -> List[BuildConfig]: import yaml template_specs = [] - for p in TEMPLATES_PATH.rglob("*.yaml"): + for p in TEMPLATES_PATH.rglob("*build.yaml"): with open(p, "r") as f: build_config = BuildConfig(**yaml.safe_load(f)) template_specs.append(build_config) diff --git a/llama_stack/cli/stack/configure.py b/llama_stack/cli/stack/configure.py index 13899715b..021134e6d 100644 --- a/llama_stack/cli/stack/configure.py +++ b/llama_stack/cli/stack/configure.py @@ -152,7 +152,7 @@ class StackConfigure(Subcommand): config = StackRunConfig( built_at=datetime.now(), image_name=image_name, - apis=[], + apis=list(build_config.distribution_spec.providers.keys()), providers={}, models=[], shields=[], diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index 033b2a81f..dd4247e4b 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -7,7 +7,6 @@ import argparse from llama_stack.cli.subcommand import Subcommand -from llama_stack.distribution.datatypes import * # noqa: F403 class StackRun(Subcommand): @@ -49,8 +48,8 @@ class StackRun(Subcommand): from termcolor import cprint from llama_stack.distribution.build import ImageType + from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR - from llama_stack.distribution.utils.exec import run_with_pty if not args.config: @@ -78,7 +77,8 @@ class StackRun(Subcommand): cprint(f"Using config `{config_file}`", "green") with open(config_file, "r") as f: - config = StackRunConfig(**yaml.safe_load(f)) + config_dict = yaml.safe_load(config_file.read_text()) + config = parse_and_maybe_upgrade_config(config_dict) if config.docker_image: script = pkg_resources.resource_filename( diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index 12f225af2..f533422fe 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -64,9 +64,6 @@ def configure_api_providers( ) -> StackRunConfig: is_nux = len(config.providers) == 0 - # keep this default so all APIs are served - config.apis = [] - if is_nux: print( textwrap.dedent( @@ -79,7 +76,12 @@ def configure_api_providers( provider_registry = get_provider_registry() builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()] - apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)] + + if config.apis: + apis_to_serve = config.apis + else: + apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)] + for api_str in apis_to_serve: api = Api(api_str) if api in builtin_apis: @@ -153,7 +155,7 @@ def configure_api_providers( "shields": (ShieldDef, configure_shields, "safety"), "memory_banks": (MemoryBankDef, configure_memory_banks, "memory"), } - safety_providers = config.providers["safety"] + safety_providers = config.providers.get("safety", []) for otype, (odef, config_method, api_str) in object_types.items(): existing_objects = getattr(config, otype) @@ -166,9 +168,15 @@ def configure_api_providers( ) updated_objects = existing_objects else: - # we are newly configuring this API - cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"]) - updated_objects = config_method(config.providers[api_str], safety_providers) + providers = config.providers.get(api_str, []) + if not providers: + updated_objects = [] + else: + # we are newly configuring this API + cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"]) + updated_objects = config_method( + config.providers[api_str], safety_providers + ) setattr(config, otype, updated_objects) print("") @@ -277,7 +285,7 @@ def upgrade_from_routing_table_to_registry( shields = [] memory_banks = [] - routing_table = config_dict["routing_table"] + routing_table = config_dict.get("routing_table", {}) for api_str, entries in routing_table.items(): providers = get_providers(entries) providers_by_api[api_str] = providers @@ -324,15 +332,13 @@ def upgrade_from_routing_table_to_registry( config_dict["shields"] = shields config_dict["memory_banks"] = memory_banks - if "api_providers" in config_dict: - for api_str, provider in config_dict["api_providers"].items(): - if api_str in ("inference", "safety", "memory"): - continue - + provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {})) + if provider_map: + for api_str, provider in provider_map.items(): if isinstance(provider, dict): providers_by_api[api_str] = [ Provider( - provider_id=f"{provider['provider_type']}-00", + provider_id=f"{provider['provider_type']}", provider_type=provider["provider_type"], config=provider["config"], ) @@ -340,11 +346,12 @@ def upgrade_from_routing_table_to_registry( config_dict["providers"] = providers_by_api - del config_dict["routing_table"] - del config_dict["api_providers"] + config_dict.pop("routing_table", None) + config_dict.pop("api_providers", None) + config_dict.pop("provider_map", None) config_dict["apis"] = config_dict["apis_to_serve"] - del config_dict["apis_to_serve"] + config_dict.pop("apis_to_serve", None) return config_dict