mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
# What does this PR do? This extracts the W3C trace context headers (traceparent and tracestate) from incoming requests, stuffs them as attributes on the spans we create, and uses them within the tracing provider implementation to actually wrap our spans in the proper context. What this means in practice is that when a client (such as an OpenAI client) is instrumented to create these traces, we'll continue that distributed trace within Llama Stack as opposed to creating our own root span that breaks the distributed trace between client and server. It's slightly awkward to do this in Llama Stack because our Tracing API knows nothing about opentelemetry, W3C trace headers, etc - that's only knowledge the specific provider implementation has. So, that's why the trace headers get extracted in the server code but are not actually used until the provider implementation forms the proper context from them. This also centralizes how we were adding the `__root__` and `__root_span__` attributes, as those two were being added in different parts of the code instead of from a single place. Closes #2097 ## Test Plan This was tested manually using the helpful scripts from #2097. I verified that Llama Stack properly joined the client's span when the client was instrumented for distributed tracing, and that Llama Stack properly started its own root span when the incoming request was not part of an existing trace. Here's an example of the joined spans:  Signed-off-by: Ben Browning <bbrownin@redhat.com>
544 lines
20 KiB
Python
544 lines
20 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import argparse
|
|
import asyncio
|
|
import inspect
|
|
import json
|
|
import os
|
|
import ssl
|
|
import sys
|
|
import traceback
|
|
import warnings
|
|
from contextlib import asynccontextmanager
|
|
from importlib.metadata import version as parse_version
|
|
from pathlib import Path
|
|
from typing import Annotated, Any
|
|
|
|
import rich.pretty
|
|
import yaml
|
|
from fastapi import Body, FastAPI, HTTPException, Request
|
|
from fastapi import Path as FastapiPath
|
|
from fastapi.exceptions import RequestValidationError
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from openai import BadRequestError
|
|
from pydantic import BaseModel, ValidationError
|
|
|
|
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
|
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
|
from llama_stack.distribution.request_headers import (
|
|
PROVIDER_DATA_VAR,
|
|
request_provider_data_context,
|
|
)
|
|
from llama_stack.distribution.resolver import InvalidProviderError
|
|
from llama_stack.distribution.server.endpoints import (
|
|
find_matching_endpoint,
|
|
initialize_endpoint_impls,
|
|
)
|
|
from llama_stack.distribution.stack import (
|
|
construct_stack,
|
|
replace_env_vars,
|
|
validate_env_pair,
|
|
)
|
|
from llama_stack.distribution.utils.config import redact_sensitive_fields
|
|
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
|
from llama_stack.log import get_logger
|
|
from llama_stack.providers.datatypes import Api
|
|
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
|
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
|
TelemetryAdapter,
|
|
)
|
|
from llama_stack.providers.utils.telemetry.tracing import (
|
|
CURRENT_TRACE_CONTEXT,
|
|
end_trace,
|
|
setup_logger,
|
|
start_trace,
|
|
)
|
|
|
|
from .auth import AuthenticationMiddleware
|
|
from .endpoints import get_all_api_endpoints
|
|
|
|
# Repository root, four levels up from this file
# (llama_stack/distribution/server/server.py -> repo root); used to locate
# bundled template run.yaml files.
REPO_ROOT = Path(__file__).parent.parent.parent.parent

# Module-level logger; note that main() re-creates a local logger once a
# user-supplied logging config (if any) has been read from the run config.
logger = get_logger(name=__name__, category="server")
|
|
|
|
|
|
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
    """Drop-in replacement for ``warnings.showwarning`` that also prints the stack.

    Mirrors the stdlib hook signature. Output goes to *file* when it is a
    writable stream, otherwise to ``sys.stderr``.
    """
    if hasattr(file, "write"):
        destination = file
    else:
        destination = sys.stderr
    traceback.print_stack(file=destination)
    formatted = warnings.formatwarning(message, category, filename, lineno, line)
    destination.write(formatted)
|
|
|
|
|
|
# Opt-in debugging aid: when LLAMA_STACK_TRACE_WARNINGS is set, every warning
# is emitted together with the stack trace that triggered it.
if os.environ.get("LLAMA_STACK_TRACE_WARNINGS"):
    warnings.showwarning = warn_with_traceback
|
|
|
|
|
|
def create_sse_event(data: Any) -> str:
    """Serialize *data* as a single Server-Sent Events message string."""
    if isinstance(data, BaseModel):
        payload = data.model_dump_json()
    else:
        payload = json.dumps(data)
    return f"data: {payload}\n\n"
|
|
|
|
|
|
async def global_exception_handler(request: Request, exc: Exception):
    """Top-level exception handler: log the traceback, then return a JSON
    error body with the HTTP status chosen by translate_exception()."""
    traceback.print_exception(exc)
    translated = translate_exception(exc)
    body = {"error": {"detail": translated.detail}}
    return JSONResponse(status_code=translated.status_code, content=body)
|
|
|
|
|
|
def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
    """Map an arbitrary exception onto the HTTPException FastAPI should return.

    Pydantic ValidationErrors are first normalized into
    RequestValidationError so both produce the same 400 payload; anything
    unrecognized becomes an opaque 500.
    """
    if isinstance(exc, ValidationError):
        exc = RequestValidationError(exc.errors())

    if isinstance(exc, RequestValidationError):
        error_entries = [
            {
                "loc": list(err["loc"]),
                "msg": err["msg"],
                "type": err["type"],
            }
            for err in exc.errors()
        ]
        return HTTPException(status_code=400, detail={"errors": error_entries})
    if isinstance(exc, ValueError):
        return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
    if isinstance(exc, BadRequestError):
        return HTTPException(status_code=400, detail=str(exc))
    if isinstance(exc, PermissionError):
        return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
    if isinstance(exc, asyncio.TimeoutError | TimeoutError):
        return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
    if isinstance(exc, NotImplementedError):
        return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
    # Deliberately opaque: do not leak internal error details to clients.
    return HTTPException(
        status_code=500,
        detail="Internal server error: An unexpected error occurred.",
    )
|
|
|
|
|
|
async def shutdown(app):
    """Initiate a graceful shutdown of the application.

    Handled by the lifespan context manager. The shutdown process involves
    shutting down all implementations registered in the application. Each
    implementation's shutdown is capped at 5 seconds so one hung provider
    cannot stall the whole server exit.
    """
    for impl in app.__llama_stack_impls__.values():
        impl_name = impl.__class__.__name__
        logger.info("Shutting down %s", impl_name)
        try:
            if hasattr(impl, "shutdown"):
                await asyncio.wait_for(impl.shutdown(), timeout=5)
            else:
                logger.warning("No shutdown method for %s", impl_name)
        except (asyncio.TimeoutError, TimeoutError):
            # logger.exception already records exc_info; no need to pass it.
            logger.exception("Shutdown timeout for %s ", impl_name)
        except (Exception, asyncio.CancelledError) as e:
            # Bug fix: previously passed `{e}` (a one-element set) as the %s
            # argument, which logged e.g. "{ValueError('x')}" instead of the
            # exception message itself.
            logger.exception("Failed to shutdown %s: %s", impl_name, e)
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    # FastAPI lifespan hook: everything before the yield runs at server
    # startup, everything after runs when the server is asked to stop.
    logger.info("Starting up")
    yield
    logger.info("Shutting down")
    await shutdown(app)
|
|
|
|
|
|
def is_streaming_request(func_name: str, request: Request, **kwargs):
    """Return whether the caller asked for a streaming (SSE) response.

    Currently decided purely by the ``stream`` keyword argument; the value
    is returned as-is (callers treat it as truthy/falsy).
    """
    # TODO: pass the api method and punt it to the Protocol definition directly
    stream_requested = kwargs.get("stream", False)
    return stream_requested
|
|
|
|
|
|
async def maybe_await(value):
    """Await *value* if it is a coroutine; otherwise return it unchanged."""
    if not inspect.iscoroutine(value):
        return value
    return await value
|
|
|
|
|
|
async def sse_generator(event_gen_coroutine):
    """Drive an async event generator and yield each item as an SSE message.

    Cancellation closes the underlying generator; any other error is
    reported to the client as a final SSE error event.
    """
    event_gen = None
    try:
        event_gen = await event_gen_coroutine
        async for item in event_gen:
            yield create_sse_event(item)
            # Brief pause so other tasks get a chance to run between events.
            await asyncio.sleep(0.01)
    except asyncio.CancelledError:
        logger.info("Generator cancelled")
        if event_gen:
            await event_gen.aclose()
    except Exception as e:
        logger.exception("Error in sse_generator")
        error_payload = {
            "error": {
                "message": str(translate_exception(e)),
            },
        }
        yield create_sse_event(error_payload)
|
|
|
|
|
|
async def log_request_pre_validation(request: Request):
    """Debug-log the raw body of mutating requests before validation runs.

    Best-effort only: any failure to read or render the body is downgraded
    to a warning and never blocks the request.
    """
    if request.method not in ("POST", "PUT", "PATCH"):
        return
    try:
        body_bytes = await request.body()
        if not body_bytes:
            logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
            return
        try:
            parsed_body = json.loads(body_bytes.decode())
            log_output = rich.pretty.pretty_repr(parsed_body)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Not valid JSON/UTF-8 -- fall back to a raw bytes repr.
            log_output = repr(body_bytes)
        logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
    except Exception as e:
        logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
|
|
|
|
|
|
def create_dynamic_typed_route(func: Any, method: str, route: str):
    """Build a FastAPI-compatible endpoint coroutine around an impl method.

    The wrapper propagates provider data and auth attributes via a context
    manager, dispatches to ``func``, and returns either a StreamingResponse
    (SSE) or the awaited value. The impl method's signature is grafted onto
    the wrapper (with a prepended Request parameter) so FastAPI can perform
    validation and dependency injection against the real parameters.
    """
    async def endpoint(request: Request, **kwargs):
        # Get auth attributes from the request scope
        user_attributes = request.scope.get("user_attributes", {})

        await log_request_pre_validation(request)

        # Use context manager with both provider data and auth attributes
        with request_provider_data_context(request.headers, user_attributes):
            is_streaming = is_streaming_request(func.__name__, request, **kwargs)

            try:
                if is_streaming:
                    # Preserve the trace and provider-data context vars across
                    # the async generator's suspension points.
                    gen = preserve_contexts_async_generator(
                        sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
                    )
                    return StreamingResponse(gen, media_type="text/event-stream")
                else:
                    value = func(**kwargs)
                    return await maybe_await(value)
            except Exception as e:
                logger.exception(f"Error executing endpoint {route=} {method=}")
                raise translate_exception(e) from e

    sig = inspect.signature(func)

    # Prepend a `request` parameter so FastAPI injects the Request object.
    new_params = [inspect.Parameter("request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request)]
    new_params.extend(sig.parameters.values())

    path_params = extract_path_params(route)
    if method == "post":
        # Annotate parameters that are in the path with Path(...) and others with Body(...)
        new_params = [new_params[0]] + [
            (
                param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
                if param.name in path_params
                else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
            )
            for param in new_params[1:]
        ]

    endpoint.__signature__ = sig.replace(parameters=new_params)

    return endpoint
|
|
|
|
|
|
class TracingMiddleware:
    """ASGI middleware that opens a trace for every matched API request.

    W3C trace context headers (traceparent/tracestate) from the incoming
    request are forwarded as trace attributes so the telemetry provider can
    continue a distributed trace started by the client instead of creating
    a new root span. The resulting trace id is echoed back to the client in
    an ``x-trace-id`` response header.
    """

    def __init__(self, app, impls):
        self.app = app
        self.impls = impls
        # FastAPI built-in paths that should bypass custom routing
        self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")

    async def __call__(self, scope, receive, send):
        # Lifespan events carry no request to trace; pass straight through.
        if scope.get("type") == "lifespan":
            return await self.app(scope, receive, send)

        path = scope.get("path", "")

        # Check if the path is a FastAPI built-in path
        if path.startswith(self.fastapi_paths):
            # Pass through to FastAPI's built-in handlers
            logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
            return await self.app(scope, receive, send)

        # Lazily build the endpoint table on the first real request.
        if not hasattr(self, "endpoint_impls"):
            self.endpoint_impls = initialize_endpoint_impls(self.impls)

        try:
            _, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
        except ValueError:
            # If no matching endpoint is found, pass through to FastAPI
            logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
            return await self.app(scope, receive, send)

        trace_attributes = {"__location__": "server", "raw_path": path}

        # Extract W3C trace context headers and store as trace attributes
        headers = dict(scope.get("headers", []))
        traceparent = headers.get(b"traceparent", b"").decode()
        if traceparent:
            trace_attributes["traceparent"] = traceparent
        tracestate = headers.get(b"tracestate", b"").decode()
        if tracestate:
            trace_attributes["tracestate"] = tracestate

        trace_context = await start_trace(trace_path, trace_attributes)

        async def send_with_trace_id(message):
            # Surface the trace id to clients via a response header.
            if message["type"] == "http.response.start":
                headers = message.get("headers", [])
                headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
                message["headers"] = headers
            await send(message)

        try:
            return await self.app(scope, receive, send_with_trace_id)
        finally:
            # Always close the trace, even if the downstream app raised.
            await end_trace()
|
|
|
|
|
|
class ClientVersionMiddleware:
    """ASGI middleware that rejects clients whose major.minor version does
    not match the server's, returning HTTP 426 (Upgrade Required).

    Clients advertise their version via the ``x-llamastack-client-version``
    header; requests without the header, or with an unparseable version,
    are let through unchanged.
    """

    def __init__(self, app):
        self.app = app
        self.server_version = parse_version("llama-stack")

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            header_map = dict(scope.get("headers", []))
            client_version = header_map.get(b"x-llamastack-client-version", b"").decode()
            if client_version:
                try:
                    client_parts = tuple(int(p) for p in client_version.split(".")[:2])
                    server_parts = tuple(int(p) for p in self.server_version.split(".")[:2])
                    if client_parts != server_parts:
                        return await self._send_version_error(send, client_version)
                except (ValueError, IndexError):
                    # If version parsing fails, let the request through
                    pass

        return await self.app(scope, receive, send)

    async def _send_version_error(self, send, client_version):
        # Emit a minimal JSON 426 response without invoking the wrapped app.
        await send(
            {
                "type": "http.response.start",
                "status": 426,
                "headers": [[b"content-type", b"application/json"]],
            }
        )
        error_msg = json.dumps(
            {
                "error": {
                    "message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please update your client."
                }
            }
        ).encode()
        await send({"type": "http.response.body", "body": error_msg})
|
|
|
|
|
|
def main(args: argparse.Namespace | None = None):
    """Start the LlamaStack server.

    Parses CLI args (or accepts a pre-built Namespace from the `run`
    command), loads the run config from --config or --template, constructs
    the provider stack, registers every served API endpoint on a FastAPI
    app, and launches uvicorn (with TLS if certificates are configured).
    """
    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
    parser.add_argument(
        "--yaml-config",
        dest="config",
        help="(Deprecated) Path to YAML configuration file - use --config instead",
    )
    parser.add_argument(
        "--config",
        dest="config",
        help="Path to YAML configuration file",
    )
    parser.add_argument(
        "--template",
        help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
        help="Port to listen on",
    )
    parser.add_argument(
        "--env",
        action="append",
        help="Environment variables in KEY=value format. Can be specified multiple times.",
    )

    # Determine whether the server args are being passed by the "run" command, if this is the case
    # the args will be passed as a Namespace object to the main function, otherwise they will be
    # parsed from the command line
    if args is None:
        args = parser.parse_args()

    log_line = ""
    if args.config:
        # if the user provided a config file, use it, even if template was specified
        config_file = Path(args.config)
        if not config_file.exists():
            raise ValueError(f"Config file {config_file} does not exist")
        log_line = f"Using config file: {config_file}"
    elif args.template:
        config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
        if not config_file.exists():
            raise ValueError(f"Template {args.template} does not exist")
        log_line = f"Using template {args.template} config file: {config_file}"
    else:
        raise ValueError("Either --config or --template must be provided")

    logger_config = None
    with open(config_file) as fp:
        config_contents = yaml.safe_load(fp)
        if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
            logger_config = LoggingConfig(**cfg)
        # Shadows the module-level logger so the user's logging config applies
        # to everything logged from here on.
        logger = get_logger(name=__name__, category="server", config=logger_config)
        if args.env:
            for env_pair in args.env:
                try:
                    key, value = validate_env_pair(env_pair)
                    logger.info(f"Setting CLI environment variable {key} => {value}")
                    os.environ[key] = value
                except ValueError as e:
                    logger.error(f"Error: {str(e)}")
                    sys.exit(1)
        # Env substitution must happen after --env vars are exported above.
        config = replace_env_vars(config_contents)
        config = StackRunConfig(**config)

    # now that the logger is initialized, print the line about which type of config we are using.
    logger.info(log_line)

    logger.info("Run configuration:")
    safe_config = redact_sensitive_fields(config.model_dump())
    logger.info(yaml.dump(safe_config, indent=2))

    app = FastAPI(
        lifespan=lifespan,
        docs_url="/docs",
        redoc_url="/redoc",
        openapi_url="/openapi.json",
    )
    if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
        app.add_middleware(ClientVersionMiddleware)

    # Add authentication middleware if configured
    if config.server.auth:
        logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
        app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)

    try:
        impls = asyncio.run(construct_stack(config))
    except InvalidProviderError as e:
        logger.error(f"Error: {str(e)}")
        sys.exit(1)

    # Wire the telemetry sink into our logging; fall back to an in-process
    # adapter when no telemetry provider is configured.
    if Api.telemetry in impls:
        setup_logger(impls[Api.telemetry])
    else:
        setup_logger(TelemetryAdapter(TelemetryConfig(), {}))

    all_endpoints = get_all_api_endpoints()

    if config.apis:
        apis_to_serve = set(config.apis)
    else:
        apis_to_serve = set(impls.keys())

    for inf in builtin_automatically_routed_apis():
        # if we do not serve the corresponding router API, we should not serve the routing table API
        if inf.router_api.value not in apis_to_serve:
            continue
        apis_to_serve.add(inf.routing_table_api.value)

    # Always-on administrative APIs.
    apis_to_serve.add("inspect")
    apis_to_serve.add("providers")
    for api_str in apis_to_serve:
        api = Api(api_str)

        endpoints = all_endpoints[api]
        impl = impls[api]

        for endpoint in endpoints:
            if not hasattr(impl, endpoint.name):
                # ideally this should be a typing violation already
                raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")

            impl_method = getattr(impl, endpoint.name)
            logger.debug(f"{endpoint.method.upper()} {endpoint.route}")

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
                # Register e.g. app.post(route)(wrapped_impl_method).
                getattr(app, endpoint.method)(endpoint.route, response_model=None)(
                    create_dynamic_typed_route(
                        impl_method,
                        endpoint.method,
                        endpoint.route,
                    )
                )

    logger.debug(f"serving APIs: {apis_to_serve}")

    app.exception_handler(RequestValidationError)(global_exception_handler)
    app.exception_handler(Exception)(global_exception_handler)

    # Stash the impls on the app so shutdown() can reach them from lifespan.
    app.__llama_stack_impls__ = impls
    app.add_middleware(TracingMiddleware, impls=impls)

    import uvicorn

    # Configure SSL if certificates are provided
    port = args.port or config.server.port

    ssl_config = None
    keyfile = config.server.tls_keyfile
    certfile = config.server.tls_certfile

    if keyfile and certfile:
        ssl_config = {
            "ssl_keyfile": keyfile,
            "ssl_certfile": certfile,
        }
        if config.server.tls_cafile:
            # A CA file enables mutual TLS: client certs become required.
            ssl_config["ssl_ca_certs"] = config.server.tls_cafile
            ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
            logger.info(
                f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
            )
        else:
            logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")

    listen_host = config.server.host or ["::", "0.0.0.0"]
    logger.info(f"Listening on {listen_host}:{port}")

    uvicorn_config = {
        "app": app,
        "host": listen_host,
        "port": port,
        "lifespan": "on",
        "log_level": logger.getEffectiveLevel(),
    }
    if ssl_config:
        uvicorn_config.update(ssl_config)

    uvicorn.run(**uvicorn_config)
|
|
|
|
|
|
def extract_path_params(route: str) -> list[str]:
    """Return the parameter names embedded in a route template.

    ``/models/{model_id}`` -> ``["model_id"]``. Converter suffixes such as
    ``{name:path}`` are stripped down to just the parameter name.
    """
    names = []
    for segment in route.split("/"):
        if segment.startswith("{") and segment.endswith("}"):
            # to handle path params like {param:path}
            names.append(segment[1:-1].split(":")[0])
    return names
|
|
|
|
|
|
# Allow running the server directly, e.g.
# `python -m llama_stack.distribution.server.server --config run.yaml`.
if __name__ == "__main__":
    main()
|