fix(telemetry): allow open telemetry to inject trace headers

Emilio Garcia 2025-11-10 16:56:43 -05:00
parent a15aa94dab
commit 994abc81b4
2 changed files with 1 addition and 90 deletions
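
For context on the change: the removals below drop the server's hand-rolled trace-header handling so that OpenTelemetry's own instrumentation can propagate W3C trace context. A minimal sketch of that pattern, not part of this commit and assuming the opentelemetry-instrumentation-fastapi and opentelemetry-sdk packages are installed:

# Sketch only, not code from this commit.
from fastapi import FastAPI
from opentelemetry import trace
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.sdk.trace import TracerProvider

# A tracer provider must be configured before instrumenting.
trace.set_tracer_provider(TracerProvider())

app = FastAPI()
# The instrumentation reads incoming W3C traceparent/tracestate headers and starts
# server spans per request, so no custom middleware is needed for propagation.
FastAPIInstrumentor.instrument_app(app)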


@@ -50,8 +50,6 @@ from llama_stack.core.stack import (
     cast_image_name_to_string,
     replace_env_vars,
 )
-from llama_stack.core.telemetry import Telemetry
-from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, setup_logger
 from llama_stack.core.utils.config import redact_sensitive_fields
 from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
 from llama_stack.core.utils.context import preserve_contexts_async_generator
@@ -60,7 +58,6 @@ from llama_stack_api import Api, ConflictError, PaginatedResponse, ResourceNotFo
 from .auth import AuthenticationMiddleware
 from .quota import QuotaMiddleware
-from .tracing import TracingMiddleware
 
 REPO_ROOT = Path(__file__).parent.parent.parent.parent
@@ -263,7 +260,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
         try:
             if is_streaming:
-                context_vars = [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
+                context_vars = [PROVIDER_DATA_VAR]
                 if test_context_var is not None:
                     context_vars.append(test_context_var)
                 gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), context_vars)
@@ -441,9 +438,6 @@ def create_app() -> StackApp:
     if cors_config:
         app.add_middleware(CORSMiddleware, **cors_config.model_dump())
 
-    if config.telemetry.enabled:
-        setup_logger(Telemetry())
-
     # Load external APIs if configured
     external_apis = load_external_apis(config)
     all_routes = get_all_api_routes(external_apis)
@@ -500,9 +494,6 @@ def create_app() -> StackApp:
     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
 
-    if config.telemetry.enabled:
-        app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
-
     return app


@@ -1,80 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from aiohttp import hdrs
-
-from llama_stack.core.external import ExternalApiSpec
-from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
-from llama_stack.core.telemetry.tracing import end_trace, start_trace
-from llama_stack.log import get_logger
-
-logger = get_logger(name=__name__, category="core::server")
-
-
-class TracingMiddleware:
-    def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
-        self.app = app
-        self.impls = impls
-        self.external_apis = external_apis
-        # FastAPI built-in paths that should bypass custom routing
-        self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
-
-    async def __call__(self, scope, receive, send):
-        if scope.get("type") == "lifespan":
-            return await self.app(scope, receive, send)
-
-        path = scope.get("path", "")
-        # Check if the path is a FastAPI built-in path
-        if path.startswith(self.fastapi_paths):
-            # Pass through to FastAPI's built-in handlers
-            logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
-            return await self.app(scope, receive, send)
-
-        if not hasattr(self, "route_impls"):
-            self.route_impls = initialize_route_impls(self.impls, self.external_apis)
-
-        try:
-            _, _, route_path, webmethod = find_matching_route(
-                scope.get("method", hdrs.METH_GET), path, self.route_impls
-            )
-        except ValueError:
-            # If no matching endpoint is found, pass through to FastAPI
-            logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
-            return await self.app(scope, receive, send)
-
-        # Log deprecation warning if route is deprecated
-        if getattr(webmethod, "deprecated", False):
-            logger.warning(
-                f"DEPRECATED ROUTE USED: {scope.get('method', 'GET')} {path} - "
-                f"This route is deprecated and may be removed in a future version. "
-                f"Please check the docs for the supported version."
-            )
-
-        trace_attributes = {"__location__": "server", "raw_path": path}
-
-        # Extract W3C trace context headers and store as trace attributes
-        headers = dict(scope.get("headers", []))
-        traceparent = headers.get(b"traceparent", b"").decode()
-        if traceparent:
-            trace_attributes["traceparent"] = traceparent
-        tracestate = headers.get(b"tracestate", b"").decode()
-        if tracestate:
-            trace_attributes["tracestate"] = tracestate
-
-        trace_path = webmethod.descriptive_name or route_path
-        trace_context = await start_trace(trace_path, trace_attributes)
-
-        async def send_with_trace_id(message):
-            if message["type"] == "http.response.start":
-                headers = message.get("headers", [])
-                headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
-                message["headers"] = headers
-            await send(message)
-
-        try:
-            return await self.app(scope, receive, send_with_trace_id)
-        finally:
-            await end_trace()
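
The deleted middleware parsed the traceparent and tracestate headers by hand, attached them as trace attributes, and echoed an x-trace-id response header. Under OpenTelemetry, extracting incoming trace context is the job of the configured propagator. A minimal sketch, assuming the opentelemetry-api and opentelemetry-sdk packages; the span name and helper function are illustrative only:

# Sketch only: the propagator API covers the manual traceparent/tracestate
# handling that the deleted TracingMiddleware performed.
from opentelemetry import propagate, trace

tracer = trace.get_tracer(__name__)

def start_span_from_headers(headers: dict[str, str]):
    # extract() reads "traceparent" and "tracestate" from the carrier mapping
    # and returns a Context linked to the caller's trace.
    ctx = propagate.extract(headers)
    # Spans opened with this context become children of the remote span, so the
    # incoming trace id is preserved end to end.
    return tracer.start_as_current_span("request", context=ctx)

# Usage: with start_span_from_headers(request_headers) as span: ...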