diff --git a/src/llama_stack/core/telemetry/telemetry.py b/src/llama_stack/core/telemetry/telemetry.py index dbd10e89c..1ba43724d 100644 --- a/src/llama_stack/core/telemetry/telemetry.py +++ b/src/llama_stack/core/telemetry/telemetry.py @@ -6,12 +6,14 @@ import os import threading +from collections.abc import Mapping, Sequence from datetime import datetime from enum import Enum from typing import ( Annotated, Any, Literal, + cast, ) from opentelemetry import metrics, trace @@ -30,6 +32,10 @@ from llama_stack.schema_utils import json_schema_type, register_schema ROOT_SPAN_MARKERS = ["__root__", "__root_span__"] +# Type alias for OpenTelemetry attribute values (excludes None) +AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float] +Attributes = Mapping[str, AttributeValue] + @json_schema_type class SpanStatus(Enum): @@ -428,6 +434,13 @@ _TRACER_PROVIDER = None logger = get_logger(name=__name__, category="telemetry") +def _clean_attributes(attrs: dict[str, Any] | None) -> Attributes | None: + """Remove None values from attributes dict to match OpenTelemetry's expected type.""" + if attrs is None: + return None + return {k: v for k, v in attrs.items() if v is not None} + + def is_tracing_enabled(tracer): with tracer.start_as_current_span("check_tracing") as span: return span.is_recording() @@ -456,7 +469,7 @@ class Telemetry: # https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter span_exporter = OTLPSpanExporter() span_processor = BatchSpanProcessor(span_exporter) - trace.get_tracer_provider().add_span_processor(span_processor) + cast(TracerProvider, trace.get_tracer_provider()).add_span_processor(span_processor) metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter()) metric_provider = MeterProvider(metric_readers=[metric_reader]) @@ -474,7 +487,7 @@ class Telemetry: async def shutdown(self) -> None: if self.is_otel_endpoint_set: - trace.get_tracer_provider().force_flush() + cast(TracerProvider, trace.get_tracer_provider()).force_flush() async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: if isinstance(event, UnstructuredLogEvent): @@ -515,7 +528,7 @@ class Telemetry: unit=unit, description=f"Counter for {name}", ) - return _GLOBAL_STORAGE["counters"][name] + return cast(metrics.Counter, _GLOBAL_STORAGE["counters"][name]) def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge: assert self.meter is not None @@ -525,7 +538,7 @@ class Telemetry: unit=unit, description=f"Gauge for {name}", ) - return _GLOBAL_STORAGE["gauges"][name] + return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name]) def _log_metric(self, event: MetricEvent) -> None: # Add metric as an event to the current span @@ -560,10 +573,10 @@ class Telemetry: return if isinstance(event.value, int): counter = self._get_or_create_counter(event.metric, event.unit) - counter.add(event.value, attributes=event.attributes) + counter.add(event.value, attributes=_clean_attributes(event.attributes)) elif isinstance(event.value, float): up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit) - up_down_counter.add(event.value, attributes=event.attributes) + up_down_counter.add(event.value, attributes=_clean_attributes(event.attributes)) def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter: assert self.meter is not None @@ -573,7 +586,7 @@ class Telemetry: unit=unit, description=f"UpDownCounter for {name}", ) - return _GLOBAL_STORAGE["up_down_counters"][name] + return cast(metrics.UpDownCounter, _GLOBAL_STORAGE["up_down_counters"][name]) def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None: with self._lock: @@ -601,7 +614,8 @@ class Telemetry: if event.payload.parent_span_id: parent_span_id = int(event.payload.parent_span_id, 16) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) - context = trace.set_span_in_context(parent_span) + if parent_span: + context = trace.set_span_in_context(parent_span) elif traceparent: carrier = { "traceparent": traceparent, @@ -612,15 +626,17 @@ class Telemetry: span = tracer.start_span( name=event.payload.name, context=context, - attributes=event.attributes or {}, + attributes=_clean_attributes(event.attributes), ) _GLOBAL_STORAGE["active_spans"][span_id] = span elif isinstance(event.payload, SpanEndPayload): - span = _GLOBAL_STORAGE["active_spans"].get(span_id) + span = _GLOBAL_STORAGE["active_spans"].get(span_id) # type: ignore[assignment] if span: if event.attributes: - span.set_attributes(event.attributes) + cleaned_attrs = _clean_attributes(event.attributes) + if cleaned_attrs: + span.set_attributes(cleaned_attrs) status = ( trace.Status(status_code=trace.StatusCode.OK) diff --git a/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index c1ccd73dd..1bd364d43 100644 --- a/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. from collections.abc import Mapping, Sequence -from typing import Any, Literal +from typing import Any, Literal, cast from sqlalchemy import ( JSON, @@ -55,17 +55,17 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement: raise ValueError(f"Operator mapping must have a single operator, got: {value}") op, operand = next(iter(value.items())) if op == "==" or op == "=": - return column == operand + return cast(ColumnElement[Any], column == operand) if op == ">": - return column > operand + return cast(ColumnElement[Any], column > operand) if op == "<": - return column < operand + return cast(ColumnElement[Any], column < operand) if op == ">=": - return column >= operand + return cast(ColumnElement[Any], column >= operand) if op == "<=": - return column <= operand + return cast(ColumnElement[Any], column <= operand) raise ValueError(f"Unsupported operator '{op}' in where mapping") - return column == value + return cast(ColumnElement[Any], column == value) class SqlAlchemySqlStoreImpl(SqlStore): @@ -210,10 +210,8 @@ class SqlAlchemySqlStoreImpl(SqlStore): query = query.limit(fetch_limit) result = await session.execute(query) - if result.rowcount == 0: - rows = [] - else: - rows = [dict(row._mapping) for row in result] + # Iterate directly - if no rows, list comprehension yields empty list + rows = [dict(row._mapping) for row in result] # Always return pagination result has_more = False