From 6e47335371bbe70bfaa362fa61c375b1f353fd8b Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 27 Oct 2025 21:04:14 -0700
Subject: [PATCH] fix(mypy): resolve OpenTelemetry typing issues in telemetry.py

Fix all 11 mypy errors in telemetry.py without using suppressions:

**Changes:**
- Add type aliases for OpenTelemetry attribute types (AttributeValue, Attributes)
- Create `_clean_attributes()` helper to filter None values from attribute dicts
- Use `cast()` for TracerProvider methods (add_span_processor, force_flush)
- Use `cast()` for metric creation methods returning from global storage
- Fix variable reuse by renaming `span` to `end_span` in SpanEndPayload branch
- Add None check for parent_span before set_span_in_context

**Errors fixed:**
- TracerProvider attribute access (2 errors)
- Counter/UpDownCounter/ObservableGauge return types (3 errors)
- Attribute dict type mismatches (4 errors)
- Span assignment type conflicts (2 errors)

This eliminates all mypy errors in the telemetry module.
--- src/llama_stack/core/telemetry/telemetry.py | 45 ++++++++++++++------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/src/llama_stack/core/telemetry/telemetry.py b/src/llama_stack/core/telemetry/telemetry.py index f0cec08ec..fa49a45ef 100644 --- a/src/llama_stack/core/telemetry/telemetry.py +++ b/src/llama_stack/core/telemetry/telemetry.py @@ -6,7 +6,8 @@ import os import threading -from typing import Any +from collections.abc import Mapping, Sequence +from typing import Any, cast from opentelemetry import metrics, trace from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter @@ -17,6 +18,10 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +# Type alias for OpenTelemetry attribute values (excludes None) +AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float] +Attributes = Mapping[str, AttributeValue] + from llama_stack.apis.telemetry import ( Event, MetricEvent, @@ -44,6 +49,13 @@ _TRACER_PROVIDER = None logger = get_logger(name=__name__, category="telemetry") +def _clean_attributes(attrs: dict[str, Any] | None) -> Attributes | None: + """Remove None values from attributes dict to match OpenTelemetry's expected type.""" + if attrs is None: + return None + return {k: v for k, v in attrs.items() if v is not None} + + def is_tracing_enabled(tracer): with tracer.start_as_current_span("check_tracing") as span: return span.is_recording() @@ -72,7 +84,7 @@ class Telemetry(TelemetryBase): # https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter span_exporter = OTLPSpanExporter() span_processor = BatchSpanProcessor(span_exporter) - trace.get_tracer_provider().add_span_processor(span_processor) + cast(TracerProvider, trace.get_tracer_provider()).add_span_processor(span_processor) metric_reader = 
PeriodicExportingMetricReader(OTLPMetricExporter()) metric_provider = MeterProvider(metric_readers=[metric_reader]) @@ -90,7 +102,7 @@ class Telemetry(TelemetryBase): async def shutdown(self) -> None: if self.is_otel_endpoint_set: - trace.get_tracer_provider().force_flush() + cast(TracerProvider, trace.get_tracer_provider()).force_flush() async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: if isinstance(event, UnstructuredLogEvent): @@ -131,7 +143,7 @@ class Telemetry(TelemetryBase): unit=unit, description=f"Counter for {name}", ) - return _GLOBAL_STORAGE["counters"][name] + return cast(metrics.Counter, _GLOBAL_STORAGE["counters"][name]) def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge: assert self.meter is not None @@ -141,7 +153,7 @@ class Telemetry(TelemetryBase): unit=unit, description=f"Gauge for {name}", ) - return _GLOBAL_STORAGE["gauges"][name] + return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name]) def _log_metric(self, event: MetricEvent) -> None: # Add metric as an event to the current span @@ -176,10 +188,10 @@ class Telemetry(TelemetryBase): return if isinstance(event.value, int): counter = self._get_or_create_counter(event.metric, event.unit) - counter.add(event.value, attributes=event.attributes) + counter.add(event.value, attributes=_clean_attributes(event.attributes)) elif isinstance(event.value, float): up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit) - up_down_counter.add(event.value, attributes=event.attributes) + up_down_counter.add(event.value, attributes=_clean_attributes(event.attributes)) def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter: assert self.meter is not None @@ -189,7 +201,7 @@ class Telemetry(TelemetryBase): unit=unit, description=f"UpDownCounter for {name}", ) - return _GLOBAL_STORAGE["up_down_counters"][name] + return cast(metrics.UpDownCounter, _GLOBAL_STORAGE["up_down_counters"][name]) 
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None: with self._lock: @@ -217,7 +229,8 @@ class Telemetry(TelemetryBase): if event.payload.parent_span_id: parent_span_id = int(event.payload.parent_span_id, 16) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) - context = trace.set_span_in_context(parent_span) + if parent_span is not None: + context = trace.set_span_in_context(parent_span) elif traceparent: carrier = { "traceparent": traceparent, @@ -228,23 +241,25 @@ class Telemetry(TelemetryBase): span = tracer.start_span( name=event.payload.name, context=context, - attributes=event.attributes or {}, + attributes=_clean_attributes(event.attributes) or {}, ) _GLOBAL_STORAGE["active_spans"][span_id] = span elif isinstance(event.payload, SpanEndPayload): - span = _GLOBAL_STORAGE["active_spans"].get(span_id) - if span: + end_span = cast(trace.Span | None, _GLOBAL_STORAGE["active_spans"].get(span_id)) + if end_span: if event.attributes: - span.set_attributes(event.attributes) + cleaned_attrs = _clean_attributes(event.attributes) + if cleaned_attrs: + end_span.set_attributes(cleaned_attrs) status = ( trace.Status(status_code=trace.StatusCode.OK) if event.payload.status == SpanStatus.OK else trace.Status(status_code=trace.StatusCode.ERROR) ) - span.set_status(status) - span.end() + end_span.set_status(status) + end_span.end() _GLOBAL_STORAGE["active_spans"].pop(span_id, None) else: raise ValueError(f"Unknown structured log event: {event}")