diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index c56d2c780..ccd345181 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import argparse import asyncio import functools import inspect @@ -19,7 +20,6 @@ from contextlib import asynccontextmanager from ssl import SSLError from typing import Any, Dict, Optional -import fire import httpx import yaml @@ -342,23 +342,36 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]: ) from e -def main( - yaml_config: str = "llamastack-run.yaml", - port: int = 5000, - disable_ipv6: bool = False, - env: list[str] = None, -): - # Process environment variables from command line - if env: - for env_pair in env: +def main(): + """Start the LlamaStack server.""" + parser = argparse.ArgumentParser(description="Start the LlamaStack server.") + parser.add_argument( + "--yaml-config", + default="llamastack-run.yaml", + help="Path to YAML configuration file", + ) + parser.add_argument("--port", type=int, default=5000, help="Port to listen on") + parser.add_argument( + "--disable-ipv6", action="store_true", help="Whether to disable IPv6 support" + ) + parser.add_argument( + "--env", + action="append", + help="Environment variables in KEY=value format. Can be specified multiple times.", + ) + + args = parser.parse_args() + if args.env: + for env_pair in args.env: try: key, value = validate_env_pair(env_pair) + print(f"Setting CLI environment variable {key} => {value}") os.environ[key] = value except ValueError as e: print(f"Error: {str(e)}") sys.exit(1) - with open(yaml_config, "r") as fp: + with open(args.yaml_config, "r") as fp: config = replace_env_vars(yaml.safe_load(fp)) config = StackRunConfig(**config) @@ -425,10 +438,10 @@ def main( # FYI this does not do hot-reloads - listen_host = ["::", "0.0.0.0"] if not disable_ipv6 else "0.0.0.0" - print(f"Listening on {listen_host}:{port}") - uvicorn.run(app, host=listen_host, port=port) + listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0" + print(f"Listening on {listen_host}:{args.port}") + uvicorn.run(app, host=listen_host, port=args.port) if __name__ == "__main__": - fire.Fire(main) + main() diff --git a/llama_stack/distribution/start_conda_env.sh b/llama_stack/distribution/start_conda_env.sh index d75b4afc9..f478a8bd8 100755 --- a/llama_stack/distribution/start_conda_env.sh +++ b/llama_stack/distribution/start_conda_env.sh @@ -58,9 +58,8 @@ eval "$(conda shell.bash hook)" conda deactivate && conda activate "$env_name" set -x -echo "ENV VARS $env_vars" $CONDA_PREFIX/bin/python \ -m llama_stack.distribution.server.server \ - --yaml_config "$yaml_config" \ + --yaml-config "$yaml_config" \ --port "$port" \ - "$env_vars" + $env_vars diff --git a/llama_stack/distribution/start_container.sh b/llama_stack/distribution/start_container.sh index c56606826..34476c8e0 100755 --- a/llama_stack/distribution/start_container.sh +++ b/llama_stack/distribution/start_container.sh @@ -92,5 +92,5 @@ $DOCKER_BINARY run $DOCKER_OPTS -it \ $mounts \ $docker_image:$version_tag \ python -m llama_stack.distribution.server.server \ - --yaml_config /app/config.yaml \ + --yaml-config /app/config.yaml \ --port "$port"