diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 7de009b87..d36e21c6d 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -249,6 +249,10 @@ class ServerConfig(BaseModel): default=None, description="Path to TLS key file for HTTPS", ) + tls_cafile: str | None = Field( + default=None, + description="Path to TLS CA file for HTTPS with mutual TLS authentication", + ) auth: AuthenticationConfig | None = Field( default=None, description="Authentication configuration for the server", diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index e34a62b00..32046d2b1 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -9,6 +9,7 @@ import asyncio import inspect import json import os +import ssl import sys import traceback import warnings @@ -484,7 +485,14 @@ def main(args: argparse.Namespace | None = None): "ssl_keyfile": keyfile, "ssl_certfile": certfile, } - logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}") + if config.server.tls_cafile: + ssl_config["ssl_ca_certs"] = config.server.tls_cafile + ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED + logger.info( + f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}" + ) + else: + logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}") listen_host = ["::", "0.0.0.0"] if not config.server.disable_ipv6 else "0.0.0.0" logger.info(f"Listening on {listen_host}:{port}")