diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index deb19ac41..92015187b 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -8,6 +8,7 @@ import argparse import os from pathlib import Path +from llama_stack.cli.stack.utils import ImageType from llama_stack.cli.subcommand import Subcommand from llama_stack.log import get_logger @@ -43,7 +44,7 @@ class StackRun(Subcommand): self.parser.add_argument( "--image-name", type=str, - default=None, + default=os.environ.get("CONDA_DEFAULT_ENV"), help="Name of the image to run. Defaults to the current conda environment", ) self.parser.add_argument( @@ -72,9 +73,24 @@ class StackRun(Subcommand): "--image-type", type=str, help="Image Type used during the build. This can be either conda or container or venv.", - choices=["conda", "container", "venv"], + choices=[e.value for e in ImageType], ) + # If neither image type nor image name is provided, but at the same time + # the current environment has conda breadcrumbs, then assume what the user + # wants to use conda mode and not the usual default mode (using + # pre-installed system packages). + # + # Note: yes, this is hacky. It's implemented this way to keep the existing + # conda users unaffected by the switch of the default behavior to using + # system packages. + def _get_image_type_and_name(self, args: argparse.Namespace) -> tuple[str, str]: + conda_env = os.environ.get("CONDA_DEFAULT_ENV") + if conda_env and args.image_name == conda_env: + logger.warning(f"Conda detected. Using conda environment {conda_env} for the run.") + return ImageType.CONDA.value, args.image_name + return args.image_type, args.image_name + def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: import yaml @@ -118,9 +134,11 @@ class StackRun(Subcommand): except AttributeError as e: self.parser.error(f"failed to parse config file '{config_file}':\n {e}") + image_type, image_name = self._get_image_type_and_name(args) + # If neither image type nor image name is provided, assume the server should be run directly # using the current environment packages. - if not args.image_type and not args.image_name: + if not image_type and not image_name: logger.info("No image type or image name provided. Assuming environment packages.") from llama_stack.distribution.server.server import main as server_main @@ -137,7 +155,7 @@ class StackRun(Subcommand): # Run the server server_main(server_args) else: - run_args = formulate_run_args(args.image_type, args.image_name, config, template_name) + run_args = formulate_run_args(image_type, image_name, config, template_name) run_args.extend([str(config_file), str(args.port)]) if args.disable_ipv6: diff --git a/llama_stack/cli/stack/utils.py b/llama_stack/cli/stack/utils.py index 1e83a5cc8..74a606b2b 100644 --- a/llama_stack/cli/stack/utils.py +++ b/llama_stack/cli/stack/utils.py @@ -4,6 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from enum import Enum + + +class ImageType(Enum): + CONDA = "conda" + CONTAINER = "container" + VENV = "venv" + def print_subcommand_description(parser, subparsers): """Print descriptions of subcommands."""