chore: refactor server.main

# What does this PR do?
Refactor `main` to split out app construction into a separate `create_app` factory so that we can use `uvicorn --workers` to run the stack across multiple processes.


## Test Plan
CI

> uv run --with llama-stack python -m llama_stack.core.server.server benchmarking/k8s-benchmark/stack_run_config.yaml

works.

> LLAMA_STACK_CONFIG=benchmarking/k8s-benchmark/stack_run_config.yaml uv run uvicorn llama_stack.core.server.server:create_app --port 8321 --workers 4

works.
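
For reference, the multi-worker invocation relies on uvicorn's `--factory` flag, which imports the target and calls it with no arguments in each worker process to build that worker's app. A minimal sketch of what each worker effectively does (assuming `LLAMA_STACK_CONFIG` is set, as in the command above):

```python
import os

# Assumption for this sketch: the config comes in via the env var,
# exactly as in the multi-worker command above.
os.environ["LLAMA_STACK_CONFIG"] = "benchmarking/k8s-benchmark/stack_run_config.yaml"

from llama_stack.core.server.server import create_app

# uvicorn --factory calls the target with no arguments; create_app()
# falls back to LLAMA_STACK_CONFIG when no config path is passed.
app = create_app()
```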
Eric Huang 2025-09-17 12:24:13 -07:00
parent ac1414b571
commit a285f9c95f
7 changed files with 233 additions and 146 deletions


```diff
@@ -17,11 +17,8 @@ export POSTGRES_PASSWORD=llamastack
 export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
 export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
-export MOCK_INFERENCE_MODEL=mock-inference
-export MOCK_INFERENCE_URL=openai-mock-service:8080
 export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
+export LLAMA_STACK_WORKERS=4
 
 set -euo pipefail
 set -x
```


```diff
@@ -5,6 +5,7 @@ data:
     image_name: kubernetes-benchmark-demo
     apis:
     - agents
+    - files
     - inference
     - files
     - safety
@@ -23,6 +24,14 @@ data:
     - provider_id: sentence-transformers
       provider_type: inline::sentence-transformers
       config: {}
+    files:
+    - provider_id: meta-reference-files
+      provider_type: inline::localfs
+      config:
+        storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+        metadata_store:
+          type: sqlite
+          db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
     vector_io:
     - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
       provider_type: remote::chromadb
```


```diff
@@ -52,9 +52,20 @@ spec:
           value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
         - name: VLLM_TLS_VERIFY
           value: "false"
-        command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"]
+        - name: LLAMA_STACK_LOGGING
+          value: "all=WARNING"
+        - name: LLAMA_STACK_CONFIG
+          value: "/etc/config/stack_run_config.yaml"
+        - name: LLAMA_STACK_WORKERS
+          value: "${LLAMA_STACK_WORKERS}"
+        command: ["uvicorn", "llama_stack.core.server.server:create_app", "--host", "0.0.0.0", "--port", "8323", "--workers", "$LLAMA_STACK_WORKERS", "--factory"]
         ports:
         - containerPort: 8323
+        resources:
+          requests:
+            cpu: "${LLAMA_STACK_WORKERS}"
+          limits:
+            cpu: "${LLAMA_STACK_WORKERS}"
         volumeMounts:
         - name: llama-storage
           mountPath: /root/.llama
```


```diff
@@ -40,7 +40,7 @@ from llama_stack.core.request_headers import (
 from llama_stack.core.resolver import ProviderRegistry
 from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
 from llama_stack.core.stack import (
-    construct_stack,
+    Stack,
     get_stack_run_config_from_distro,
     replace_env_vars,
 )
@@ -252,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         try:
             self.route_impls = None
-            self.impls = await construct_stack(self.config, self.custom_provider_registry)
+
+            stack = Stack(self.config, self.custom_provider_registry)
+            await stack.initialize()
+            self.impls = stack.impls
         except ModuleNotFoundError as _e:
             cprint(_e.msg, color="red", file=sys.stderr)
             cprint(
@@ -289,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             )
             raise _e
 
+        assert self.impls is not None
         if Api.telemetry in self.impls:
             setup_logger(self.impls[Api.telemetry])
```


```diff
@@ -6,6 +6,7 @@
 import argparse
 import asyncio
+import concurrent.futures
 import functools
 import inspect
 import json
@@ -50,17 +51,15 @@ from llama_stack.core.request_headers import (
     request_provider_data_context,
     user_from_scope,
 )
-from llama_stack.core.resolver import InvalidProviderError
 from llama_stack.core.server.routes import (
     find_matching_route,
     get_all_api_routes,
     initialize_route_impls,
 )
 from llama_stack.core.stack import (
+    Stack,
     cast_image_name_to_string,
-    construct_stack,
     replace_env_vars,
-    shutdown_stack,
     validate_env_pair,
 )
 from llama_stack.core.utils.config import redact_sensitive_fields
```
```diff
@@ -156,21 +155,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
     )
 
 
-async def shutdown(app):
-    """Initiate a graceful shutdown of the application.
-
-    Handled by the lifespan context manager. The shutdown process involves
-    shutting down all implementations registered in the application.
+class StackApp(FastAPI):
+    """
+    A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
+    start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
     """
-    await shutdown_stack(app.__llama_stack_impls__)
+
+    def __init__(self, config: StackRunConfig, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.stack: Stack = Stack(config)
+
+        # This code is called from a running event loop managed by uvicorn so we cannot simply call
+        # asyncio.run() to initialize the stack. We cannot await either since this is not an async
+        # function.
+        # As a workaround, we use a thread pool executor to run the initialize() method
+        # in a separate thread.
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(asyncio.run, self.stack.initialize())
+            future.result()
 
 
 @asynccontextmanager
-async def lifespan(app: FastAPI):
+async def lifespan(app: StackApp):
     logger.info("Starting up")
+    assert app.stack is not None
+    app.stack.create_registry_refresh_task()
     yield
     logger.info("Shutting down")
-    await shutdown(app)
+    await app.stack.shutdown()
 
 
 def is_streaming_request(func_name: str, request: Request, **kwargs):
```
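
The thread-pool trick in `StackApp.__init__` is worth calling out: `asyncio.run()` fails with "cannot be called from a running event loop" when invoked under uvicorn's loop, and `__init__` cannot `await`, so the coroutine is driven to completion on a fresh loop in a worker thread. A self-contained sketch of the same pattern, with illustrative names not taken from the PR:

```python
import asyncio
import concurrent.futures


async def initialize() -> str:
    # Stand-in for Stack.initialize(): any awaitable setup work.
    await asyncio.sleep(0)
    return "ready"


def init_from_sync_context() -> str:
    # asyncio.run() would raise if this frame were already inside a running
    # event loop, so run the coroutine on a fresh loop in a separate thread
    # and block on its result.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(asyncio.run, initialize())
        return future.result()


print(init_from_sync_context())  # prints "ready"
```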
```diff
@@ -386,73 +398,61 @@ class ClientVersionMiddleware:
         return await self.app(scope, receive, send)
 
 
-def main(args: argparse.Namespace | None = None):
-    """Start the LlamaStack server."""
-    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
-
-    add_config_distro_args(parser)
-    parser.add_argument(
-        "--port",
-        type=int,
-        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
-        help="Port to listen on",
-    )
-    parser.add_argument(
-        "--env",
-        action="append",
-        help="Environment variables in KEY=value format. Can be specified multiple times.",
-    )
-
-    # Determine whether the server args are being passed by the "run" command, if this is the case
-    # the args will be passed as a Namespace object to the main function, otherwise they will be
-    # parsed from the command line
-    if args is None:
-        args = parser.parse_args()
-
-    config_or_distro = get_config_from_args(args)
-    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+def create_app(
+    config_file: str | None = None,
+    env_vars: list[str] | None = None,
+) -> StackApp:
+    """Create and configure the FastAPI application.
+
+    Args:
+        config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
+        env_vars: List of environment variables in KEY=value format.
+        disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
+
+    Returns:
+        Configured StackApp instance.
+    """
+    config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
+    if config_file is None:
+        raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
+
+    config_file = resolve_config_or_distro(config_file, Mode.RUN)
 
+    # Load and process configuration
     logger_config = None
     with open(config_file) as fp:
         config_contents = yaml.safe_load(fp)
         if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
             logger_config = LoggingConfig(**cfg)
     logger = get_logger(name=__name__, category="core::server", config=logger_config)
-    if args.env:
-        for env_pair in args.env:
+
+    if env_vars:
+        for env_pair in env_vars:
             try:
                 key, value = validate_env_pair(env_pair)
-                logger.info(f"Setting CLI environment variable {key} => {value}")
+                logger.info(f"Setting environment variable {key} => {value}")
                 os.environ[key] = value
             except ValueError as e:
                 logger.error(f"Error: {str(e)}")
-                sys.exit(1)
+                raise ValueError(f"Invalid environment variable format: {env_pair}") from e
 
     config = replace_env_vars(config_contents)
     config = StackRunConfig(**cast_image_name_to_string(config))
 
     _log_run_config(run_config=config)
 
-    app = FastAPI(
+    app = StackApp(
         lifespan=lifespan,
         docs_url="/docs",
         redoc_url="/redoc",
         openapi_url="/openapi.json",
+        config=config,
     )
 
     if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
         app.add_middleware(ClientVersionMiddleware)
 
-    try:
-        # Create and set the event loop that will be used for both construction and server runtime
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
-        # Construct the stack in the persistent event loop
-        impls = loop.run_until_complete(construct_stack(config))
-
-    except InvalidProviderError as e:
-        logger.error(f"Error: {str(e)}")
-        sys.exit(1)
+    impls = app.stack.impls
 
     if config.server.auth:
         logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
```
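
With this signature, `create_app` can also be built and served programmatically in a single process; a hedged sketch (the config path is the benchmark config from the test plan, and the `env_vars` value is illustrative):

```python
import uvicorn

from llama_stack.core.server.server import create_app

# Build the app directly instead of going through `uvicorn --factory`.
app = create_app(
    config_file="benchmarking/k8s-benchmark/stack_run_config.yaml",
    env_vars=["INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct"],
)

# Passing an app object (rather than an import string) limits uvicorn to a
# single process; multi-process needs the import-string + --workers form.
uvicorn.run(app, host="0.0.0.0", port=8321)
```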
```diff
@@ -553,9 +553,54 @@ def main(args: argparse.Namespace | None = None):
     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
 
-    app.__llama_stack_impls__ = impls
     app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
 
+    return app
+
+
+def main(args: argparse.Namespace | None = None):
+    """Start the LlamaStack server."""
+    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
+
+    add_config_distro_args(parser)
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
+        help="Port to listen on",
+    )
+    parser.add_argument(
+        "--env",
+        action="append",
+        help="Environment variables in KEY=value format. Can be specified multiple times.",
+    )
+
+    # Determine whether the server args are being passed by the "run" command, if this is the case
+    # the args will be passed as a Namespace object to the main function, otherwise they will be
+    # parsed from the command line
+    if args is None:
+        args = parser.parse_args()
+
+    config_or_distro = get_config_from_args(args)
+
+    try:
+        app = create_app(
+            config_file=config_or_distro,
+            env_vars=args.env,
+        )
+    except Exception as e:
+        logger.error(f"Error creating app: {str(e)}")
+        sys.exit(1)
+
+    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+    with open(config_file) as fp:
+        config_contents = yaml.safe_load(fp)
+        if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
+            logger_config = LoggingConfig(**cfg)
+        else:
+            logger_config = None
+        config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
+
     import uvicorn
 
     # Configure SSL if certificates are provided
@@ -593,7 +638,6 @@ def main(args: argparse.Namespace | None = None):
     if ssl_config:
         uvicorn_config.update(ssl_config)
 
-    # Run uvicorn in the existing event loop to preserve background tasks
     # We need to catch KeyboardInterrupt because uvicorn's signal handling
     # re-raises SIGINT signals using signal.raise_signal(), which Python
     # converts to KeyboardInterrupt. Without this catch, we'd get a confusing
@@ -604,13 +648,9 @@
     # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
     # signal handling but this is quite intrusive and not worth the effort.
     try:
-        loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
+        asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
     except (KeyboardInterrupt, SystemExit):
         logger.info("Received interrupt signal, shutting down gracefully...")
-    finally:
-        if not loop.is_closed():
-            logger.debug("Closing event loop")
-            loop.close()
 
 
 def _log_run_config(run_config: StackRunConfig):
```
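
A toy reproduction of the interrupt behavior those retained comments describe, using plain asyncio with no uvicorn involved (the `serve` coroutine here is a stand-in):

```python
import asyncio


async def serve() -> None:
    # Stand-in for uvicorn.Server(uvicorn.Config(...)).serve().
    await asyncio.sleep(3600)


try:
    asyncio.run(serve())
except (KeyboardInterrupt, SystemExit):
    # uvicorn re-raises SIGINT via signal.raise_signal(), which surfaces as
    # KeyboardInterrupt; catching it here avoids a confusing traceback.
    print("Received interrupt signal, shutting down gracefully...")
```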


```diff
@@ -315,11 +315,15 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
     impls[Api.prompts] = prompts_impl
 
 
-# Produces a stack of providers for the given run config. Not all APIs may be
-# asked for in the run config.
-async def construct_stack(
-    run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
-) -> dict[Api, Any]:
+class Stack:
+    def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
+        self.run_config = run_config
+        self.provider_registry = provider_registry
+        self.impls = None
+
+    # Produces a stack of providers for the given run config. Not all APIs may be
+    # asked for in the run config.
+    async def initialize(self):
         if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
             from llama_stack.testing.inference_recorder import setup_inference_recording
 
@@ -329,24 +333,28 @@ async def construct_stack(
             TEST_RECORDING_CONTEXT.__enter__()
             logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
 
-    dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
-    policy = run_config.server.auth.access_policy if run_config.server.auth else []
+        dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
+        policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
         impls = await resolve_impls(
-        run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
+            self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
         )
 
         # Add internal implementations after all other providers are resolved
-    add_internal_implementations(impls, run_config)
+        add_internal_implementations(impls, self.run_config)
 
         if Api.prompts in impls:
             await impls[Api.prompts].initialize()
 
-    await register_resources(run_config, impls)
+        await register_resources(self.run_config, impls)
 
         await refresh_registry_once(impls)
+        self.impls = impls
 
+    def create_registry_refresh_task(self):
+        assert self.impls is not None, "Must call initialize() before starting"
+
         global REGISTRY_REFRESH_TASK
-    REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
+        REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
 
         def cb(task):
             import traceback
@@ -360,11 +368,9 @@ async def construct_stack(
                 logger.debug("Model refresh task completed")
 
         REGISTRY_REFRESH_TASK.add_done_callback(cb)
 
-    return impls
-
-
-async def shutdown_stack(impls: dict[Api, Any]):
-    for impl in impls.values():
+    async def shutdown(self):
+        for impl in self.impls.values():
             impl_name = impl.__class__.__name__
             logger.info(f"Shutting down {impl_name}")
             try:
```
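
Taken together, the class replaces the `construct_stack()` / `shutdown_stack()` pair with an explicit lifecycle object. A sketch of a consumer under the new shape (assuming `StackRunConfig` is importable from `llama_stack.core.datatypes`; the diff does not show its import path):

```python
import asyncio

from llama_stack.core.datatypes import StackRunConfig  # assumed import path
from llama_stack.core.stack import Stack


async def run_stack(config: StackRunConfig) -> None:
    stack = Stack(config)
    await stack.initialize()  # resolves providers and fills stack.impls
    try:
        print(list(stack.impls))  # the resolved Api -> implementation mapping
    finally:
        await stack.shutdown()  # replaces shutdown_stack(impls)
```

The library client diff above follows exactly this sequence: construct, `await initialize()`, then read `stack.impls`.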


```diff
@@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})
 
-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass
 
         def mock_initialize_route_impls(impls):
             return mock_route_impls
 
-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
 
         client = LlamaStackAsLibraryClient("ci-tests")
@@ -46,13 +50,17 @@
         mock_impls = {}
         mock_route_impls = RouteImpls({})
 
-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass
 
         def mock_initialize_route_impls(impls):
             return mock_route_impls
 
-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
 
         client = AsyncLlamaStackAsLibraryClient("ci-tests")
@@ -68,13 +76,17 @@
         mock_impls = {}
         mock_route_impls = RouteImpls({})
 
-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass
 
         def mock_initialize_route_impls(impls):
             return mock_route_impls
 
-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
 
         client = LlamaStackAsLibraryClient("ci-tests")
@@ -90,13 +102,17 @@
         mock_impls = {}
         mock_route_impls = RouteImpls({})
 
-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass
 
         def mock_initialize_route_impls(impls):
             return mock_route_impls
 
-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
 
         client = AsyncLlamaStackAsLibraryClient("ci-tests")
@@ -112,13 +128,17 @@
         mock_impls = {}
         mock_route_impls = RouteImpls({})
 
-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass
 
         def mock_initialize_route_impls(impls):
             return mock_route_impls
 
-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
 
         sync_client = LlamaStackAsLibraryClient("ci-tests")
```