fix(mypy): resolve OpenTelemetry typing issues in telemetry.py (#3931)

## Summary

Fix all 11 mypy type checking errors in `telemetry.py` without using any
type suppressions.

**Changes:**
- Add type aliases for OpenTelemetry attribute types (`AttributeValue`,
`Attributes`)
- Create `_clean_attributes()` helper to filter None values from
attribute dicts
- Use `cast()` for TracerProvider methods (`add_span_processor`,
`force_flush`)
- Use `cast()` on metric instruments (`Counter`, `ObservableGauge`,
  `UpDownCounter`) retrieved from global storage
- Fix variable reuse by renaming `span` to `end_span` in the SpanEndPayload
  branch
- Add None check for `parent_span` before calling `set_span_in_context` (see
  the sketch below)
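
A condensed sketch of the core patterns (illustrative only; the names and the
`Sequence` unions match the diff below, which mirrors OpenTelemetry's own
attribute typing):

```python
from collections.abc import Mapping, Sequence
from typing import Any, cast

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider

# Type aliases for OpenTelemetry attribute values (which exclude None)
AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]
Attributes = Mapping[str, AttributeValue]


def _clean_attributes(attrs: dict[str, Any] | None) -> Attributes | None:
    """Drop None values so the dict matches OTel's expected attribute type."""
    if attrs is None:
        return None
    return {k: v for k, v in attrs.items() if v is not None}


# trace.get_tracer_provider() is typed as the abstract API TracerProvider,
# which lacks SDK-only methods like add_span_processor(); the cast narrows
# it to the SDK class that actually provides them.
provider = cast(TracerProvider, trace.get_tracer_provider())
```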

**Errors Fixed:**
- TracerProvider attribute access: 2 errors
- Counter/UpDownCounter/ObservableGauge return types: 3 errors
- Attribute dict type mismatches: 4 errors
- Span assignment type conflicts: 2 errors

**Testing:**
```bash
uv run mypy src/llama_stack/core/telemetry/telemetry.py
# Success: no issues found
```

**Part of:** Mypy suppression removal plan (Phase 2a/4)

**Stack:**
- [Phase 1] Add type stubs (#3930)
- [Phase 2a] Fix OpenTelemetry types (this PR)
- [Phase 2b+] Fix remaining errors (upcoming)
- [Phase 3] Remove inline suppressions (upcoming)
- [Phase 4] Un-exclude files from mypy (upcoming)
Ashwin Bharambe · 2025-10-28 09:47:20 -07:00 · committed by GitHub
parent 5598f61e12 · commit 9afc52a36a
2 changed files with 49 additions and 423 deletions

src/llama_stack/core/telemetry/telemetry.py

@@ -6,13 +6,8 @@
import os
import threading
from datetime import datetime
from enum import Enum
from typing import (
    Annotated,
    Any,
    Literal,
)
from collections.abc import Mapping, Sequence
from typing import Any, cast

from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@@ -22,399 +17,22 @@ from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from pydantic import BaseModel, Field

# Type alias for OpenTelemetry attribute values (excludes None)
AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]
Attributes = Mapping[str, AttributeValue]

from llama_stack.apis.telemetry import (
    Event,
    MetricEvent,
    SpanEndPayload,
    SpanStartPayload,
    SpanStatus,
    StructuredLogEvent,
    UnstructuredLogEvent,
)
from llama_stack.core.telemetry.tracing import ROOT_SPAN_MARKERS
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import Primitive
from llama_stack.schema_utils import json_schema_type, register_schema

ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]

@json_schema_type
class SpanStatus(Enum):
    """The status of a span indicating whether it completed successfully or with an error.

    :cvar OK: Span completed successfully without errors
    :cvar ERROR: Span completed with an error or failure
    """

    OK = "ok"
    ERROR = "error"


@json_schema_type
class Span(BaseModel):
    """A span representing a single operation within a trace.

    :param span_id: Unique identifier for the span
    :param trace_id: Unique identifier for the trace this span belongs to
    :param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
    :param name: Human-readable name describing the operation this span represents
    :param start_time: Timestamp when the operation began
    :param end_time: (Optional) Timestamp when the operation finished, if completed
    :param attributes: (Optional) Key-value pairs containing additional metadata about the span
    """

    span_id: str
    trace_id: str
    parent_span_id: str | None = None
    name: str
    start_time: datetime
    end_time: datetime | None = None
    attributes: dict[str, Any] | None = Field(default_factory=lambda: {})

    def set_attribute(self, key: str, value: Any):
        if self.attributes is None:
            self.attributes = {}
        self.attributes[key] = value


@json_schema_type
class Trace(BaseModel):
    """A trace representing the complete execution path of a request across multiple operations.

    :param trace_id: Unique identifier for the trace
    :param root_span_id: Unique identifier for the root span that started this trace
    :param start_time: Timestamp when the trace began
    :param end_time: (Optional) Timestamp when the trace finished, if completed
    """

    trace_id: str
    root_span_id: str
    start_time: datetime
    end_time: datetime | None = None


@json_schema_type
class EventType(Enum):
    """The type of telemetry event being logged.

    :cvar UNSTRUCTURED_LOG: A simple log message with severity level
    :cvar STRUCTURED_LOG: A structured log event with typed payload data
    :cvar METRIC: A metric measurement with value and unit
    """

    UNSTRUCTURED_LOG = "unstructured_log"
    STRUCTURED_LOG = "structured_log"
    METRIC = "metric"


@json_schema_type
class LogSeverity(Enum):
    """The severity level of a log message.

    :cvar VERBOSE: Detailed diagnostic information for troubleshooting
    :cvar DEBUG: Debug information useful during development
    :cvar INFO: General informational messages about normal operation
    :cvar WARN: Warning messages about potentially problematic situations
    :cvar ERROR: Error messages indicating failures that don't stop execution
    :cvar CRITICAL: Critical error messages indicating severe failures
    """

    VERBOSE = "verbose"
    DEBUG = "debug"
    INFO = "info"
    WARN = "warn"
    ERROR = "error"
    CRITICAL = "critical"


class EventCommon(BaseModel):
    """Common fields shared by all telemetry events.

    :param trace_id: Unique identifier for the trace this event belongs to
    :param span_id: Unique identifier for the span this event belongs to
    :param timestamp: Timestamp when the event occurred
    :param attributes: (Optional) Key-value pairs containing additional metadata about the event
    """

    trace_id: str
    span_id: str
    timestamp: datetime
    attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {})


@json_schema_type
class UnstructuredLogEvent(EventCommon):
    """An unstructured log event containing a simple text message.

    :param type: Event type identifier set to UNSTRUCTURED_LOG
    :param message: The log message text
    :param severity: The severity level of the log message
    """

    type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG
    message: str
    severity: LogSeverity


@json_schema_type
class MetricEvent(EventCommon):
    """A metric event containing a measured value.

    :param type: Event type identifier set to METRIC
    :param metric: The name of the metric being measured
    :param value: The numeric value of the metric measurement
    :param unit: The unit of measurement for the metric value
    """

    type: Literal[EventType.METRIC] = EventType.METRIC
    metric: str  # this would be an enum
    value: int | float
    unit: str


@json_schema_type
class MetricInResponse(BaseModel):
    """A metric value included in API responses.

    :param metric: The name of the metric
    :param value: The numeric value of the metric
    :param unit: (Optional) The unit of measurement for the metric value
    """

    metric: str
    value: int | float
    unit: str | None = None


# This is a short term solution to allow inference API to return metrics
# The ideal way to do this is to have a way for all response types to include metrics
# and all metric events logged to the telemetry API to be included with the response
# To do this, we will need to augment all response types with a metrics field.
# We have hit a blocker from stainless SDK that prevents us from doing this.
# The blocker is that if we were to augment the response types that have a data field
# in them like so
# class ListModelsResponse(BaseModel):
#     metrics: Optional[List[MetricEvent]] = None
#     data: List[Models]
#     ...
# The client SDK will need to access the data by using a .data field, which is not
# ergonomic. Stainless SDK does support unwrapping the response type, but it
# requires that the response type to only have a single field.
# We will need a way in the client SDK to signal that the metrics are needed
# and if they are needed, the client SDK has to return the full response type
# without unwrapping it.
class MetricResponseMixin(BaseModel):
    """Mixin class for API responses that can include metrics.

    :param metrics: (Optional) List of metrics associated with the API response
    """

    metrics: list[MetricInResponse] | None = None


@json_schema_type
class StructuredLogType(Enum):
    """The type of structured log event payload.

    :cvar SPAN_START: Event indicating the start of a new span
    :cvar SPAN_END: Event indicating the completion of a span
    """

    SPAN_START = "span_start"
    SPAN_END = "span_end"


@json_schema_type
class SpanStartPayload(BaseModel):
    """Payload for a span start event.

    :param type: Payload type identifier set to SPAN_START
    :param name: Human-readable name describing the operation this span represents
    :param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
    """

    type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START
    name: str
    parent_span_id: str | None = None


@json_schema_type
class SpanEndPayload(BaseModel):
    """Payload for a span end event.

    :param type: Payload type identifier set to SPAN_END
    :param status: The final status of the span indicating success or failure
    """

    type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END
    status: SpanStatus


StructuredLogPayload = Annotated[
    SpanStartPayload | SpanEndPayload,
    Field(discriminator="type"),
]
register_schema(StructuredLogPayload, name="StructuredLogPayload")


@json_schema_type
class StructuredLogEvent(EventCommon):
    """A structured log event containing typed payload data.

    :param type: Event type identifier set to STRUCTURED_LOG
    :param payload: The structured payload data for the log event
    """

    type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG
    payload: StructuredLogPayload


Event = Annotated[
    UnstructuredLogEvent | MetricEvent | StructuredLogEvent,
    Field(discriminator="type"),
]
register_schema(Event, name="Event")


@json_schema_type
class EvalTrace(BaseModel):
    """A trace record for evaluation purposes.

    :param session_id: Unique identifier for the evaluation session
    :param step: The evaluation step or phase identifier
    :param input: The input data for the evaluation
    :param output: The actual output produced during evaluation
    :param expected_output: The expected output for comparison during evaluation
    """

    session_id: str
    step: str
    input: str
    output: str
    expected_output: str


@json_schema_type
class SpanWithStatus(Span):
    """A span that includes status information.

    :param status: (Optional) The current status of the span
    """

    status: SpanStatus | None = None


@json_schema_type
class QueryConditionOp(Enum):
    """Comparison operators for query conditions.

    :cvar EQ: Equal to comparison
    :cvar NE: Not equal to comparison
    :cvar GT: Greater than comparison
    :cvar LT: Less than comparison
    """

    EQ = "eq"
    NE = "ne"
    GT = "gt"
    LT = "lt"


@json_schema_type
class QueryCondition(BaseModel):
    """A condition for filtering query results.

    :param key: The attribute key to filter on
    :param op: The comparison operator to apply
    :param value: The value to compare against
    """

    key: str
    op: QueryConditionOp
    value: Any


class QueryTracesResponse(BaseModel):
    """Response containing a list of traces.

    :param data: List of traces matching the query criteria
    """

    data: list[Trace]


class QuerySpansResponse(BaseModel):
    """Response containing a list of spans.

    :param data: List of spans matching the query criteria
    """

    data: list[Span]


class QuerySpanTreeResponse(BaseModel):
    """Response containing a tree structure of spans.

    :param data: Dictionary mapping span IDs to spans with status information
    """

    data: dict[str, SpanWithStatus]


class MetricQueryType(Enum):
    """The type of metric query to perform.

    :cvar RANGE: Query metrics over a time range
    :cvar INSTANT: Query metrics at a specific point in time
    """

    RANGE = "range"
    INSTANT = "instant"


class MetricLabelOperator(Enum):
    """Operators for matching metric labels.

    :cvar EQUALS: Label value must equal the specified value
    :cvar NOT_EQUALS: Label value must not equal the specified value
    :cvar REGEX_MATCH: Label value must match the specified regular expression
    :cvar REGEX_NOT_MATCH: Label value must not match the specified regular expression
    """

    EQUALS = "="
    NOT_EQUALS = "!="
    REGEX_MATCH = "=~"
    REGEX_NOT_MATCH = "!~"


class MetricLabelMatcher(BaseModel):
    """A matcher for filtering metrics by label values.

    :param name: The name of the label to match
    :param value: The value to match against
    :param operator: The comparison operator to use for matching
    """

    name: str
    value: str
    operator: MetricLabelOperator = MetricLabelOperator.EQUALS


@json_schema_type
class MetricLabel(BaseModel):
    """A label associated with a metric.

    :param name: The name of the label
    :param value: The value of the label
    """

    name: str
    value: str


@json_schema_type
class MetricDataPoint(BaseModel):
    """A single data point in a metric time series.

    :param timestamp: Unix timestamp when the metric value was recorded
    :param value: The numeric value of the metric at this timestamp
    """

    timestamp: int
    value: float
    unit: str


@json_schema_type
class MetricSeries(BaseModel):
    """A time series of metric data points.

    :param metric: The name of the metric
    :param labels: List of labels associated with this metric series
    :param values: List of data points in chronological order
    """

    metric: str
    labels: list[MetricLabel]
    values: list[MetricDataPoint]


class QueryMetricsResponse(BaseModel):
    """Response containing metric time series data.

    :param data: List of metric series matching the query criteria
    """

    data: list[MetricSeries]


_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
    "active_spans": {},
@@ -428,6 +46,13 @@ _TRACER_PROVIDER = None
logger = get_logger(name=__name__, category="telemetry")


def _clean_attributes(attrs: dict[str, Any] | None) -> Attributes | None:
    """Remove None values from attributes dict to match OpenTelemetry's expected type."""
    if attrs is None:
        return None
    return {k: v for k, v in attrs.items() if v is not None}


def is_tracing_enabled(tracer):
    with tracer.start_as_current_span("check_tracing") as span:
        return span.is_recording()
@@ -456,7 +81,7 @@ class Telemetry:
        # https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
        span_exporter = OTLPSpanExporter()
        span_processor = BatchSpanProcessor(span_exporter)
        trace.get_tracer_provider().add_span_processor(span_processor)
        cast(TracerProvider, trace.get_tracer_provider()).add_span_processor(span_processor)

        metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
        metric_provider = MeterProvider(metric_readers=[metric_reader])
@@ -474,7 +99,7 @@
    async def shutdown(self) -> None:
        if self.is_otel_endpoint_set:
            trace.get_tracer_provider().force_flush()
            cast(TracerProvider, trace.get_tracer_provider()).force_flush()

    async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
        if isinstance(event, UnstructuredLogEvent):
@@ -515,7 +140,7 @@
                unit=unit,
                description=f"Counter for {name}",
            )
        return _GLOBAL_STORAGE["counters"][name]
        return cast(metrics.Counter, _GLOBAL_STORAGE["counters"][name])

    def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
        assert self.meter is not None
@@ -525,7 +150,7 @@
                unit=unit,
                description=f"Gauge for {name}",
            )
        return _GLOBAL_STORAGE["gauges"][name]
        return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])

    def _log_metric(self, event: MetricEvent) -> None:
        # Add metric as an event to the current span
@@ -560,10 +185,10 @@
            return
        if isinstance(event.value, int):
            counter = self._get_or_create_counter(event.metric, event.unit)
            counter.add(event.value, attributes=event.attributes)
            counter.add(event.value, attributes=_clean_attributes(event.attributes))
        elif isinstance(event.value, float):
            up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
            up_down_counter.add(event.value, attributes=event.attributes)
            up_down_counter.add(event.value, attributes=_clean_attributes(event.attributes))

    def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
        assert self.meter is not None
@@ -573,7 +198,7 @@
                unit=unit,
                description=f"UpDownCounter for {name}",
            )
        return _GLOBAL_STORAGE["up_down_counters"][name]
        return cast(metrics.UpDownCounter, _GLOBAL_STORAGE["up_down_counters"][name])

    def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
        with self._lock:
@@ -601,7 +226,8 @@
                if event.payload.parent_span_id:
                    parent_span_id = int(event.payload.parent_span_id, 16)
                    parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
                    context = trace.set_span_in_context(parent_span)
                    if parent_span is not None:
                        context = trace.set_span_in_context(parent_span)
                elif traceparent:
                    carrier = {
                        "traceparent": traceparent,
@@ -612,23 +238,25 @@
                span = tracer.start_span(
                    name=event.payload.name,
                    context=context,
                    attributes=event.attributes or {},
                    attributes=_clean_attributes(event.attributes) or {},
                )
                _GLOBAL_STORAGE["active_spans"][span_id] = span
            elif isinstance(event.payload, SpanEndPayload):
                span = _GLOBAL_STORAGE["active_spans"].get(span_id)
                if span:
                end_span = cast(trace.Span | None, _GLOBAL_STORAGE["active_spans"].get(span_id))
                if end_span:
                    if event.attributes:
                        span.set_attributes(event.attributes)
                    cleaned_attrs = _clean_attributes(event.attributes)
                    if cleaned_attrs:
                        end_span.set_attributes(cleaned_attrs)
                    status = (
                        trace.Status(status_code=trace.StatusCode.OK)
                        if event.payload.status == SpanStatus.OK
                        else trace.Status(status_code=trace.StatusCode.ERROR)
                    )
                    span.set_status(status)
                    span.end()
                    end_span.set_status(status)
                    end_span.end()
                    _GLOBAL_STORAGE["active_spans"].pop(span_id, None)
            else:
                raise ValueError(f"Unknown structured log event: {event}")

(second changed file: the SQLAlchemy sqlstore implementation; path not shown in this view)

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing import Any, Literal, cast

from sqlalchemy import (
    JSON,
@@ -55,17 +55,17 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
            raise ValueError(f"Operator mapping must have a single operator, got: {value}")
        op, operand = next(iter(value.items()))
        if op == "==" or op == "=":
            return column == operand
            return cast(ColumnElement[Any], column == operand)
        if op == ">":
            return column > operand
            return cast(ColumnElement[Any], column > operand)
        if op == "<":
            return column < operand
            return cast(ColumnElement[Any], column < operand)
        if op == ">=":
            return column >= operand
            return cast(ColumnElement[Any], column >= operand)
        if op == "<=":
            return column <= operand
            return cast(ColumnElement[Any], column <= operand)
        raise ValueError(f"Unsupported operator '{op}' in where mapping")
    return column == value
    return cast(ColumnElement[Any], column == value)


class SqlAlchemySqlStoreImpl(SqlStore):
@@ -210,10 +210,8 @@ class SqlAlchemySqlStoreImpl(SqlStore):
                query = query.limit(fetch_limit)

            result = await session.execute(query)
            if result.rowcount == 0:
                rows = []
            else:
                rows = [dict(row._mapping) for row in result]
            # Iterate directly - if no rows, list comprehension yields empty list
            rows = [dict(row._mapping) for row in result]

            # Always return pagination result
            has_more = False
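
The SQLAlchemy casts above follow the same idea as the telemetry fixes: mypy
cannot always infer that a comparison on a `ColumnElement` yields a
`ColumnElement` rather than a plain `bool`, so the result is pinned to the
declared return type. A minimal standalone sketch (the `age` column and
`where_gt` helper are hypothetical, not from this PR):

```python
from typing import Any, cast

from sqlalchemy import Column, Integer
from sqlalchemy.sql.elements import ColumnElement

age = Column("age", Integer)


def where_gt(column: ColumnElement, operand: Any) -> ColumnElement:
    # Without the cast, mypy may infer a looser type for `column > operand`
    # than the declared ColumnElement return type and flag the return.
    return cast(ColumnElement[Any], column > operand)


expr = where_gt(age, 21)
print(expr)  # renders roughly as "age > :age_1" when compiled
```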