Move to use argparse, fix issues with multiple --env cmdline options

This commit is contained in:
Ashwin Bharambe 2024-11-18 16:31:59 -08:00
parent b87f3ac499
commit fb15ff4a97
3 changed files with 31 additions and 19 deletions

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import asyncio
import functools
import inspect
@ -19,7 +20,6 @@ from contextlib import asynccontextmanager
from ssl import SSLError
from typing import Any, Dict, Optional
import fire
import httpx
import yaml
@ -342,23 +342,36 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
) from e
def main(
yaml_config: str = "llamastack-run.yaml",
port: int = 5000,
disable_ipv6: bool = False,
env: list[str] = None,
):
# Process environment variables from command line
if env:
for env_pair in env:
def main():
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument(
"--yaml-config",
default="llamastack-run.yaml",
help="Path to YAML configuration file",
)
parser.add_argument("--port", type=int, default=5000, help="Port to listen on")
parser.add_argument(
"--disable-ipv6", action="store_true", help="Whether to disable IPv6 support"
)
parser.add_argument(
"--env",
action="append",
help="Environment variables in KEY=value format. Can be specified multiple times.",
)
args = parser.parse_args()
if args.env:
for env_pair in args.env:
try:
key, value = validate_env_pair(env_pair)
print(f"Setting CLI environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
print(f"Error: {str(e)}")
sys.exit(1)
with open(yaml_config, "r") as fp:
with open(args.yaml_config, "r") as fp:
config = replace_env_vars(yaml.safe_load(fp))
config = StackRunConfig(**config)
@ -425,10 +438,10 @@ def main(
# FYI this does not do hot-reloads
listen_host = ["::", "0.0.0.0"] if not disable_ipv6 else "0.0.0.0"
print(f"Listening on {listen_host}:{port}")
uvicorn.run(app, host=listen_host, port=port)
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
print(f"Listening on {listen_host}:{args.port}")
uvicorn.run(app, host=listen_host, port=args.port)
if __name__ == "__main__":
fire.Fire(main)
main()

View file

@ -58,9 +58,8 @@ eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name"
set -x
echo "ENV VARS $env_vars"
$CONDA_PREFIX/bin/python \
-m llama_stack.distribution.server.server \
--yaml_config "$yaml_config" \
--yaml-config "$yaml_config" \
--port "$port" \
"$env_vars"
$env_vars

View file

@ -92,5 +92,5 @@ $DOCKER_BINARY run $DOCKER_OPTS -it \
$mounts \
$docker_image:$version_tag \
python -m llama_stack.distribution.server.server \
--yaml_config /app/config.yaml \
--yaml-config /app/config.yaml \
--port "$port"