chore: refactor tracingmiddelware

# What does this PR do?


## Test Plan
This commit is contained in:
Eric Huang 2025-09-22 11:23:31 -07:00
parent 8d8261961e
commit e58b7427a7
2 changed files with 75 additions and 68 deletions

View file

@ -25,7 +25,6 @@ from typing import Annotated, Any, get_origin
import httpx import httpx
import rich.pretty import rich.pretty
import yaml import yaml
from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Path as FastapiPath from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
@ -45,17 +44,13 @@ from llama_stack.core.datatypes import (
process_cors_config, process_cors_config,
) )
from llama_stack.core.distribution import builtin_automatically_routed_apis from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import ExternalApiSpec, load_external_apis from llama_stack.core.external import load_external_apis
from llama_stack.core.request_headers import ( from llama_stack.core.request_headers import (
PROVIDER_DATA_VAR, PROVIDER_DATA_VAR,
request_provider_data_context, request_provider_data_context,
user_from_scope, user_from_scope,
) )
from llama_stack.core.server.routes import ( from llama_stack.core.server.routes import get_all_api_routes
find_matching_route,
get_all_api_routes,
initialize_route_impls,
)
from llama_stack.core.stack import ( from llama_stack.core.stack import (
Stack, Stack,
cast_image_name_to_string, cast_image_name_to_string,
@ -73,13 +68,12 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
) )
from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT, CURRENT_TRACE_CONTEXT,
end_trace,
setup_logger, setup_logger,
start_trace,
) )
from .auth import AuthenticationMiddleware from .auth import AuthenticationMiddleware
from .quota import QuotaMiddleware from .quota import QuotaMiddleware
from .tracing import TracingMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -299,65 +293,6 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
return route_handler return route_handler
class TracingMiddleware:
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
self.app = app
self.impls = impls
self.external_apis = external_apis
# FastAPI built-in paths that should bypass custom routing
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
async def __call__(self, scope, receive, send):
if scope.get("type") == "lifespan":
return await self.app(scope, receive, send)
path = scope.get("path", "")
# Check if the path is a FastAPI built-in path
if path.startswith(self.fastapi_paths):
# Pass through to FastAPI's built-in handlers
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
return await self.app(scope, receive, send)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
try:
_, _, route_path, webmethod = find_matching_route(
scope.get("method", hdrs.METH_GET), path, self.route_impls
)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send)
trace_attributes = {"__location__": "server", "raw_path": path}
# Extract W3C trace context headers and store as trace attributes
headers = dict(scope.get("headers", []))
traceparent = headers.get(b"traceparent", b"").decode()
if traceparent:
trace_attributes["traceparent"] = traceparent
tracestate = headers.get(b"tracestate", b"").decode()
if tracestate:
trace_attributes["tracestate"] = tracestate
trace_path = webmethod.descriptive_name or route_path
trace_context = await start_trace(trace_path, trace_attributes)
async def send_with_trace_id(message):
if message["type"] == "http.response.start":
headers = message.get("headers", [])
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
message["headers"] = headers
await send(message)
try:
return await self.app(scope, receive, send_with_trace_id)
finally:
await end_trace()
class ClientVersionMiddleware: class ClientVersionMiddleware:
def __init__(self, app): def __init__(self, app):
self.app = app self.app = app

View file

@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from aiohttp import hdrs
from llama_stack.core.external import ExternalApiSpec
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace
logger = get_logger(name=__name__, category="core::server")
class TracingMiddleware:
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
self.app = app
self.impls = impls
self.external_apis = external_apis
# FastAPI built-in paths that should bypass custom routing
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
async def __call__(self, scope, receive, send):
if scope.get("type") == "lifespan":
return await self.app(scope, receive, send)
path = scope.get("path", "")
# Check if the path is a FastAPI built-in path
if path.startswith(self.fastapi_paths):
# Pass through to FastAPI's built-in handlers
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
return await self.app(scope, receive, send)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
try:
_, _, route_path, webmethod = find_matching_route(
scope.get("method", hdrs.METH_GET), path, self.route_impls
)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send)
trace_attributes = {"__location__": "server", "raw_path": path}
# Extract W3C trace context headers and store as trace attributes
headers = dict(scope.get("headers", []))
traceparent = headers.get(b"traceparent", b"").decode()
if traceparent:
trace_attributes["traceparent"] = traceparent
tracestate = headers.get(b"tracestate", b"").decode()
if tracestate:
trace_attributes["tracestate"] = tracestate
trace_path = webmethod.descriptive_name or route_path
trace_context = await start_trace(trace_path, trace_attributes)
async def send_with_trace_id(message):
if message["type"] == "http.response.start":
headers = message.get("headers", [])
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
message["headers"] = headers
await send(message)
try:
return await self.app(scope, receive, send_with_trace_id)
finally:
await end_trace()