From 503522716f0f87a28cd1c9e8adf1a250e9c624ad Mon Sep 17 00:00:00 2001 From: Emilio Garcia Date: Tue, 11 Nov 2025 13:58:04 -0500 Subject: [PATCH] fix(telemetry): remove legacy telemetry tools This change removes all the hand written telemetry machinery that has been replaced in prior changes with open telemetry library calls. --- src/llama_stack/core/telemetry/telemetry.py | 619 ------------------ .../core/telemetry/trace_protocol.py | 154 ----- src/llama_stack/core/telemetry/tracing.py | 388 ----------- 3 files changed, 1161 deletions(-) delete mode 100644 src/llama_stack/core/telemetry/telemetry.py delete mode 100644 src/llama_stack/core/telemetry/trace_protocol.py delete mode 100644 src/llama_stack/core/telemetry/tracing.py diff --git a/src/llama_stack/core/telemetry/telemetry.py b/src/llama_stack/core/telemetry/telemetry.py deleted file mode 100644 index b0913b477..000000000 --- a/src/llama_stack/core/telemetry/telemetry.py +++ /dev/null @@ -1,619 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import os -import threading -from collections.abc import Mapping, Sequence -from datetime import datetime -from enum import Enum -from typing import ( - Annotated, - Any, - Literal, - cast, -) - -from opentelemetry import metrics, trace -from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter -from opentelemetry.sdk.metrics import MeterProvider -from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator -from pydantic import BaseModel, Field - -from llama_stack.log import get_logger -from llama_stack.models.llama.datatypes import Primitive -from llama_stack_api import json_schema_type, register_schema - -ROOT_SPAN_MARKERS = ["__root__", "__root_span__"] - -# Type alias for OpenTelemetry attribute values (excludes None) -AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float] -Attributes = Mapping[str, AttributeValue] - - -@json_schema_type -class SpanStatus(Enum): - """The status of a span indicating whether it completed successfully or with an error. - :cvar OK: Span completed successfully without errors - :cvar ERROR: Span completed with an error or failure - """ - - OK = "ok" - ERROR = "error" - - -@json_schema_type -class Span(BaseModel): - """A span representing a single operation within a trace. - :param span_id: Unique identifier for the span - :param trace_id: Unique identifier for the trace this span belongs to - :param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span - :param name: Human-readable name describing the operation this span represents - :param start_time: Timestamp when the operation began - :param end_time: (Optional) Timestamp when the operation finished, if completed - :param attributes: (Optional) Key-value pairs containing additional metadata about the span - """ - - span_id: str - trace_id: str - parent_span_id: str | None = None - name: str - start_time: datetime - end_time: datetime | None = None - attributes: dict[str, Any] | None = Field(default_factory=lambda: {}) - - def set_attribute(self, key: str, value: Any): - if self.attributes is None: - self.attributes = {} - self.attributes[key] = value - - -@json_schema_type -class Trace(BaseModel): - """A trace representing the complete execution path of a request across multiple operations. - :param trace_id: Unique identifier for the trace - :param root_span_id: Unique identifier for the root span that started this trace - :param start_time: Timestamp when the trace began - :param end_time: (Optional) Timestamp when the trace finished, if completed - """ - - trace_id: str - root_span_id: str - start_time: datetime - end_time: datetime | None = None - - -@json_schema_type -class EventType(Enum): - """The type of telemetry event being logged. - :cvar UNSTRUCTURED_LOG: A simple log message with severity level - :cvar STRUCTURED_LOG: A structured log event with typed payload data - :cvar METRIC: A metric measurement with value and unit - """ - - UNSTRUCTURED_LOG = "unstructured_log" - STRUCTURED_LOG = "structured_log" - METRIC = "metric" - - -@json_schema_type -class LogSeverity(Enum): - """The severity level of a log message. - :cvar VERBOSE: Detailed diagnostic information for troubleshooting - :cvar DEBUG: Debug information useful during development - :cvar INFO: General informational messages about normal operation - :cvar WARN: Warning messages about potentially problematic situations - :cvar ERROR: Error messages indicating failures that don't stop execution - :cvar CRITICAL: Critical error messages indicating severe failures - """ - - VERBOSE = "verbose" - DEBUG = "debug" - INFO = "info" - WARN = "warn" - ERROR = "error" - CRITICAL = "critical" - - -class EventCommon(BaseModel): - """Common fields shared by all telemetry events. - :param trace_id: Unique identifier for the trace this event belongs to - :param span_id: Unique identifier for the span this event belongs to - :param timestamp: Timestamp when the event occurred - :param attributes: (Optional) Key-value pairs containing additional metadata about the event - """ - - trace_id: str - span_id: str - timestamp: datetime - attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {}) - - -@json_schema_type -class UnstructuredLogEvent(EventCommon): - """An unstructured log event containing a simple text message. - :param type: Event type identifier set to UNSTRUCTURED_LOG - :param message: The log message text - :param severity: The severity level of the log message - """ - - type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG - message: str - severity: LogSeverity - - -@json_schema_type -class MetricEvent(EventCommon): - """A metric event containing a measured value. - :param type: Event type identifier set to METRIC - :param metric: The name of the metric being measured - :param value: The numeric value of the metric measurement - :param unit: The unit of measurement for the metric value - """ - - type: Literal[EventType.METRIC] = EventType.METRIC - metric: str # this would be an enum - value: int | float - unit: str - - -@json_schema_type -class StructuredLogType(Enum): - """The type of structured log event payload. - :cvar SPAN_START: Event indicating the start of a new span - :cvar SPAN_END: Event indicating the completion of a span - """ - - SPAN_START = "span_start" - SPAN_END = "span_end" - - -@json_schema_type -class SpanStartPayload(BaseModel): - """Payload for a span start event. - :param type: Payload type identifier set to SPAN_START - :param name: Human-readable name describing the operation this span represents - :param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span - """ - - type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START - name: str - parent_span_id: str | None = None - - -@json_schema_type -class SpanEndPayload(BaseModel): - """Payload for a span end event. - :param type: Payload type identifier set to SPAN_END - :param status: The final status of the span indicating success or failure - """ - - type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END - status: SpanStatus - - -StructuredLogPayload = Annotated[ - SpanStartPayload | SpanEndPayload, - Field(discriminator="type"), -] -register_schema(StructuredLogPayload, name="StructuredLogPayload") - - -@json_schema_type -class StructuredLogEvent(EventCommon): - """A structured log event containing typed payload data. - :param type: Event type identifier set to STRUCTURED_LOG - :param payload: The structured payload data for the log event - """ - - type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG - payload: StructuredLogPayload - - -Event = Annotated[ - UnstructuredLogEvent | MetricEvent | StructuredLogEvent, - Field(discriminator="type"), -] -register_schema(Event, name="Event") - - -@json_schema_type -class EvalTrace(BaseModel): - """A trace record for evaluation purposes. - :param session_id: Unique identifier for the evaluation session - :param step: The evaluation step or phase identifier - :param input: The input data for the evaluation - :param output: The actual output produced during evaluation - :param expected_output: The expected output for comparison during evaluation - """ - - session_id: str - step: str - input: str - output: str - expected_output: str - - -@json_schema_type -class SpanWithStatus(Span): - """A span that includes status information. - :param status: (Optional) The current status of the span - """ - - status: SpanStatus | None = None - - -@json_schema_type -class QueryConditionOp(Enum): - """Comparison operators for query conditions. - :cvar EQ: Equal to comparison - :cvar NE: Not equal to comparison - :cvar GT: Greater than comparison - :cvar LT: Less than comparison - """ - - EQ = "eq" - NE = "ne" - GT = "gt" - LT = "lt" - - -@json_schema_type -class QueryCondition(BaseModel): - """A condition for filtering query results. - :param key: The attribute key to filter on - :param op: The comparison operator to apply - :param value: The value to compare against - """ - - key: str - op: QueryConditionOp - value: Any - - -class QueryTracesResponse(BaseModel): - """Response containing a list of traces. - :param data: List of traces matching the query criteria - """ - - data: list[Trace] - - -class QuerySpansResponse(BaseModel): - """Response containing a list of spans. - :param data: List of spans matching the query criteria - """ - - data: list[Span] - - -class QuerySpanTreeResponse(BaseModel): - """Response containing a tree structure of spans. - :param data: Dictionary mapping span IDs to spans with status information - """ - - data: dict[str, SpanWithStatus] - - -class MetricQueryType(Enum): - """The type of metric query to perform. - :cvar RANGE: Query metrics over a time range - :cvar INSTANT: Query metrics at a specific point in time - """ - - RANGE = "range" - INSTANT = "instant" - - -class MetricLabelOperator(Enum): - """Operators for matching metric labels. - :cvar EQUALS: Label value must equal the specified value - :cvar NOT_EQUALS: Label value must not equal the specified value - :cvar REGEX_MATCH: Label value must match the specified regular expression - :cvar REGEX_NOT_MATCH: Label value must not match the specified regular expression - """ - - EQUALS = "=" - NOT_EQUALS = "!=" - REGEX_MATCH = "=~" - REGEX_NOT_MATCH = "!~" - - -class MetricLabelMatcher(BaseModel): - """A matcher for filtering metrics by label values. - :param name: The name of the label to match - :param value: The value to match against - :param operator: The comparison operator to use for matching - """ - - name: str - value: str - operator: MetricLabelOperator = MetricLabelOperator.EQUALS - - -@json_schema_type -class MetricLabel(BaseModel): - """A label associated with a metric. - :param name: The name of the label - :param value: The value of the label - """ - - name: str - value: str - - -@json_schema_type -class MetricDataPoint(BaseModel): - """A single data point in a metric time series. - :param timestamp: Unix timestamp when the metric value was recorded - :param value: The numeric value of the metric at this timestamp - """ - - timestamp: int - value: float - unit: str - - -@json_schema_type -class MetricSeries(BaseModel): - """A time series of metric data points. - :param metric: The name of the metric - :param labels: List of labels associated with this metric series - :param values: List of data points in chronological order - """ - - metric: str - labels: list[MetricLabel] - values: list[MetricDataPoint] - - -class QueryMetricsResponse(BaseModel): - """Response containing metric time series data. - :param data: List of metric series matching the query criteria - """ - - data: list[MetricSeries] - - -_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = { - "active_spans": {}, - "counters": {}, - "gauges": {}, - "up_down_counters": {}, - "histograms": {}, -} -_global_lock = threading.Lock() -_TRACER_PROVIDER = None - -logger = get_logger(name=__name__, category="telemetry") - - -def _clean_attributes(attrs: dict[str, Any] | None) -> Attributes | None: - """Remove None values from attributes dict to match OpenTelemetry's expected type.""" - if attrs is None: - return None - return {k: v for k, v in attrs.items() if v is not None} - - -def is_tracing_enabled(tracer): - with tracer.start_as_current_span("check_tracing") as span: - return span.is_recording() - - -class Telemetry: - def __init__(self) -> None: - self.meter = None - - global _TRACER_PROVIDER - # Initialize the correct span processor based on the provider state. - # This is needed since once the span processor is set, it cannot be unset. - # Recreating the telemetry adapter multiple times will result in duplicate span processors. - # Since the library client can be recreated multiple times in a notebook, - # the kernel will hold on to the span processor and cause duplicate spans to be written. - if os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"): - if _TRACER_PROVIDER is None: - provider = TracerProvider() - trace.set_tracer_provider(provider) - _TRACER_PROVIDER = provider - - # Use single OTLP endpoint for all telemetry signals - - # Let OpenTelemetry SDK handle endpoint construction automatically - # The SDK will read OTEL_EXPORTER_OTLP_ENDPOINT and construct appropriate URLs - # https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter - span_exporter = OTLPSpanExporter() - span_processor = BatchSpanProcessor(span_exporter) - cast(TracerProvider, trace.get_tracer_provider()).add_span_processor(span_processor) - - metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter()) - metric_provider = MeterProvider(metric_readers=[metric_reader]) - metrics.set_meter_provider(metric_provider) - self.is_otel_endpoint_set = True - else: - logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT is not set, skipping telemetry") - self.is_otel_endpoint_set = False - - self.meter = metrics.get_meter(__name__) - self._lock = _global_lock - - async def initialize(self) -> None: - pass - - async def shutdown(self) -> None: - if self.is_otel_endpoint_set: - cast(TracerProvider, trace.get_tracer_provider()).force_flush() - - async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: - if isinstance(event, UnstructuredLogEvent): - self._log_unstructured(event, ttl_seconds) - elif isinstance(event, MetricEvent): - self._log_metric(event) - elif isinstance(event, StructuredLogEvent): - self._log_structured(event, ttl_seconds) - else: - raise ValueError(f"Unknown event type: {event}") - - def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None: - with self._lock: - # Use global storage instead of instance storage - span_id = int(event.span_id, 16) - span = _GLOBAL_STORAGE["active_spans"].get(span_id) - - if span: - timestamp_ns = int(event.timestamp.timestamp() * 1e9) - span.add_event( - name=event.type.value, - attributes={ - "message": event.message, - "severity": event.severity.value, - "__ttl__": ttl_seconds, - **(event.attributes or {}), - }, - timestamp=timestamp_ns, - ) - else: - print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}") - - def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter: - assert self.meter is not None - if name not in _GLOBAL_STORAGE["counters"]: - _GLOBAL_STORAGE["counters"][name] = self.meter.create_counter( - name=name, - unit=unit, - description=f"Counter for {name}", - ) - return cast(metrics.Counter, _GLOBAL_STORAGE["counters"][name]) - - def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram: - assert self.meter is not None - if name not in _GLOBAL_STORAGE["histograms"]: - _GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram( - name=name, - unit=unit, - description=f"Histogram for {name}", - ) - return cast(metrics.Histogram, _GLOBAL_STORAGE["histograms"][name]) - - def _log_metric(self, event: MetricEvent) -> None: - # Add metric as an event to the current span - try: - with self._lock: - # Only try to add to span if we have a valid span_id - if event.span_id: - try: - span_id = int(event.span_id, 16) - span = _GLOBAL_STORAGE["active_spans"].get(span_id) - - if span: - timestamp_ns = int(event.timestamp.timestamp() * 1e9) - span.add_event( - name=f"metric.{event.metric}", - attributes={ - "value": event.value, - "unit": event.unit, - **(event.attributes or {}), - }, - timestamp=timestamp_ns, - ) - except (ValueError, KeyError): - # Invalid span_id or span not found, but we already logged to console above - pass - except Exception: - # Lock acquisition failed - logger.debug("Failed to acquire lock to add metric to span") - - # Log to OpenTelemetry meter if available - if self.meter is None: - return - - # Use histograms for token-related metrics (per-request measurements) - # Use counters for other cumulative metrics - token_metrics = {"prompt_tokens", "completion_tokens", "total_tokens"} - - if event.metric in token_metrics: - # Token metrics are per-request measurements, use histogram - histogram = self._get_or_create_histogram(event.metric, event.unit) - histogram.record(event.value, attributes=_clean_attributes(event.attributes)) - elif isinstance(event.value, int): - counter = self._get_or_create_counter(event.metric, event.unit) - counter.add(event.value, attributes=_clean_attributes(event.attributes)) - elif isinstance(event.value, float): - up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit) - up_down_counter.add(event.value, attributes=_clean_attributes(event.attributes)) - - def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter: - assert self.meter is not None - if name not in _GLOBAL_STORAGE["up_down_counters"]: - _GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter( - name=name, - unit=unit, - description=f"UpDownCounter for {name}", - ) - return cast(metrics.UpDownCounter, _GLOBAL_STORAGE["up_down_counters"][name]) - - def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None: - with self._lock: - span_id = int(event.span_id, 16) - tracer = trace.get_tracer(__name__) - if event.attributes is None: - event.attributes = {} - event.attributes["__ttl__"] = ttl_seconds - - # Extract these W3C trace context attributes so they are not written to - # underlying storage, as we just need them to propagate the trace context. - traceparent = event.attributes.pop("traceparent", None) - tracestate = event.attributes.pop("tracestate", None) - if traceparent: - # If we have a traceparent header value, we're not the root span. - for root_attribute in ROOT_SPAN_MARKERS: - event.attributes.pop(root_attribute, None) - - if isinstance(event.payload, SpanStartPayload): - # Check if span already exists to prevent duplicates - if span_id in _GLOBAL_STORAGE["active_spans"]: - return - - context = None - if event.payload.parent_span_id: - parent_span_id = int(event.payload.parent_span_id, 16) - parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) - if parent_span: - context = trace.set_span_in_context(parent_span) - elif traceparent: - carrier = { - "traceparent": traceparent, - "tracestate": tracestate, - } - context = TraceContextTextMapPropagator().extract(carrier=carrier) - - span = tracer.start_span( - name=event.payload.name, - context=context, - attributes=_clean_attributes(event.attributes), - ) - _GLOBAL_STORAGE["active_spans"][span_id] = span - - elif isinstance(event.payload, SpanEndPayload): - span = _GLOBAL_STORAGE["active_spans"].get(span_id) # type: ignore[assignment] - if span: - if event.attributes: - cleaned_attrs = _clean_attributes(event.attributes) - if cleaned_attrs: - span.set_attributes(cleaned_attrs) - - status = ( - trace.Status(status_code=trace.StatusCode.OK) - if event.payload.status == SpanStatus.OK - else trace.Status(status_code=trace.StatusCode.ERROR) - ) - span.set_status(status) - span.end() - _GLOBAL_STORAGE["active_spans"].pop(span_id, None) - else: - raise ValueError(f"Unknown structured log event: {event}") diff --git a/src/llama_stack/core/telemetry/trace_protocol.py b/src/llama_stack/core/telemetry/trace_protocol.py deleted file mode 100644 index 95b33a4bc..000000000 --- a/src/llama_stack/core/telemetry/trace_protocol.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import asyncio -import inspect -import json -from collections.abc import AsyncGenerator, Callable -from functools import wraps -from typing import Any, cast - -from pydantic import BaseModel - -from llama_stack.models.llama.datatypes import Primitive - -type JSONValue = Primitive | list["JSONValue"] | dict[str, "JSONValue"] - - -def serialize_value(value: Any) -> str: - return str(_prepare_for_json(value)) - - -def _prepare_for_json(value: Any) -> JSONValue: - """Serialize a single value into JSON-compatible format.""" - if value is None: - return "" - elif isinstance(value, str | int | float | bool): - return value - elif hasattr(value, "_name_"): - return cast(str, value._name_) - elif isinstance(value, BaseModel): - return cast(JSONValue, json.loads(value.model_dump_json())) - elif isinstance(value, list | tuple | set): - return [_prepare_for_json(item) for item in value] - elif isinstance(value, dict): - return {str(k): _prepare_for_json(v) for k, v in value.items()} - else: - try: - json.dumps(value) - return cast(JSONValue, value) - except Exception: - return str(value) - - -def trace_protocol[T: type[Any]](cls: T) -> T: - """ - A class decorator that automatically traces all methods in a protocol/base class - and its inheriting classes. - """ - - def trace_method(method: Callable[..., Any]) -> Callable[..., Any]: - is_async = asyncio.iscoroutinefunction(method) - is_async_gen = inspect.isasyncgenfunction(method) - - def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple[str, str, dict[str, Primitive]]: - class_name = self.__class__.__name__ - method_name = method.__name__ - span_type = "async_generator" if is_async_gen else "async" if is_async else "sync" - sig = inspect.signature(method) - param_names = list(sig.parameters.keys())[1:] # Skip 'self' - combined_args: dict[str, str] = {} - for i, arg in enumerate(args): - param_name = param_names[i] if i < len(param_names) else f"position_{i + 1}" - combined_args[param_name] = serialize_value(arg) - for k, v in kwargs.items(): - combined_args[str(k)] = serialize_value(v) - - span_attributes: dict[str, Primitive] = { - "__autotraced__": True, - "__class__": class_name, - "__method__": method_name, - "__type__": span_type, - "__args__": json.dumps(combined_args), - } - - return class_name, method_name, span_attributes - - @wraps(method) - async def async_gen_wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]: - from llama_stack.core.telemetry import tracing - - class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs) - - with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: - count = 0 - try: - async for item in method(self, *args, **kwargs): - yield item - count += 1 - finally: - span.set_attribute("chunk_count", count) - - @wraps(method) - async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - from llama_stack.core.telemetry import tracing - - class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs) - - with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: - try: - result = await method(self, *args, **kwargs) - span.set_attribute("output", serialize_value(result)) - return result - except Exception as e: - span.set_attribute("error", str(e)) - raise - - @wraps(method) - def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - from llama_stack.core.telemetry import tracing - - class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs) - - with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: - try: - result = method(self, *args, **kwargs) - span.set_attribute("output", serialize_value(result)) - return result - except Exception as e: - span.set_attribute("error", str(e)) - raise - - if is_async_gen: - return async_gen_wrapper - elif is_async: - return async_wrapper - else: - return sync_wrapper - - # Wrap methods on the class itself (for classes applied at runtime) - # Skip if already wrapped (indicated by __wrapped__ attribute) - for name, method in vars(cls).items(): - if inspect.isfunction(method) and not name.startswith("_"): - if not hasattr(method, "__wrapped__"): - wrapped = trace_method(method) - setattr(cls, name, wrapped) # noqa: B010 - - # Also set up __init_subclass__ for future subclasses - original_init_subclass = cast(Callable[..., Any] | None, getattr(cls, "__init_subclass__", None)) - - def __init_subclass__(cls_child: type[Any], **kwargs: Any) -> None: # noqa: N807 - if original_init_subclass: - cast(Callable[..., None], original_init_subclass)(**kwargs) - - for name, method in vars(cls_child).items(): - if inspect.isfunction(method) and not name.startswith("_"): - setattr(cls_child, name, trace_method(method)) # noqa: B010 - - cls_any = cast(Any, cls) - cls_any.__init_subclass__ = classmethod(__init_subclass__) - - return cls diff --git a/src/llama_stack/core/telemetry/tracing.py b/src/llama_stack/core/telemetry/tracing.py deleted file mode 100644 index a67cbe784..000000000 --- a/src/llama_stack/core/telemetry/tracing.py +++ /dev/null @@ -1,388 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import asyncio -import contextvars -import logging # allow-direct-logging -import queue -import secrets -import sys -import threading -import time -from collections.abc import Callable -from datetime import UTC, datetime -from functools import wraps -from typing import Any, Self - -from llama_stack.core.telemetry.telemetry import ( - ROOT_SPAN_MARKERS, - Event, - LogSeverity, - Span, - SpanEndPayload, - SpanStartPayload, - SpanStatus, - StructuredLogEvent, - Telemetry, - UnstructuredLogEvent, -) -from llama_stack.core.telemetry.trace_protocol import serialize_value -from llama_stack.log import get_logger - -logger = get_logger(__name__, category="core") - -# Fallback logger that does NOT propagate to TelemetryHandler to avoid recursion -_fallback_logger = logging.getLogger("llama_stack.telemetry.background") -if not _fallback_logger.handlers: - _fallback_logger.propagate = False - _fallback_logger.setLevel(logging.ERROR) - _fallback_handler = logging.StreamHandler(sys.stderr) - _fallback_handler.setLevel(logging.ERROR) - _fallback_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")) - _fallback_logger.addHandler(_fallback_handler) - - -INVALID_SPAN_ID = 0x0000000000000000 -INVALID_TRACE_ID = 0x00000000000000000000000000000000 - -# The logical root span may not be visible to this process if a parent context -# is passed in. The local root span is the first local span in a trace. -LOCAL_ROOT_SPAN_MARKER = "__local_root_span__" - - -def trace_id_to_str(trace_id: int) -> str: - """Convenience trace ID formatting method - Args: - trace_id: Trace ID int - - Returns: - The trace ID as 32-byte hexadecimal string - """ - return format(trace_id, "032x") - - -def span_id_to_str(span_id: int) -> str: - """Convenience span ID formatting method - Args: - span_id: Span ID int - - Returns: - The span ID as 16-byte hexadecimal string - """ - return format(span_id, "016x") - - -def generate_span_id() -> str: - span_id = secrets.randbits(64) - while span_id == INVALID_SPAN_ID: - span_id = secrets.randbits(64) - return span_id_to_str(span_id) - - -def generate_trace_id() -> str: - trace_id = secrets.randbits(128) - while trace_id == INVALID_TRACE_ID: - trace_id = secrets.randbits(128) - return trace_id_to_str(trace_id) - - -LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS = 60.0 - - -class BackgroundLogger: - def __init__(self, api: Telemetry, capacity: int = 100000): - self.api = api - self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity) - self.worker_thread = threading.Thread(target=self._worker, daemon=True) - self.worker_thread.start() - self._last_queue_full_log_time: float = 0.0 - self._dropped_since_last_notice: int = 0 - - def log_event(self, event: Event) -> None: - try: - self.log_queue.put_nowait(event) - except queue.Full: - # Aggregate drops and emit at most once per interval via fallback logger - self._dropped_since_last_notice += 1 - current_time = time.time() - if current_time - self._last_queue_full_log_time >= LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS: - _fallback_logger.error( - "Log queue is full; dropped %d events since last notice", - self._dropped_since_last_notice, - ) - self._last_queue_full_log_time = current_time - self._dropped_since_last_notice = 0 - - def _worker(self): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(self._process_logs()) - - async def _process_logs(self): - while True: - try: - event = self.log_queue.get() - await self.api.log_event(event) - except Exception: - import traceback - - traceback.print_exc() - print("Error processing log event") - finally: - self.log_queue.task_done() - - def __del__(self) -> None: - self.log_queue.join() - - -BACKGROUND_LOGGER: BackgroundLogger | None = None - - -def enqueue_event(event: Event) -> None: - """Enqueue a telemetry event to the background logger if available. - - This provides a non-blocking path for routers and other hot paths to - submit telemetry without awaiting the Telemetry API, reducing contention - with the main event loop. - """ - global BACKGROUND_LOGGER - if BACKGROUND_LOGGER is None: - raise RuntimeError("Telemetry API not initialized") - BACKGROUND_LOGGER.log_event(event) - - -class TraceContext: - def __init__(self, logger: BackgroundLogger, trace_id: str): - self.logger = logger - self.trace_id = trace_id - self.spans: list[Span] = [] - - def push_span(self, name: str, attributes: dict[str, Any] | None = None) -> Span: - current_span = self.get_current_span() - span = Span( - span_id=generate_span_id(), - trace_id=self.trace_id, - name=name, - start_time=datetime.now(UTC), - parent_span_id=current_span.span_id if current_span else None, - attributes=attributes, - ) - - self.logger.log_event( - StructuredLogEvent( - trace_id=span.trace_id, - span_id=span.span_id, - timestamp=span.start_time, - attributes=span.attributes, - payload=SpanStartPayload( - name=span.name, - parent_span_id=span.parent_span_id, - ), - ) - ) - - self.spans.append(span) - return span - - def pop_span(self, status: SpanStatus = SpanStatus.OK) -> None: - span = self.spans.pop() - if span is not None: - self.logger.log_event( - StructuredLogEvent( - trace_id=span.trace_id, - span_id=span.span_id, - timestamp=span.start_time, - attributes=span.attributes, - payload=SpanEndPayload( - status=status, - ), - ) - ) - - def get_current_span(self) -> Span | None: - return self.spans[-1] if self.spans else None - - -CURRENT_TRACE_CONTEXT: contextvars.ContextVar[TraceContext | None] = contextvars.ContextVar( - "trace_context", default=None -) - - -def setup_logger(api: Telemetry, level: int = logging.INFO): - global BACKGROUND_LOGGER - - if BACKGROUND_LOGGER is None: - BACKGROUND_LOGGER = BackgroundLogger(api) - root_logger = logging.getLogger() - root_logger.setLevel(level) - root_logger.addHandler(TelemetryHandler()) - - -async def start_trace(name: str, attributes: dict[str, Any] | None = None) -> TraceContext | None: - global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER - - if BACKGROUND_LOGGER is None: - logger.debug("No Telemetry implementation set. Skipping trace initialization...") - return None - - trace_id = generate_trace_id() - context = TraceContext(BACKGROUND_LOGGER, trace_id) - # Mark this span as the root for the trace for now. The processing of - # traceparent context if supplied comes later and will result in the - # ROOT_SPAN_MARKERS being removed. Also mark this is the 'local' root, - # i.e. the root of the spans originating in this process as this is - # needed to ensure that we insert this 'local' root span's id into - # the trace record in sqlite store. - attributes = dict.fromkeys(ROOT_SPAN_MARKERS, True) | {LOCAL_ROOT_SPAN_MARKER: True} | (attributes or {}) - context.push_span(name, attributes) - - CURRENT_TRACE_CONTEXT.set(context) - return context - - -async def end_trace(status: SpanStatus = SpanStatus.OK): - global CURRENT_TRACE_CONTEXT - - context = CURRENT_TRACE_CONTEXT.get() - if context is None: - logger.debug("No trace context to end") - return - - context.pop_span(status) - CURRENT_TRACE_CONTEXT.set(None) - - -def severity(levelname: str) -> LogSeverity: - if levelname == "DEBUG": - return LogSeverity.DEBUG - elif levelname == "INFO": - return LogSeverity.INFO - elif levelname == "WARNING": - return LogSeverity.WARN - elif levelname == "ERROR": - return LogSeverity.ERROR - elif levelname == "CRITICAL": - return LogSeverity.CRITICAL - else: - raise ValueError(f"Unknown log level: {levelname}") - - -# TODO: ideally, the actual emitting should be done inside a separate daemon -# process completely isolated from the server -class TelemetryHandler(logging.Handler): - def emit(self, record: logging.LogRecord) -> None: - # horrendous hack to avoid logging from asyncio and getting into an infinite loop - if record.module in ("asyncio", "selector_events"): - return - - global CURRENT_TRACE_CONTEXT - context = CURRENT_TRACE_CONTEXT.get() - if context is None: - return - - span = context.get_current_span() - if span is None: - return - - enqueue_event( - UnstructuredLogEvent( - trace_id=span.trace_id, - span_id=span.span_id, - timestamp=datetime.now(UTC), - message=self.format(record), - severity=severity(record.levelname), - ) - ) - - def close(self) -> None: - pass - - -class SpanContextManager: - def __init__(self, name: str, attributes: dict[str, Any] | None = None): - self.name = name - self.attributes = attributes - self.span: Span | None = None - - def __enter__(self) -> Self: - global CURRENT_TRACE_CONTEXT - context = CURRENT_TRACE_CONTEXT.get() - if not context: - logger.debug("No trace context to push span") - return self - - self.span = context.push_span(self.name, self.attributes) - return self - - def __exit__(self, exc_type, exc_value, traceback) -> None: - global CURRENT_TRACE_CONTEXT - context = CURRENT_TRACE_CONTEXT.get() - if not context: - logger.debug("No trace context to pop span") - return - - context.pop_span() - - def set_attribute(self, key: str, value: Any) -> None: - if self.span: - if self.span.attributes is None: - self.span.attributes = {} - self.span.attributes[key] = serialize_value(value) - - async def __aenter__(self) -> Self: - global CURRENT_TRACE_CONTEXT - context = CURRENT_TRACE_CONTEXT.get() - if not context: - logger.debug("No trace context to push span") - return self - - self.span = context.push_span(self.name, self.attributes) - return self - - async def __aexit__(self, exc_type, exc_value, traceback) -> None: - global CURRENT_TRACE_CONTEXT - context = CURRENT_TRACE_CONTEXT.get() - if not context: - logger.debug("No trace context to pop span") - return - - context.pop_span() - - def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]: - @wraps(func) - def sync_wrapper(*args: Any, **kwargs: Any) -> Any: - with self: - return func(*args, **kwargs) - - @wraps(func) - async def async_wrapper(*args: Any, **kwargs: Any) -> Any: - async with self: - return await func(*args, **kwargs) - - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - if asyncio.iscoroutinefunction(func): - return async_wrapper(*args, **kwargs) - else: - return sync_wrapper(*args, **kwargs) - - return wrapper - - -def span(name: str, attributes: dict[str, Any] | None = None) -> SpanContextManager: - return SpanContextManager(name, attributes) - - -def get_current_span() -> Span | None: - global CURRENT_TRACE_CONTEXT - if CURRENT_TRACE_CONTEXT is None: - logger.debug("No trace context to get current span") - return None - - context = CURRENT_TRACE_CONTEXT.get() - if context: - return context.get_current_span() - return None