mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-08 04:54:38 +00:00
feat(major): move new telemetry architecture into new provider
This commit is contained in:
parent
ce3a804893
commit
e45e77f7b0
10 changed files with 207 additions and 52 deletions
|
@ -0,0 +1,83 @@
|
|||
from aiohttp import hdrs
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.core.external import ExternalApiSpec
|
||||
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace
|
||||
|
||||
|
||||
logger = get_logger(name=__name__, category="telemetry::meta_reference")
|
||||
|
||||
class TracingMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
impls: dict[Api, Any],
|
||||
external_apis: dict[str, ExternalApiSpec],
|
||||
):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
self.external_apis = external_apis
|
||||
# FastAPI built-in paths that should bypass custom routing
|
||||
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope.get("type") == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Check if the path is a FastAPI built-in path
|
||||
if path.startswith(self.fastapi_paths):
|
||||
# Pass through to FastAPI's built-in handlers
|
||||
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
external_api_map = {Api(api_name): spec for api_name, spec in self.external_apis.items()}
|
||||
self.route_impls = initialize_route_impls(self.impls, external_api_map)
|
||||
|
||||
try:
|
||||
_, _, route_path, webmethod = find_matching_route(
|
||||
scope.get("method", hdrs.METH_GET), path, self.route_impls
|
||||
)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
# Log deprecation warning if route is deprecated
|
||||
if getattr(webmethod, "deprecated", False):
|
||||
logger.warning(
|
||||
f"DEPRECATED ROUTE USED: {scope.get('method', 'GET')} {path} - "
|
||||
f"This route is deprecated and may be removed in a future version. "
|
||||
f"Please check the docs for the supported version."
|
||||
)
|
||||
|
||||
trace_attributes = {"__location__": "server", "raw_path": path}
|
||||
|
||||
# Extract W3C trace context headers and store as trace attributes
|
||||
headers = dict(scope.get("headers", []))
|
||||
traceparent = headers.get(b"traceparent", b"").decode()
|
||||
if traceparent:
|
||||
trace_attributes["traceparent"] = traceparent
|
||||
tracestate = headers.get(b"tracestate", b"").decode()
|
||||
if tracestate:
|
||||
trace_attributes["tracestate"] = tracestate
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
trace_context = await start_trace(trace_path, trace_attributes)
|
||||
|
||||
async def send_with_trace_id(message):
|
||||
if message["type"] == "http.response.start":
|
||||
headers = message.get("headers", [])
|
||||
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
|
||||
message["headers"] = headers
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
return await self.app(scope, receive, send_with_trace_id)
|
||||
finally:
|
||||
await end_trace()
|
|
@ -7,8 +7,7 @@
|
|||
import datetime
|
||||
import os
|
||||
import threading
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
@ -22,7 +21,12 @@ from opentelemetry.sdk.trace import TracerProvider
|
|||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv.attributes import service_attributes
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||
from opentelemetry.util.types import Attributes
|
||||
|
||||
from llama_stack.core.external import ExternalApiSpec
|
||||
from llama_stack.core.server.tracing import TelemetryProvider
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.middleware import TracingMiddleware
|
||||
|
||||
|
||||
from llama_stack.apis.telemetry import (
|
||||
Event,
|
||||
|
@ -73,7 +77,7 @@ def is_tracing_enabled(tracer):
|
|||
return span.is_recording()
|
||||
|
||||
|
||||
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
||||
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry, TelemetryProvider):
|
||||
def __init__(self, config: TelemetryConfig, deps: dict[Api, Any]) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = deps.get(Api.datasetio)
|
||||
|
@ -266,12 +270,13 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
# Log to OpenTelemetry meter if available
|
||||
if self.meter is None:
|
||||
return
|
||||
normalized_attributes = self._normalize_attributes(event.attributes)
|
||||
if isinstance(event.value, int):
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=event.attributes)
|
||||
counter.add(event.value, attributes=normalized_attributes)
|
||||
elif isinstance(event.value, float):
|
||||
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
|
||||
up_down_counter.add(event.value, attributes=event.attributes)
|
||||
up_down_counter.add(event.value, attributes=normalized_attributes)
|
||||
|
||||
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
|
||||
assert self.meter is not None
|
||||
|
@ -287,18 +292,17 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
with self._lock:
|
||||
span_id = int(event.span_id, 16)
|
||||
tracer = trace.get_tracer(__name__)
|
||||
if event.attributes is None:
|
||||
event.attributes = {}
|
||||
event.attributes["__ttl__"] = ttl_seconds
|
||||
event_attributes = dict(event.attributes or {})
|
||||
event_attributes["__ttl__"] = ttl_seconds
|
||||
|
||||
# Extract these W3C trace context attributes so they are not written to
|
||||
# underlying storage, as we just need them to propagate the trace context.
|
||||
traceparent = event.attributes.pop("traceparent", None)
|
||||
tracestate = event.attributes.pop("tracestate", None)
|
||||
traceparent = event_attributes.pop("traceparent", None)
|
||||
tracestate = event_attributes.pop("tracestate", None)
|
||||
if traceparent:
|
||||
# If we have a traceparent header value, we're not the root span.
|
||||
for root_attribute in ROOT_SPAN_MARKERS:
|
||||
event.attributes.pop(root_attribute, None)
|
||||
event_attributes.pop(root_attribute, None)
|
||||
|
||||
if isinstance(event.payload, SpanStartPayload):
|
||||
# Check if span already exists to prevent duplicates
|
||||
|
@ -309,7 +313,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
if event.payload.parent_span_id:
|
||||
parent_span_id = int(event.payload.parent_span_id, 16)
|
||||
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||
context = trace.set_span_in_context(parent_span)
|
||||
if parent_span:
|
||||
context = trace.set_span_in_context(parent_span)
|
||||
elif traceparent:
|
||||
carrier = {
|
||||
"traceparent": traceparent,
|
||||
|
@ -320,15 +325,15 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
span = tracer.start_span(
|
||||
name=event.payload.name,
|
||||
context=context,
|
||||
attributes=event.attributes or {},
|
||||
attributes=self._normalize_attributes(event_attributes),
|
||||
)
|
||||
_GLOBAL_STORAGE["active_spans"][span_id] = span
|
||||
|
||||
elif isinstance(event.payload, SpanEndPayload):
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
if span:
|
||||
if event.attributes:
|
||||
span.set_attributes(event.attributes)
|
||||
if event_attributes:
|
||||
span.set_attributes(self._normalize_attributes(event_attributes))
|
||||
|
||||
status = (
|
||||
trace.Status(status_code=trace.StatusCode.OK)
|
||||
|
@ -377,5 +382,14 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
)
|
||||
)
|
||||
|
||||
def fastapi_middleware(self, app: FastAPI) -> None:
|
||||
FastAPIInstrumentor.instrument_app(app)
|
||||
def fastapi_middleware(
|
||||
self,
|
||||
app: FastAPI,
|
||||
impls: dict[Api, Any],
|
||||
external_apis: dict[str, ExternalApiSpec],
|
||||
):
|
||||
TracingMiddleware(app, impls, external_apis)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_attributes(attributes: dict[str, Any] | None) -> Attributes:
|
||||
return cast(Attributes, dict(attributes) if attributes else {})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue