fix(telemetry): remove legacy telemetry tools

This change removes all the hand written telemetry machinery
that has been replaced in prior changes with open telemetry
library calls.
This commit is contained in:
Emilio Garcia 2025-11-11 13:58:04 -05:00
parent cb3ce80917
commit 503522716f
3 changed files with 0 additions and 1161 deletions

View file

@ -1,619 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import threading
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum
from typing import (
Annotated,
Any,
Literal,
cast,
)
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from pydantic import BaseModel, Field
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import Primitive
from llama_stack_api import json_schema_type, register_schema
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
# Type alias for OpenTelemetry attribute values (excludes None)
AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]
Attributes = Mapping[str, AttributeValue]
@json_schema_type
class SpanStatus(Enum):
"""The status of a span indicating whether it completed successfully or with an error.
:cvar OK: Span completed successfully without errors
:cvar ERROR: Span completed with an error or failure
"""
OK = "ok"
ERROR = "error"
@json_schema_type
class Span(BaseModel):
"""A span representing a single operation within a trace.
:param span_id: Unique identifier for the span
:param trace_id: Unique identifier for the trace this span belongs to
:param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
:param name: Human-readable name describing the operation this span represents
:param start_time: Timestamp when the operation began
:param end_time: (Optional) Timestamp when the operation finished, if completed
:param attributes: (Optional) Key-value pairs containing additional metadata about the span
"""
span_id: str
trace_id: str
parent_span_id: str | None = None
name: str
start_time: datetime
end_time: datetime | None = None
attributes: dict[str, Any] | None = Field(default_factory=lambda: {})
def set_attribute(self, key: str, value: Any):
if self.attributes is None:
self.attributes = {}
self.attributes[key] = value
@json_schema_type
class Trace(BaseModel):
"""A trace representing the complete execution path of a request across multiple operations.
:param trace_id: Unique identifier for the trace
:param root_span_id: Unique identifier for the root span that started this trace
:param start_time: Timestamp when the trace began
:param end_time: (Optional) Timestamp when the trace finished, if completed
"""
trace_id: str
root_span_id: str
start_time: datetime
end_time: datetime | None = None
@json_schema_type
class EventType(Enum):
"""The type of telemetry event being logged.
:cvar UNSTRUCTURED_LOG: A simple log message with severity level
:cvar STRUCTURED_LOG: A structured log event with typed payload data
:cvar METRIC: A metric measurement with value and unit
"""
UNSTRUCTURED_LOG = "unstructured_log"
STRUCTURED_LOG = "structured_log"
METRIC = "metric"
@json_schema_type
class LogSeverity(Enum):
"""The severity level of a log message.
:cvar VERBOSE: Detailed diagnostic information for troubleshooting
:cvar DEBUG: Debug information useful during development
:cvar INFO: General informational messages about normal operation
:cvar WARN: Warning messages about potentially problematic situations
:cvar ERROR: Error messages indicating failures that don't stop execution
:cvar CRITICAL: Critical error messages indicating severe failures
"""
VERBOSE = "verbose"
DEBUG = "debug"
INFO = "info"
WARN = "warn"
ERROR = "error"
CRITICAL = "critical"
class EventCommon(BaseModel):
"""Common fields shared by all telemetry events.
:param trace_id: Unique identifier for the trace this event belongs to
:param span_id: Unique identifier for the span this event belongs to
:param timestamp: Timestamp when the event occurred
:param attributes: (Optional) Key-value pairs containing additional metadata about the event
"""
trace_id: str
span_id: str
timestamp: datetime
attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {})
@json_schema_type
class UnstructuredLogEvent(EventCommon):
"""An unstructured log event containing a simple text message.
:param type: Event type identifier set to UNSTRUCTURED_LOG
:param message: The log message text
:param severity: The severity level of the log message
"""
type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG
message: str
severity: LogSeverity
@json_schema_type
class MetricEvent(EventCommon):
"""A metric event containing a measured value.
:param type: Event type identifier set to METRIC
:param metric: The name of the metric being measured
:param value: The numeric value of the metric measurement
:param unit: The unit of measurement for the metric value
"""
type: Literal[EventType.METRIC] = EventType.METRIC
metric: str # this would be an enum
value: int | float
unit: str
@json_schema_type
class StructuredLogType(Enum):
"""The type of structured log event payload.
:cvar SPAN_START: Event indicating the start of a new span
:cvar SPAN_END: Event indicating the completion of a span
"""
SPAN_START = "span_start"
SPAN_END = "span_end"
@json_schema_type
class SpanStartPayload(BaseModel):
"""Payload for a span start event.
:param type: Payload type identifier set to SPAN_START
:param name: Human-readable name describing the operation this span represents
:param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
"""
type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START
name: str
parent_span_id: str | None = None
@json_schema_type
class SpanEndPayload(BaseModel):
"""Payload for a span end event.
:param type: Payload type identifier set to SPAN_END
:param status: The final status of the span indicating success or failure
"""
type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END
status: SpanStatus
StructuredLogPayload = Annotated[
SpanStartPayload | SpanEndPayload,
Field(discriminator="type"),
]
register_schema(StructuredLogPayload, name="StructuredLogPayload")
@json_schema_type
class StructuredLogEvent(EventCommon):
"""A structured log event containing typed payload data.
:param type: Event type identifier set to STRUCTURED_LOG
:param payload: The structured payload data for the log event
"""
type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG
payload: StructuredLogPayload
Event = Annotated[
UnstructuredLogEvent | MetricEvent | StructuredLogEvent,
Field(discriminator="type"),
]
register_schema(Event, name="Event")
@json_schema_type
class EvalTrace(BaseModel):
"""A trace record for evaluation purposes.
:param session_id: Unique identifier for the evaluation session
:param step: The evaluation step or phase identifier
:param input: The input data for the evaluation
:param output: The actual output produced during evaluation
:param expected_output: The expected output for comparison during evaluation
"""
session_id: str
step: str
input: str
output: str
expected_output: str
@json_schema_type
class SpanWithStatus(Span):
"""A span that includes status information.
:param status: (Optional) The current status of the span
"""
status: SpanStatus | None = None
@json_schema_type
class QueryConditionOp(Enum):
"""Comparison operators for query conditions.
:cvar EQ: Equal to comparison
:cvar NE: Not equal to comparison
:cvar GT: Greater than comparison
:cvar LT: Less than comparison
"""
EQ = "eq"
NE = "ne"
GT = "gt"
LT = "lt"
@json_schema_type
class QueryCondition(BaseModel):
"""A condition for filtering query results.
:param key: The attribute key to filter on
:param op: The comparison operator to apply
:param value: The value to compare against
"""
key: str
op: QueryConditionOp
value: Any
class QueryTracesResponse(BaseModel):
"""Response containing a list of traces.
:param data: List of traces matching the query criteria
"""
data: list[Trace]
class QuerySpansResponse(BaseModel):
"""Response containing a list of spans.
:param data: List of spans matching the query criteria
"""
data: list[Span]
class QuerySpanTreeResponse(BaseModel):
"""Response containing a tree structure of spans.
:param data: Dictionary mapping span IDs to spans with status information
"""
data: dict[str, SpanWithStatus]
class MetricQueryType(Enum):
"""The type of metric query to perform.
:cvar RANGE: Query metrics over a time range
:cvar INSTANT: Query metrics at a specific point in time
"""
RANGE = "range"
INSTANT = "instant"
class MetricLabelOperator(Enum):
"""Operators for matching metric labels.
:cvar EQUALS: Label value must equal the specified value
:cvar NOT_EQUALS: Label value must not equal the specified value
:cvar REGEX_MATCH: Label value must match the specified regular expression
:cvar REGEX_NOT_MATCH: Label value must not match the specified regular expression
"""
EQUALS = "="
NOT_EQUALS = "!="
REGEX_MATCH = "=~"
REGEX_NOT_MATCH = "!~"
class MetricLabelMatcher(BaseModel):
"""A matcher for filtering metrics by label values.
:param name: The name of the label to match
:param value: The value to match against
:param operator: The comparison operator to use for matching
"""
name: str
value: str
operator: MetricLabelOperator = MetricLabelOperator.EQUALS
@json_schema_type
class MetricLabel(BaseModel):
"""A label associated with a metric.
:param name: The name of the label
:param value: The value of the label
"""
name: str
value: str
@json_schema_type
class MetricDataPoint(BaseModel):
"""A single data point in a metric time series.
:param timestamp: Unix timestamp when the metric value was recorded
:param value: The numeric value of the metric at this timestamp
"""
timestamp: int
value: float
unit: str
@json_schema_type
class MetricSeries(BaseModel):
"""A time series of metric data points.
:param metric: The name of the metric
:param labels: List of labels associated with this metric series
:param values: List of data points in chronological order
"""
metric: str
labels: list[MetricLabel]
values: list[MetricDataPoint]
class QueryMetricsResponse(BaseModel):
"""Response containing metric time series data.
:param data: List of metric series matching the query criteria
"""
data: list[MetricSeries]
_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
"active_spans": {},
"counters": {},
"gauges": {},
"up_down_counters": {},
"histograms": {},
}
_global_lock = threading.Lock()
_TRACER_PROVIDER = None
logger = get_logger(name=__name__, category="telemetry")
def _clean_attributes(attrs: dict[str, Any] | None) -> Attributes | None:
"""Remove None values from attributes dict to match OpenTelemetry's expected type."""
if attrs is None:
return None
return {k: v for k, v in attrs.items() if v is not None}
def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording()
class Telemetry:
def __init__(self) -> None:
self.meter = None
global _TRACER_PROVIDER
# Initialize the correct span processor based on the provider state.
# This is needed since once the span processor is set, it cannot be unset.
# Recreating the telemetry adapter multiple times will result in duplicate span processors.
# Since the library client can be recreated multiple times in a notebook,
# the kernel will hold on to the span processor and cause duplicate spans to be written.
if os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"):
if _TRACER_PROVIDER is None:
provider = TracerProvider()
trace.set_tracer_provider(provider)
_TRACER_PROVIDER = provider
# Use single OTLP endpoint for all telemetry signals
# Let OpenTelemetry SDK handle endpoint construction automatically
# The SDK will read OTEL_EXPORTER_OTLP_ENDPOINT and construct appropriate URLs
# https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
span_exporter = OTLPSpanExporter()
span_processor = BatchSpanProcessor(span_exporter)
cast(TracerProvider, trace.get_tracer_provider()).add_span_processor(span_processor)
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
metric_provider = MeterProvider(metric_readers=[metric_reader])
metrics.set_meter_provider(metric_provider)
self.is_otel_endpoint_set = True
else:
logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT is not set, skipping telemetry")
self.is_otel_endpoint_set = False
self.meter = metrics.get_meter(__name__)
self._lock = _global_lock
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
if self.is_otel_endpoint_set:
cast(TracerProvider, trace.get_tracer_provider()).force_flush()
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
if isinstance(event, UnstructuredLogEvent):
self._log_unstructured(event, ttl_seconds)
elif isinstance(event, MetricEvent):
self._log_metric(event)
elif isinstance(event, StructuredLogEvent):
self._log_structured(event, ttl_seconds)
else:
raise ValueError(f"Unknown event type: {event}")
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
# Use global storage instead of instance storage
span_id = int(event.span_id, 16)
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
span.add_event(
name=event.type.value,
attributes={
"message": event.message,
"severity": event.severity.value,
"__ttl__": ttl_seconds,
**(event.attributes or {}),
},
timestamp=timestamp_ns,
)
else:
print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}")
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["counters"]:
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
name=name,
unit=unit,
description=f"Counter for {name}",
)
return cast(metrics.Counter, _GLOBAL_STORAGE["counters"][name])
def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["histograms"]:
_GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram(
name=name,
unit=unit,
description=f"Histogram for {name}",
)
return cast(metrics.Histogram, _GLOBAL_STORAGE["histograms"][name])
def _log_metric(self, event: MetricEvent) -> None:
# Add metric as an event to the current span
try:
with self._lock:
# Only try to add to span if we have a valid span_id
if event.span_id:
try:
span_id = int(event.span_id, 16)
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
span.add_event(
name=f"metric.{event.metric}",
attributes={
"value": event.value,
"unit": event.unit,
**(event.attributes or {}),
},
timestamp=timestamp_ns,
)
except (ValueError, KeyError):
# Invalid span_id or span not found, but we already logged to console above
pass
except Exception:
# Lock acquisition failed
logger.debug("Failed to acquire lock to add metric to span")
# Log to OpenTelemetry meter if available
if self.meter is None:
return
# Use histograms for token-related metrics (per-request measurements)
# Use counters for other cumulative metrics
token_metrics = {"prompt_tokens", "completion_tokens", "total_tokens"}
if event.metric in token_metrics:
# Token metrics are per-request measurements, use histogram
histogram = self._get_or_create_histogram(event.metric, event.unit)
histogram.record(event.value, attributes=_clean_attributes(event.attributes))
elif isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=_clean_attributes(event.attributes))
elif isinstance(event.value, float):
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
up_down_counter.add(event.value, attributes=_clean_attributes(event.attributes))
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["up_down_counters"]:
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
name=name,
unit=unit,
description=f"UpDownCounter for {name}",
)
return cast(metrics.UpDownCounter, _GLOBAL_STORAGE["up_down_counters"][name])
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
span_id = int(event.span_id, 16)
tracer = trace.get_tracer(__name__)
if event.attributes is None:
event.attributes = {}
event.attributes["__ttl__"] = ttl_seconds
# Extract these W3C trace context attributes so they are not written to
# underlying storage, as we just need them to propagate the trace context.
traceparent = event.attributes.pop("traceparent", None)
tracestate = event.attributes.pop("tracestate", None)
if traceparent:
# If we have a traceparent header value, we're not the root span.
for root_attribute in ROOT_SPAN_MARKERS:
event.attributes.pop(root_attribute, None)
if isinstance(event.payload, SpanStartPayload):
# Check if span already exists to prevent duplicates
if span_id in _GLOBAL_STORAGE["active_spans"]:
return
context = None
if event.payload.parent_span_id:
parent_span_id = int(event.payload.parent_span_id, 16)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
if parent_span:
context = trace.set_span_in_context(parent_span)
elif traceparent:
carrier = {
"traceparent": traceparent,
"tracestate": tracestate,
}
context = TraceContextTextMapPropagator().extract(carrier=carrier)
span = tracer.start_span(
name=event.payload.name,
context=context,
attributes=_clean_attributes(event.attributes),
)
_GLOBAL_STORAGE["active_spans"][span_id] = span
elif isinstance(event.payload, SpanEndPayload):
span = _GLOBAL_STORAGE["active_spans"].get(span_id) # type: ignore[assignment]
if span:
if event.attributes:
cleaned_attrs = _clean_attributes(event.attributes)
if cleaned_attrs:
span.set_attributes(cleaned_attrs)
status = (
trace.Status(status_code=trace.StatusCode.OK)
if event.payload.status == SpanStatus.OK
else trace.Status(status_code=trace.StatusCode.ERROR)
)
span.set_status(status)
span.end()
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
else:
raise ValueError(f"Unknown structured log event: {event}")

View file

@ -1,154 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import inspect
import json
from collections.abc import AsyncGenerator, Callable
from functools import wraps
from typing import Any, cast
from pydantic import BaseModel
from llama_stack.models.llama.datatypes import Primitive
type JSONValue = Primitive | list["JSONValue"] | dict[str, "JSONValue"]
def serialize_value(value: Any) -> str:
return str(_prepare_for_json(value))
def _prepare_for_json(value: Any) -> JSONValue:
"""Serialize a single value into JSON-compatible format."""
if value is None:
return ""
elif isinstance(value, str | int | float | bool):
return value
elif hasattr(value, "_name_"):
return cast(str, value._name_)
elif isinstance(value, BaseModel):
return cast(JSONValue, json.loads(value.model_dump_json()))
elif isinstance(value, list | tuple | set):
return [_prepare_for_json(item) for item in value]
elif isinstance(value, dict):
return {str(k): _prepare_for_json(v) for k, v in value.items()}
else:
try:
json.dumps(value)
return cast(JSONValue, value)
except Exception:
return str(value)
def trace_protocol[T: type[Any]](cls: T) -> T:
"""
A class decorator that automatically traces all methods in a protocol/base class
and its inheriting classes.
"""
def trace_method(method: Callable[..., Any]) -> Callable[..., Any]:
is_async = asyncio.iscoroutinefunction(method)
is_async_gen = inspect.isasyncgenfunction(method)
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple[str, str, dict[str, Primitive]]:
class_name = self.__class__.__name__
method_name = method.__name__
span_type = "async_generator" if is_async_gen else "async" if is_async else "sync"
sig = inspect.signature(method)
param_names = list(sig.parameters.keys())[1:] # Skip 'self'
combined_args: dict[str, str] = {}
for i, arg in enumerate(args):
param_name = param_names[i] if i < len(param_names) else f"position_{i + 1}"
combined_args[param_name] = serialize_value(arg)
for k, v in kwargs.items():
combined_args[str(k)] = serialize_value(v)
span_attributes: dict[str, Primitive] = {
"__autotraced__": True,
"__class__": class_name,
"__method__": method_name,
"__type__": span_type,
"__args__": json.dumps(combined_args),
}
return class_name, method_name, span_attributes
@wraps(method)
async def async_gen_wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
from llama_stack.core.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
count = 0
try:
async for item in method(self, *args, **kwargs):
yield item
count += 1
finally:
span.set_attribute("chunk_count", count)
@wraps(method)
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
from llama_stack.core.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
result = await method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result))
return result
except Exception as e:
span.set_attribute("error", str(e))
raise
@wraps(method)
def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
from llama_stack.core.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
result = method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result))
return result
except Exception as e:
span.set_attribute("error", str(e))
raise
if is_async_gen:
return async_gen_wrapper
elif is_async:
return async_wrapper
else:
return sync_wrapper
# Wrap methods on the class itself (for classes applied at runtime)
# Skip if already wrapped (indicated by __wrapped__ attribute)
for name, method in vars(cls).items():
if inspect.isfunction(method) and not name.startswith("_"):
if not hasattr(method, "__wrapped__"):
wrapped = trace_method(method)
setattr(cls, name, wrapped) # noqa: B010
# Also set up __init_subclass__ for future subclasses
original_init_subclass = cast(Callable[..., Any] | None, getattr(cls, "__init_subclass__", None))
def __init_subclass__(cls_child: type[Any], **kwargs: Any) -> None: # noqa: N807
if original_init_subclass:
cast(Callable[..., None], original_init_subclass)(**kwargs)
for name, method in vars(cls_child).items():
if inspect.isfunction(method) and not name.startswith("_"):
setattr(cls_child, name, trace_method(method)) # noqa: B010
cls_any = cast(Any, cls)
cls_any.__init_subclass__ = classmethod(__init_subclass__)
return cls

View file

@ -1,388 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import contextvars
import logging # allow-direct-logging
import queue
import secrets
import sys
import threading
import time
from collections.abc import Callable
from datetime import UTC, datetime
from functools import wraps
from typing import Any, Self
from llama_stack.core.telemetry.telemetry import (
ROOT_SPAN_MARKERS,
Event,
LogSeverity,
Span,
SpanEndPayload,
SpanStartPayload,
SpanStatus,
StructuredLogEvent,
Telemetry,
UnstructuredLogEvent,
)
from llama_stack.core.telemetry.trace_protocol import serialize_value
from llama_stack.log import get_logger
logger = get_logger(__name__, category="core")
# Fallback logger that does NOT propagate to TelemetryHandler to avoid recursion
_fallback_logger = logging.getLogger("llama_stack.telemetry.background")
if not _fallback_logger.handlers:
_fallback_logger.propagate = False
_fallback_logger.setLevel(logging.ERROR)
_fallback_handler = logging.StreamHandler(sys.stderr)
_fallback_handler.setLevel(logging.ERROR)
_fallback_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
_fallback_logger.addHandler(_fallback_handler)
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
# The logical root span may not be visible to this process if a parent context
# is passed in. The local root span is the first local span in a trace.
LOCAL_ROOT_SPAN_MARKER = "__local_root_span__"
def trace_id_to_str(trace_id: int) -> str:
"""Convenience trace ID formatting method
Args:
trace_id: Trace ID int
Returns:
The trace ID as 32-byte hexadecimal string
"""
return format(trace_id, "032x")
def span_id_to_str(span_id: int) -> str:
"""Convenience span ID formatting method
Args:
span_id: Span ID int
Returns:
The span ID as 16-byte hexadecimal string
"""
return format(span_id, "016x")
def generate_span_id() -> str:
span_id = secrets.randbits(64)
while span_id == INVALID_SPAN_ID:
span_id = secrets.randbits(64)
return span_id_to_str(span_id)
def generate_trace_id() -> str:
trace_id = secrets.randbits(128)
while trace_id == INVALID_TRACE_ID:
trace_id = secrets.randbits(128)
return trace_id_to_str(trace_id)
LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS = 60.0
class BackgroundLogger:
def __init__(self, api: Telemetry, capacity: int = 100000):
self.api = api
self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity)
self.worker_thread = threading.Thread(target=self._worker, daemon=True)
self.worker_thread.start()
self._last_queue_full_log_time: float = 0.0
self._dropped_since_last_notice: int = 0
def log_event(self, event: Event) -> None:
try:
self.log_queue.put_nowait(event)
except queue.Full:
# Aggregate drops and emit at most once per interval via fallback logger
self._dropped_since_last_notice += 1
current_time = time.time()
if current_time - self._last_queue_full_log_time >= LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS:
_fallback_logger.error(
"Log queue is full; dropped %d events since last notice",
self._dropped_since_last_notice,
)
self._last_queue_full_log_time = current_time
self._dropped_since_last_notice = 0
def _worker(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self._process_logs())
async def _process_logs(self):
while True:
try:
event = self.log_queue.get()
await self.api.log_event(event)
except Exception:
import traceback
traceback.print_exc()
print("Error processing log event")
finally:
self.log_queue.task_done()
def __del__(self) -> None:
self.log_queue.join()
BACKGROUND_LOGGER: BackgroundLogger | None = None
def enqueue_event(event: Event) -> None:
"""Enqueue a telemetry event to the background logger if available.
This provides a non-blocking path for routers and other hot paths to
submit telemetry without awaiting the Telemetry API, reducing contention
with the main event loop.
"""
global BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
raise RuntimeError("Telemetry API not initialized")
BACKGROUND_LOGGER.log_event(event)
class TraceContext:
def __init__(self, logger: BackgroundLogger, trace_id: str):
self.logger = logger
self.trace_id = trace_id
self.spans: list[Span] = []
def push_span(self, name: str, attributes: dict[str, Any] | None = None) -> Span:
current_span = self.get_current_span()
span = Span(
span_id=generate_span_id(),
trace_id=self.trace_id,
name=name,
start_time=datetime.now(UTC),
parent_span_id=current_span.span_id if current_span else None,
attributes=attributes,
)
self.logger.log_event(
StructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=span.start_time,
attributes=span.attributes,
payload=SpanStartPayload(
name=span.name,
parent_span_id=span.parent_span_id,
),
)
)
self.spans.append(span)
return span
def pop_span(self, status: SpanStatus = SpanStatus.OK) -> None:
span = self.spans.pop()
if span is not None:
self.logger.log_event(
StructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=span.start_time,
attributes=span.attributes,
payload=SpanEndPayload(
status=status,
),
)
)
def get_current_span(self) -> Span | None:
return self.spans[-1] if self.spans else None
CURRENT_TRACE_CONTEXT: contextvars.ContextVar[TraceContext | None] = contextvars.ContextVar(
"trace_context", default=None
)
def setup_logger(api: Telemetry, level: int = logging.INFO):
global BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
BACKGROUND_LOGGER = BackgroundLogger(api)
root_logger = logging.getLogger()
root_logger.setLevel(level)
root_logger.addHandler(TelemetryHandler())
async def start_trace(name: str, attributes: dict[str, Any] | None = None) -> TraceContext | None:
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
return None
trace_id = generate_trace_id()
context = TraceContext(BACKGROUND_LOGGER, trace_id)
# Mark this span as the root for the trace for now. The processing of
# traceparent context if supplied comes later and will result in the
# ROOT_SPAN_MARKERS being removed. Also mark this is the 'local' root,
# i.e. the root of the spans originating in this process as this is
# needed to ensure that we insert this 'local' root span's id into
# the trace record in sqlite store.
attributes = dict.fromkeys(ROOT_SPAN_MARKERS, True) | {LOCAL_ROOT_SPAN_MARKER: True} | (attributes or {})
context.push_span(name, attributes)
CURRENT_TRACE_CONTEXT.set(context)
return context
async def end_trace(status: SpanStatus = SpanStatus.OK):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if context is None:
logger.debug("No trace context to end")
return
context.pop_span(status)
CURRENT_TRACE_CONTEXT.set(None)
def severity(levelname: str) -> LogSeverity:
if levelname == "DEBUG":
return LogSeverity.DEBUG
elif levelname == "INFO":
return LogSeverity.INFO
elif levelname == "WARNING":
return LogSeverity.WARN
elif levelname == "ERROR":
return LogSeverity.ERROR
elif levelname == "CRITICAL":
return LogSeverity.CRITICAL
else:
raise ValueError(f"Unknown log level: {levelname}")
# TODO: ideally, the actual emitting should be done inside a separate daemon
# process completely isolated from the server
class TelemetryHandler(logging.Handler):
def emit(self, record: logging.LogRecord) -> None:
# horrendous hack to avoid logging from asyncio and getting into an infinite loop
if record.module in ("asyncio", "selector_events"):
return
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if context is None:
return
span = context.get_current_span()
if span is None:
return
enqueue_event(
UnstructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=datetime.now(UTC),
message=self.format(record),
severity=severity(record.levelname),
)
)
def close(self) -> None:
pass
class SpanContextManager:
def __init__(self, name: str, attributes: dict[str, Any] | None = None):
self.name = name
self.attributes = attributes
self.span: Span | None = None
def __enter__(self) -> Self:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to push span")
return self
self.span = context.push_span(self.name, self.attributes)
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to pop span")
return
context.pop_span()
def set_attribute(self, key: str, value: Any) -> None:
if self.span:
if self.span.attributes is None:
self.span.attributes = {}
self.span.attributes[key] = serialize_value(value)
async def __aenter__(self) -> Self:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to push span")
return self
self.span = context.push_span(self.name, self.attributes)
return self
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to pop span")
return
context.pop_span()
def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
with self:
return func(*args, **kwargs)
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
async with self:
return await func(*args, **kwargs)
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if asyncio.iscoroutinefunction(func):
return async_wrapper(*args, **kwargs)
else:
return sync_wrapper(*args, **kwargs)
return wrapper
def span(name: str, attributes: dict[str, Any] | None = None) -> SpanContextManager:
return SpanContextManager(name, attributes)
def get_current_span() -> Span | None:
global CURRENT_TRACE_CONTEXT
if CURRENT_TRACE_CONTEXT is None:
logger.debug("No trace context to get current span")
return None
context = CURRENT_TRACE_CONTEXT.get()
if context:
return context.get_current_span()
return None