Revert "fix(mypy): resolve OpenTelemetry typing issues in telemetry.py (#3931)"

This reverts commit 9afc52a36a.
This commit is contained in:
Ashwin Bharambe 2025-10-28 09:48:46 -07:00
parent 9afc52a36a
commit 85887d724f
2 changed files with 423 additions and 49 deletions

View file

@ -6,8 +6,13 @@
import os import os
import threading import threading
from collections.abc import Mapping, Sequence from datetime import datetime
from typing import Any, cast from enum import Enum
from typing import (
Annotated,
Any,
Literal,
)
from opentelemetry import metrics, trace from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@ -17,22 +22,399 @@ from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from pydantic import BaseModel, Field
# Type alias for OpenTelemetry attribute values (excludes None)
AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]
Attributes = Mapping[str, AttributeValue]
from llama_stack.apis.telemetry import (
Event,
MetricEvent,
SpanEndPayload,
SpanStartPayload,
SpanStatus,
StructuredLogEvent,
UnstructuredLogEvent,
)
from llama_stack.core.telemetry.tracing import ROOT_SPAN_MARKERS
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import Primitive
from llama_stack.schema_utils import json_schema_type, register_schema
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
@json_schema_type
class SpanStatus(Enum):
"""The status of a span indicating whether it completed successfully or with an error.
:cvar OK: Span completed successfully without errors
:cvar ERROR: Span completed with an error or failure
"""
OK = "ok"
ERROR = "error"
@json_schema_type
class Span(BaseModel):
"""A span representing a single operation within a trace.
:param span_id: Unique identifier for the span
:param trace_id: Unique identifier for the trace this span belongs to
:param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
:param name: Human-readable name describing the operation this span represents
:param start_time: Timestamp when the operation began
:param end_time: (Optional) Timestamp when the operation finished, if completed
:param attributes: (Optional) Key-value pairs containing additional metadata about the span
"""
span_id: str
trace_id: str
parent_span_id: str | None = None
name: str
start_time: datetime
end_time: datetime | None = None
attributes: dict[str, Any] | None = Field(default_factory=lambda: {})
def set_attribute(self, key: str, value: Any):
if self.attributes is None:
self.attributes = {}
self.attributes[key] = value
@json_schema_type
class Trace(BaseModel):
"""A trace representing the complete execution path of a request across multiple operations.
:param trace_id: Unique identifier for the trace
:param root_span_id: Unique identifier for the root span that started this trace
:param start_time: Timestamp when the trace began
:param end_time: (Optional) Timestamp when the trace finished, if completed
"""
trace_id: str
root_span_id: str
start_time: datetime
end_time: datetime | None = None
@json_schema_type
class EventType(Enum):
"""The type of telemetry event being logged.
:cvar UNSTRUCTURED_LOG: A simple log message with severity level
:cvar STRUCTURED_LOG: A structured log event with typed payload data
:cvar METRIC: A metric measurement with value and unit
"""
UNSTRUCTURED_LOG = "unstructured_log"
STRUCTURED_LOG = "structured_log"
METRIC = "metric"
@json_schema_type
class LogSeverity(Enum):
"""The severity level of a log message.
:cvar VERBOSE: Detailed diagnostic information for troubleshooting
:cvar DEBUG: Debug information useful during development
:cvar INFO: General informational messages about normal operation
:cvar WARN: Warning messages about potentially problematic situations
:cvar ERROR: Error messages indicating failures that don't stop execution
:cvar CRITICAL: Critical error messages indicating severe failures
"""
VERBOSE = "verbose"
DEBUG = "debug"
INFO = "info"
WARN = "warn"
ERROR = "error"
CRITICAL = "critical"
class EventCommon(BaseModel):
"""Common fields shared by all telemetry events.
:param trace_id: Unique identifier for the trace this event belongs to
:param span_id: Unique identifier for the span this event belongs to
:param timestamp: Timestamp when the event occurred
:param attributes: (Optional) Key-value pairs containing additional metadata about the event
"""
trace_id: str
span_id: str
timestamp: datetime
attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {})
@json_schema_type
class UnstructuredLogEvent(EventCommon):
"""An unstructured log event containing a simple text message.
:param type: Event type identifier set to UNSTRUCTURED_LOG
:param message: The log message text
:param severity: The severity level of the log message
"""
type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG
message: str
severity: LogSeverity
@json_schema_type
class MetricEvent(EventCommon):
"""A metric event containing a measured value.
:param type: Event type identifier set to METRIC
:param metric: The name of the metric being measured
:param value: The numeric value of the metric measurement
:param unit: The unit of measurement for the metric value
"""
type: Literal[EventType.METRIC] = EventType.METRIC
metric: str # this would be an enum
value: int | float
unit: str
@json_schema_type
class MetricInResponse(BaseModel):
"""A metric value included in API responses.
:param metric: The name of the metric
:param value: The numeric value of the metric
:param unit: (Optional) The unit of measurement for the metric value
"""
metric: str
value: int | float
unit: str | None = None
# This is a short term solution to allow inference API to return metrics
# The ideal way to do this is to have a way for all response types to include metrics
# and all metric events logged to the telemetry API to be included with the response
# To do this, we will need to augment all response types with a metrics field.
# We have hit a blocker from stainless SDK that prevents us from doing this.
# The blocker is that if we were to augment the response types that have a data field
# in them like so
# class ListModelsResponse(BaseModel):
# metrics: Optional[List[MetricEvent]] = None
# data: List[Models]
# ...
# The client SDK will need to access the data by using a .data field, which is not
# ergonomic. Stainless SDK does support unwrapping the response type, but it
# requires that the response type to only have a single field.
# We will need a way in the client SDK to signal that the metrics are needed
# and if they are needed, the client SDK has to return the full response type
# without unwrapping it.
class MetricResponseMixin(BaseModel):
"""Mixin class for API responses that can include metrics.
:param metrics: (Optional) List of metrics associated with the API response
"""
metrics: list[MetricInResponse] | None = None
@json_schema_type
class StructuredLogType(Enum):
"""The type of structured log event payload.
:cvar SPAN_START: Event indicating the start of a new span
:cvar SPAN_END: Event indicating the completion of a span
"""
SPAN_START = "span_start"
SPAN_END = "span_end"
@json_schema_type
class SpanStartPayload(BaseModel):
"""Payload for a span start event.
:param type: Payload type identifier set to SPAN_START
:param name: Human-readable name describing the operation this span represents
:param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
"""
type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START
name: str
parent_span_id: str | None = None
@json_schema_type
class SpanEndPayload(BaseModel):
"""Payload for a span end event.
:param type: Payload type identifier set to SPAN_END
:param status: The final status of the span indicating success or failure
"""
type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END
status: SpanStatus
StructuredLogPayload = Annotated[
SpanStartPayload | SpanEndPayload,
Field(discriminator="type"),
]
register_schema(StructuredLogPayload, name="StructuredLogPayload")
@json_schema_type
class StructuredLogEvent(EventCommon):
"""A structured log event containing typed payload data.
:param type: Event type identifier set to STRUCTURED_LOG
:param payload: The structured payload data for the log event
"""
type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG
payload: StructuredLogPayload
Event = Annotated[
UnstructuredLogEvent | MetricEvent | StructuredLogEvent,
Field(discriminator="type"),
]
register_schema(Event, name="Event")
@json_schema_type
class EvalTrace(BaseModel):
"""A trace record for evaluation purposes.
:param session_id: Unique identifier for the evaluation session
:param step: The evaluation step or phase identifier
:param input: The input data for the evaluation
:param output: The actual output produced during evaluation
:param expected_output: The expected output for comparison during evaluation
"""
session_id: str
step: str
input: str
output: str
expected_output: str
@json_schema_type
class SpanWithStatus(Span):
"""A span that includes status information.
:param status: (Optional) The current status of the span
"""
status: SpanStatus | None = None
@json_schema_type
class QueryConditionOp(Enum):
"""Comparison operators for query conditions.
:cvar EQ: Equal to comparison
:cvar NE: Not equal to comparison
:cvar GT: Greater than comparison
:cvar LT: Less than comparison
"""
EQ = "eq"
NE = "ne"
GT = "gt"
LT = "lt"
@json_schema_type
class QueryCondition(BaseModel):
"""A condition for filtering query results.
:param key: The attribute key to filter on
:param op: The comparison operator to apply
:param value: The value to compare against
"""
key: str
op: QueryConditionOp
value: Any
class QueryTracesResponse(BaseModel):
"""Response containing a list of traces.
:param data: List of traces matching the query criteria
"""
data: list[Trace]
class QuerySpansResponse(BaseModel):
"""Response containing a list of spans.
:param data: List of spans matching the query criteria
"""
data: list[Span]
class QuerySpanTreeResponse(BaseModel):
"""Response containing a tree structure of spans.
:param data: Dictionary mapping span IDs to spans with status information
"""
data: dict[str, SpanWithStatus]
class MetricQueryType(Enum):
"""The type of metric query to perform.
:cvar RANGE: Query metrics over a time range
:cvar INSTANT: Query metrics at a specific point in time
"""
RANGE = "range"
INSTANT = "instant"
class MetricLabelOperator(Enum):
"""Operators for matching metric labels.
:cvar EQUALS: Label value must equal the specified value
:cvar NOT_EQUALS: Label value must not equal the specified value
:cvar REGEX_MATCH: Label value must match the specified regular expression
:cvar REGEX_NOT_MATCH: Label value must not match the specified regular expression
"""
EQUALS = "="
NOT_EQUALS = "!="
REGEX_MATCH = "=~"
REGEX_NOT_MATCH = "!~"
class MetricLabelMatcher(BaseModel):
"""A matcher for filtering metrics by label values.
:param name: The name of the label to match
:param value: The value to match against
:param operator: The comparison operator to use for matching
"""
name: str
value: str
operator: MetricLabelOperator = MetricLabelOperator.EQUALS
@json_schema_type
class MetricLabel(BaseModel):
"""A label associated with a metric.
:param name: The name of the label
:param value: The value of the label
"""
name: str
value: str
@json_schema_type
class MetricDataPoint(BaseModel):
"""A single data point in a metric time series.
:param timestamp: Unix timestamp when the metric value was recorded
:param value: The numeric value of the metric at this timestamp
"""
timestamp: int
value: float
unit: str
@json_schema_type
class MetricSeries(BaseModel):
"""A time series of metric data points.
:param metric: The name of the metric
:param labels: List of labels associated with this metric series
:param values: List of data points in chronological order
"""
metric: str
labels: list[MetricLabel]
values: list[MetricDataPoint]
class QueryMetricsResponse(BaseModel):
"""Response containing metric time series data.
:param data: List of metric series matching the query criteria
"""
data: list[MetricSeries]
_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = { _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
"active_spans": {}, "active_spans": {},
@ -46,13 +428,6 @@ _TRACER_PROVIDER = None
logger = get_logger(name=__name__, category="telemetry") logger = get_logger(name=__name__, category="telemetry")
def _clean_attributes(attrs: dict[str, Any] | None) -> Attributes | None:
"""Remove None values from attributes dict to match OpenTelemetry's expected type."""
if attrs is None:
return None
return {k: v for k, v in attrs.items() if v is not None}
def is_tracing_enabled(tracer): def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span: with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording() return span.is_recording()
@ -81,7 +456,7 @@ class Telemetry:
# https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter # https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
span_exporter = OTLPSpanExporter() span_exporter = OTLPSpanExporter()
span_processor = BatchSpanProcessor(span_exporter) span_processor = BatchSpanProcessor(span_exporter)
cast(TracerProvider, trace.get_tracer_provider()).add_span_processor(span_processor) trace.get_tracer_provider().add_span_processor(span_processor)
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter()) metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
metric_provider = MeterProvider(metric_readers=[metric_reader]) metric_provider = MeterProvider(metric_readers=[metric_reader])
@ -99,7 +474,7 @@ class Telemetry:
async def shutdown(self) -> None: async def shutdown(self) -> None:
if self.is_otel_endpoint_set: if self.is_otel_endpoint_set:
cast(TracerProvider, trace.get_tracer_provider()).force_flush() trace.get_tracer_provider().force_flush()
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
if isinstance(event, UnstructuredLogEvent): if isinstance(event, UnstructuredLogEvent):
@ -140,7 +515,7 @@ class Telemetry:
unit=unit, unit=unit,
description=f"Counter for {name}", description=f"Counter for {name}",
) )
return cast(metrics.Counter, _GLOBAL_STORAGE["counters"][name]) return _GLOBAL_STORAGE["counters"][name]
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge: def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
assert self.meter is not None assert self.meter is not None
@ -150,7 +525,7 @@ class Telemetry:
unit=unit, unit=unit,
description=f"Gauge for {name}", description=f"Gauge for {name}",
) )
return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name]) return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None: def _log_metric(self, event: MetricEvent) -> None:
# Add metric as an event to the current span # Add metric as an event to the current span
@ -185,10 +560,10 @@ class Telemetry:
return return
if isinstance(event.value, int): if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit) counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=_clean_attributes(event.attributes)) counter.add(event.value, attributes=event.attributes)
elif isinstance(event.value, float): elif isinstance(event.value, float):
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit) up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
up_down_counter.add(event.value, attributes=_clean_attributes(event.attributes)) up_down_counter.add(event.value, attributes=event.attributes)
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter: def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
assert self.meter is not None assert self.meter is not None
@ -198,7 +573,7 @@ class Telemetry:
unit=unit, unit=unit,
description=f"UpDownCounter for {name}", description=f"UpDownCounter for {name}",
) )
return cast(metrics.UpDownCounter, _GLOBAL_STORAGE["up_down_counters"][name]) return _GLOBAL_STORAGE["up_down_counters"][name]
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None: def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock: with self._lock:
@ -226,8 +601,7 @@ class Telemetry:
if event.payload.parent_span_id: if event.payload.parent_span_id:
parent_span_id = int(event.payload.parent_span_id, 16) parent_span_id = int(event.payload.parent_span_id, 16)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
if parent_span is not None: context = trace.set_span_in_context(parent_span)
context = trace.set_span_in_context(parent_span)
elif traceparent: elif traceparent:
carrier = { carrier = {
"traceparent": traceparent, "traceparent": traceparent,
@ -238,25 +612,23 @@ class Telemetry:
span = tracer.start_span( span = tracer.start_span(
name=event.payload.name, name=event.payload.name,
context=context, context=context,
attributes=_clean_attributes(event.attributes) or {}, attributes=event.attributes or {},
) )
_GLOBAL_STORAGE["active_spans"][span_id] = span _GLOBAL_STORAGE["active_spans"][span_id] = span
elif isinstance(event.payload, SpanEndPayload): elif isinstance(event.payload, SpanEndPayload):
end_span = cast(trace.Span | None, _GLOBAL_STORAGE["active_spans"].get(span_id)) span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if end_span: if span:
if event.attributes: if event.attributes:
cleaned_attrs = _clean_attributes(event.attributes) span.set_attributes(event.attributes)
if cleaned_attrs:
end_span.set_attributes(cleaned_attrs)
status = ( status = (
trace.Status(status_code=trace.StatusCode.OK) trace.Status(status_code=trace.StatusCode.OK)
if event.payload.status == SpanStatus.OK if event.payload.status == SpanStatus.OK
else trace.Status(status_code=trace.StatusCode.ERROR) else trace.Status(status_code=trace.StatusCode.ERROR)
) )
end_span.set_status(status) span.set_status(status)
end_span.end() span.end()
_GLOBAL_STORAGE["active_spans"].pop(span_id, None) _GLOBAL_STORAGE["active_spans"].pop(span_id, None)
else: else:
raise ValueError(f"Unknown structured log event: {event}") raise ValueError(f"Unknown structured log event: {event}")

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Literal, cast from typing import Any, Literal
from sqlalchemy import ( from sqlalchemy import (
JSON, JSON,
@ -55,17 +55,17 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
raise ValueError(f"Operator mapping must have a single operator, got: {value}") raise ValueError(f"Operator mapping must have a single operator, got: {value}")
op, operand = next(iter(value.items())) op, operand = next(iter(value.items()))
if op == "==" or op == "=": if op == "==" or op == "=":
return cast(ColumnElement[Any], column == operand) return column == operand
if op == ">": if op == ">":
return cast(ColumnElement[Any], column > operand) return column > operand
if op == "<": if op == "<":
return cast(ColumnElement[Any], column < operand) return column < operand
if op == ">=": if op == ">=":
return cast(ColumnElement[Any], column >= operand) return column >= operand
if op == "<=": if op == "<=":
return cast(ColumnElement[Any], column <= operand) return column <= operand
raise ValueError(f"Unsupported operator '{op}' in where mapping") raise ValueError(f"Unsupported operator '{op}' in where mapping")
return cast(ColumnElement[Any], column == value) return column == value
class SqlAlchemySqlStoreImpl(SqlStore): class SqlAlchemySqlStoreImpl(SqlStore):
@ -210,8 +210,10 @@ class SqlAlchemySqlStoreImpl(SqlStore):
query = query.limit(fetch_limit) query = query.limit(fetch_limit)
result = await session.execute(query) result = await session.execute(query)
# Iterate directly - if no rows, list comprehension yields empty list if result.rowcount == 0:
rows = [dict(row._mapping) for row in result] rows = []
else:
rows = [dict(row._mapping) for row in result]
# Always return pagination result # Always return pagination result
has_more = False has_more = False