Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-10-04 04:04:14 +00:00
chore: refactor server.main (#3462)
Some checks failed
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.13) (push) Failing after 3s
Vector IO Integration Tests / test-matrix (push) Failing after 6s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 5s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 8s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 13s
Unit Tests / unit-tests (3.13) (push) Failing after 4s
Test External API and Providers / test-external (venv) (push) Failing after 7s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
Python Package Build Test / build (3.12) (push) Failing after 10s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 18s
API Conformance Tests / check-schema-compatibility (push) Successful in 22s
UI Tests / ui-tests (22) (push) Successful in 29s
Pre-commit / pre-commit (push) Successful in 1m25s
# What does this PR do?

As shown in #3421, we can scale the stack to handle more RPS with k8s replicas. This PR enables a multi-process stack with `uvicorn --workers`, so we can achieve the same scaling without being in k8s.

To achieve that, we refactor `main` to split out the app construction logic into a factory. This method needs to be non-async. We created a new `Stack` class to house impls and have a `start()` method to be called in the lifespan to start background tasks, instead of starting them in the old `construct_stack`. This way we avoid having to manage an event loop manually.

## Test Plan

CI.

> uv run --with llama-stack python -m llama_stack.core.server.server benchmarking/k8s-benchmark/stack_run_config.yaml

works.

> LLAMA_STACK_CONFIG=benchmarking/k8s-benchmark/stack_run_config.yaml uv run uvicorn llama_stack.core.server.server:create_app --port 8321 --workers 4

works.
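For context: uvicorn's `--factory` flag makes each worker process import and call the given callable itself, so the factory must be a plain synchronous function that returns the app. A minimal sketch of the pattern, with illustrative names (not this PR's code):

```python
# Minimal sketch of the uvicorn factory pattern; names are illustrative.
from fastapi import FastAPI


def create_app() -> FastAPI:
    # Called once per worker process when uvicorn is started with --factory.
    # Must be a regular def: uvicorn calls it during startup and does not await it.
    app = FastAPI()

    @app.get("/health")
    def health() -> dict[str, str]:
        return {"status": "ok"}

    return app


# Run with: uvicorn my_module:create_app --factory --workers 4
```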
This commit is contained in:
parent 8422bd102a
commit 4c2fcb6b51

7 changed files with 233 additions and 146 deletions
Benchmark launch script:

```diff
@@ -17,11 +17,8 @@ export POSTGRES_PASSWORD=llamastack
 export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
 export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
 
-export MOCK_INFERENCE_MODEL=mock-inference
-
-export MOCK_INFERENCE_URL=openai-mock-service:8080
-
 export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
+export LLAMA_STACK_WORKERS=4
 
 set -euo pipefail
 set -x
```
Benchmark stack run config (ConfigMap):

```diff
@@ -5,6 +5,7 @@ data:
     image_name: kubernetes-benchmark-demo
     apis:
     - agents
+    - files
     - inference
     - files
     - safety
@@ -23,6 +24,14 @@ data:
     - provider_id: sentence-transformers
       provider_type: inline::sentence-transformers
       config: {}
+    files:
+    - provider_id: meta-reference-files
+      provider_type: inline::localfs
+      config:
+        storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+        metadata_store:
+          type: sqlite
+          db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
     vector_io:
     - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
       provider_type: remote::chromadb
```
Benchmark Kubernetes deployment template:

```diff
@@ -52,9 +52,20 @@ spec:
           value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
         - name: VLLM_TLS_VERIFY
           value: "false"
-        command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"]
+        - name: LLAMA_STACK_LOGGING
+          value: "all=WARNING"
+        - name: LLAMA_STACK_CONFIG
+          value: "/etc/config/stack_run_config.yaml"
+        - name: LLAMA_STACK_WORKERS
+          value: "${LLAMA_STACK_WORKERS}"
+        command: ["uvicorn", "llama_stack.core.server.server:create_app", "--host", "0.0.0.0", "--port", "8323", "--workers", "$LLAMA_STACK_WORKERS", "--factory"]
         ports:
         - containerPort: 8323
+        resources:
+          requests:
+            cpu: "${LLAMA_STACK_WORKERS}"
+          limits:
+            cpu: "${LLAMA_STACK_WORKERS}"
         volumeMounts:
         - name: llama-storage
           mountPath: /root/.llama
```
Library client (`llama_stack.core.library_client`):

```diff
@@ -40,7 +40,7 @@ from llama_stack.core.request_headers import (
 from llama_stack.core.resolver import ProviderRegistry
 from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
 from llama_stack.core.stack import (
-    construct_stack,
+    Stack,
     get_stack_run_config_from_distro,
     replace_env_vars,
 )
@@ -252,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 
         try:
             self.route_impls = None
-            self.impls = await construct_stack(self.config, self.custom_provider_registry)
+
+            stack = Stack(self.config, self.custom_provider_registry)
+            await stack.initialize()
+            self.impls = stack.impls
         except ModuleNotFoundError as _e:
             cprint(_e.msg, color="red", file=sys.stderr)
             cprint(
@@ -289,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             )
             raise _e
 
+        assert self.impls is not None
         if Api.telemetry in self.impls:
             setup_logger(self.impls[Api.telemetry])
 
```
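Consumers of the library client are unaffected by this change; construction simply goes through `Stack` internally. A hedged usage sketch (the distro name follows the tests below and is illustrative):

```python
# Usage sketch: the public library-client API is unchanged by this refactor.
import asyncio

from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient


async def main() -> None:
    client = AsyncLlamaStackAsLibraryClient("ci-tests")  # distro name is illustrative
    await client.initialize()  # internally: Stack(config) -> await stack.initialize()


asyncio.run(main())
```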
Server entry point (`llama_stack.core.server.server`):

```diff
@@ -6,6 +6,7 @@
 
 import argparse
 import asyncio
+import concurrent.futures
 import functools
 import inspect
 import json
```
```diff
@@ -50,17 +51,15 @@ from llama_stack.core.request_headers import (
     request_provider_data_context,
     user_from_scope,
 )
-from llama_stack.core.resolver import InvalidProviderError
 from llama_stack.core.server.routes import (
     find_matching_route,
     get_all_api_routes,
     initialize_route_impls,
 )
 from llama_stack.core.stack import (
+    Stack,
     cast_image_name_to_string,
-    construct_stack,
     replace_env_vars,
-    shutdown_stack,
     validate_env_pair,
 )
 from llama_stack.core.utils.config import redact_sensitive_fields
```
```diff
@@ -156,21 +155,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationError
     )
 
 
-async def shutdown(app):
-    """Initiate a graceful shutdown of the application.
-
-    Handled by the lifespan context manager. The shutdown process involves
-    shutting down all implementations registered in the application.
-    """
-    await shutdown_stack(app.__llama_stack_impls__)
+class StackApp(FastAPI):
+    """
+    A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
+    start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
+    """
+
+    def __init__(self, config: StackRunConfig, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.stack: Stack = Stack(config)
+
+        # This code is called from a running event loop managed by uvicorn so we cannot simply call
+        # asyncio.run() to initialize the stack. We cannot await either since this is not an async
+        # function.
+        # As a workaround, we use a thread pool executor to run the initialize() method
+        # in a separate thread.
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(asyncio.run, self.stack.initialize())
+            future.result()
 
 
 @asynccontextmanager
-async def lifespan(app: FastAPI):
+async def lifespan(app: StackApp):
     logger.info("Starting up")
+    assert app.stack is not None
+    app.stack.create_registry_refresh_task()
     yield
     logger.info("Shutting down")
-    await shutdown(app)
+    await app.stack.shutdown()
 
 
 def is_streaming_request(func_name: str, request: Request, **kwargs):
```
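`StackApp.__init__` runs under uvicorn's already-running event loop, where calling `asyncio.run()` directly would raise `RuntimeError`; submitting `asyncio.run` to a worker thread gives the coroutine a fresh loop of its own. A standalone sketch of that workaround, with illustrative names:

```python
# Standalone sketch of running an async initializer from sync code that may
# itself be executing under a running event loop (as in StackApp.__init__).
import asyncio
import concurrent.futures


async def initialize() -> str:
    await asyncio.sleep(0.01)  # stand-in for real async setup
    return "ready"


def sync_init() -> str:
    # asyncio.run() in this thread would fail if a loop is already running,
    # so run it in a separate thread, which has no event loop of its own.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(asyncio.run, initialize())
        return future.result()  # blocks until the thread's loop completes


print(sync_init())  # -> "ready"
```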
```diff
@@ -386,73 +398,61 @@ class ClientVersionMiddleware:
         return await self.app(scope, receive, send)
 
 
-def main(args: argparse.Namespace | None = None):
-    """Start the LlamaStack server."""
-    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
-
-    add_config_distro_args(parser)
-    parser.add_argument(
-        "--port",
-        type=int,
-        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
-        help="Port to listen on",
-    )
-    parser.add_argument(
-        "--env",
-        action="append",
-        help="Environment variables in KEY=value format. Can be specified multiple times.",
-    )
-
-    # Determine whether the server args are being passed by the "run" command, if this is the case
-    # the args will be passed as a Namespace object to the main function, otherwise they will be
-    # parsed from the command line
-    if args is None:
-        args = parser.parse_args()
-
-    config_or_distro = get_config_from_args(args)
-    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+def create_app(
+    config_file: str | None = None,
+    env_vars: list[str] | None = None,
+) -> StackApp:
+    """Create and configure the FastAPI application.
+
+    Args:
+        config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
+        env_vars: List of environment variables in KEY=value format.
+        disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
+
+    Returns:
+        Configured StackApp instance.
+    """
+    config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
+    if config_file is None:
+        raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
+
+    config_file = resolve_config_or_distro(config_file, Mode.RUN)
 
+    # Load and process configuration
     logger_config = None
     with open(config_file) as fp:
         config_contents = yaml.safe_load(fp)
         if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
             logger_config = LoggingConfig(**cfg)
         logger = get_logger(name=__name__, category="core::server", config=logger_config)
-    if args.env:
-        for env_pair in args.env:
+
+    if env_vars:
+        for env_pair in env_vars:
             try:
                 key, value = validate_env_pair(env_pair)
-                logger.info(f"Setting CLI environment variable {key} => {value}")
+                logger.info(f"Setting environment variable {key} => {value}")
                 os.environ[key] = value
             except ValueError as e:
                 logger.error(f"Error: {str(e)}")
-                sys.exit(1)
+                raise ValueError(f"Invalid environment variable format: {env_pair}") from e
 
     config = replace_env_vars(config_contents)
     config = StackRunConfig(**cast_image_name_to_string(config))
 
     _log_run_config(run_config=config)
 
-    app = FastAPI(
+    app = StackApp(
         lifespan=lifespan,
         docs_url="/docs",
         redoc_url="/redoc",
         openapi_url="/openapi.json",
+        config=config,
     )
 
     if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
         app.add_middleware(ClientVersionMiddleware)
 
-    try:
-        # Create and set the event loop that will be used for both construction and server runtime
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
-        # Construct the stack in the persistent event loop
-        impls = loop.run_until_complete(construct_stack(config))
-
-    except InvalidProviderError as e:
-        logger.error(f"Error: {str(e)}")
-        sys.exit(1)
+    impls = app.stack.impls
 
     if config.server.auth:
         logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
```
```diff
@@ -553,9 +553,54 @@ def main(args: argparse.Namespace | None = None):
     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
 
-    app.__llama_stack_impls__ = impls
     app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
 
+    return app
+
+
+def main(args: argparse.Namespace | None = None):
+    """Start the LlamaStack server."""
+    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
+
+    add_config_distro_args(parser)
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
+        help="Port to listen on",
+    )
+    parser.add_argument(
+        "--env",
+        action="append",
+        help="Environment variables in KEY=value format. Can be specified multiple times.",
+    )
+
+    # Determine whether the server args are being passed by the "run" command, if this is the case
+    # the args will be passed as a Namespace object to the main function, otherwise they will be
+    # parsed from the command line
+    if args is None:
+        args = parser.parse_args()
+
+    config_or_distro = get_config_from_args(args)
+
+    try:
+        app = create_app(
+            config_file=config_or_distro,
+            env_vars=args.env,
+        )
+    except Exception as e:
+        logger.error(f"Error creating app: {str(e)}")
+        sys.exit(1)
+
+    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+    with open(config_file) as fp:
+        config_contents = yaml.safe_load(fp)
+        if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
+            logger_config = LoggingConfig(**cfg)
+        else:
+            logger_config = None
+    config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
+
     import uvicorn
 
     # Configure SSL if certificates are provided
@@ -593,7 +638,6 @@ def main(args: argparse.Namespace | None = None):
     if ssl_config:
         uvicorn_config.update(ssl_config)
 
-    # Run uvicorn in the existing event loop to preserve background tasks
     # We need to catch KeyboardInterrupt because uvicorn's signal handling
     # re-raises SIGINT signals using signal.raise_signal(), which Python
     # converts to KeyboardInterrupt. Without this catch, we'd get a confusing
@@ -604,13 +648,9 @@ def main(args: argparse.Namespace | None = None):
     # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
     # signal handling but this is quite intrusive and not worth the effort.
     try:
-        loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
+        asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
     except (KeyboardInterrupt, SystemExit):
         logger.info("Received interrupt signal, shutting down gracefully...")
-    finally:
-        if not loop.is_closed():
-            logger.debug("Closing event loop")
-            loop.close()
 
 
 def _log_run_config(run_config: StackRunConfig):
```
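With construction moved into the factory, `main()` no longer needs a hand-managed event loop shared between stack construction and serving; `asyncio.run()` creates, runs, and closes a loop in one call. A trivial sketch of the simplification:

```python
# Sketch: the manual new_event_loop()/run_until_complete()/close() dance above
# collapses to a single asyncio.run() once nothing else needs the same loop.
import asyncio


async def serve() -> None:
    await asyncio.sleep(0)  # stand-in for uvicorn.Server(...).serve()


asyncio.run(serve())  # creates, runs, and closes the loop in one call
```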
Stack construction (`llama_stack.core.stack`):

```diff
@@ -315,78 +315,84 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig)
     impls[Api.prompts] = prompts_impl
 
 
-# Produces a stack of providers for the given run config. Not all APIs may be
-# asked for in the run config.
-async def construct_stack(
-    run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
-) -> dict[Api, Any]:
-    if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
-        from llama_stack.testing.inference_recorder import setup_inference_recording
-
-        global TEST_RECORDING_CONTEXT
-        TEST_RECORDING_CONTEXT = setup_inference_recording()
-        if TEST_RECORDING_CONTEXT:
-            TEST_RECORDING_CONTEXT.__enter__()
-            logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
-
-    dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
-    policy = run_config.server.auth.access_policy if run_config.server.auth else []
-    impls = await resolve_impls(
-        run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
-    )
-
-    # Add internal implementations after all other providers are resolved
-    add_internal_implementations(impls, run_config)
-
-    if Api.prompts in impls:
-        await impls[Api.prompts].initialize()
-
-    await register_resources(run_config, impls)
-
-    await refresh_registry_once(impls)
-
-    global REGISTRY_REFRESH_TASK
-    REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
-
-    def cb(task):
-        import traceback
-
-        if task.cancelled():
-            logger.error("Model refresh task cancelled")
-        elif task.exception():
-            logger.error(f"Model refresh task failed: {task.exception()}")
-            traceback.print_exception(task.exception())
-        else:
-            logger.debug("Model refresh task completed")
-
-    REGISTRY_REFRESH_TASK.add_done_callback(cb)
-
-    return impls
-
-
-async def shutdown_stack(impls: dict[Api, Any]):
-    for impl in impls.values():
-        impl_name = impl.__class__.__name__
-        logger.info(f"Shutting down {impl_name}")
-        try:
-            if hasattr(impl, "shutdown"):
-                await asyncio.wait_for(impl.shutdown(), timeout=5)
-            else:
-                logger.warning(f"No shutdown method for {impl_name}")
-        except TimeoutError:
-            logger.exception(f"Shutdown timeout for {impl_name}")
-        except (Exception, asyncio.CancelledError) as e:
-            logger.exception(f"Failed to shutdown {impl_name}: {e}")
-
-    global TEST_RECORDING_CONTEXT
-    if TEST_RECORDING_CONTEXT:
-        try:
-            TEST_RECORDING_CONTEXT.__exit__(None, None, None)
-        except Exception as e:
-            logger.error(f"Error during inference recording cleanup: {e}")
-
-    global REGISTRY_REFRESH_TASK
-    if REGISTRY_REFRESH_TASK:
-        REGISTRY_REFRESH_TASK.cancel()
+class Stack:
+    def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
+        self.run_config = run_config
+        self.provider_registry = provider_registry
+        self.impls = None
+
+    # Produces a stack of providers for the given run config. Not all APIs may be
+    # asked for in the run config.
+    async def initialize(self):
+        if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
+            from llama_stack.testing.inference_recorder import setup_inference_recording
+
+            global TEST_RECORDING_CONTEXT
+            TEST_RECORDING_CONTEXT = setup_inference_recording()
+            if TEST_RECORDING_CONTEXT:
+                TEST_RECORDING_CONTEXT.__enter__()
+                logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
+
+        dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
+        policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
+        impls = await resolve_impls(
+            self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
+        )
+
+        # Add internal implementations after all other providers are resolved
+        add_internal_implementations(impls, self.run_config)
+
+        if Api.prompts in impls:
+            await impls[Api.prompts].initialize()
+
+        await register_resources(self.run_config, impls)
+
+        await refresh_registry_once(impls)
+        self.impls = impls
+
+    def create_registry_refresh_task(self):
+        assert self.impls is not None, "Must call initialize() before starting"
+
+        global REGISTRY_REFRESH_TASK
+        REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
+
+        def cb(task):
+            import traceback
+
+            if task.cancelled():
+                logger.error("Model refresh task cancelled")
+            elif task.exception():
+                logger.error(f"Model refresh task failed: {task.exception()}")
+                traceback.print_exception(task.exception())
+            else:
+                logger.debug("Model refresh task completed")
+
+        REGISTRY_REFRESH_TASK.add_done_callback(cb)
+
+    async def shutdown(self):
+        for impl in self.impls.values():
+            impl_name = impl.__class__.__name__
+            logger.info(f"Shutting down {impl_name}")
+            try:
+                if hasattr(impl, "shutdown"):
+                    await asyncio.wait_for(impl.shutdown(), timeout=5)
+                else:
+                    logger.warning(f"No shutdown method for {impl_name}")
+            except TimeoutError:
+                logger.exception(f"Shutdown timeout for {impl_name}")
+            except (Exception, asyncio.CancelledError) as e:
+                logger.exception(f"Failed to shutdown {impl_name}: {e}")
+
+        global TEST_RECORDING_CONTEXT
+        if TEST_RECORDING_CONTEXT:
+            try:
+                TEST_RECORDING_CONTEXT.__exit__(None, None, None)
+            except Exception as e:
+                logger.error(f"Error during inference recording cleanup: {e}")
+
+        global REGISTRY_REFRESH_TASK
+        if REGISTRY_REFRESH_TASK:
+            REGISTRY_REFRESH_TASK.cancel()
 
 
 async def refresh_registry_once(impls: dict[Api, Any]):
```
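Taken together, the new class gives callers an explicit lifecycle. A hedged sketch of the intended call order (method names come from the diff above; the `StackRunConfig` import path is an assumption):

```python
# Sketch of the Stack lifecycle introduced in this diff; method names come from
# the code above, while the StackRunConfig import path is an assumption.
from llama_stack.core.datatypes import StackRunConfig  # assumed location
from llama_stack.core.stack import Stack


async def run_stack(config: StackRunConfig) -> None:
    stack = Stack(config)
    await stack.initialize()  # resolve providers, register resources
    stack.create_registry_refresh_task()  # needs a running event loop
    try:
        ...  # serve requests via stack.impls
    finally:
        await stack.shutdown()  # stop impls, then cancel the refresh task
```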
Library client initialization tests:

```diff
@@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})
 
-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass
 
         def mock_initialize_route_impls(impls):
             return mock_route_impls
 
-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
 
         client = LlamaStackAsLibraryClient("ci-tests")
```

The remaining four hunks (at old lines 46, 68, 90, and 112) apply the identical `MockStack` substitution, differing only in the final context line: they construct `AsyncLlamaStackAsLibraryClient("ci-tests")`, `LlamaStackAsLibraryClient("ci-tests")`, `AsyncLlamaStackAsLibraryClient("ci-tests")`, and `sync_client = LlamaStackAsLibraryClient("ci-tests")` respectively.