mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
chore: refactor server.main
# What does this PR do? Refactor main to split out the app construction so that we can use `uvicorn --workers` to enable multi-process stack. ## Test Plan CI > uv run --with llama-stack python -m llama_stack.core.server.server benchmarking/k8s-benchmark/stack_run_config.yaml works. > LLAMA_STACK_CONFIG=benchmarking/k8s-benchmark/stack_run_config.yaml uv run uvicorn llama_stack.core.server.server:create_app --port 8321 --workers 4 works.
This commit is contained in:
parent
ac1414b571
commit
a285f9c95f
7 changed files with 233 additions and 146 deletions
|
@ -17,11 +17,8 @@ export POSTGRES_PASSWORD=llamastack
|
|||
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
|
||||
export MOCK_INFERENCE_MODEL=mock-inference
|
||||
|
||||
export MOCK_INFERENCE_URL=openai-mock-service:8080
|
||||
|
||||
export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
export LLAMA_STACK_WORKERS=4
|
||||
|
||||
set -euo pipefail
|
||||
set -x
|
||||
|
|
|
@ -5,6 +5,7 @@ data:
|
|||
image_name: kubernetes-benchmark-demo
|
||||
apis:
|
||||
- agents
|
||||
- files
|
||||
- inference
|
||||
- files
|
||||
- safety
|
||||
|
@ -23,6 +24,14 @@ data:
|
|||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
|
||||
vector_io:
|
||||
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
|
|
|
@ -52,9 +52,20 @@ spec:
|
|||
value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
|
||||
- name: VLLM_TLS_VERIFY
|
||||
value: "false"
|
||||
command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"]
|
||||
- name: LLAMA_STACK_LOGGING
|
||||
value: "all=WARNING"
|
||||
- name: LLAMA_STACK_CONFIG
|
||||
value: "/etc/config/stack_run_config.yaml"
|
||||
- name: LLAMA_STACK_WORKERS
|
||||
value: "${LLAMA_STACK_WORKERS}"
|
||||
command: ["uvicorn", "llama_stack.core.server.server:create_app", "--host", "0.0.0.0", "--port", "8323", "--workers", "$LLAMA_STACK_WORKERS", "--factory"]
|
||||
ports:
|
||||
- containerPort: 8323
|
||||
resources:
|
||||
requests:
|
||||
cpu: "${LLAMA_STACK_WORKERS}"
|
||||
limits:
|
||||
cpu: "${LLAMA_STACK_WORKERS}"
|
||||
volumeMounts:
|
||||
- name: llama-storage
|
||||
mountPath: /root/.llama
|
||||
|
|
|
@ -40,7 +40,7 @@ from llama_stack.core.request_headers import (
|
|||
from llama_stack.core.resolver import ProviderRegistry
|
||||
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
|
||||
from llama_stack.core.stack import (
|
||||
construct_stack,
|
||||
Stack,
|
||||
get_stack_run_config_from_distro,
|
||||
replace_env_vars,
|
||||
)
|
||||
|
@ -252,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
try:
|
||||
self.route_impls = None
|
||||
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
||||
|
||||
stack = Stack(self.config, self.custom_provider_registry)
|
||||
await stack.initialize()
|
||||
self.impls = stack.impls
|
||||
except ModuleNotFoundError as _e:
|
||||
cprint(_e.msg, color="red", file=sys.stderr)
|
||||
cprint(
|
||||
|
@ -289,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
raise _e
|
||||
|
||||
assert self.impls is not None
|
||||
if Api.telemetry in self.impls:
|
||||
setup_logger(self.impls[Api.telemetry])
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import argparse
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
|
@ -50,17 +51,15 @@ from llama_stack.core.request_headers import (
|
|||
request_provider_data_context,
|
||||
user_from_scope,
|
||||
)
|
||||
from llama_stack.core.resolver import InvalidProviderError
|
||||
from llama_stack.core.server.routes import (
|
||||
find_matching_route,
|
||||
get_all_api_routes,
|
||||
initialize_route_impls,
|
||||
)
|
||||
from llama_stack.core.stack import (
|
||||
Stack,
|
||||
cast_image_name_to_string,
|
||||
construct_stack,
|
||||
replace_env_vars,
|
||||
shutdown_stack,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
|
@ -156,21 +155,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
|
|||
)
|
||||
|
||||
|
||||
async def shutdown(app):
|
||||
"""Initiate a graceful shutdown of the application.
|
||||
|
||||
Handled by the lifespan context manager. The shutdown process involves
|
||||
shutting down all implementations registered in the application.
|
||||
class StackApp(FastAPI):
|
||||
"""
|
||||
await shutdown_stack(app.__llama_stack_impls__)
|
||||
A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
|
||||
start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
|
||||
"""
|
||||
|
||||
def __init__(self, config: StackRunConfig, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.stack: Stack = Stack(config)
|
||||
|
||||
# This code is called from a running event loop managed by uvicorn so we cannot simply call
|
||||
# asyncio.run() to initialize the stack. We cannot await either since this is not an async
|
||||
# function.
|
||||
# As a workaround, we use a thread pool executor to run the initialize() method
|
||||
# in a separate thread.
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(asyncio.run, self.stack.initialize())
|
||||
future.result()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
async def lifespan(app: StackApp):
|
||||
logger.info("Starting up")
|
||||
assert app.stack is not None
|
||||
app.stack.create_registry_refresh_task()
|
||||
yield
|
||||
logger.info("Shutting down")
|
||||
await shutdown(app)
|
||||
await app.stack.shutdown()
|
||||
|
||||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
|
@ -386,73 +398,61 @@ class ClientVersionMiddleware:
|
|||
return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace | None = None):
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
def create_app(
|
||||
config_file: str | None = None,
|
||||
env_vars: list[str] | None = None,
|
||||
) -> StackApp:
|
||||
"""Create and configure the FastAPI application.
|
||||
|
||||
add_config_distro_args(parser)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
help="Port to listen on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables in KEY=value format. Can be specified multiple times.",
|
||||
)
|
||||
Args:
|
||||
config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
|
||||
env_vars: List of environment variables in KEY=value format.
|
||||
disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
|
||||
|
||||
# Determine whether the server args are being passed by the "run" command, if this is the case
|
||||
# the args will be passed as a Namespace object to the main function, otherwise they will be
|
||||
# parsed from the command line
|
||||
if args is None:
|
||||
args = parser.parse_args()
|
||||
Returns:
|
||||
Configured StackApp instance.
|
||||
"""
|
||||
config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
|
||||
if config_file is None:
|
||||
raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
|
||||
|
||||
config_or_distro = get_config_from_args(args)
|
||||
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
|
||||
config_file = resolve_config_or_distro(config_file, Mode.RUN)
|
||||
|
||||
# Load and process configuration
|
||||
logger_config = None
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
logger = get_logger(name=__name__, category="core::server", config=logger_config)
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
|
||||
if env_vars:
|
||||
for env_pair in env_vars:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
||||
logger.info(f"Setting environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
raise ValueError(f"Invalid environment variable format: {env_pair}") from e
|
||||
|
||||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**cast_image_name_to_string(config))
|
||||
|
||||
_log_run_config(run_config=config)
|
||||
|
||||
app = FastAPI(
|
||||
app = StackApp(
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
config=config,
|
||||
)
|
||||
|
||||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
try:
|
||||
# Create and set the event loop that will be used for both construction and server runtime
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Construct the stack in the persistent event loop
|
||||
impls = loop.run_until_complete(construct_stack(config))
|
||||
|
||||
except InvalidProviderError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
impls = app.stack.impls
|
||||
|
||||
if config.server.auth:
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
|
||||
|
@ -553,9 +553,54 @@ def main(args: argparse.Namespace | None = None):
|
|||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def main(args: argparse.Namespace | None = None):
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
|
||||
add_config_distro_args(parser)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
help="Port to listen on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables in KEY=value format. Can be specified multiple times.",
|
||||
)
|
||||
|
||||
# Determine whether the server args are being passed by the "run" command, if this is the case
|
||||
# the args will be passed as a Namespace object to the main function, otherwise they will be
|
||||
# parsed from the command line
|
||||
if args is None:
|
||||
args = parser.parse_args()
|
||||
|
||||
config_or_distro = get_config_from_args(args)
|
||||
|
||||
try:
|
||||
app = create_app(
|
||||
config_file=config_or_distro,
|
||||
env_vars=args.env,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating app: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
else:
|
||||
logger_config = None
|
||||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
|
||||
import uvicorn
|
||||
|
||||
# Configure SSL if certificates are provided
|
||||
|
@ -593,7 +638,6 @@ def main(args: argparse.Namespace | None = None):
|
|||
if ssl_config:
|
||||
uvicorn_config.update(ssl_config)
|
||||
|
||||
# Run uvicorn in the existing event loop to preserve background tasks
|
||||
# We need to catch KeyboardInterrupt because uvicorn's signal handling
|
||||
# re-raises SIGINT signals using signal.raise_signal(), which Python
|
||||
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
|
||||
|
@ -604,13 +648,9 @@ def main(args: argparse.Namespace | None = None):
|
|||
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
|
||||
# signal handling but this is quite intrusive and not worth the effort.
|
||||
try:
|
||||
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
|
||||
asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("Received interrupt signal, shutting down gracefully...")
|
||||
finally:
|
||||
if not loop.is_closed():
|
||||
logger.debug("Closing event loop")
|
||||
loop.close()
|
||||
|
||||
|
||||
def _log_run_config(run_config: StackRunConfig):
|
||||
|
|
|
@ -315,11 +315,15 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
|
|||
impls[Api.prompts] = prompts_impl
|
||||
|
||||
|
||||
class Stack:
|
||||
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
|
||||
self.run_config = run_config
|
||||
self.provider_registry = provider_registry
|
||||
self.impls = None
|
||||
|
||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
# asked for in the run config.
|
||||
async def construct_stack(
|
||||
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
|
||||
) -> dict[Api, Any]:
|
||||
async def initialize(self):
|
||||
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
|
||||
from llama_stack.testing.inference_recorder import setup_inference_recording
|
||||
|
||||
|
@ -329,24 +333,28 @@ async def construct_stack(
|
|||
TEST_RECORDING_CONTEXT.__enter__()
|
||||
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
||||
|
||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||
policy = run_config.server.auth.access_policy if run_config.server.auth else []
|
||||
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
|
||||
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
|
||||
impls = await resolve_impls(
|
||||
run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
|
||||
self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
|
||||
)
|
||||
|
||||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, run_config)
|
||||
add_internal_implementations(impls, self.run_config)
|
||||
|
||||
if Api.prompts in impls:
|
||||
await impls[Api.prompts].initialize()
|
||||
|
||||
await register_resources(run_config, impls)
|
||||
await register_resources(self.run_config, impls)
|
||||
|
||||
await refresh_registry_once(impls)
|
||||
self.impls = impls
|
||||
|
||||
def create_registry_refresh_task(self):
|
||||
assert self.impls is not None, "Must call initialize() before starting"
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
|
||||
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
|
||||
|
||||
def cb(task):
|
||||
import traceback
|
||||
|
@ -360,11 +368,9 @@ async def construct_stack(
|
|||
logger.debug("Model refresh task completed")
|
||||
|
||||
REGISTRY_REFRESH_TASK.add_done_callback(cb)
|
||||
return impls
|
||||
|
||||
|
||||
async def shutdown_stack(impls: dict[Api, Any]):
|
||||
for impl in impls.values():
|
||||
async def shutdown(self):
|
||||
for impl in self.impls.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info(f"Shutting down {impl_name}")
|
||||
try:
|
||||
|
|
|
@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
|
|||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
class MockStack:
|
||||
def __init__(self, config, custom_provider_registry=None):
|
||||
self.impls = mock_impls
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = LlamaStackAsLibraryClient("ci-tests")
|
||||
|
@ -46,13 +50,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
|
|||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
class MockStack:
|
||||
def __init__(self, config, custom_provider_registry=None):
|
||||
self.impls = mock_impls
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = AsyncLlamaStackAsLibraryClient("ci-tests")
|
||||
|
@ -68,13 +76,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
|
|||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
class MockStack:
|
||||
def __init__(self, config, custom_provider_registry=None):
|
||||
self.impls = mock_impls
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = LlamaStackAsLibraryClient("ci-tests")
|
||||
|
@ -90,13 +102,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
|
|||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
class MockStack:
|
||||
def __init__(self, config, custom_provider_registry=None):
|
||||
self.impls = mock_impls
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = AsyncLlamaStackAsLibraryClient("ci-tests")
|
||||
|
@ -112,13 +128,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
|
|||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
class MockStack:
|
||||
def __init__(self, config, custom_provider_registry=None):
|
||||
self.impls = mock_impls
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
sync_client = LlamaStackAsLibraryClient("ci-tests")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue