fix(tests): deduplicate code and simplify user experience for telemetry test code

This commit is contained in:
Emilio Garcia 2025-10-29 12:53:38 -04:00
parent e4e8a59325
commit 52a7784847
4 changed files with 180 additions and 79 deletions

View file

@ -11,6 +11,31 @@ from dataclasses import dataclass
from typing import Any
@dataclass
class MetricStub:
"""Unified metric interface for both in-memory and OTLP collectors."""
name: str
value: Any
attributes: dict[str, Any] | None = None
def get_value(self) -> Any:
"""Get the metric value."""
return self.value
def get_name(self) -> str:
"""Get the metric name."""
return self.name
def get_attributes(self) -> dict[str, Any]:
"""Get metric attributes as a dictionary."""
return self.attributes or {}
def get_attribute(self, key: str) -> Any:
"""Get a specific attribute value by key."""
return self.get_attributes().get(key)
@dataclass
class SpanStub:
"""Unified span interface for both in-memory and OTLP collectors."""
@ -35,18 +60,7 @@ class SpanStub:
Handles different attribute types (mapping, dict, etc.) and returns
a consistent dictionary format.
"""
attrs = self.attributes
if attrs is None:
return {}
# Handle mapping-like objects (e.g., mappingproxy)
try:
return dict(attrs.items()) # type: ignore[attr-defined]
except AttributeError:
try:
return dict(attrs)
except TypeError:
return dict(attrs) if attrs else {}
return BaseTelemetryCollector._convert_attributes_to_dict(self.attributes)
def get_attribute(self, key: str) -> Any:
"""Get a specific attribute value by key."""
@ -167,16 +181,142 @@ class BaseTelemetryCollector:
last_len = len(spans)
time.sleep(poll_interval)
def get_metrics(self) -> Any | None:
def get_metrics(self) -> tuple[MetricStub, ...] | None:
return self._snapshot_metrics()
def get_metrics_dict(self) -> dict[str, Any]:
"""Get metrics as a simple name->value dictionary for easy lookup.
This method works with MetricStub objects for consistent interface
across both in-memory and OTLP collectors.
"""
metrics = self._snapshot_metrics()
if not metrics:
return {}
return {metric.get_name(): metric.get_value() for metric in metrics}
def get_metric_value(self, name: str) -> Any | None:
"""Get a specific metric value by name."""
return self.get_metrics_dict().get(name)
def has_metric(self, name: str) -> bool:
"""Check if a metric with the given name exists."""
return name in self.get_metrics_dict()
def get_metric_names(self) -> list[str]:
"""Get all available metric names."""
return list(self.get_metrics_dict().keys())
@staticmethod
def _convert_attributes_to_dict(attrs: Any) -> dict[str, Any]:
"""Convert various attribute types to a consistent dictionary format.
Handles mappingproxy, dict, and other attribute types.
"""
if attrs is None:
return {}
try:
return dict(attrs.items()) # type: ignore[attr-defined]
except AttributeError:
try:
return dict(attrs)
except TypeError:
return dict(attrs) if attrs else {}
@staticmethod
def _extract_trace_span_ids(span: Any) -> tuple[str | None, str | None]:
"""Extract trace_id and span_id from OpenTelemetry span object.
Handles both context-based and direct attribute access.
"""
trace_id = None
span_id = None
context = getattr(span, "context", None)
if context:
trace_id = f"{context.trace_id:032x}"
span_id = f"{context.span_id:016x}"
else:
trace_id = getattr(span, "trace_id", None)
span_id = getattr(span, "span_id", None)
return trace_id, span_id
@staticmethod
def _create_span_stub_from_opentelemetry(span: Any) -> SpanStub:
"""Create SpanStub from OpenTelemetry span object.
This helper reduces code duplication between collectors.
"""
trace_id, span_id = BaseTelemetryCollector._extract_trace_span_ids(span)
attributes = BaseTelemetryCollector._convert_attributes_to_dict(span.attributes)
return SpanStub(
name=span.name,
attributes=attributes,
trace_id=trace_id,
span_id=span_id,
)
@staticmethod
def _create_span_stub_from_protobuf(span: Any, resource_attrs: dict[str, Any] | None = None) -> SpanStub:
"""Create SpanStub from protobuf span object.
This helper handles the different structure of protobuf spans.
"""
attributes = attributes_to_dict(span.attributes)
events = events_to_list(span.events) if span.events else None
trace_id = span.trace_id.hex() if span.trace_id else None
span_id = span.span_id.hex() if span.span_id else None
return SpanStub(
name=span.name,
attributes=attributes,
resource_attributes=resource_attrs,
events=events,
trace_id=trace_id,
span_id=span_id,
)
@staticmethod
def _extract_metric_from_opentelemetry(metric: Any) -> MetricStub | None:
"""Extract MetricStub from OpenTelemetry metric object.
This helper reduces code duplication between collectors.
"""
if not (hasattr(metric, "name") and hasattr(metric, "data") and hasattr(metric.data, "data_points")):
return None
if not (metric.data.data_points and len(metric.data.data_points) > 0):
return None
# Get the value from the first data point
value = metric.data.data_points[0].value
# Extract attributes if available
attributes = {}
if hasattr(metric.data.data_points[0], "attributes"):
attrs = metric.data.data_points[0].attributes
if attrs is not None and hasattr(attrs, "items"):
attributes = dict(attrs.items())
elif attrs is not None and not isinstance(attrs, dict):
attributes = dict(attrs)
return MetricStub(
name=metric.name,
value=value,
attributes=attributes if attributes else None,
)
def clear(self) -> None:
self._clear_impl()
def _snapshot_spans(self) -> tuple[SpanStub, ...]: # pragma: no cover - interface hook
raise NotImplementedError
def _snapshot_metrics(self) -> Any | None: # pragma: no cover - interface hook
def _snapshot_metrics(self) -> tuple[MetricStub, ...] | None: # pragma: no cover - interface hook
raise NotImplementedError
def _clear_impl(self) -> None: # pragma: no cover - interface hook

View file

@ -6,8 +6,6 @@
"""In-memory telemetry collector for library-client tests."""
from typing import Any
import opentelemetry.metrics as otel_metrics
import opentelemetry.trace as otel_trace
from opentelemetry import metrics, trace
@ -19,7 +17,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanE
import llama_stack.core.telemetry.telemetry as telemetry_module
from .base import BaseTelemetryCollector, SpanStub
from .base import BaseTelemetryCollector, MetricStub, SpanStub
class InMemoryTelemetryCollector(BaseTelemetryCollector):
@ -36,48 +34,24 @@ class InMemoryTelemetryCollector(BaseTelemetryCollector):
def _snapshot_spans(self) -> tuple[SpanStub, ...]:
spans = []
for span in self._span_exporter.get_finished_spans():
# Extract trace_id and span_id
trace_id = None
span_id = None
context = getattr(span, "context", None)
if context:
trace_id = f"{context.trace_id:032x}"
span_id = f"{context.span_id:016x}"
else:
trace_id = getattr(span, "trace_id", None)
span_id = getattr(span, "span_id", None)
# Convert attributes to dict if needed
attrs = span.attributes
if attrs is not None and hasattr(attrs, "items"):
attrs = dict(attrs.items())
elif attrs is not None and not isinstance(attrs, dict):
attrs = dict(attrs)
elif attrs is None:
attrs = {}
spans.append(
SpanStub(
name=span.name,
attributes=attrs,
trace_id=trace_id,
span_id=span_id,
)
)
spans.append(self._create_span_stub_from_opentelemetry(span))
return tuple(spans)
def _snapshot_metrics(self) -> Any | None:
def _snapshot_metrics(self) -> tuple[MetricStub, ...] | None:
data = self._metric_reader.get_metrics_data()
if not data or not data.resource_metrics:
return None
all_metrics = []
metric_stubs = []
for resource_metric in data.resource_metrics:
if resource_metric.scope_metrics:
for scope_metric in resource_metric.scope_metrics:
all_metrics.extend(scope_metric.metrics)
return all_metrics if all_metrics else None
for metric in scope_metric.metrics:
metric_stub = self._extract_metric_from_opentelemetry(metric)
if metric_stub:
metric_stubs.append(metric_stub)
return tuple(metric_stubs) if metric_stubs else None
def _clear_impl(self) -> None:
self._span_exporter.clear()

View file

@ -11,18 +11,17 @@ import os
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from socketserver import ThreadingMixIn
from typing import Any
from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2 import ExportMetricsServiceRequest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest
from .base import BaseTelemetryCollector, SpanStub, attributes_to_dict, events_to_list
from .base import BaseTelemetryCollector, MetricStub, SpanStub, attributes_to_dict
class OtlpHttpTestCollector(BaseTelemetryCollector):
def __init__(self) -> None:
self._spans: list[SpanStub] = []
self._metrics: list[Any] = []
self._metrics: list[MetricStub] = []
self._lock = threading.Lock()
class _ThreadingHTTPServer(ThreadingMixIn, HTTPServer):
@ -47,11 +46,7 @@ class OtlpHttpTestCollector(BaseTelemetryCollector):
for scope_spans in resource_spans.scope_spans:
for span in scope_spans.spans:
attributes = attributes_to_dict(span.attributes)
events = events_to_list(span.events) if span.events else None
trace_id = span.trace_id.hex() if span.trace_id else None
span_id = span.span_id.hex() if span.span_id else None
new_spans.append(SpanStub(span.name, attributes, resource_attrs or None, events, trace_id, span_id))
new_spans.append(self._create_span_stub_from_protobuf(span, resource_attrs or None))
if not new_spans:
return
@ -60,10 +55,13 @@ class OtlpHttpTestCollector(BaseTelemetryCollector):
self._spans.extend(new_spans)
def _handle_metrics(self, request: ExportMetricsServiceRequest) -> None:
new_metrics: list[Any] = []
new_metrics: list[MetricStub] = []
for resource_metrics in request.resource_metrics:
for scope_metrics in resource_metrics.scope_metrics:
new_metrics.extend(scope_metrics.metrics)
for metric in scope_metrics.metrics:
metric_stub = self._extract_metric_from_opentelemetry(metric)
if metric_stub:
new_metrics.append(metric_stub)
if not new_metrics:
return
@ -75,9 +73,9 @@ class OtlpHttpTestCollector(BaseTelemetryCollector):
with self._lock:
return tuple(self._spans)
def _snapshot_metrics(self) -> Any | None:
def _snapshot_metrics(self) -> tuple[MetricStub, ...] | None:
with self._lock:
return list(self._metrics) if self._metrics else None
return tuple(self._metrics) if self._metrics else None
def _clear_impl(self) -> None:
with self._lock:

View file

@ -107,25 +107,14 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
assert text_model_id in logged_model_ids, f"Expected to find {text_model_id} in spans, but got {logged_model_ids}"
# Verify token usage metrics in response
metrics = mock_otlp_collector.get_metrics()
assert metrics, "Expected metrics to be generated"
# Convert metrics to a dictionary for easier lookup
metrics_dict = {}
for metric in metrics:
if hasattr(metric, "name") and hasattr(metric, "data") and hasattr(metric.data, "data_points"):
if metric.data.data_points and len(metric.data.data_points) > 0:
# Get the value from the first data point
value = metric.data.data_points[0].value
metrics_dict[metric.name] = value
# Verify expected metrics are present
expected_metrics = ["completion_tokens", "total_tokens", "prompt_tokens"]
for metric_name in expected_metrics:
assert metric_name in metrics_dict, f"Expected metric {metric_name} not found in {list(metrics_dict.keys())}"
assert mock_otlp_collector.has_metric(metric_name), (
f"Expected metric {metric_name} not found in {mock_otlp_collector.get_metric_names()}"
)
# Verify metric values match usage data
assert metrics_dict["completion_tokens"] == usage["completion_tokens"]
assert metrics_dict["total_tokens"] == usage["total_tokens"]
assert metrics_dict["prompt_tokens"] == usage["prompt_tokens"]
assert mock_otlp_collector.get_metric_value("completion_tokens") == usage["completion_tokens"]
assert mock_otlp_collector.get_metric_value("total_tokens") == usage["total_tokens"]
assert mock_otlp_collector.get_metric_value("prompt_tokens") == usage["prompt_tokens"]