llama-stack-mirror/llama_stack/distribution/server/server.py
Ihar Hrachyshka aca82df7ed
fix: Multiple fixes for server shutdown (fix lifespan handling; fix handling CancelledError when raised by provider; let uvicorn handle signals) (#1495)
# What does this PR do?

If an implementation raises `CancelledError` (e.g. when it runs its own async loop for jobs), the main server shutdown handler gets confused and doesn't attempt to shut down the remaining main loop tasks.
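
For background, `asyncio.CancelledError` has derived from `BaseException` rather than `Exception` since Python 3.8, so a bare `except Exception` handler in the shutdown path never sees it; a quick check:

```python
# CancelledError escapes ordinary exception handlers because it is not an
# Exception subclass on Python 3.8+.
import asyncio

print(issubclass(asyncio.CancelledError, Exception))      # False
print(issubclass(asyncio.CancelledError, BaseException))  # True
```

This is why the per-provider shutdown loop in the file below now catches it explicitly alongside `Exception`.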

While at it, this PR also fixes the following failure that occurs when that happens:

```
UnboundLocalError: cannot access local variable 'loop' where it is not
associated with a value
```

Shutdown handlers were not running because the lifespan logic had been broken since ~Oct 2024. This PR fixes that too and now enforces `lifespan` (making sure the server crashes when it fails to interact with the app through middleware).


## Test Plan

Spotted while working on
https://github.com/meta-llama/llama-stack/pull/1437

One way to trigger it without the PR above is to add `raise CancelledError` to any of the running providers' `shutdown` methods; then `kill -INT <pid>` the server process.

Validated this with the following test patch:

```
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index b85c463a..10dad83e 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -174,6 +174,7 @@ def handle_signal(app, signum, _) -> None:
         except asyncio.CancelledError:
             pass
         finally:
+            logger.info("Stopping event loop")
             loop.stop()
 
     loop = asyncio.get_running_loop()
diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py
index b837362d..163f43d8 100644
--- a/llama_stack/providers/inline/post_training/torchtune/post_training.py
+++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import asyncio
 from datetime import datetime
 from typing import Any, Dict, Optional
 
@@ -43,6 +44,9 @@ class TorchtunePostTrainingImpl:
         self.jobs = {}
         self.checkpoints_dict = {}
 
+    async def shutdown(self) -> None:
+        raise asyncio.CancelledError("Shutdown")
+
     async def supervised_fine_tune(
         self,
         job_uuid: str,
```

Without the fix:

```
INFO:     Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit)
INFO:     Shutting down
INFO:     Finished server process [52099]
INFO     2025-03-07 23:25:33,548 __main__:143 server: Received signal SIGINT (2). Exiting gracefully...
INFO     2025-03-07 23:25:33,550 __main__:150 server: Shutting down DatasetsRoutingTable
INFO     2025-03-07 23:25:33,551 __main__:177 server: Stopping event loop
ERROR    2025-03-07 23:25:33,552 asyncio:1785 uncategorized: unhandled exception during asyncio.run() shutdown
         task: <Task finished name='Task-12' coro=<handle_signal.<locals>.shutdown() done, defined at
         /home/ec2-user/src/llama-stack/schedule/llama_stack/distribution/server/server.py:145>
         exception=UnboundLocalError("cannot access local variable 'loop' where it is not associated with a value")>
         ╭───────────────────────────────────── Traceback (most recent call last) ─────────────────────────────────────╮
         │ /home/ec2-user/src/llama-stack/schedule/llama_stack/distribution/server/server.py:178 in shutdown           │
         │                                                                                                             │
         │   175 │   │   │   pass                                                                                      │
         │   176 │   │   finally:                                                                                      │
         │   177 │   │   │   logger.info("Stopping event loop")                                                        │
         │ ❱ 178 │   │   │   loop.stop()                                                                               │
         │   179 │                                                                                                     │
         │   180 │   loop = asyncio.get_running_loop()                                                                 │
         │   181 │   loop.create_task(shutdown())                                                                      │
         ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
         UnboundLocalError: cannot access local variable 'loop' where it is not associated with a value

```

With the fix, the following messages are now seen when the server is killed:

```
INFO:     Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit)
INFO:     Shutting down
INFO:     Finished server process [50836]
INFO     2025-03-07 23:20:35,182 __main__:143 server: Received signal SIGINT (2). Exiting gracefully...
INFO     2025-03-07 23:20:35,184 __main__:149 server: Shutting down DatasetsRoutingTable
ERROR    2025-03-07 23:20:35,185 __main__:158 server: Failed to shutdown DatasetsRoutingTable: {CancelledError()}
         ╭───────────────────────────────────── Traceback (most recent call last) ─────────────────────────────────────╮
         │ /usr/lib64/python3.11/asyncio/tasks.py:476 in wait_for                                                      │
         │                                                                                                             │
         │   473 │   try:                                                                                              │
         │   474 │   │   # wait until the future completes or the timeout                                              │
         │   475 │   │   try:                                                                                          │
         │ ❱ 476 │   │   │   await waiter                                                                              │
         │   477 │   │   except exceptions.CancelledError:                                                             │
         │   478 │   │   │   if fut.done():                                                                            │
         │   479 │   │   │   │   return fut.result()                                                                   │
         ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
         CancelledError

         During handling of the above exception, another exception occurred:

         ╭───────────────────────────────────── Traceback (most recent call last) ─────────────────────────────────────╮
         │ /home/ec2-user/src/llama-stack/schedule/llama_stack/distribution/server/server.py:152 in shutdown           │
         │                                                                                                             │
         │   149 │   │   │   logger.info("Shutting down %s", impl_name)                                                │
         │   150 │   │   │   try:                                                                                      │
         │   151 │   │   │   │   if hasattr(impl, "shutdown"):                                                         │
         │ ❱ 152 │   │   │   │   │   await asyncio.wait_for(impl.shutdown(), timeout=5)                                │
         │   153 │   │   │   │   else:                                                                                 │
         │   154 │   │   │   │   │   logger.warning("No shutdown method for %s", impl_name)                            │
         │   155 │   │   │   except asyncio.TimeoutError:                                                              │
         │                                                                                                             │
         │ /usr/lib64/python3.11/asyncio/tasks.py:479 in wait_for                                                      │
         │                                                                                                             │
         │   476 │   │   │   await waiter                                                                              │
         │   477 │   │   except exceptions.CancelledError:                                                             │
         │   478 │   │   │   if fut.done():                                                                            │
         │ ❱ 479 │   │   │   │   return fut.result()                                                                   │
         │   480 │   │   │   else:                                                                                     │
         │   481 │   │   │   │   fut.remove_done_callback(cb)                                                          │
         │   482 │   │   │   │   # We must ensure that the task is not running                                         │
         │                                                                                                             │
         │ /home/ec2-user/src/llama-stack/schedule/llama_stack/distribution/routers/routing_tables.py:131 in shutdown  │
         │                                                                                                             │
         │   128 │   │   │   elif api == Api.tool_runtime:                                                             │
         │   129 │   │   │   │   p.tool_store = self                                                                   │
         │   130 │                                                                                                     │
         │ ❱ 131 │   async def shutdown(self) -> None:                                                                 │
         │   132 │   │   for p in self.impls_by_provider_id.values():                                                  │
         │   133 │   │   │   await p.shutdown()                                                                        │
         │   134                                                                                                       │
         ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
         CancelledError
INFO     2025-03-07 23:20:35,295 __main__:149 server: Shutting down DatasetIORouter
INFO     2025-03-07 23:20:35,296 __main__:149 server: Shutting down ScoringFunctionsRoutingTable
INFO     2025-03-07 23:20:35,297 __main__:149 server: Shutting down ScoringRouter
INFO     2025-03-07 23:20:35,298 __main__:149 server: Shutting down ModelsRoutingTable
INFO     2025-03-07 23:20:35,299 __main__:149 server: Shutting down InferenceRouter
INFO     2025-03-07 23:20:35,300 __main__:149 server: Shutting down ShieldsRoutingTable
INFO     2025-03-07 23:20:35,300 __main__:149 server: Shutting down SafetyRouter
INFO     2025-03-07 23:20:35,301 __main__:149 server: Shutting down VectorDBsRoutingTable
INFO     2025-03-07 23:20:35,302 __main__:149 server: Shutting down VectorIORouter
INFO     2025-03-07 23:20:35,303 __main__:149 server: Shutting down ToolGroupsRoutingTable
INFO     2025-03-07 23:20:35,304 __main__:149 server: Shutting down ToolRuntimeRouter
INFO     2025-03-07 23:20:35,304 __main__:149 server: Shutting down MetaReferenceAgentsImpl
INFO     2025-03-07 23:20:35,305 __main__:149 server: Shutting down TelemetryAdapter
INFO     2025-03-07 23:20:35,306 __main__:149 server: Shutting down TorchtunePostTrainingImpl
ERROR    2025-03-07 23:20:35,307 __main__:158 server: Failed to shutdown TorchtunePostTrainingImpl:
         {CancelledError('Shutdown')}
         ╭───────────────────────────────────── Traceback (most recent call last) ─────────────────────────────────────╮
         │ /home/ec2-user/src/llama-stack/schedule/llama_stack/distribution/server/server.py:152 in shutdown           │
         │                                                                                                             │
         │   149 │   │   │   logger.info("Shutting down %s", impl_name)                                                │
         │   150 │   │   │   try:                                                                                      │
         │   151 │   │   │   │   if hasattr(impl, "shutdown"):                                                         │
         │ ❱ 152 │   │   │   │   │   await asyncio.wait_for(impl.shutdown(), timeout=5)                                │
         │   153 │   │   │   │   else:                                                                                 │
         │   154 │   │   │   │   │   logger.warning("No shutdown method for %s", impl_name)                            │
         │   155 │   │   │   except asyncio.TimeoutError:                                                              │
         │                                                                                                             │
         │ /usr/lib64/python3.11/asyncio/tasks.py:489 in wait_for                                                      │
         │                                                                                                             │
         │   486 │   │   │   │   raise                                                                                 │
         │   487 │   │                                                                                                 │
         │   488 │   │   if fut.done():                                                                                │
         │ ❱ 489 │   │   │   return fut.result()                                                                       │
         │   490 │   │   else:                                                                                         │
         │   491 │   │   │   fut.remove_done_callback(cb)                                                              │
         │   492 │   │   │   # We must ensure that the task is not running                                             │
         │                                                                                                             │
         │ /home/ec2-user/src/llama-stack/schedule/llama_stack/providers/inline/post_training/torchtune/post_training. │
         │ py:48 in shutdown                                                                                           │
         │                                                                                                             │
         │    45 │   │   self.checkpoints_dict = {}                                                                    │
         │    46 │                                                                                                     │
         │    47 │   async def shutdown(self) -> None:                                                                 │
         │ ❱  48 │   │   raise asyncio.CancelledError("Shutdown")                                                      │
         │    49 │                                                                                                     │
         │    50 │   async def supervised_fine_tune(                                                                   │
         │    51 │   │   self,                                                                                         │
         ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
         CancelledError: Shutdown
INFO     2025-03-07 23:20:35,352 __main__:149 server: Shutting down BenchmarksRoutingTable
INFO     2025-03-07 23:20:35,353 __main__:149 server: Shutting down EvalRouter
INFO     2025-03-07 23:20:35,354 __main__:149 server: Shutting down DistributionInspectImpl
INFO     2025-03-07 23:20:35,355 __main__:177 server: Stopping event loop
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/ec2-user/src/llama-stack/schedule/llama_stack/distribution/server/server.py", line 488, in <module>
    main()
  File "/home/ec2-user/src/llama-stack/schedule/llama_stack/distribution/server/server.py", line 476, in main
    uvicorn.run(**uvicorn_config)
  File "/home/ec2-user/src/llama-stack/schedule/venv/lib64/python3.11/site-packages/uvicorn/main.py", line 579, in run
    server.run()
  File "/home/ec2-user/src/llama-stack/schedule/venv/lib64/python3.11/site-packages/uvicorn/server.py", line 66, in run
    return asyncio.run(self.serve(sockets=sockets))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/asyncio/runners.py", line 189, in run
    with Runner(debug=debug) as runner:
  File "/usr/lib64/python3.11/asyncio/runners.py", line 63, in __exit__
    self.close()
  File "/usr/lib64/python3.11/asyncio/runners.py", line 71, in close
    _cancel_all_tasks(loop)
  File "/usr/lib64/python3.11/asyncio/runners.py", line 201, in _cancel_all_tasks
    loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))
  File "/usr/lib64/python3.11/asyncio/base_events.py", line 652, in run_until_complete
    raise RuntimeError('Event loop stopped before Future completed.')
RuntimeError: Event loop stopped before Future completed.
++ error_handler 104
++ echo 'Error occurred in script at line: 104'
Error occurred in script at line: 104
++ exit 1
```

With all patches included, the shutdown now looks as follows:

```
$ kill -INT $(ps ax | grep  llama_stack.distribution.server.server | grep -v nvim | awk -e '{print $1}' | sort | head -n 1)
```

```
20:56:09.308 [START]
INFO:     Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit)
INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO     2025-03-10 20:56:43,961 __main__:140 server: Shutting down
INFO     2025-03-10 20:56:43,962 __main__:124 server: Shutting down DatasetsRoutingTable
INFO     2025-03-10 20:56:43,964 __main__:124 server: Shutting down DatasetIORouter
INFO     2025-03-10 20:56:43,965 __main__:124 server: Shutting down ScoringFunctionsRoutingTable
INFO     2025-03-10 20:56:43,966 __main__:124 server: Shutting down ScoringRouter
INFO     2025-03-10 20:56:43,967 __main__:124 server: Shutting down ModelsRoutingTable
INFO     2025-03-10 20:56:43,968 __main__:124 server: Shutting down InferenceRouter
INFO     2025-03-10 20:56:43,969 __main__:124 server: Shutting down ShieldsRoutingTable
INFO     2025-03-10 20:56:43,971 __main__:124 server: Shutting down SafetyRouter
INFO     2025-03-10 20:56:43,972 __main__:124 server: Shutting down VectorDBsRoutingTable
INFO     2025-03-10 20:56:43,973 __main__:124 server: Shutting down VectorIORouter
INFO     2025-03-10 20:56:43,974 __main__:124 server: Shutting down ToolGroupsRoutingTable
INFO     2025-03-10 20:56:43,975 __main__:124 server: Shutting down ToolRuntimeRouter
INFO     2025-03-10 20:56:43,976 __main__:124 server: Shutting down MetaReferenceAgentsImpl
INFO     2025-03-10 20:56:43,977 __main__:124 server: Shutting down TelemetryAdapter
INFO     2025-03-10 20:56:43,978 __main__:124 server: Shutting down TorchtunePostTrainingImpl
WARNING  2025-03-10 20:56:43,979 __main__:129 server: No shutdown method for TorchtunePostTrainingImpl
INFO     2025-03-10 20:56:43,979 __main__:124 server: Shutting down BenchmarksRoutingTable
INFO     2025-03-10 20:56:43,980 __main__:124 server: Shutting down EvalRouter
INFO     2025-03-10 20:56:43,981 __main__:124 server: Shutting down DistributionInspectImpl
INFO:     Application shutdown complete.
INFO:     Finished server process [33862]
```


---------

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
2025-03-11 10:30:55 -07:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import asyncio
import inspect
import json
import os
import sys
import traceback
import warnings
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Any, List, Union

import yaml
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from typing_extensions import Annotated

from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import (
    preserve_headers_context_async_generator,
    request_provider_data_context,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import (
    construct_stack,
    redact_sensitive_fields,
    replace_env_vars,
    validate_env_pair,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
    TelemetryAdapter,
)
from llama_stack.providers.utils.telemetry.tracing import (
    end_trace,
    setup_logger,
    start_trace,
)

from .endpoints import get_all_api_endpoints

REPO_ROOT = Path(__file__).parent.parent.parent.parent

logger = get_logger(name=__name__, category="server")


def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
    log = file if hasattr(file, "write") else sys.stderr
    traceback.print_stack(file=log)
    log.write(warnings.formatwarning(message, category, filename, lineno, line))


if os.environ.get("LLAMA_STACK_TRACE_WARNINGS"):
    warnings.showwarning = warn_with_traceback


def create_sse_event(data: Any) -> str:
    if isinstance(data, BaseModel):
        data = data.model_dump_json()
    else:
        data = json.dumps(data)

    return f"data: {data}\n\n"


async def global_exception_handler(request: Request, exc: Exception):
    traceback.print_exception(exc)
    http_exc = translate_exception(exc)

    return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})


def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
    if isinstance(exc, ValidationError):
        exc = RequestValidationError(exc.raw_errors)

    if isinstance(exc, RequestValidationError):
        return HTTPException(
            status_code=400,
            detail={
                "errors": [
                    {
                        "loc": list(error["loc"]),
                        "msg": error["msg"],
                        "type": error["type"],
                    }
                    for error in exc.errors()
                ]
            },
        )
    elif isinstance(exc, ValueError):
        return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
    elif isinstance(exc, PermissionError):
        return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
    elif isinstance(exc, TimeoutError):
        return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
    elif isinstance(exc, NotImplementedError):
        return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
    else:
        return HTTPException(
            status_code=500,
            detail="Internal server error: An unexpected error occurred.",
        )


async def shutdown(app):
    """Initiate a graceful shutdown of the application.

    Handled by the lifespan context manager. The shutdown process involves
    shutting down all implementations registered in the application.
    """
    for impl in app.__llama_stack_impls__.values():
        impl_name = impl.__class__.__name__
        logger.info("Shutting down %s", impl_name)
        try:
            if hasattr(impl, "shutdown"):
                await asyncio.wait_for(impl.shutdown(), timeout=5)
            else:
                logger.warning("No shutdown method for %s", impl_name)
        except asyncio.TimeoutError:
            logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
        except (Exception, asyncio.CancelledError) as e:
            logger.exception("Failed to shutdown %s: %s", impl_name, {e})


@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("Starting up")
    yield
    logger.info("Shutting down")
    await shutdown(app)


def is_streaming_request(func_name: str, request: Request, **kwargs):
    # TODO: pass the api method and punt it to the Protocol definition directly
    return kwargs.get("stream", False)


async def maybe_await(value):
    if inspect.iscoroutine(value):
        return await value
    return value


async def sse_generator(event_gen):
    try:
        async for item in await event_gen:
            yield create_sse_event(item)
            await asyncio.sleep(0.01)
    except asyncio.CancelledError:
        logger.info("Generator cancelled")
        await event_gen.aclose()
    except Exception as e:
        logger.exception("Error in sse_generator")
        yield create_sse_event(
            {
                "error": {
                    "message": str(translate_exception(e)),
                },
            }
        )


def create_dynamic_typed_route(func: Any, method: str, route: str):
    async def endpoint(request: Request, **kwargs):
        # Use context manager for request provider data
        with request_provider_data_context(request.headers):
            is_streaming = is_streaming_request(func.__name__, request, **kwargs)
            try:
                if is_streaming:
                    gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
                    return StreamingResponse(gen, media_type="text/event-stream")
                else:
                    value = func(**kwargs)
                    return await maybe_await(value)
            except Exception as e:
                logger.exception(f"Error executing endpoint {route=} {method=}")
                raise translate_exception(e) from e

    sig = inspect.signature(func)

    new_params = [inspect.Parameter("request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request)]
    new_params.extend(sig.parameters.values())

    path_params = extract_path_params(route)
    if method == "post":
        # Annotate parameters that are in the path with Path(...) and others with Body(...)
        new_params = [new_params[0]] + [
            (
                param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
                if param.name in path_params
                else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
            )
            for param in new_params[1:]
        ]

    endpoint.__signature__ = sig.replace(parameters=new_params)

    return endpoint


class TracingMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        path = scope.get("path", "")
        await start_trace(path, {"__location__": "server"})
        try:
            return await self.app(scope, receive, send)
        finally:
            await end_trace()


class ClientVersionMiddleware:
    def __init__(self, app):
        self.app = app
        self.server_version = parse_version("llama-stack")

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            headers = dict(scope.get("headers", []))
            client_version = headers.get(b"x-llamastack-client-version", b"").decode()
            if client_version:
                try:
                    client_version_parts = tuple(map(int, client_version.split(".")[:2]))
                    server_version_parts = tuple(map(int, self.server_version.split(".")[:2]))
                    if client_version_parts != server_version_parts:

                        async def send_version_error(send):
                            await send(
                                {
                                    "type": "http.response.start",
                                    "status": 426,
                                    "headers": [[b"content-type", b"application/json"]],
                                }
                            )
                            error_msg = json.dumps(
                                {
                                    "error": {
                                        "message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please update your client."
                                    }
                                }
                            ).encode()
                            await send({"type": "http.response.body", "body": error_msg})

                        return await send_version_error(send)
                except (ValueError, IndexError):
                    # If version parsing fails, let the request through
                    pass

        return await self.app(scope, receive, send)


def main():
    """Start the LlamaStack server."""
    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
    parser.add_argument(
        "--yaml-config",
        help="Path to YAML configuration file",
    )
    parser.add_argument(
        "--template",
        help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
        help="Port to listen on",
    )
    parser.add_argument("--disable-ipv6", action="store_true", help="Whether to disable IPv6 support")
    parser.add_argument(
        "--env",
        action="append",
        help="Environment variables in KEY=value format. Can be specified multiple times.",
    )
    parser.add_argument(
        "--tls-keyfile",
        help="Path to TLS key file for HTTPS",
        required="--tls-certfile" in sys.argv,
    )
    parser.add_argument(
        "--tls-certfile",
        help="Path to TLS certificate file for HTTPS",
        required="--tls-keyfile" in sys.argv,
    )

    args = parser.parse_args()

    if args.env:
        for env_pair in args.env:
            try:
                key, value = validate_env_pair(env_pair)
                logger.info(f"Setting CLI environment variable {key} => {value}")
                os.environ[key] = value
            except ValueError as e:
                logger.error(f"Error: {str(e)}")
                sys.exit(1)

    if args.yaml_config:
        # if the user provided a config file, use it, even if template was specified
        config_file = Path(args.yaml_config)
        if not config_file.exists():
            raise ValueError(f"Config file {config_file} does not exist")
        logger.info(f"Using config file: {config_file}")
    elif args.template:
        config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
        if not config_file.exists():
            raise ValueError(f"Template {args.template} does not exist")
        logger.info(f"Using template {args.template} config file: {config_file}")
    else:
        raise ValueError("Either --yaml-config or --template must be provided")

    with open(config_file, "r") as fp:
        config = replace_env_vars(yaml.safe_load(fp))
        config = StackRunConfig(**config)

    logger.info("Run configuration:")
    safe_config = redact_sensitive_fields(config.model_dump())
    logger.info(yaml.dump(safe_config, indent=2))

    app = FastAPI(lifespan=lifespan)
    app.add_middleware(TracingMiddleware)
    if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
        app.add_middleware(ClientVersionMiddleware)

    try:
        impls = asyncio.run(construct_stack(config))
    except InvalidProviderError as e:
        logger.error(f"Error: {str(e)}")
        sys.exit(1)

    if Api.telemetry in impls:
        setup_logger(impls[Api.telemetry])
    else:
        setup_logger(TelemetryAdapter(TelemetryConfig()))

    all_endpoints = get_all_api_endpoints()

    if config.apis:
        apis_to_serve = set(config.apis)
    else:
        apis_to_serve = set(impls.keys())

    for inf in builtin_automatically_routed_apis():
        # if we do not serve the corresponding router API, we should not serve the routing table API
        if inf.router_api.value not in apis_to_serve:
            continue
        apis_to_serve.add(inf.routing_table_api.value)

    apis_to_serve.add("inspect")
    for api_str in apis_to_serve:
        api = Api(api_str)

        endpoints = all_endpoints[api]
        impl = impls[api]

        for endpoint in endpoints:
            if not hasattr(impl, endpoint.name):
                # ideally this should be a typing violation already
                raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")

            impl_method = getattr(impl, endpoint.name)

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
                getattr(app, endpoint.method)(endpoint.route, response_model=None)(
                    create_dynamic_typed_route(
                        impl_method,
                        endpoint.method,
                        endpoint.route,
                    )
                )

    logger.debug(f"serving APIs: {apis_to_serve}")

    app.exception_handler(RequestValidationError)(global_exception_handler)
    app.exception_handler(Exception)(global_exception_handler)

    app.__llama_stack_impls__ = impls

    import uvicorn

    # Configure SSL if certificates are provided
    port = args.port or config.server.port

    ssl_config = None
    if args.tls_keyfile:
        keyfile = args.tls_keyfile
        certfile = args.tls_certfile
    else:
        keyfile = config.server.tls_keyfile
        certfile = config.server.tls_certfile

    if keyfile and certfile:
        ssl_config = {
            "ssl_keyfile": keyfile,
            "ssl_certfile": certfile,
        }
        logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")

    listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
    logger.info(f"Listening on {listen_host}:{port}")

    uvicorn_config = {
        "app": app,
        "host": listen_host,
        "port": port,
        "lifespan": "on",
    }
    if ssl_config:
        uvicorn_config.update(ssl_config)

    uvicorn.run(**uvicorn_config)


def extract_path_params(route: str) -> List[str]:
    segments = route.split("/")
    params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
    # to handle path params like {param:path}
    params = [param.split(":")[0] for param in params]
    return params


if __name__ == "__main__":
    main()