mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
chore: refactor tracingmiddelware
# What does this PR do? ## Test Plan
This commit is contained in:
parent
8d8261961e
commit
e58b7427a7
2 changed files with 75 additions and 68 deletions
|
@ -25,7 +25,6 @@ from typing import Annotated, Any, get_origin
|
||||||
import httpx
|
import httpx
|
||||||
import rich.pretty
|
import rich.pretty
|
||||||
import yaml
|
import yaml
|
||||||
from aiohttp import hdrs
|
|
||||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||||
from fastapi import Path as FastapiPath
|
from fastapi import Path as FastapiPath
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
@ -45,17 +44,13 @@ from llama_stack.core.datatypes import (
|
||||||
process_cors_config,
|
process_cors_config,
|
||||||
)
|
)
|
||||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||||
from llama_stack.core.external import ExternalApiSpec, load_external_apis
|
from llama_stack.core.external import load_external_apis
|
||||||
from llama_stack.core.request_headers import (
|
from llama_stack.core.request_headers import (
|
||||||
PROVIDER_DATA_VAR,
|
PROVIDER_DATA_VAR,
|
||||||
request_provider_data_context,
|
request_provider_data_context,
|
||||||
user_from_scope,
|
user_from_scope,
|
||||||
)
|
)
|
||||||
from llama_stack.core.server.routes import (
|
from llama_stack.core.server.routes import get_all_api_routes
|
||||||
find_matching_route,
|
|
||||||
get_all_api_routes,
|
|
||||||
initialize_route_impls,
|
|
||||||
)
|
|
||||||
from llama_stack.core.stack import (
|
from llama_stack.core.stack import (
|
||||||
Stack,
|
Stack,
|
||||||
cast_image_name_to_string,
|
cast_image_name_to_string,
|
||||||
|
@ -73,13 +68,12 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
||||||
)
|
)
|
||||||
from llama_stack.providers.utils.telemetry.tracing import (
|
from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
CURRENT_TRACE_CONTEXT,
|
CURRENT_TRACE_CONTEXT,
|
||||||
end_trace,
|
|
||||||
setup_logger,
|
setup_logger,
|
||||||
start_trace,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from .auth import AuthenticationMiddleware
|
from .auth import AuthenticationMiddleware
|
||||||
from .quota import QuotaMiddleware
|
from .quota import QuotaMiddleware
|
||||||
|
from .tracing import TracingMiddleware
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
|
@ -299,65 +293,6 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
||||||
return route_handler
|
return route_handler
|
||||||
|
|
||||||
|
|
||||||
class TracingMiddleware:
|
|
||||||
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
|
|
||||||
self.app = app
|
|
||||||
self.impls = impls
|
|
||||||
self.external_apis = external_apis
|
|
||||||
# FastAPI built-in paths that should bypass custom routing
|
|
||||||
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
|
||||||
|
|
||||||
async def __call__(self, scope, receive, send):
|
|
||||||
if scope.get("type") == "lifespan":
|
|
||||||
return await self.app(scope, receive, send)
|
|
||||||
|
|
||||||
path = scope.get("path", "")
|
|
||||||
|
|
||||||
# Check if the path is a FastAPI built-in path
|
|
||||||
if path.startswith(self.fastapi_paths):
|
|
||||||
# Pass through to FastAPI's built-in handlers
|
|
||||||
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
|
||||||
return await self.app(scope, receive, send)
|
|
||||||
|
|
||||||
if not hasattr(self, "route_impls"):
|
|
||||||
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
|
|
||||||
|
|
||||||
try:
|
|
||||||
_, _, route_path, webmethod = find_matching_route(
|
|
||||||
scope.get("method", hdrs.METH_GET), path, self.route_impls
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
# If no matching endpoint is found, pass through to FastAPI
|
|
||||||
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
|
||||||
return await self.app(scope, receive, send)
|
|
||||||
|
|
||||||
trace_attributes = {"__location__": "server", "raw_path": path}
|
|
||||||
|
|
||||||
# Extract W3C trace context headers and store as trace attributes
|
|
||||||
headers = dict(scope.get("headers", []))
|
|
||||||
traceparent = headers.get(b"traceparent", b"").decode()
|
|
||||||
if traceparent:
|
|
||||||
trace_attributes["traceparent"] = traceparent
|
|
||||||
tracestate = headers.get(b"tracestate", b"").decode()
|
|
||||||
if tracestate:
|
|
||||||
trace_attributes["tracestate"] = tracestate
|
|
||||||
|
|
||||||
trace_path = webmethod.descriptive_name or route_path
|
|
||||||
trace_context = await start_trace(trace_path, trace_attributes)
|
|
||||||
|
|
||||||
async def send_with_trace_id(message):
|
|
||||||
if message["type"] == "http.response.start":
|
|
||||||
headers = message.get("headers", [])
|
|
||||||
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
|
|
||||||
message["headers"] = headers
|
|
||||||
await send(message)
|
|
||||||
|
|
||||||
try:
|
|
||||||
return await self.app(scope, receive, send_with_trace_id)
|
|
||||||
finally:
|
|
||||||
await end_trace()
|
|
||||||
|
|
||||||
|
|
||||||
class ClientVersionMiddleware:
|
class ClientVersionMiddleware:
|
||||||
def __init__(self, app):
|
def __init__(self, app):
|
||||||
self.app = app
|
self.app = app
|
||||||
|
|
72
llama_stack/core/server/tracing.py
Normal file
72
llama_stack/core/server/tracing.py
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
from aiohttp import hdrs
|
||||||
|
|
||||||
|
from llama_stack.core.external import ExternalApiSpec
|
||||||
|
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="core::server")
|
||||||
|
|
||||||
|
|
||||||
|
class TracingMiddleware:
|
||||||
|
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
|
||||||
|
self.app = app
|
||||||
|
self.impls = impls
|
||||||
|
self.external_apis = external_apis
|
||||||
|
# FastAPI built-in paths that should bypass custom routing
|
||||||
|
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send):
|
||||||
|
if scope.get("type") == "lifespan":
|
||||||
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
path = scope.get("path", "")
|
||||||
|
|
||||||
|
# Check if the path is a FastAPI built-in path
|
||||||
|
if path.startswith(self.fastapi_paths):
|
||||||
|
# Pass through to FastAPI's built-in handlers
|
||||||
|
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||||
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
if not hasattr(self, "route_impls"):
|
||||||
|
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_, _, route_path, webmethod = find_matching_route(
|
||||||
|
scope.get("method", hdrs.METH_GET), path, self.route_impls
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
# If no matching endpoint is found, pass through to FastAPI
|
||||||
|
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||||
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
trace_attributes = {"__location__": "server", "raw_path": path}
|
||||||
|
|
||||||
|
# Extract W3C trace context headers and store as trace attributes
|
||||||
|
headers = dict(scope.get("headers", []))
|
||||||
|
traceparent = headers.get(b"traceparent", b"").decode()
|
||||||
|
if traceparent:
|
||||||
|
trace_attributes["traceparent"] = traceparent
|
||||||
|
tracestate = headers.get(b"tracestate", b"").decode()
|
||||||
|
if tracestate:
|
||||||
|
trace_attributes["tracestate"] = tracestate
|
||||||
|
|
||||||
|
trace_path = webmethod.descriptive_name or route_path
|
||||||
|
trace_context = await start_trace(trace_path, trace_attributes)
|
||||||
|
|
||||||
|
async def send_with_trace_id(message):
|
||||||
|
if message["type"] == "http.response.start":
|
||||||
|
headers = message.get("headers", [])
|
||||||
|
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
|
||||||
|
message["headers"] = headers
|
||||||
|
await send(message)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self.app(scope, receive, send_with_trace_id)
|
||||||
|
finally:
|
||||||
|
await end_trace()
|
Loading…
Add table
Add a link
Reference in a new issue