mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
fix(mypy): resolve OpenTelemetry typing issues in telemetry.py (#3943)
Fixes mypy type errors in OpenTelemetry integration: - Add type aliases for AttributeValue and Attributes - Add helper to filter None values from attributes (OpenTelemetry doesn't accept None) - Cast metric and tracer objects to proper types - Update imports after refactoring No functional changes.
This commit is contained in:
parent
85887d724f
commit
4a2ea278c5
2 changed files with 36 additions and 22 deletions
|
|
@ -6,12 +6,14 @@
|
|||
|
||||
import os
|
||||
import threading
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
cast,
|
||||
)
|
||||
|
||||
from opentelemetry import metrics, trace
|
||||
|
|
@ -30,6 +32,10 @@ from llama_stack.schema_utils import json_schema_type, register_schema
|
|||
|
||||
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
|
||||
|
||||
# Type alias for OpenTelemetry attribute values (excludes None)
|
||||
AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]
|
||||
Attributes = Mapping[str, AttributeValue]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanStatus(Enum):
|
||||
|
|
@ -428,6 +434,13 @@ _TRACER_PROVIDER = None
|
|||
logger = get_logger(name=__name__, category="telemetry")
|
||||
|
||||
|
||||
def _clean_attributes(attrs: dict[str, Any] | None) -> Attributes | None:
|
||||
"""Remove None values from attributes dict to match OpenTelemetry's expected type."""
|
||||
if attrs is None:
|
||||
return None
|
||||
return {k: v for k, v in attrs.items() if v is not None}
|
||||
|
||||
|
||||
def is_tracing_enabled(tracer):
|
||||
with tracer.start_as_current_span("check_tracing") as span:
|
||||
return span.is_recording()
|
||||
|
|
@ -456,7 +469,7 @@ class Telemetry:
|
|||
# https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
|
||||
span_exporter = OTLPSpanExporter()
|
||||
span_processor = BatchSpanProcessor(span_exporter)
|
||||
trace.get_tracer_provider().add_span_processor(span_processor)
|
||||
cast(TracerProvider, trace.get_tracer_provider()).add_span_processor(span_processor)
|
||||
|
||||
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
|
||||
metric_provider = MeterProvider(metric_readers=[metric_reader])
|
||||
|
|
@ -474,7 +487,7 @@ class Telemetry:
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
if self.is_otel_endpoint_set:
|
||||
trace.get_tracer_provider().force_flush()
|
||||
cast(TracerProvider, trace.get_tracer_provider()).force_flush()
|
||||
|
||||
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
|
|
@ -515,7 +528,7 @@ class Telemetry:
|
|||
unit=unit,
|
||||
description=f"Counter for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["counters"][name]
|
||||
return cast(metrics.Counter, _GLOBAL_STORAGE["counters"][name])
|
||||
|
||||
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
|
||||
assert self.meter is not None
|
||||
|
|
@ -525,7 +538,7 @@ class Telemetry:
|
|||
unit=unit,
|
||||
description=f"Gauge for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["gauges"][name]
|
||||
return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
# Add metric as an event to the current span
|
||||
|
|
@ -560,10 +573,10 @@ class Telemetry:
|
|||
return
|
||||
if isinstance(event.value, int):
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=event.attributes)
|
||||
counter.add(event.value, attributes=_clean_attributes(event.attributes))
|
||||
elif isinstance(event.value, float):
|
||||
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
|
||||
up_down_counter.add(event.value, attributes=event.attributes)
|
||||
up_down_counter.add(event.value, attributes=_clean_attributes(event.attributes))
|
||||
|
||||
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
|
||||
assert self.meter is not None
|
||||
|
|
@ -573,7 +586,7 @@ class Telemetry:
|
|||
unit=unit,
|
||||
description=f"UpDownCounter for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["up_down_counters"][name]
|
||||
return cast(metrics.UpDownCounter, _GLOBAL_STORAGE["up_down_counters"][name])
|
||||
|
||||
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
|
|
@ -601,7 +614,8 @@ class Telemetry:
|
|||
if event.payload.parent_span_id:
|
||||
parent_span_id = int(event.payload.parent_span_id, 16)
|
||||
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||
context = trace.set_span_in_context(parent_span)
|
||||
if parent_span:
|
||||
context = trace.set_span_in_context(parent_span)
|
||||
elif traceparent:
|
||||
carrier = {
|
||||
"traceparent": traceparent,
|
||||
|
|
@ -612,15 +626,17 @@ class Telemetry:
|
|||
span = tracer.start_span(
|
||||
name=event.payload.name,
|
||||
context=context,
|
||||
attributes=event.attributes or {},
|
||||
attributes=_clean_attributes(event.attributes),
|
||||
)
|
||||
_GLOBAL_STORAGE["active_spans"][span_id] = span
|
||||
|
||||
elif isinstance(event.payload, SpanEndPayload):
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id) # type: ignore[assignment]
|
||||
if span:
|
||||
if event.attributes:
|
||||
span.set_attributes(event.attributes)
|
||||
cleaned_attrs = _clean_attributes(event.attributes)
|
||||
if cleaned_attrs:
|
||||
span.set_attributes(cleaned_attrs)
|
||||
|
||||
status = (
|
||||
trace.Status(status_code=trace.StatusCode.OK)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
|
|
@ -55,17 +55,17 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
|
|||
raise ValueError(f"Operator mapping must have a single operator, got: {value}")
|
||||
op, operand = next(iter(value.items()))
|
||||
if op == "==" or op == "=":
|
||||
return column == operand
|
||||
return cast(ColumnElement[Any], column == operand)
|
||||
if op == ">":
|
||||
return column > operand
|
||||
return cast(ColumnElement[Any], column > operand)
|
||||
if op == "<":
|
||||
return column < operand
|
||||
return cast(ColumnElement[Any], column < operand)
|
||||
if op == ">=":
|
||||
return column >= operand
|
||||
return cast(ColumnElement[Any], column >= operand)
|
||||
if op == "<=":
|
||||
return column <= operand
|
||||
return cast(ColumnElement[Any], column <= operand)
|
||||
raise ValueError(f"Unsupported operator '{op}' in where mapping")
|
||||
return column == value
|
||||
return cast(ColumnElement[Any], column == value)
|
||||
|
||||
|
||||
class SqlAlchemySqlStoreImpl(SqlStore):
|
||||
|
|
@ -210,10 +210,8 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
query = query.limit(fetch_limit)
|
||||
|
||||
result = await session.execute(query)
|
||||
if result.rowcount == 0:
|
||||
rows = []
|
||||
else:
|
||||
rows = [dict(row._mapping) for row in result]
|
||||
# Iterate directly - if no rows, list comprehension yields empty list
|
||||
rows = [dict(row._mapping) for row in result]
|
||||
|
||||
# Always return pagination result
|
||||
has_more = False
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue