From 9afc52a36a73a748ea107846794177e144043e8e Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Tue, 28 Oct 2025 09:47:20 -0700
Subject: [PATCH] fix(mypy): resolve OpenTelemetry typing issues in telemetry.py (#3931)

## Summary

Fix all 11 mypy type checking errors in `telemetry.py` without using any type suppressions.

**Changes:**

- Add type aliases for OpenTelemetry attribute types (`AttributeValue`, `Attributes`)
- Create a `_clean_attributes()` helper to filter `None` values from attribute dicts
- Use `cast()` for SDK `TracerProvider` methods (`add_span_processor`, `force_flush`)
- Use `cast()` on metric instruments (`Counter`, `UpDownCounter`, `ObservableGauge`) retrieved from global storage
- Fix variable reuse by renaming `span` to `end_span` in the `SpanEndPayload` branch
- Add a `None` check for `parent_span` before calling `set_span_in_context`
- Import the event and span models from `llama_stack.apis.telemetry` (and `ROOT_SPAN_MARKERS` from `llama_stack.core.telemetry.tracing`) instead of redefining them locally
- In `sqlalchemy_sqlstore.py`, wrap column comparisons in `cast(ColumnElement[Any], ...)` and drop the redundant `rowcount` check

**Errors Fixed:**

- TracerProvider attribute access: 2 errors
- Counter/UpDownCounter/ObservableGauge return types: 3 errors
- Attribute dict type mismatches: 4 errors
- Span assignment type conflicts: 2 errors

**Testing:**

```bash
uv run mypy src/llama_stack/core/telemetry/telemetry.py
# Success: no issues found
```

**Part of:** Mypy suppression removal plan (Phase 2a/4)

**Stack:**

- [Phase 1] Add type stubs (#3930)
- [Phase 2a] Fix OpenTelemetry types (this PR)
- [Phase 2b+] Fix remaining errors (upcoming)
- [Phase 3] Remove inline suppressions (upcoming)
- [Phase 4] Un-exclude files from mypy (upcoming)
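**Illustration:** the two patterns at the heart of this fix, shown in isolation. This is a condensed sketch mirroring the diff below, not a drop-in module:

```python
from collections.abc import Mapping, Sequence
from typing import Any, cast

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider

# OpenTelemetry attribute values may not be None, so the alias excludes it.
AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]
Attributes = Mapping[str, AttributeValue]


def _clean_attributes(attrs: dict[str, Any] | None) -> Attributes | None:
    """Drop None-valued keys so the dict satisfies OpenTelemetry's attribute type."""
    if attrs is None:
        return None
    return {k: v for k, v in attrs.items() if v is not None}


# trace.get_tracer_provider() is typed as the abstract API class, which has no
# add_span_processor(); cast to the SDK TracerProvider installed at startup.
provider = cast(TracerProvider, trace.get_tracer_provider())
```

`cast()` is a no-op at runtime; it only tells mypy which concrete type to assume, which is why no suppressions are needed.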
---
 src/llama_stack/core/telemetry/telemetry.py | 452 ++----------------
 .../utils/sqlstore/sqlalchemy_sqlstore.py   |  20 +-
 2 files changed, 49 insertions(+), 423 deletions(-)

diff --git a/src/llama_stack/core/telemetry/telemetry.py b/src/llama_stack/core/telemetry/telemetry.py
index dbd10e89c..b5e651572 100644
--- a/src/llama_stack/core/telemetry/telemetry.py
+++ b/src/llama_stack/core/telemetry/telemetry.py
@@ -6,13 +6,8 @@
 import os
 import threading
-from datetime import datetime
-from enum import Enum
-from typing import (
-    Annotated,
-    Any,
-    Literal,
-)
+from collections.abc import Mapping, Sequence
+from typing import Any, cast
 
 from opentelemetry import metrics, trace
 from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@@ -22,399 +17,22 @@
 from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
 from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import BatchSpanProcessor
 from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
-from pydantic import BaseModel, Field
+from llama_stack.apis.telemetry import (
+    Event,
+    MetricEvent,
+    SpanEndPayload,
+    SpanStartPayload,
+    SpanStatus,
+    StructuredLogEvent,
+    UnstructuredLogEvent,
+)
+from llama_stack.core.telemetry.tracing import ROOT_SPAN_MARKERS
 
 from llama_stack.log import get_logger
-from llama_stack.models.llama.datatypes import Primitive
-from llama_stack.schema_utils import json_schema_type, register_schema
-
-ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
-
-
-@json_schema_type
-class SpanStatus(Enum):
-    """The status of a span indicating whether it completed successfully or with an error.
-    :cvar OK: Span completed successfully without errors
-    :cvar ERROR: Span completed with an error or failure
-    """
-
-    OK = "ok"
-    ERROR = "error"
-
-
-@json_schema_type
-class Span(BaseModel):
-    """A span representing a single operation within a trace.
-    :param span_id: Unique identifier for the span
-    :param trace_id: Unique identifier for the trace this span belongs to
-    :param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
-    :param name: Human-readable name describing the operation this span represents
-    :param start_time: Timestamp when the operation began
-    :param end_time: (Optional) Timestamp when the operation finished, if completed
-    :param attributes: (Optional) Key-value pairs containing additional metadata about the span
-    """
-
-    span_id: str
-    trace_id: str
-    parent_span_id: str | None = None
-    name: str
-    start_time: datetime
-    end_time: datetime | None = None
-    attributes: dict[str, Any] | None = Field(default_factory=lambda: {})
-
-    def set_attribute(self, key: str, value: Any):
-        if self.attributes is None:
-            self.attributes = {}
-        self.attributes[key] = value
-
-
-@json_schema_type
-class Trace(BaseModel):
-    """A trace representing the complete execution path of a request across multiple operations.
-    :param trace_id: Unique identifier for the trace
-    :param root_span_id: Unique identifier for the root span that started this trace
-    :param start_time: Timestamp when the trace began
-    :param end_time: (Optional) Timestamp when the trace finished, if completed
-    """
-
-    trace_id: str
-    root_span_id: str
-    start_time: datetime
-    end_time: datetime | None = None
-
-
-@json_schema_type
-class EventType(Enum):
-    """The type of telemetry event being logged.
-    :cvar UNSTRUCTURED_LOG: A simple log message with severity level
-    :cvar STRUCTURED_LOG: A structured log event with typed payload data
-    :cvar METRIC: A metric measurement with value and unit
-    """
-
-    UNSTRUCTURED_LOG = "unstructured_log"
-    STRUCTURED_LOG = "structured_log"
-    METRIC = "metric"
-
-
-@json_schema_type
-class LogSeverity(Enum):
-    """The severity level of a log message.
-    :cvar VERBOSE: Detailed diagnostic information for troubleshooting
-    :cvar DEBUG: Debug information useful during development
-    :cvar INFO: General informational messages about normal operation
-    :cvar WARN: Warning messages about potentially problematic situations
-    :cvar ERROR: Error messages indicating failures that don't stop execution
-    :cvar CRITICAL: Critical error messages indicating severe failures
-    """
-
-    VERBOSE = "verbose"
-    DEBUG = "debug"
-    INFO = "info"
-    WARN = "warn"
-    ERROR = "error"
-    CRITICAL = "critical"
-
-
-class EventCommon(BaseModel):
-    """Common fields shared by all telemetry events.
-    :param trace_id: Unique identifier for the trace this event belongs to
-    :param span_id: Unique identifier for the span this event belongs to
-    :param timestamp: Timestamp when the event occurred
-    :param attributes: (Optional) Key-value pairs containing additional metadata about the event
-    """
-
-    trace_id: str
-    span_id: str
-    timestamp: datetime
-    attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {})
-
-
-@json_schema_type
-class UnstructuredLogEvent(EventCommon):
-    """An unstructured log event containing a simple text message.
-    :param type: Event type identifier set to UNSTRUCTURED_LOG
-    :param message: The log message text
-    :param severity: The severity level of the log message
-    """
-
-    type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG
-    message: str
-    severity: LogSeverity
-
-
-@json_schema_type
-class MetricEvent(EventCommon):
-    """A metric event containing a measured value.
-    :param type: Event type identifier set to METRIC
-    :param metric: The name of the metric being measured
-    :param value: The numeric value of the metric measurement
-    :param unit: The unit of measurement for the metric value
-    """
-
-    type: Literal[EventType.METRIC] = EventType.METRIC
-    metric: str  # this would be an enum
-    value: int | float
-    unit: str
-
-
-@json_schema_type
-class MetricInResponse(BaseModel):
-    """A metric value included in API responses.
-    :param metric: The name of the metric
-    :param value: The numeric value of the metric
-    :param unit: (Optional) The unit of measurement for the metric value
-    """
-
-    metric: str
-    value: int | float
-    unit: str | None = None
-
-
-# This is a short term solution to allow inference API to return metrics
-# The ideal way to do this is to have a way for all response types to include metrics
-# and all metric events logged to the telemetry API to be included with the response
-# To do this, we will need to augment all response types with a metrics field.
-# We have hit a blocker from stainless SDK that prevents us from doing this.
-# The blocker is that if we were to augment the response types that have a data field
-# in them like so
-# class ListModelsResponse(BaseModel):
-#     metrics: Optional[List[MetricEvent]] = None
-#     data: List[Models]
-#     ...
-# The client SDK will need to access the data by using a .data field, which is not
-# ergonomic. Stainless SDK does support unwrapping the response type, but it
-# requires that the response type to only have a single field.
-
-# We will need a way in the client SDK to signal that the metrics are needed
-# and if they are needed, the client SDK has to return the full response type
-# without unwrapping it.
-
-
-class MetricResponseMixin(BaseModel):
-    """Mixin class for API responses that can include metrics.
-    :param metrics: (Optional) List of metrics associated with the API response
-    """
-
-    metrics: list[MetricInResponse] | None = None
-
-
-@json_schema_type
-class StructuredLogType(Enum):
-    """The type of structured log event payload.
-    :cvar SPAN_START: Event indicating the start of a new span
-    :cvar SPAN_END: Event indicating the completion of a span
-    """
-
-    SPAN_START = "span_start"
-    SPAN_END = "span_end"
-
-
-@json_schema_type
-class SpanStartPayload(BaseModel):
-    """Payload for a span start event.
-    :param type: Payload type identifier set to SPAN_START
-    :param name: Human-readable name describing the operation this span represents
-    :param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
-    """
-
-    type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START
-    name: str
-    parent_span_id: str | None = None
-
-
-@json_schema_type
-class SpanEndPayload(BaseModel):
-    """Payload for a span end event.
-    :param type: Payload type identifier set to SPAN_END
-    :param status: The final status of the span indicating success or failure
-    """
-
-    type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END
-    status: SpanStatus
-
-
-StructuredLogPayload = Annotated[
-    SpanStartPayload | SpanEndPayload,
-    Field(discriminator="type"),
-]
-register_schema(StructuredLogPayload, name="StructuredLogPayload")
-
-
-@json_schema_type
-class StructuredLogEvent(EventCommon):
-    """A structured log event containing typed payload data.
-    :param type: Event type identifier set to STRUCTURED_LOG
-    :param payload: The structured payload data for the log event
-    """
-
-    type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG
-    payload: StructuredLogPayload
-
-
-Event = Annotated[
-    UnstructuredLogEvent | MetricEvent | StructuredLogEvent,
-    Field(discriminator="type"),
-]
-register_schema(Event, name="Event")
-
-
-@json_schema_type
-class EvalTrace(BaseModel):
-    """A trace record for evaluation purposes.
-    :param session_id: Unique identifier for the evaluation session
-    :param step: The evaluation step or phase identifier
-    :param input: The input data for the evaluation
-    :param output: The actual output produced during evaluation
-    :param expected_output: The expected output for comparison during evaluation
-    """
-
-    session_id: str
-    step: str
-    input: str
-    output: str
-    expected_output: str
-
-
-@json_schema_type
-class SpanWithStatus(Span):
-    """A span that includes status information.
-    :param status: (Optional) The current status of the span
-    """
-
-    status: SpanStatus | None = None
-
-
-@json_schema_type
-class QueryConditionOp(Enum):
-    """Comparison operators for query conditions.
-    :cvar EQ: Equal to comparison
-    :cvar NE: Not equal to comparison
-    :cvar GT: Greater than comparison
-    :cvar LT: Less than comparison
-    """
-
-    EQ = "eq"
-    NE = "ne"
-    GT = "gt"
-    LT = "lt"
-
-
-@json_schema_type
-class QueryCondition(BaseModel):
-    """A condition for filtering query results.
-    :param key: The attribute key to filter on
-    :param op: The comparison operator to apply
-    :param value: The value to compare against
-    """
-
-    key: str
-    op: QueryConditionOp
-    value: Any
-
-
-class QueryTracesResponse(BaseModel):
-    """Response containing a list of traces.
-    :param data: List of traces matching the query criteria
-    """
-
-    data: list[Trace]
-
-
-class QuerySpansResponse(BaseModel):
-    """Response containing a list of spans.
-    :param data: List of spans matching the query criteria
-    """
-
-    data: list[Span]
-
-
-class QuerySpanTreeResponse(BaseModel):
-    """Response containing a tree structure of spans.
-    :param data: Dictionary mapping span IDs to spans with status information
-    """
-
-    data: dict[str, SpanWithStatus]
-
-
-class MetricQueryType(Enum):
-    """The type of metric query to perform.
-    :cvar RANGE: Query metrics over a time range
-    :cvar INSTANT: Query metrics at a specific point in time
-    """
-
-    RANGE = "range"
-    INSTANT = "instant"
-
-
-class MetricLabelOperator(Enum):
-    """Operators for matching metric labels.
-    :cvar EQUALS: Label value must equal the specified value
-    :cvar NOT_EQUALS: Label value must not equal the specified value
-    :cvar REGEX_MATCH: Label value must match the specified regular expression
-    :cvar REGEX_NOT_MATCH: Label value must not match the specified regular expression
-    """
-
-    EQUALS = "="
-    NOT_EQUALS = "!="
-    REGEX_MATCH = "=~"
-    REGEX_NOT_MATCH = "!~"
-
-
-class MetricLabelMatcher(BaseModel):
-    """A matcher for filtering metrics by label values.
-    :param name: The name of the label to match
-    :param value: The value to match against
-    :param operator: The comparison operator to use for matching
-    """
-
-    name: str
-    value: str
-    operator: MetricLabelOperator = MetricLabelOperator.EQUALS
-
-
-@json_schema_type
-class MetricLabel(BaseModel):
-    """A label associated with a metric.
-    :param name: The name of the label
-    :param value: The value of the label
-    """
-
-    name: str
-    value: str
-
-
-@json_schema_type
-class MetricDataPoint(BaseModel):
-    """A single data point in a metric time series.
-    :param timestamp: Unix timestamp when the metric value was recorded
-    :param value: The numeric value of the metric at this timestamp
-    """
-
-    timestamp: int
-    value: float
-    unit: str
-
-
-@json_schema_type
-class MetricSeries(BaseModel):
-    """A time series of metric data points.
-    :param metric: The name of the metric
-    :param labels: List of labels associated with this metric series
-    :param values: List of data points in chronological order
-    """
-
-    metric: str
-    labels: list[MetricLabel]
-    values: list[MetricDataPoint]
-
-
-class QueryMetricsResponse(BaseModel):
-    """Response containing metric time series data.
-    :param data: List of metric series matching the query criteria
-    """
-
-    data: list[MetricSeries]
+
+# Type aliases for OpenTelemetry attribute values (None is not allowed)
+AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]
+Attributes = Mapping[str, AttributeValue]
 
 _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
     "active_spans": {},
@@ -428,6 +46,13 @@
 _TRACER_PROVIDER = None
 logger = get_logger(name=__name__, category="telemetry")
 
+def _clean_attributes(attrs: dict[str, Any] | None) -> Attributes | None:
+    """Remove None values from attributes dict to match OpenTelemetry's expected type."""
+    if attrs is None:
+        return None
+    return {k: v for k, v in attrs.items() if v is not None}
+
+
 def is_tracing_enabled(tracer):
     with tracer.start_as_current_span("check_tracing") as span:
         return span.is_recording()
@@ -456,7 +81,7 @@ class Telemetry:
             # https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
             span_exporter = OTLPSpanExporter()
             span_processor = BatchSpanProcessor(span_exporter)
-            trace.get_tracer_provider().add_span_processor(span_processor)
+            cast(TracerProvider, trace.get_tracer_provider()).add_span_processor(span_processor)
 
             metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
             metric_provider = MeterProvider(metric_readers=[metric_reader])
@@ -474,7 +99,7 @@ class Telemetry:
     async def shutdown(self) -> None:
         if self.is_otel_endpoint_set:
-            trace.get_tracer_provider().force_flush()
+            cast(TracerProvider, trace.get_tracer_provider()).force_flush()
 
     async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
         if isinstance(event, UnstructuredLogEvent):
@@ -515,7 +140,7 @@ class Telemetry:
                 unit=unit,
                 description=f"Counter for {name}",
             )
-        return _GLOBAL_STORAGE["counters"][name]
+        return cast(metrics.Counter, _GLOBAL_STORAGE["counters"][name])
 
     def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
         assert self.meter is not None
@@ -525,7 +150,7 @@ class Telemetry:
                 unit=unit,
                 description=f"Gauge for {name}",
            )
-        return _GLOBAL_STORAGE["gauges"][name]
+        return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
 
     def _log_metric(self, event: MetricEvent) -> None:
         # Add metric as an event to the current span
@@ -560,10 +185,10 @@ class Telemetry:
             return
        if isinstance(event.value, int):
            counter = self._get_or_create_counter(event.metric, event.unit)
-            counter.add(event.value, attributes=event.attributes)
+            counter.add(event.value, attributes=_clean_attributes(event.attributes))
        elif isinstance(event.value, float):
            up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
-            up_down_counter.add(event.value, attributes=event.attributes)
+            up_down_counter.add(event.value, attributes=_clean_attributes(event.attributes))
 
     def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
         assert self.meter is not None
@@ -573,7 +198,7 @@ class Telemetry:
                 unit=unit,
                 description=f"UpDownCounter for {name}",
             )
-        return _GLOBAL_STORAGE["up_down_counters"][name]
+        return cast(metrics.UpDownCounter, _GLOBAL_STORAGE["up_down_counters"][name])
 
     def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
         with self._lock:
@@ -601,7 +226,8 @@ class Telemetry:
             if event.payload.parent_span_id:
                 parent_span_id = int(event.payload.parent_span_id, 16)
                 parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
-                context = trace.set_span_in_context(parent_span)
+                if parent_span is not None:
+                    context = trace.set_span_in_context(parent_span)
             elif traceparent:
                 carrier = {
                     "traceparent": traceparent,
@@ -612,23 +238,25 @@ class Telemetry:
             span = tracer.start_span(
                 name=event.payload.name,
                 context=context,
-                attributes=event.attributes or {},
+                attributes=_clean_attributes(event.attributes) or {},
             )
             _GLOBAL_STORAGE["active_spans"][span_id] = span
         elif isinstance(event.payload, SpanEndPayload):
-            span = _GLOBAL_STORAGE["active_spans"].get(span_id)
-            if span:
+            end_span = cast(trace.Span | None, _GLOBAL_STORAGE["active_spans"].get(span_id))
+            if end_span:
                 if event.attributes:
-                    span.set_attributes(event.attributes)
+                    cleaned_attrs = _clean_attributes(event.attributes)
+                    if cleaned_attrs:
+                        end_span.set_attributes(cleaned_attrs)
                 status = (
                     trace.Status(status_code=trace.StatusCode.OK)
                     if event.payload.status == SpanStatus.OK
                     else trace.Status(status_code=trace.StatusCode.ERROR)
                 )
-                span.set_status(status)
-                span.end()
+                end_span.set_status(status)
+                end_span.end()
                 _GLOBAL_STORAGE["active_spans"].pop(span_id, None)
         else:
             raise ValueError(f"Unknown structured log event: {event}")
diff --git a/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py
index c1ccd73dd..1bd364d43 100644
--- a/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py
+++ b/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
 from collections.abc import Mapping, Sequence
-from typing import Any, Literal
+from typing import Any, Literal, cast
 
 from sqlalchemy import (
     JSON,
@@ -55,17 +55,17 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
             raise ValueError(f"Operator mapping must have a single operator, got: {value}")
         op, operand = next(iter(value.items()))
         if op == "==" or op == "=":
-            return column == operand
+            return cast(ColumnElement[Any], column == operand)
         if op == ">":
-            return column > operand
+            return cast(ColumnElement[Any], column > operand)
         if op == "<":
-            return column < operand
+            return cast(ColumnElement[Any], column < operand)
         if op == ">=":
-            return column >= operand
+            return cast(ColumnElement[Any], column >= operand)
         if op == "<=":
-            return column <= operand
+            return cast(ColumnElement[Any], column <= operand)
         raise ValueError(f"Unsupported operator '{op}' in where mapping")
-    return column == value
+    return cast(ColumnElement[Any], column == value)
 
 
 class SqlAlchemySqlStoreImpl(SqlStore):
@@ -210,10 +210,8 @@ class SqlAlchemySqlStoreImpl(SqlStore):
                 query = query.limit(fetch_limit)
 
             result = await session.execute(query)
-            if result.rowcount == 0:
-                rows = []
-            else:
-                rows = [dict(row._mapping) for row in result]
+            # Iterate directly - if no rows, list comprehension yields empty list
+            rows = [dict(row._mapping) for row in result]
 
             # Always return pagination result
             has_more = False
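Reviewer note: as context for the `_build_where_expr` changes above, here is a minimal sketch of the `where` mapping semantics that function implements. The `users` table and its columns are hypothetical, invented purely for illustration:

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, select

from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import _build_where_expr

metadata = MetaData()
# Hypothetical table, used only to exercise the helper.
users = Table(
    "users",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("age", Integer),
    Column("name", String),
)

# A bare value means equality; a one-entry mapping picks the comparison operator.
name_expr = _build_where_expr(users.c.name, "alice")  # WHERE name = 'alice'
age_expr = _build_where_expr(users.c.age, {">": 21})  # WHERE age > 21
query = select(users).where(name_expr, age_expr)
```

Casting each comparison to `ColumnElement[Any]` changes nothing at runtime; it only gives mypy a concrete return type for SQLAlchemy's operator overloads.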