diff --git a/benchmarking/k8s-benchmark/apply.sh b/benchmarking/k8s-benchmark/apply.sh
index 4f2270da8..6e6607663 100755
--- a/benchmarking/k8s-benchmark/apply.sh
+++ b/benchmarking/k8s-benchmark/apply.sh
@@ -17,11 +17,8 @@ export POSTGRES_PASSWORD=llamastack
 export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
 export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
 
-export MOCK_INFERENCE_MODEL=mock-inference
-
-export MOCK_INFERENCE_URL=openai-mock-service:8080
-
 export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
+export LLAMA_STACK_WORKERS=4
 
 set -euo pipefail
 set -x
diff --git a/benchmarking/k8s-benchmark/stack-configmap.yaml b/benchmarking/k8s-benchmark/stack-configmap.yaml
index bf6109b68..286ba5f77 100644
--- a/benchmarking/k8s-benchmark/stack-configmap.yaml
+++ b/benchmarking/k8s-benchmark/stack-configmap.yaml
@@ -5,6 +5,7 @@ data:
     image_name: kubernetes-benchmark-demo
     apis:
     - agents
+    - files
     - inference
     - files
     - safety
@@ -23,6 +24,14 @@ data:
       - provider_id: sentence-transformers
         provider_type: inline::sentence-transformers
         config: {}
+      files:
+      - provider_id: meta-reference-files
+        provider_type: inline::localfs
+        config:
+          storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+          metadata_store:
+            type: sqlite
+            db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
       vector_io:
       - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
         provider_type: remote::chromadb
diff --git a/benchmarking/k8s-benchmark/stack-k8s.yaml.template b/benchmarking/k8s-benchmark/stack-k8s.yaml.template
index 9cb1e5be3..8842c0bea 100644
--- a/benchmarking/k8s-benchmark/stack-k8s.yaml.template
+++ b/benchmarking/k8s-benchmark/stack-k8s.yaml.template
@@ -52,9 +52,20 @@ spec:
             value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
           - name: VLLM_TLS_VERIFY
             value: "false"
-        command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"]
+          - name: LLAMA_STACK_LOGGING
+            value: "all=WARNING"
+          - name: LLAMA_STACK_CONFIG
+            value: "/etc/config/stack_run_config.yaml"
+          - name: LLAMA_STACK_WORKERS
+            value: "${LLAMA_STACK_WORKERS}"
+        command: ["uvicorn", "llama_stack.core.server.server:create_app", "--host", "0.0.0.0", "--port", "8323", "--workers", "$LLAMA_STACK_WORKERS", "--factory"]
         ports:
           - containerPort: 8323
+        resources:
+          requests:
+            cpu: "${LLAMA_STACK_WORKERS}"
+          limits:
+            cpu: "${LLAMA_STACK_WORKERS}"
         volumeMounts:
         - name: llama-storage
          mountPath: /root/.llama
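For reference, the templated command above amounts to the following programmatic invocation once envsubst fills in the variables (the values shown, such as `LLAMA_STACK_WORKERS=4`, are illustrative; `create_app` reads `LLAMA_STACK_CONFIG` from the environment, which the manifest sets):

```python
# Illustrative only: the programmatic equivalent of the templated container
# command with LLAMA_STACK_WORKERS=4. Each worker process imports the factory
# and calls create_app(), which reads LLAMA_STACK_CONFIG from the environment.
import uvicorn

uvicorn.run(
    "llama_stack.core.server.server:create_app",  # import string; required when workers > 1
    host="0.0.0.0",
    port=8323,
    workers=4,     # matches the CPU request/limit in the manifest
    factory=True,  # treat the import target as a factory, not an app instance
)
```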
diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py
index ea5a2ac8e..e722e4de6 100644
--- a/llama_stack/core/library_client.py
+++ b/llama_stack/core/library_client.py
@@ -40,7 +40,7 @@ from llama_stack.core.request_headers import (
 from llama_stack.core.resolver import ProviderRegistry
 from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
 from llama_stack.core.stack import (
-    construct_stack,
+    Stack,
     get_stack_run_config_from_distro,
     replace_env_vars,
 )
@@ -252,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         try:
             self.route_impls = None
-            self.impls = await construct_stack(self.config, self.custom_provider_registry)
+
+            stack = Stack(self.config, self.custom_provider_registry)
+            await stack.initialize()
+            self.impls = stack.impls
         except ModuleNotFoundError as _e:
             cprint(_e.msg, color="red", file=sys.stderr)
             cprint(
@@ -289,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             )
             raise _e
 
+        assert self.impls is not None
         if Api.telemetry in self.impls:
             setup_logger(self.impls[Api.telemetry])
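For downstream code that called `construct_stack()` directly, the replacement pattern mirrors what the library client now does. A minimal migration sketch, assuming an already-parsed `StackRunConfig`:

```python
# Hedged migration sketch: replaces a direct construct_stack() call with the
# new Stack class, mirroring AsyncLlamaStackAsLibraryClient.initialize() above.
# run_config is assumed to be an already-parsed StackRunConfig.
from llama_stack.core.stack import Stack


async def build_impls(run_config):
    stack = Stack(run_config)  # a custom provider registry is an optional second argument
    await stack.initialize()   # resolves providers and registers resources
    assert stack.impls is not None
    return stack.impls
```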
diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py
index d3e875fec..9cca42268 100644
--- a/llama_stack/core/server/server.py
+++ b/llama_stack/core/server/server.py
@@ -6,6 +6,7 @@
 
 import argparse
 import asyncio
+import concurrent.futures
 import functools
 import inspect
 import json
@@ -50,17 +51,15 @@ from llama_stack.core.request_headers import (
     request_provider_data_context,
     user_from_scope,
 )
-from llama_stack.core.resolver import InvalidProviderError
 from llama_stack.core.server.routes import (
     find_matching_route,
     get_all_api_routes,
     initialize_route_impls,
 )
 from llama_stack.core.stack import (
+    Stack,
     cast_image_name_to_string,
-    construct_stack,
     replace_env_vars,
-    shutdown_stack,
     validate_env_pair,
 )
 from llama_stack.core.utils.config import redact_sensitive_fields
@@ -156,21 +155,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
     )
 
 
-async def shutdown(app):
-    """Initiate a graceful shutdown of the application.
-
-    Handled by the lifespan context manager. The shutdown process involves
-    shutting down all implementations registered in the application.
+class StackApp(FastAPI):
     """
-    await shutdown_stack(app.__llama_stack_impls__)
+    A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
+    start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
+    """
+
+    def __init__(self, config: StackRunConfig, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.stack: Stack = Stack(config)
+
+        # This code is called from a running event loop managed by uvicorn so we cannot simply call
+        # asyncio.run() to initialize the stack. We cannot await either since this is not an async
+        # function.
+        # As a workaround, we use a thread pool executor to run the initialize() method
+        # in a separate thread.
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(asyncio.run, self.stack.initialize())
+            future.result()
 
 
 @asynccontextmanager
-async def lifespan(app: FastAPI):
+async def lifespan(app: StackApp):
     logger.info("Starting up")
+    assert app.stack is not None
+    app.stack.create_registry_refresh_task()
     yield
     logger.info("Shutting down")
-    await shutdown(app)
+    await app.stack.shutdown()
 
 
 def is_streaming_request(func_name: str, request: Request, **kwargs):
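Before the next hunk, the thread-pool workaround in `StackApp.__init__` is worth seeing in isolation. A self-contained sketch (standard library only) of why the coroutine is handed to a worker thread that owns its own event loop:

```python
# Minimal standalone illustration of the StackApp.__init__ workaround:
# asyncio.run() raises RuntimeError on a thread that already runs an event
# loop, so the coroutine is executed on a worker thread with a fresh loop.
import asyncio
import concurrent.futures


async def initialize() -> str:
    await asyncio.sleep(0.1)
    return "initialized"


def sync_init_inside_running_loop() -> str:
    # Calling asyncio.run(initialize()) directly here would fail while a loop is running.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(asyncio.run, initialize())
        return future.result()  # blocks until the worker thread's loop completes


async def main():
    # Simulates uvicorn's running loop importing and instantiating the app factory.
    print(sync_init_inside_running_loop())


asyncio.run(main())
```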
@@ -386,73 +398,60 @@ class ClientVersionMiddleware:
         return await self.app(scope, receive, send)
 
 
-def main(args: argparse.Namespace | None = None):
-    """Start the LlamaStack server."""
-    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
+def create_app(
+    config_file: str | None = None,
+    env_vars: list[str] | None = None,
+) -> StackApp:
+    """Create and configure the FastAPI application.
 
-    add_config_distro_args(parser)
-    parser.add_argument(
-        "--port",
-        type=int,
-        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
-        help="Port to listen on",
-    )
-    parser.add_argument(
-        "--env",
-        action="append",
-        help="Environment variables in KEY=value format. Can be specified multiple times.",
-    )
+    Args:
+        config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
+        env_vars: List of environment variables in KEY=value format.
 
-    # Determine whether the server args are being passed by the "run" command, if this is the case
-    # the args will be passed as a Namespace object to the main function, otherwise they will be
-    # parsed from the command line
-    if args is None:
-        args = parser.parse_args()
+    Returns:
+        Configured StackApp instance.
+    """
+    config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
+    if config_file is None:
+        raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
 
-    config_or_distro = get_config_from_args(args)
-    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+    config_file = resolve_config_or_distro(config_file, Mode.RUN)
 
+    # Load and process configuration
     logger_config = None
     with open(config_file) as fp:
         config_contents = yaml.safe_load(fp)
         if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
             logger_config = LoggingConfig(**cfg)
     logger = get_logger(name=__name__, category="core::server", config=logger_config)
-    if args.env:
-        for env_pair in args.env:
+
+    if env_vars:
+        for env_pair in env_vars:
             try:
                 key, value = validate_env_pair(env_pair)
-                logger.info(f"Setting CLI environment variable {key} => {value}")
+                logger.info(f"Setting environment variable {key} => {value}")
                 os.environ[key] = value
             except ValueError as e:
                 logger.error(f"Error: {str(e)}")
-                sys.exit(1)
+                raise ValueError(f"Invalid environment variable format: {env_pair}") from e
+
     config = replace_env_vars(config_contents)
     config = StackRunConfig(**cast_image_name_to_string(config))
 
     _log_run_config(run_config=config)
 
-    app = FastAPI(
+    app = StackApp(
         lifespan=lifespan,
         docs_url="/docs",
         redoc_url="/redoc",
         openapi_url="/openapi.json",
+        config=config,
     )
 
     if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
         app.add_middleware(ClientVersionMiddleware)
 
-    try:
-        # Create and set the event loop that will be used for both construction and server runtime
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
-        # Construct the stack in the persistent event loop
-        impls = loop.run_until_complete(construct_stack(config))
-
-    except InvalidProviderError as e:
-        logger.error(f"Error: {str(e)}")
-        sys.exit(1)
+    impls = app.stack.impls
 
     if config.server.auth:
         logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
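Because `create_app()` returns a plain FastAPI subclass, it can also be exercised without a server process. A hypothetical test sketch, assuming a valid run config at `./run.yaml` (the path and the `/v1/health` route are assumptions, not part of this patch):

```python
# Hypothetical usage sketch: create_app() yields a FastAPI app, so Starlette's
# TestClient can drive it in-process. "./run.yaml" is an assumed config path.
from fastapi.testclient import TestClient

from llama_stack.core.server.server import create_app

app = create_app(config_file="./run.yaml", env_vars=["FOO=bar"])

with TestClient(app) as client:  # entering the context runs the lifespan hooks
    response = client.get("/v1/health")
    print(response.status_code)
```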
@@ -553,9 +552,54 @@ def main(args: argparse.Namespace | None = None):
     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
 
-    app.__llama_stack_impls__ = impls
     app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
 
+    return app
+
+
+def main(args: argparse.Namespace | None = None):
+    """Start the LlamaStack server."""
+    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
+
+    add_config_distro_args(parser)
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
+        help="Port to listen on",
+    )
+    parser.add_argument(
+        "--env",
+        action="append",
+        help="Environment variables in KEY=value format. Can be specified multiple times.",
+    )
+
+    # Determine whether the server args are being passed by the "run" command, if this is the case
+    # the args will be passed as a Namespace object to the main function, otherwise they will be
+    # parsed from the command line
+    if args is None:
+        args = parser.parse_args()
+
+    config_or_distro = get_config_from_args(args)
+
+    try:
+        app = create_app(
+            config_file=config_or_distro,
+            env_vars=args.env,
+        )
+    except Exception as e:
+        logger.error(f"Error creating app: {str(e)}")
+        sys.exit(1)
+
+    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+    with open(config_file) as fp:
+        config_contents = yaml.safe_load(fp)
+        if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
+            logger_config = LoggingConfig(**cfg)
+        else:
+            logger_config = None
+        config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
+
     import uvicorn
 
     # Configure SSL if certificates are provided
@@ -593,7 +637,6 @@ def main(args: argparse.Namespace | None = None):
     if ssl_config:
         uvicorn_config.update(ssl_config)
 
-    # Run uvicorn in the existing event loop to preserve background tasks
     # We need to catch KeyboardInterrupt because uvicorn's signal handling
     # re-raises SIGINT signals using signal.raise_signal(), which Python
     # converts to KeyboardInterrupt. Without this catch, we'd get a confusing
@@ -604,13 +647,9 @@ def main(args: argparse.Namespace | None = None):
     # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
     # signal handling but this is quite intrusive and not worth the effort.
     try:
-        loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
+        asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
     except (KeyboardInterrupt, SystemExit):
         logger.info("Received interrupt signal, shutting down gracefully...")
-    finally:
-        if not loop.is_closed():
-            logger.debug("Closing event loop")
-            loop.close()
 
 
 def _log_run_config(run_config: StackRunConfig):
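Condensed, the single-process entry path that `main()` now takes boils down to the sketch below (the config path is illustrative); `asyncio.run()` owns the event loop, and uvicorn's SIGINT handling surfaces as `KeyboardInterrupt`:

```python
# Condensed sketch of the new single-process serve path from main() above.
# "./run.yaml" is an assumed config path.
import asyncio

import uvicorn

from llama_stack.core.server.server import create_app

app = create_app(config_file="./run.yaml")

try:
    asyncio.run(uvicorn.Server(uvicorn.Config(app, host="0.0.0.0", port=8321)).serve())
except (KeyboardInterrupt, SystemExit):
    print("Received interrupt signal, shutting down gracefully...")
```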
diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py
index 7ab8d2c64..a6c5093eb 100644
--- a/llama_stack/core/stack.py
+++ b/llama_stack/core/stack.py
@@ -315,78 +315,84 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig):
     impls[Api.prompts] = prompts_impl
 
 
-# Produces a stack of providers for the given run config. Not all APIs may be
-# asked for in the run config.
-async def construct_stack(
-    run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
-) -> dict[Api, Any]:
-    if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
-        from llama_stack.testing.inference_recorder import setup_inference_recording
+class Stack:
+    def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
+        self.run_config = run_config
+        self.provider_registry = provider_registry
+        self.impls = None
+
+    # Produces a stack of providers for the given run config. Not all APIs may be
+    # asked for in the run config.
+    async def initialize(self):
+        if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
+            from llama_stack.testing.inference_recorder import setup_inference_recording
+
+            global TEST_RECORDING_CONTEXT
+            TEST_RECORDING_CONTEXT = setup_inference_recording()
+            if TEST_RECORDING_CONTEXT:
+                TEST_RECORDING_CONTEXT.__enter__()
+                logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
+
+        dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
+        policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
+        impls = await resolve_impls(
+            self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
+        )
+
+        # Add internal implementations after all other providers are resolved
+        add_internal_implementations(impls, self.run_config)
+
+        if Api.prompts in impls:
+            await impls[Api.prompts].initialize()
+
+        await register_resources(self.run_config, impls)
+
+        await refresh_registry_once(impls)
+        self.impls = impls
+
+    def create_registry_refresh_task(self):
+        assert self.impls is not None, "Must call initialize() before starting"
+
+        global REGISTRY_REFRESH_TASK
+        REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
+
+        def cb(task):
+            import traceback
+
+            if task.cancelled():
+                logger.error("Model refresh task cancelled")
+            elif task.exception():
+                logger.error(f"Model refresh task failed: {task.exception()}")
+                traceback.print_exception(task.exception())
+            else:
+                logger.debug("Model refresh task completed")
+
+        REGISTRY_REFRESH_TASK.add_done_callback(cb)
+
+    async def shutdown(self):
+        for impl in self.impls.values():
+            impl_name = impl.__class__.__name__
+            logger.info(f"Shutting down {impl_name}")
+            try:
+                if hasattr(impl, "shutdown"):
+                    await asyncio.wait_for(impl.shutdown(), timeout=5)
+                else:
+                    logger.warning(f"No shutdown method for {impl_name}")
+            except TimeoutError:
+                logger.exception(f"Shutdown timeout for {impl_name}")
+            except (Exception, asyncio.CancelledError) as e:
+                logger.exception(f"Failed to shutdown {impl_name}: {e}")
 
         global TEST_RECORDING_CONTEXT
-        TEST_RECORDING_CONTEXT = setup_inference_recording()
         if TEST_RECORDING_CONTEXT:
-            TEST_RECORDING_CONTEXT.__enter__()
-            logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
+            try:
+                TEST_RECORDING_CONTEXT.__exit__(None, None, None)
+            except Exception as e:
+                logger.error(f"Error during inference recording cleanup: {e}")
 
-    dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
-    policy = run_config.server.auth.access_policy if run_config.server.auth else []
-    impls = await resolve_impls(
-        run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
-    )
-
-    # Add internal implementations after all other providers are resolved
-    add_internal_implementations(impls, run_config)
-
-    if Api.prompts in impls:
-        await impls[Api.prompts].initialize()
-
-    await register_resources(run_config, impls)
-
-    await refresh_registry_once(impls)
-
-    global REGISTRY_REFRESH_TASK
-    REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
-
-    def cb(task):
-        import traceback
-
-        if task.cancelled():
-            logger.error("Model refresh task cancelled")
-        elif task.exception():
-            logger.error(f"Model refresh task failed: {task.exception()}")
-            traceback.print_exception(task.exception())
-        else:
-            logger.debug("Model refresh task completed")
-
-    REGISTRY_REFRESH_TASK.add_done_callback(cb)
-    return impls
-
-
-async def shutdown_stack(impls: dict[Api, Any]):
-    for impl in impls.values():
-        impl_name = impl.__class__.__name__
-        logger.info(f"Shutting down {impl_name}")
-        try:
-            if hasattr(impl, "shutdown"):
-                await asyncio.wait_for(impl.shutdown(), timeout=5)
-            else:
-                logger.warning(f"No shutdown method for {impl_name}")
-        except TimeoutError:
-            logger.exception(f"Shutdown timeout for {impl_name}")
-        except (Exception, asyncio.CancelledError) as e:
-            logger.exception(f"Failed to shutdown {impl_name}: {e}")
-
-    global TEST_RECORDING_CONTEXT
-    if TEST_RECORDING_CONTEXT:
-        try:
-            TEST_RECORDING_CONTEXT.__exit__(None, None, None)
-        except Exception as e:
-            logger.error(f"Error during inference recording cleanup: {e}")
-
-    global REGISTRY_REFRESH_TASK
-    if REGISTRY_REFRESH_TASK:
-        REGISTRY_REFRESH_TASK.cancel()
+        global REGISTRY_REFRESH_TASK
+        if REGISTRY_REFRESH_TASK:
+            REGISTRY_REFRESH_TASK.cancel()
 
 
 async def refresh_registry_once(impls: dict[Api, Any]):
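Taken together, `StackApp.__init__` and the lifespan hook drive the `Stack` lifecycle in a fixed order. A minimal sketch, assuming a parsed `StackRunConfig`:

```python
# Minimal lifecycle sketch for the Stack class above, assuming a parsed
# StackRunConfig. In the server, StackApp.__init__ and the lifespan hook
# split this work between them.
import asyncio

from llama_stack.core.stack import Stack


async def run_stack(run_config):
    stack = Stack(run_config)
    await stack.initialize()              # resolve providers, register resources
    stack.create_registry_refresh_task()  # requires initialize() to have run
    try:
        ...  # serve requests while the refresh task runs in the background
    finally:
        await stack.shutdown()            # per-provider 5s timeout; cancels the refresh task
```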
diff --git a/tests/unit/distribution/test_library_client_initialization.py b/tests/unit/distribution/test_library_client_initialization.py
index b7e7a1857..b01a5c3e2 100644
--- a/tests/unit/distribution/test_library_client_initialization.py
+++ b/tests/unit/distribution/test_library_client_initialization.py
@@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})
 
-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass
 
         def mock_initialize_route_impls(impls):
             return mock_route_impls
 
-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
 
         client = LlamaStackAsLibraryClient("ci-tests")
@@ -46,13 +50,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})
 
-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass
 
         def mock_initialize_route_impls(impls):
             return mock_route_impls
 
-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
 
         client = AsyncLlamaStackAsLibraryClient("ci-tests")
@@ -68,13 +76,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})
 
-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass
 
         def mock_initialize_route_impls(impls):
             return mock_route_impls
 
-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
 
         client = LlamaStackAsLibraryClient("ci-tests")
@@ -90,13 +102,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})
 
-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass
 
         def mock_initialize_route_impls(impls):
             return mock_route_impls
 
-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
 
         client = AsyncLlamaStackAsLibraryClient("ci-tests")
@@ -112,13 +128,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})
 
-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass
 
         def mock_initialize_route_impls(impls):
             return mock_route_impls
 
-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
 
         sync_client = LlamaStackAsLibraryClient("ci-tests")