fix(mypy): resolve OpenTelemetry typing issues in telemetry.py

Fix all 11 mypy errors in telemetry.py without using suppressions:

**Changes:**
- Add type aliases for OpenTelemetry attribute types (AttributeValue, Attributes)
- Create `_clean_attributes()` helper to filter None values from attribute dicts
- Use `cast()` for TracerProvider methods (add_span_processor, force_flush)
- Use `cast()` for metric creation methods returning from global storage
- Fix variable reuse by renaming `span` to `end_span` in SpanEndPayload branch
- Add None check for parent_span before set_span_in_context

**Errors fixed:**
- TracerProvider attribute access (2 errors)
- Counter/UpDownCounter/ObservableGauge return types (3 errors)
- Attribute dict type mismatches (4 errors)
- Span assignment type conflicts (2 errors)

This eliminates all mypy errors in the telemetry module.
This commit is contained in:
Ashwin Bharambe 2025-10-27 21:04:14 -07:00
parent 893c237759
commit 6e47335371

View file

@ -6,7 +6,8 @@
import os import os
import threading import threading
from typing import Any from collections.abc import Mapping, Sequence
from typing import Any, cast
from opentelemetry import metrics, trace from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@ -17,6 +18,10 @@ from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
# Type alias for OpenTelemetry attribute values (excludes None)
AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]
Attributes = Mapping[str, AttributeValue]
from llama_stack.apis.telemetry import ( from llama_stack.apis.telemetry import (
Event, Event,
MetricEvent, MetricEvent,
@ -44,6 +49,13 @@ _TRACER_PROVIDER = None
logger = get_logger(name=__name__, category="telemetry") logger = get_logger(name=__name__, category="telemetry")
def _clean_attributes(attrs: dict[str, Any] | None) -> Attributes | None:
"""Remove None values from attributes dict to match OpenTelemetry's expected type."""
if attrs is None:
return None
return {k: v for k, v in attrs.items() if v is not None}
def is_tracing_enabled(tracer): def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span: with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording() return span.is_recording()
@ -72,7 +84,7 @@ class Telemetry(TelemetryBase):
# https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter # https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
span_exporter = OTLPSpanExporter() span_exporter = OTLPSpanExporter()
span_processor = BatchSpanProcessor(span_exporter) span_processor = BatchSpanProcessor(span_exporter)
trace.get_tracer_provider().add_span_processor(span_processor) cast(TracerProvider, trace.get_tracer_provider()).add_span_processor(span_processor)
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter()) metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
metric_provider = MeterProvider(metric_readers=[metric_reader]) metric_provider = MeterProvider(metric_readers=[metric_reader])
@ -90,7 +102,7 @@ class Telemetry(TelemetryBase):
async def shutdown(self) -> None: async def shutdown(self) -> None:
if self.is_otel_endpoint_set: if self.is_otel_endpoint_set:
trace.get_tracer_provider().force_flush() cast(TracerProvider, trace.get_tracer_provider()).force_flush()
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
if isinstance(event, UnstructuredLogEvent): if isinstance(event, UnstructuredLogEvent):
@ -131,7 +143,7 @@ class Telemetry(TelemetryBase):
unit=unit, unit=unit,
description=f"Counter for {name}", description=f"Counter for {name}",
) )
return _GLOBAL_STORAGE["counters"][name] return cast(metrics.Counter, _GLOBAL_STORAGE["counters"][name])
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge: def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
assert self.meter is not None assert self.meter is not None
@ -141,7 +153,7 @@ class Telemetry(TelemetryBase):
unit=unit, unit=unit,
description=f"Gauge for {name}", description=f"Gauge for {name}",
) )
return _GLOBAL_STORAGE["gauges"][name] return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
def _log_metric(self, event: MetricEvent) -> None: def _log_metric(self, event: MetricEvent) -> None:
# Add metric as an event to the current span # Add metric as an event to the current span
@ -176,10 +188,10 @@ class Telemetry(TelemetryBase):
return return
if isinstance(event.value, int): if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit) counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes) counter.add(event.value, attributes=_clean_attributes(event.attributes))
elif isinstance(event.value, float): elif isinstance(event.value, float):
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit) up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
up_down_counter.add(event.value, attributes=event.attributes) up_down_counter.add(event.value, attributes=_clean_attributes(event.attributes))
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter: def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
assert self.meter is not None assert self.meter is not None
@ -189,7 +201,7 @@ class Telemetry(TelemetryBase):
unit=unit, unit=unit,
description=f"UpDownCounter for {name}", description=f"UpDownCounter for {name}",
) )
return _GLOBAL_STORAGE["up_down_counters"][name] return cast(metrics.UpDownCounter, _GLOBAL_STORAGE["up_down_counters"][name])
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None: def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock: with self._lock:
@ -217,7 +229,8 @@ class Telemetry(TelemetryBase):
if event.payload.parent_span_id: if event.payload.parent_span_id:
parent_span_id = int(event.payload.parent_span_id, 16) parent_span_id = int(event.payload.parent_span_id, 16)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
context = trace.set_span_in_context(parent_span) if parent_span is not None:
context = trace.set_span_in_context(parent_span)
elif traceparent: elif traceparent:
carrier = { carrier = {
"traceparent": traceparent, "traceparent": traceparent,
@ -228,23 +241,25 @@ class Telemetry(TelemetryBase):
span = tracer.start_span( span = tracer.start_span(
name=event.payload.name, name=event.payload.name,
context=context, context=context,
attributes=event.attributes or {}, attributes=_clean_attributes(event.attributes) or {},
) )
_GLOBAL_STORAGE["active_spans"][span_id] = span _GLOBAL_STORAGE["active_spans"][span_id] = span
elif isinstance(event.payload, SpanEndPayload): elif isinstance(event.payload, SpanEndPayload):
span = _GLOBAL_STORAGE["active_spans"].get(span_id) end_span = cast(trace.Span | None, _GLOBAL_STORAGE["active_spans"].get(span_id))
if span: if end_span:
if event.attributes: if event.attributes:
span.set_attributes(event.attributes) cleaned_attrs = _clean_attributes(event.attributes)
if cleaned_attrs:
end_span.set_attributes(cleaned_attrs)
status = ( status = (
trace.Status(status_code=trace.StatusCode.OK) trace.Status(status_code=trace.StatusCode.OK)
if event.payload.status == SpanStatus.OK if event.payload.status == SpanStatus.OK
else trace.Status(status_code=trace.StatusCode.ERROR) else trace.Status(status_code=trace.StatusCode.ERROR)
) )
span.set_status(status) end_span.set_status(status)
span.end() end_span.end()
_GLOBAL_STORAGE["active_spans"].pop(span_id, None) _GLOBAL_STORAGE["active_spans"].pop(span_id, None)
else: else:
raise ValueError(f"Unknown structured log event: {event}") raise ValueError(f"Unknown structured log event: {event}")