Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-04 20:14:13 +00:00
chore: refactor server.main
# What does this PR do?

Refactor `main` to split out the app construction, so that we can use `uvicorn --workers` to enable a multi-process stack.

## Test Plan

CI.

> uv run --with llama-stack python -m llama_stack.core.server.server benchmarking/k8s-benchmark/stack_run_config.yaml

works.

> LLAMA_STACK_CONFIG=benchmarking/k8s-benchmark/stack_run_config.yaml uv run uvicorn llama_stack.core.server.server:create_app --port 8321 --workers 4

works.
parent ac1414b571
commit a285f9c95f

7 changed files with 233 additions and 146 deletions
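At the heart of the change is uvicorn's app-factory mode: with `--factory`, each worker process calls a zero-argument callable to build its own application instance, which is what lets `--workers` fan the stack out across processes. A minimal, self-contained sketch of that pattern (illustrative file and app names, not the llama-stack code; assume it is saved as `app_factory.py`):

```python
# app_factory.py - minimal sketch of the uvicorn app-factory pattern.
import uvicorn
from fastapi import FastAPI


def create_app() -> FastAPI:
    # Called once per worker process, so per-process state (event loops,
    # connection pools, registries) is built here rather than at import time.
    app = FastAPI()

    @app.get("/health")
    def health() -> dict[str, str]:
        return {"status": "ok"}

    return app


if __name__ == "__main__":
    # CLI equivalent: uvicorn app_factory:create_app --factory --workers 4 --port 8321
    uvicorn.run("app_factory:create_app", factory=True, workers=4, port=8321)
```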
```diff
@@ -17,11 +17,8 @@ export POSTGRES_PASSWORD=llamastack
 export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
 export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
 
-export MOCK_INFERENCE_MODEL=mock-inference
-
-export MOCK_INFERENCE_URL=openai-mock-service:8080
-
 export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
+export LLAMA_STACK_WORKERS=4
 
 set -euo pipefail
 set -x
```
```diff
@@ -5,6 +5,7 @@ data:
     image_name: kubernetes-benchmark-demo
     apis:
     - agents
+    - files
    - inference
     - files
     - safety
```
```diff
@@ -23,6 +24,14 @@ data:
       - provider_id: sentence-transformers
         provider_type: inline::sentence-transformers
         config: {}
+      files:
+      - provider_id: meta-reference-files
+        provider_type: inline::localfs
+        config:
+          storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+          metadata_store:
+            type: sqlite
+            db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
       vector_io:
       - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
         provider_type: remote::chromadb
```
```diff
@@ -52,9 +52,20 @@ spec:
           value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
         - name: VLLM_TLS_VERIFY
           value: "false"
-        command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"]
+        - name: LLAMA_STACK_LOGGING
+          value: "all=WARNING"
+        - name: LLAMA_STACK_CONFIG
+          value: "/etc/config/stack_run_config.yaml"
+        - name: LLAMA_STACK_WORKERS
+          value: "${LLAMA_STACK_WORKERS}"
+        command: ["uvicorn", "llama_stack.core.server.server:create_app", "--host", "0.0.0.0", "--port", "8323", "--workers", "$LLAMA_STACK_WORKERS", "--factory"]
         ports:
           - containerPort: 8323
+        resources:
+          requests:
+            cpu: "${LLAMA_STACK_WORKERS}"
+          limits:
+            cpu: "${LLAMA_STACK_WORKERS}"
         volumeMounts:
           - name: llama-storage
             mountPath: /root/.llama
```
```diff
@@ -40,7 +40,7 @@ from llama_stack.core.request_headers import (
 from llama_stack.core.resolver import ProviderRegistry
 from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
 from llama_stack.core.stack import (
-    construct_stack,
+    Stack,
     get_stack_run_config_from_distro,
     replace_env_vars,
 )
```
```diff
@@ -252,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 
         try:
             self.route_impls = None
-            self.impls = await construct_stack(self.config, self.custom_provider_registry)
+
+            stack = Stack(self.config, self.custom_provider_registry)
+            await stack.initialize()
+            self.impls = stack.impls
         except ModuleNotFoundError as _e:
             cprint(_e.msg, color="red", file=sys.stderr)
             cprint(
```
```diff
@@ -289,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             )
             raise _e
 
+        assert self.impls is not None
         if Api.telemetry in self.impls:
             setup_logger(self.impls[Api.telemetry])
 
```
```diff
@@ -6,6 +6,7 @@
 
 import argparse
 import asyncio
+import concurrent.futures
 import functools
 import inspect
 import json
```
```diff
@@ -50,17 +51,15 @@ from llama_stack.core.request_headers import (
     request_provider_data_context,
     user_from_scope,
 )
-from llama_stack.core.resolver import InvalidProviderError
 from llama_stack.core.server.routes import (
     find_matching_route,
     get_all_api_routes,
     initialize_route_impls,
 )
 from llama_stack.core.stack import (
+    Stack,
     cast_image_name_to_string,
-    construct_stack,
     replace_env_vars,
-    shutdown_stack,
     validate_env_pair,
 )
 from llama_stack.core.utils.config import redact_sensitive_fields
```
```diff
@@ -156,21 +155,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
     )
 
 
-async def shutdown(app):
-    """Initiate a graceful shutdown of the application.
-
-    Handled by the lifespan context manager. The shutdown process involves
-    shutting down all implementations registered in the application.
+class StackApp(FastAPI):
     """
-    await shutdown_stack(app.__llama_stack_impls__)
+    A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
+    start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
+    """
+
+    def __init__(self, config: StackRunConfig, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.stack: Stack = Stack(config)
+
+        # This code is called from a running event loop managed by uvicorn so we cannot simply call
+        # asyncio.run() to initialize the stack. We cannot await either since this is not an async
+        # function.
+        # As a workaround, we use a thread pool executor to run the initialize() method
+        # in a separate thread.
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(asyncio.run, self.stack.initialize())
+            future.result()
 
 
 @asynccontextmanager
-async def lifespan(app: FastAPI):
+async def lifespan(app: StackApp):
     logger.info("Starting up")
+    assert app.stack is not None
+    app.stack.create_registry_refresh_task()
     yield
     logger.info("Shutting down")
-    await shutdown(app)
+    await app.stack.shutdown()
 
 
 def is_streaming_request(func_name: str, request: Request, **kwargs):
```
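The `StackApp.__init__` above runs async initialization from synchronous code that may itself be executing inside uvicorn's already-running event loop, where a direct `asyncio.run()` would raise `RuntimeError`. A standalone sketch of that workaround (toy coroutine, not the real initializer):

```python
# Sketch of running a coroutine to completion from sync code that may be
# called while another event loop is already running in this thread.
import asyncio
import concurrent.futures


async def initialize() -> str:
    await asyncio.sleep(0.01)  # stand-in for real async setup work
    return "initialized"


def blocking_init() -> str:
    # The worker thread has no running loop, so asyncio.run() is safe there;
    # the calling thread simply blocks on the future until setup completes.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(asyncio.run, initialize())
        return future.result()


async def main() -> None:
    # Calling blocking_init() here mimics uvicorn constructing the app while
    # its own loop is already running.
    print(blocking_init())


if __name__ == "__main__":
    asyncio.run(main())
```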
```diff
@@ -386,73 +398,61 @@ class ClientVersionMiddleware:
         return await self.app(scope, receive, send)
 
 
-def main(args: argparse.Namespace | None = None):
-    """Start the LlamaStack server."""
-    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
-
-    add_config_distro_args(parser)
-    parser.add_argument(
-        "--port",
-        type=int,
-        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
-        help="Port to listen on",
-    )
-    parser.add_argument(
-        "--env",
-        action="append",
-        help="Environment variables in KEY=value format. Can be specified multiple times.",
-    )
-
-    # Determine whether the server args are being passed by the "run" command, if this is the case
-    # the args will be passed as a Namespace object to the main function, otherwise they will be
-    # parsed from the command line
-    if args is None:
-        args = parser.parse_args()
-
-    config_or_distro = get_config_from_args(args)
-    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+def create_app(
+    config_file: str | None = None,
+    env_vars: list[str] | None = None,
+) -> StackApp:
+    """Create and configure the FastAPI application.
+
+    Args:
+        config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
+        env_vars: List of environment variables in KEY=value format.
+        disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
+
+    Returns:
+        Configured StackApp instance.
+    """
+    config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
+    if config_file is None:
+        raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
+
+    config_file = resolve_config_or_distro(config_file, Mode.RUN)
 
+    # Load and process configuration
     logger_config = None
     with open(config_file) as fp:
         config_contents = yaml.safe_load(fp)
         if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
             logger_config = LoggingConfig(**cfg)
         logger = get_logger(name=__name__, category="core::server", config=logger_config)
-        if args.env:
-            for env_pair in args.env:
+
+        if env_vars:
+            for env_pair in env_vars:
                 try:
                     key, value = validate_env_pair(env_pair)
-                    logger.info(f"Setting CLI environment variable {key} => {value}")
+                    logger.info(f"Setting environment variable {key} => {value}")
                     os.environ[key] = value
                 except ValueError as e:
                     logger.error(f"Error: {str(e)}")
-                    sys.exit(1)
+                    raise ValueError(f"Invalid environment variable format: {env_pair}") from e
 
         config = replace_env_vars(config_contents)
         config = StackRunConfig(**cast_image_name_to_string(config))
 
     _log_run_config(run_config=config)
 
-    app = FastAPI(
+    app = StackApp(
         lifespan=lifespan,
         docs_url="/docs",
         redoc_url="/redoc",
         openapi_url="/openapi.json",
+        config=config,
     )
 
     if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
         app.add_middleware(ClientVersionMiddleware)
 
-    try:
-        # Create and set the event loop that will be used for both construction and server runtime
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
-        # Construct the stack in the persistent event loop
-        impls = loop.run_until_complete(construct_stack(config))
-
-    except InvalidProviderError as e:
-        logger.error(f"Error: {str(e)}")
-        sys.exit(1)
+    impls = app.stack.impls
 
     if config.server.auth:
         logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
```
```diff
@@ -553,9 +553,54 @@ def main(args: argparse.Namespace | None = None):
     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
 
-    app.__llama_stack_impls__ = impls
     app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
 
+    return app
+
+
+def main(args: argparse.Namespace | None = None):
+    """Start the LlamaStack server."""
+    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
+
+    add_config_distro_args(parser)
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
+        help="Port to listen on",
+    )
+    parser.add_argument(
+        "--env",
+        action="append",
+        help="Environment variables in KEY=value format. Can be specified multiple times.",
+    )
+
+    # Determine whether the server args are being passed by the "run" command, if this is the case
+    # the args will be passed as a Namespace object to the main function, otherwise they will be
+    # parsed from the command line
+    if args is None:
+        args = parser.parse_args()
+
+    config_or_distro = get_config_from_args(args)
+
+    try:
+        app = create_app(
+            config_file=config_or_distro,
+            env_vars=args.env,
+        )
+    except Exception as e:
+        logger.error(f"Error creating app: {str(e)}")
+        sys.exit(1)
+
+    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+    with open(config_file) as fp:
+        config_contents = yaml.safe_load(fp)
+        if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
+            logger_config = LoggingConfig(**cfg)
+        else:
+            logger_config = None
+        config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
+
     import uvicorn
 
     # Configure SSL if certificates are provided
```
```diff
@@ -593,7 +638,6 @@ def main(args: argparse.Namespace | None = None):
     if ssl_config:
         uvicorn_config.update(ssl_config)
 
-    # Run uvicorn in the existing event loop to preserve background tasks
     # We need to catch KeyboardInterrupt because uvicorn's signal handling
     # re-raises SIGINT signals using signal.raise_signal(), which Python
     # converts to KeyboardInterrupt. Without this catch, we'd get a confusing
```
```diff
@@ -604,13 +648,9 @@ def main(args: argparse.Namespace | None = None):
     # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
     # signal handling but this is quite intrusive and not worth the effort.
     try:
-        loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
+        asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
     except (KeyboardInterrupt, SystemExit):
         logger.info("Received interrupt signal, shutting down gracefully...")
-    finally:
-        if not loop.is_closed():
-            logger.debug("Closing event loop")
-            loop.close()
 
 
 def _log_run_config(run_config: StackRunConfig):
```
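For reference, a hypothetical smoke check of the new factory entry point (the `LLAMA_STACK_CONFIG` fallback is visible in the `create_app` diff above; the config path is the one from the test plan):

```python
# Hypothetical smoke check; not part of this commit.
import os

from llama_stack.core.server.server import create_app

# create_app() falls back to LLAMA_STACK_CONFIG when no config_file is passed.
os.environ["LLAMA_STACK_CONFIG"] = "benchmarking/k8s-benchmark/stack_run_config.yaml"

app = create_app()
print(type(app).__name__)  # StackApp, a FastAPI subclass
```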
```diff
@@ -315,11 +315,15 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig):
     impls[Api.prompts] = prompts_impl
 
 
-# Produces a stack of providers for the given run config. Not all APIs may be
-# asked for in the run config.
-async def construct_stack(
-    run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
-) -> dict[Api, Any]:
+class Stack:
+    def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
+        self.run_config = run_config
+        self.provider_registry = provider_registry
+        self.impls = None
+
+    # Produces a stack of providers for the given run config. Not all APIs may be
+    # asked for in the run config.
+    async def initialize(self):
         if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
             from llama_stack.testing.inference_recorder import setup_inference_recording
 
```
```diff
@@ -329,24 +333,28 @@ async def construct_stack(
             TEST_RECORDING_CONTEXT.__enter__()
             logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
 
-        dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
-        policy = run_config.server.auth.access_policy if run_config.server.auth else []
+        dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
+        policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
         impls = await resolve_impls(
-            run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
+            self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
         )
 
         # Add internal implementations after all other providers are resolved
-        add_internal_implementations(impls, run_config)
+        add_internal_implementations(impls, self.run_config)
 
         if Api.prompts in impls:
             await impls[Api.prompts].initialize()
 
-        await register_resources(run_config, impls)
+        await register_resources(self.run_config, impls)
 
         await refresh_registry_once(impls)
+        self.impls = impls
+
+    def create_registry_refresh_task(self):
+        assert self.impls is not None, "Must call initialize() before starting"
 
         global REGISTRY_REFRESH_TASK
-        REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
+        REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
 
         def cb(task):
             import traceback
```
```diff
@@ -360,11 +368,9 @@ async def construct_stack(
                 logger.debug("Model refresh task completed")
 
         REGISTRY_REFRESH_TASK.add_done_callback(cb)
-    return impls
 
-
-async def shutdown_stack(impls: dict[Api, Any]):
-    for impl in impls.values():
+    async def shutdown(self):
+        for impl in self.impls.values():
             impl_name = impl.__class__.__name__
             logger.info(f"Shutting down {impl_name}")
             try:
```
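Taken together, these hunks replace the `construct_stack()`/`shutdown_stack()` pair with an explicit object lifecycle. A sketch of the assumed calling sequence, distilled from how `server.py` and `library_client.py` use it in this diff:

```python
# Assumed usage of the new Stack lifecycle (sketch, not code from the commit).
import asyncio

from llama_stack.core.stack import Stack


async def run_stack(run_config) -> None:  # run_config: a StackRunConfig
    stack = Stack(run_config)   # only stores config; stack.impls is None here
    await stack.initialize()    # resolves providers and fills stack.impls
    try:
        # Valid only after initialize(); starts the background registry refresh.
        stack.create_registry_refresh_task()
        await asyncio.sleep(60)  # stand-in for serving traffic
    finally:
        await stack.shutdown()   # shuts down every impl in turn
```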
```diff
@@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})
 
-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass
 
         def mock_initialize_route_impls(impls):
             return mock_route_impls
 
-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
 
         client = LlamaStackAsLibraryClient("ci-tests")
```
```diff
@@ -46,13 +50,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})
 
-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass
 
         def mock_initialize_route_impls(impls):
             return mock_route_impls
 
-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
 
         client = AsyncLlamaStackAsLibraryClient("ci-tests")
```
```diff
@@ -68,13 +76,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})
 
-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass
 
         def mock_initialize_route_impls(impls):
             return mock_route_impls
 
-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
 
         client = LlamaStackAsLibraryClient("ci-tests")
```
```diff
@@ -90,13 +102,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})
 
-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass
 
         def mock_initialize_route_impls(impls):
             return mock_route_impls
 
-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
 
         client = AsyncLlamaStackAsLibraryClient("ci-tests")
```
```diff
@@ -112,13 +128,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})
 
-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass
 
         def mock_initialize_route_impls(impls):
             return mock_route_impls
 
-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
 
         sync_client = LlamaStackAsLibraryClient("ci-tests")
```
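The test updates all follow the same pattern: substitute the `Stack` class at its import path so the client under test instantiates a stand-in. A minimal, self-contained illustration of that monkeypatch technique (toy classes, not the repo's code):

```python
# Patching a class by dotted path so call-time lookups resolve to a stand-in.
import pytest


class Real:
    def value(self) -> str:
        return "real"


class Fake:
    def value(self) -> str:
        return "fake"


def make() -> str:
    # Real is looked up in this module's namespace at call time, so a
    # monkeypatched attribute takes effect.
    return Real().value()


def test_make_uses_patched_class(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(f"{__name__}.Real", Fake)
    assert make() == "fake"
```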