feat(major): move new telemetry architecture into new provider

This commit is contained in:
Emilio Garcia 2025-10-01 11:54:14 -04:00
parent ce3a804893
commit e45e77f7b0
10 changed files with 207 additions and 52 deletions

View file

@ -7,8 +7,7 @@
import datetime
import os
import threading
import logging
from typing import Any
from typing import Any, cast
from fastapi import FastAPI
@ -22,7 +21,12 @@ from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.util.types import Attributes
from llama_stack.core.external import ExternalApiSpec
from llama_stack.core.server.tracing import TelemetryProvider
from llama_stack.providers.inline.telemetry.meta_reference.middleware import TracingMiddleware
from llama_stack.apis.telemetry import (
Event,
@ -73,7 +77,7 @@ def is_tracing_enabled(tracer):
return span.is_recording()
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry, TelemetryProvider):
def __init__(self, config: TelemetryConfig, deps: dict[Api, Any]) -> None:
self.config = config
self.datasetio_api = deps.get(Api.datasetio)
@ -266,12 +270,13 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
# Log to OpenTelemetry meter if available
if self.meter is None:
return
normalized_attributes = self._normalize_attributes(event.attributes)
if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)
counter.add(event.value, attributes=normalized_attributes)
elif isinstance(event.value, float):
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
up_down_counter.add(event.value, attributes=event.attributes)
up_down_counter.add(event.value, attributes=normalized_attributes)
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
assert self.meter is not None
@ -287,18 +292,17 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
with self._lock:
span_id = int(event.span_id, 16)
tracer = trace.get_tracer(__name__)
if event.attributes is None:
event.attributes = {}
event.attributes["__ttl__"] = ttl_seconds
event_attributes = dict(event.attributes or {})
event_attributes["__ttl__"] = ttl_seconds
# Extract these W3C trace context attributes so they are not written to
# underlying storage, as we just need them to propagate the trace context.
traceparent = event.attributes.pop("traceparent", None)
tracestate = event.attributes.pop("tracestate", None)
traceparent = event_attributes.pop("traceparent", None)
tracestate = event_attributes.pop("tracestate", None)
if traceparent:
# If we have a traceparent header value, we're not the root span.
for root_attribute in ROOT_SPAN_MARKERS:
event.attributes.pop(root_attribute, None)
event_attributes.pop(root_attribute, None)
if isinstance(event.payload, SpanStartPayload):
# Check if span already exists to prevent duplicates
@ -309,7 +313,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
if event.payload.parent_span_id:
parent_span_id = int(event.payload.parent_span_id, 16)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
context = trace.set_span_in_context(parent_span)
if parent_span:
context = trace.set_span_in_context(parent_span)
elif traceparent:
carrier = {
"traceparent": traceparent,
@ -320,15 +325,15 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
span = tracer.start_span(
name=event.payload.name,
context=context,
attributes=event.attributes or {},
attributes=self._normalize_attributes(event_attributes),
)
_GLOBAL_STORAGE["active_spans"][span_id] = span
elif isinstance(event.payload, SpanEndPayload):
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
if event.attributes:
span.set_attributes(event.attributes)
if event_attributes:
span.set_attributes(self._normalize_attributes(event_attributes))
status = (
trace.Status(status_code=trace.StatusCode.OK)
@ -377,5 +382,14 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
)
)
def fastapi_middleware(self, app: FastAPI) -> None:
FastAPIInstrumentor.instrument_app(app)
def fastapi_middleware(
self,
app: FastAPI,
impls: dict[Api, Any],
external_apis: dict[str, ExternalApiSpec],
):
TracingMiddleware(app, impls, external_apis)
@staticmethod
def _normalize_attributes(attributes: dict[str, Any] | None) -> Attributes:
return cast(Attributes, dict(attributes) if attributes else {})