fix(mypy): resolve OpenTelemetry typing issues in telemetry.py (#3943)

Fixes mypy type errors in OpenTelemetry integration:
- Add type aliases for AttributeValue and Attributes
- Add helper to filter None values from attributes (OpenTelemetry
doesn't accept None)
- Cast metric instruments and the tracer provider to their concrete types
- Update imports after refactoring

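No functional changes intended. Note that the diff does alter runtime behavior in three small ways: None-valued attributes are now dropped before being passed to OpenTelemetry, the parent span is only set in the context when it is found, and the SQL store no longer special-cases `rowcount == 0` when collecting rows.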
This commit is contained in:
Ashwin Bharambe 2025-10-28 10:10:18 -07:00 committed by GitHub
parent 85887d724f
commit 4a2ea278c5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 36 additions and 22 deletions

View file

@ -6,12 +6,14 @@
import os import os
import threading import threading
from collections.abc import Mapping, Sequence
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import ( from typing import (
Annotated, Annotated,
Any, Any,
Literal, Literal,
cast,
) )
from opentelemetry import metrics, trace from opentelemetry import metrics, trace
@ -30,6 +32,10 @@ from llama_stack.schema_utils import json_schema_type, register_schema
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"] ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
# Type alias for OpenTelemetry attribute values (excludes None)
AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]
Attributes = Mapping[str, AttributeValue]
@json_schema_type @json_schema_type
class SpanStatus(Enum): class SpanStatus(Enum):
@ -428,6 +434,13 @@ _TRACER_PROVIDER = None
logger = get_logger(name=__name__, category="telemetry") logger = get_logger(name=__name__, category="telemetry")
def _clean_attributes(attrs: dict[str, Any] | None) -> Attributes | None:
    """Return a copy of ``attrs`` without the entries whose value is None.

    OpenTelemetry's attribute type does not include None, so such entries
    are filtered out before the mapping is handed to the SDK.

    Args:
        attrs: Attribute dictionary to filter, or None if no attributes
            were supplied.

    Returns:
        A new dict holding only the non-None entries, or None when
        ``attrs`` itself is None.
    """
    if attrs is not None:
        # Compare against None explicitly: falsy values such as 0, False
        # and "" must be kept.
        return {key: value for key, value in attrs.items() if value is not None}
    return None
def is_tracing_enabled(tracer): def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span: with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording() return span.is_recording()
@ -456,7 +469,7 @@ class Telemetry:
# https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter # https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
span_exporter = OTLPSpanExporter() span_exporter = OTLPSpanExporter()
span_processor = BatchSpanProcessor(span_exporter) span_processor = BatchSpanProcessor(span_exporter)
trace.get_tracer_provider().add_span_processor(span_processor) cast(TracerProvider, trace.get_tracer_provider()).add_span_processor(span_processor)
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter()) metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
metric_provider = MeterProvider(metric_readers=[metric_reader]) metric_provider = MeterProvider(metric_readers=[metric_reader])
@ -474,7 +487,7 @@ class Telemetry:
async def shutdown(self) -> None: async def shutdown(self) -> None:
if self.is_otel_endpoint_set: if self.is_otel_endpoint_set:
trace.get_tracer_provider().force_flush() cast(TracerProvider, trace.get_tracer_provider()).force_flush()
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
if isinstance(event, UnstructuredLogEvent): if isinstance(event, UnstructuredLogEvent):
@ -515,7 +528,7 @@ class Telemetry:
unit=unit, unit=unit,
description=f"Counter for {name}", description=f"Counter for {name}",
) )
return _GLOBAL_STORAGE["counters"][name] return cast(metrics.Counter, _GLOBAL_STORAGE["counters"][name])
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge: def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
assert self.meter is not None assert self.meter is not None
@ -525,7 +538,7 @@ class Telemetry:
unit=unit, unit=unit,
description=f"Gauge for {name}", description=f"Gauge for {name}",
) )
return _GLOBAL_STORAGE["gauges"][name] return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
def _log_metric(self, event: MetricEvent) -> None: def _log_metric(self, event: MetricEvent) -> None:
# Add metric as an event to the current span # Add metric as an event to the current span
@ -560,10 +573,10 @@ class Telemetry:
return return
if isinstance(event.value, int): if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit) counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes) counter.add(event.value, attributes=_clean_attributes(event.attributes))
elif isinstance(event.value, float): elif isinstance(event.value, float):
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit) up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
up_down_counter.add(event.value, attributes=event.attributes) up_down_counter.add(event.value, attributes=_clean_attributes(event.attributes))
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter: def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
assert self.meter is not None assert self.meter is not None
@ -573,7 +586,7 @@ class Telemetry:
unit=unit, unit=unit,
description=f"UpDownCounter for {name}", description=f"UpDownCounter for {name}",
) )
return _GLOBAL_STORAGE["up_down_counters"][name] return cast(metrics.UpDownCounter, _GLOBAL_STORAGE["up_down_counters"][name])
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None: def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock: with self._lock:
@ -601,6 +614,7 @@ class Telemetry:
if event.payload.parent_span_id: if event.payload.parent_span_id:
parent_span_id = int(event.payload.parent_span_id, 16) parent_span_id = int(event.payload.parent_span_id, 16)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
if parent_span:
context = trace.set_span_in_context(parent_span) context = trace.set_span_in_context(parent_span)
elif traceparent: elif traceparent:
carrier = { carrier = {
@ -612,15 +626,17 @@ class Telemetry:
span = tracer.start_span( span = tracer.start_span(
name=event.payload.name, name=event.payload.name,
context=context, context=context,
attributes=event.attributes or {}, attributes=_clean_attributes(event.attributes),
) )
_GLOBAL_STORAGE["active_spans"][span_id] = span _GLOBAL_STORAGE["active_spans"][span_id] = span
elif isinstance(event.payload, SpanEndPayload): elif isinstance(event.payload, SpanEndPayload):
span = _GLOBAL_STORAGE["active_spans"].get(span_id) span = _GLOBAL_STORAGE["active_spans"].get(span_id) # type: ignore[assignment]
if span: if span:
if event.attributes: if event.attributes:
span.set_attributes(event.attributes) cleaned_attrs = _clean_attributes(event.attributes)
if cleaned_attrs:
span.set_attributes(cleaned_attrs)
status = ( status = (
trace.Status(status_code=trace.StatusCode.OK) trace.Status(status_code=trace.StatusCode.OK)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Literal from typing import Any, Literal, cast
from sqlalchemy import ( from sqlalchemy import (
JSON, JSON,
@ -55,17 +55,17 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
raise ValueError(f"Operator mapping must have a single operator, got: {value}") raise ValueError(f"Operator mapping must have a single operator, got: {value}")
op, operand = next(iter(value.items())) op, operand = next(iter(value.items()))
if op == "==" or op == "=": if op == "==" or op == "=":
return column == operand return cast(ColumnElement[Any], column == operand)
if op == ">": if op == ">":
return column > operand return cast(ColumnElement[Any], column > operand)
if op == "<": if op == "<":
return column < operand return cast(ColumnElement[Any], column < operand)
if op == ">=": if op == ">=":
return column >= operand return cast(ColumnElement[Any], column >= operand)
if op == "<=": if op == "<=":
return column <= operand return cast(ColumnElement[Any], column <= operand)
raise ValueError(f"Unsupported operator '{op}' in where mapping") raise ValueError(f"Unsupported operator '{op}' in where mapping")
return column == value return cast(ColumnElement[Any], column == value)
class SqlAlchemySqlStoreImpl(SqlStore): class SqlAlchemySqlStoreImpl(SqlStore):
@ -210,9 +210,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
query = query.limit(fetch_limit) query = query.limit(fetch_limit)
result = await session.execute(query) result = await session.execute(query)
if result.rowcount == 0: # Iterate directly - if no rows, list comprehension yields empty list
rows = []
else:
rows = [dict(row._mapping) for row in result] rows = [dict(row._mapping) for row in result]
# Always return pagination result # Always return pagination result