chore: refactor server.main (#3462)
Some checks failed
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.13) (push) Failing after 3s
Vector IO Integration Tests / test-matrix (push) Failing after 6s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 5s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 8s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 13s
Unit Tests / unit-tests (3.13) (push) Failing after 4s
Test External API and Providers / test-external (venv) (push) Failing after 7s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
Python Package Build Test / build (3.12) (push) Failing after 10s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 18s
API Conformance Tests / check-schema-compatibility (push) Successful in 22s
UI Tests / ui-tests (22) (push) Successful in 29s
Pre-commit / pre-commit (push) Successful in 1m25s

# What does this PR do?
As shown in #3421, we can scale stack to handle more RPS with k8s
replicas. This PR enables multi process stack with uvicorn --workers so
that we can achieve the same scaling without being in k8s.

To achieve that we refactor main to split out the app construction
logic. This method needs to be non-async. We created a new `Stack` class
to house impls and have a `start()` method to be called in lifespan to
start background tasks instead of starting them in the old
`construct_stack`. This way we avoid having to manage an event loop
manually.


## Test Plan
CI

> uv run --with llama-stack python -m llama_stack.core.server.server
benchmarking/k8s-benchmark/stack_run_config.yaml

works.

> LLAMA_STACK_CONFIG=benchmarking/k8s-benchmark/stack_run_config.yaml uv
run uvicorn llama_stack.core.server.server:create_app --port 8321
--workers 4

works.
This commit is contained in:
ehhuang 2025-09-18 21:11:13 -07:00 committed by GitHub
parent 8422bd102a
commit 4c2fcb6b51
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 233 additions and 146 deletions

View file

@ -315,78 +315,84 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
impls[Api.prompts] = prompts_impl
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
) -> dict[Api, Any]:
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
from llama_stack.testing.inference_recorder import setup_inference_recording
class Stack:
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
self.run_config = run_config
self.provider_registry = provider_registry
self.impls = None
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def initialize(self):
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
from llama_stack.testing.inference_recorder import setup_inference_recording
global TEST_RECORDING_CONTEXT
TEST_RECORDING_CONTEXT = setup_inference_recording()
if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
impls = await resolve_impls(
self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
)
# Add internal implementations after all other providers are resolved
add_internal_implementations(impls, self.run_config)
if Api.prompts in impls:
await impls[Api.prompts].initialize()
await register_resources(self.run_config, impls)
await refresh_registry_once(impls)
self.impls = impls
def create_registry_refresh_task(self):
assert self.impls is not None, "Must call initialize() before starting"
global REGISTRY_REFRESH_TASK
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
def cb(task):
import traceback
if task.cancelled():
logger.error("Model refresh task cancelled")
elif task.exception():
logger.error(f"Model refresh task failed: {task.exception()}")
traceback.print_exception(task.exception())
else:
logger.debug("Model refresh task completed")
REGISTRY_REFRESH_TASK.add_done_callback(cb)
async def shutdown(self):
for impl in self.impls.values():
impl_name = impl.__class__.__name__
logger.info(f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning(f"No shutdown method for {impl_name}")
except TimeoutError:
logger.exception(f"Shutdown timeout for {impl_name}")
except (Exception, asyncio.CancelledError) as e:
logger.exception(f"Failed to shutdown {impl_name}: {e}")
global TEST_RECORDING_CONTEXT
TEST_RECORDING_CONTEXT = setup_inference_recording()
if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
try:
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e:
logger.error(f"Error during inference recording cleanup: {e}")
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
policy = run_config.server.auth.access_policy if run_config.server.auth else []
impls = await resolve_impls(
run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
)
# Add internal implementations after all other providers are resolved
add_internal_implementations(impls, run_config)
if Api.prompts in impls:
await impls[Api.prompts].initialize()
await register_resources(run_config, impls)
await refresh_registry_once(impls)
global REGISTRY_REFRESH_TASK
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
def cb(task):
import traceback
if task.cancelled():
logger.error("Model refresh task cancelled")
elif task.exception():
logger.error(f"Model refresh task failed: {task.exception()}")
traceback.print_exception(task.exception())
else:
logger.debug("Model refresh task completed")
REGISTRY_REFRESH_TASK.add_done_callback(cb)
return impls
async def shutdown_stack(impls: dict[Api, Any]):
for impl in impls.values():
impl_name = impl.__class__.__name__
logger.info(f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning(f"No shutdown method for {impl_name}")
except TimeoutError:
logger.exception(f"Shutdown timeout for {impl_name}")
except (Exception, asyncio.CancelledError) as e:
logger.exception(f"Failed to shutdown {impl_name}: {e}")
global TEST_RECORDING_CONTEXT
if TEST_RECORDING_CONTEXT:
try:
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e:
logger.error(f"Error during inference recording cleanup: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:
REGISTRY_REFRESH_TASK.cancel()
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:
REGISTRY_REFRESH_TASK.cancel()
async def refresh_registry_once(impls: dict[Api, Any]):