From 4f3c19114ca00bb86cf377dd7cc687eb457b66a8 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 6 Feb 2025 16:04:25 -0800 Subject: [PATCH] Add HTTPS serving option --- llama_stack/cli/stack/run.py | 13 ++++++ llama_stack/distribution/datatypes.py | 20 ++++++++++ llama_stack/distribution/server/server.py | 44 +++++++++++++++++++-- llama_stack/distribution/start_conda_env.sh | 5 ++- llama_stack/distribution/start_container.sh | 12 +++++- 5 files changed, 88 insertions(+), 6 deletions(-) diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index f84def184..502dfbed4 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -55,6 +55,16 @@ class StackRun(Subcommand): default=[], metavar="KEY=VALUE", ) + self.parser.add_argument( + "--ssl-keyfile", + type=str, + help="Path to SSL key file for HTTPS", + ) + self.parser.add_argument( + "--ssl-certfile", + type=str, + help="Path to SSL certificate file for HTTPS", + ) def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: import importlib.resources @@ -178,4 +188,7 @@ class StackRun(Subcommand): return run_args.extend(["--env", f"{key}={value}"]) + if args.ssl_keyfile and args.ssl_certfile: + run_args.extend(["--ssl-keyfile", args.ssl_keyfile, "--ssl-certfile", args.ssl_certfile]) + run_with_pty(run_args) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 8b579b636..a9b64398e 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -117,6 +117,21 @@ class Provider(BaseModel): config: Dict[str, Any] +class ServerConfig(BaseModel): + port: int = Field( + default=8321, + description="Port to listen on", + ) + ssl_certfile: Optional[str] = Field( + default=None, + description="Path to SSL certificate file for HTTPS", + ) + ssl_keyfile: Optional[str] = Field( + default=None, + description="Path to SSL key file for HTTPS", + ) + + class StackRunConfig(BaseModel): version: str = LLAMA_STACK_RUN_CONFIG_VERSION @@ -159,6 +174,11 @@ a default SQLite store will be used.""", eval_tasks: List[EvalTaskInput] = Field(default_factory=list) tool_groups: List[ToolGroupInput] = Field(default_factory=list) + server: ServerConfig = Field( + default_factory=ServerConfig, + description="Configuration for the HTTP(S) server", + ) + class BuildConfig(BaseModel): version: str = LLAMA_STACK_BUILD_CONFIG_VERSION diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index fcd0e3cad..69d3e3a62 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -282,6 +282,19 @@ def main(): action="append", help="Environment variables in KEY=value format. Can be specified multiple times.", ) + parser.add_argument( + "--ssl-keyfile", + help="Path to SSL key file for HTTPS", + ) + parser.add_argument( + "--ssl-certfile", + help="Path to SSL certificate file for HTTPS", + ) + + if args.ssl_keyfile and not args.ssl_certfile: + parser.error("You must provide both --ssl-keyfile and --ssl-certfile when using HTTPS") + if args.ssl_certfile and not args.ssl_keyfile: + parser.error("You must provide both --ssl-keyfile and --ssl-certfile when using HTTPS") args = parser.parse_args() if args.env: @@ -381,11 +394,36 @@ def main(): import uvicorn - # FYI this does not do hot-reloads + # Configure SSL if certificates are provided + port = args.port or config.server.port + + ssl_config = None + if args.ssl_keyfile: + keyfile = args.ssl_keyfile + certfile = args.ssl_certfile + else: + keyfile = config.server.ssl_keyfile + certfile = config.server.ssl_certfile + + if keyfile and certfile: + ssl_config = { + "ssl_keyfile": keyfile, + "ssl_certfile": certfile, + } + print(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}") listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0" - print(f"Listening on {listen_host}:{args.port}") - uvicorn.run(app, host=listen_host, port=args.port) + print(f"Listening on {listen_host}:{port}") + + uvicorn_config = { + "app": app, + "host": listen_host, + "port": port, + } + if ssl_config: + uvicorn_config.update(ssl_config) + + uvicorn.run(**uvicorn_config) def extract_path_params(route: str) -> List[str]: diff --git a/llama_stack/distribution/start_conda_env.sh b/llama_stack/distribution/start_conda_env.sh index c37f30ef0..fe830059f 100755 --- a/llama_stack/distribution/start_conda_env.sh +++ b/llama_stack/distribution/start_conda_env.sh @@ -34,6 +34,7 @@ shift # Process environment variables from --env arguments env_vars="" +other_args="" while [[ $# -gt 0 ]]; do case "$1" in --env) @@ -48,6 +49,7 @@ while [[ $# -gt 0 ]]; do fi ;; *) + other_args="$other_args $1" shift ;; esac @@ -61,4 +63,5 @@ $CONDA_PREFIX/bin/python \ -m llama_stack.distribution.server.server \ --yaml-config "$yaml_config" \ --port "$port" \ - $env_vars + $env_vars \ + $other_args diff --git a/llama_stack/distribution/start_container.sh b/llama_stack/distribution/start_container.sh index 2c5d65d09..a5f543fb4 100755 --- a/llama_stack/distribution/start_container.sh +++ b/llama_stack/distribution/start_container.sh @@ -40,8 +40,12 @@ shift port="$1" shift +# Initialize other_args +other_args="" + # Process environment variables from --env arguments env_vars="" + while [[ $# -gt 0 ]]; do case "$1" in --env) @@ -55,6 +59,7 @@ while [[ $# -gt 0 ]]; do fi ;; *) + other_args="$other_args $1" shift ;; esac @@ -93,5 +98,8 @@ $CONTAINER_BINARY run $CONTAINER_OPTS -it \ -v "$yaml_config:/app/config.yaml" \ $mounts \ --env LLAMA_STACK_PORT=$port \ - --entrypoint='["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"]' \ - $container_image:$version_tag + --entrypoint python \ + $container_image:$version_tag \ + -m llama_stack.distribution.server.server \ + --yaml-config /app/config.yaml \ + $other_args