From e58b7427a7c7de953ffb706da9f3f845bfa3c3da Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Mon, 22 Sep 2025 11:23:31 -0700 Subject: [PATCH] chore: refactor tracingmiddelware # What does this PR do? ## Test Plan --- llama_stack/core/server/server.py | 71 ++--------------------------- llama_stack/core/server/tracing.py | 72 ++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 68 deletions(-) create mode 100644 llama_stack/core/server/tracing.py diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index 9cca42268..7d119c139 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -25,7 +25,6 @@ from typing import Annotated, Any, get_origin import httpx import rich.pretty import yaml -from aiohttp import hdrs from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError @@ -45,17 +44,13 @@ from llama_stack.core.datatypes import ( process_cors_config, ) from llama_stack.core.distribution import builtin_automatically_routed_apis -from llama_stack.core.external import ExternalApiSpec, load_external_apis +from llama_stack.core.external import load_external_apis from llama_stack.core.request_headers import ( PROVIDER_DATA_VAR, request_provider_data_context, user_from_scope, ) -from llama_stack.core.server.routes import ( - find_matching_route, - get_all_api_routes, - initialize_route_impls, -) +from llama_stack.core.server.routes import get_all_api_routes from llama_stack.core.stack import ( Stack, cast_image_name_to_string, @@ -73,13 +68,12 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import ( ) from llama_stack.providers.utils.telemetry.tracing import ( CURRENT_TRACE_CONTEXT, - end_trace, setup_logger, - start_trace, ) from .auth import AuthenticationMiddleware from .quota import QuotaMiddleware +from .tracing import TracingMiddleware REPO_ROOT = Path(__file__).parent.parent.parent.parent @@ -299,65 +293,6 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: return route_handler -class TracingMiddleware: - def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]): - self.app = app - self.impls = impls - self.external_apis = external_apis - # FastAPI built-in paths that should bypass custom routing - self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static") - - async def __call__(self, scope, receive, send): - if scope.get("type") == "lifespan": - return await self.app(scope, receive, send) - - path = scope.get("path", "") - - # Check if the path is a FastAPI built-in path - if path.startswith(self.fastapi_paths): - # Pass through to FastAPI's built-in handlers - logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}") - return await self.app(scope, receive, send) - - if not hasattr(self, "route_impls"): - self.route_impls = initialize_route_impls(self.impls, self.external_apis) - - try: - _, _, route_path, webmethod = find_matching_route( - scope.get("method", hdrs.METH_GET), path, self.route_impls - ) - except ValueError: - # If no matching endpoint is found, pass through to FastAPI - logger.debug(f"No matching route found for path: {path}, falling back to FastAPI") - return await self.app(scope, receive, send) - - trace_attributes = {"__location__": "server", "raw_path": path} - - # Extract W3C trace context headers and store as trace attributes - headers = dict(scope.get("headers", [])) - traceparent = headers.get(b"traceparent", b"").decode() - if traceparent: - trace_attributes["traceparent"] = traceparent - tracestate = headers.get(b"tracestate", b"").decode() - if tracestate: - trace_attributes["tracestate"] = tracestate - - trace_path = webmethod.descriptive_name or route_path - trace_context = await start_trace(trace_path, trace_attributes) - - async def send_with_trace_id(message): - if message["type"] == "http.response.start": - headers = message.get("headers", []) - headers.append([b"x-trace-id", str(trace_context.trace_id).encode()]) - message["headers"] = headers - await send(message) - - try: - return await self.app(scope, receive, send_with_trace_id) - finally: - await end_trace() - - class ClientVersionMiddleware: def __init__(self, app): self.app = app diff --git a/llama_stack/core/server/tracing.py b/llama_stack/core/server/tracing.py new file mode 100644 index 000000000..c48fc4d33 --- /dev/null +++ b/llama_stack/core/server/tracing.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from aiohttp import hdrs + +from llama_stack.core.external import ExternalApiSpec +from llama_stack.core.server.routes import find_matching_route, initialize_route_impls +from llama_stack.log import get_logger +from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace + +logger = get_logger(name=__name__, category="core::server") + + +class TracingMiddleware: + def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]): + self.app = app + self.impls = impls + self.external_apis = external_apis + # FastAPI built-in paths that should bypass custom routing + self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static") + + async def __call__(self, scope, receive, send): + if scope.get("type") == "lifespan": + return await self.app(scope, receive, send) + + path = scope.get("path", "") + + # Check if the path is a FastAPI built-in path + if path.startswith(self.fastapi_paths): + # Pass through to FastAPI's built-in handlers + logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}") + return await self.app(scope, receive, send) + + if not hasattr(self, "route_impls"): + self.route_impls = initialize_route_impls(self.impls, self.external_apis) + + try: + _, _, route_path, webmethod = find_matching_route( + scope.get("method", hdrs.METH_GET), path, self.route_impls + ) + except ValueError: + # If no matching endpoint is found, pass through to FastAPI + logger.debug(f"No matching route found for path: {path}, falling back to FastAPI") + return await self.app(scope, receive, send) + + trace_attributes = {"__location__": "server", "raw_path": path} + + # Extract W3C trace context headers and store as trace attributes + headers = dict(scope.get("headers", [])) + traceparent = headers.get(b"traceparent", b"").decode() + if traceparent: + trace_attributes["traceparent"] = traceparent + tracestate = headers.get(b"tracestate", b"").decode() + if tracestate: + trace_attributes["tracestate"] = tracestate + + trace_path = webmethod.descriptive_name or route_path + trace_context = await start_trace(trace_path, trace_attributes) + + async def send_with_trace_id(message): + if message["type"] == "http.response.start": + headers = message.get("headers", []) + headers.append([b"x-trace-id", str(trace_context.trace_id).encode()]) + message["headers"] = headers + await send(message) + + try: + return await self.app(scope, receive, send_with_trace_id) + finally: + await end_trace()