fix(tests): improve structure of telemetry tests for consistency

Emilio Garcia 2025-10-29 11:56:11 -04:00
parent 583df48479
commit 79156bb08c
3 changed files with 55 additions and 24 deletions


@@ -13,6 +13,8 @@ from typing import Any
 @dataclass
 class SpanStub:
+    """Unified span interface for both in-memory and OTLP collectors."""
+
     name: str
     attributes: Mapping[str, Any] | None = None
     resource_attributes: dict[str, Any] | None = None
@@ -20,6 +22,13 @@ class SpanStub:
     trace_id: str | None = None
     span_id: str | None = None
 
+    @property
+    def context(self):
+        """Provide context-like interface for trace_id compatibility."""
+        if self.trace_id is None:
+            return None
+        return type("Context", (), {"trace_id": int(self.trace_id, 16)})()
+
 
 def _value_to_python(value: Any) -> Any:
     kind = value.WhichOneof("value")
@@ -56,12 +65,18 @@ def events_to_list(events: Iterable[Any]) -> list[dict[str, Any]]:
 class BaseTelemetryCollector:
+    """Base class for telemetry collectors that ensures consistent return types.
+
+    All collectors must return SpanStub objects to ensure test compatibility
+    across both library-client and server modes.
+    """
+
     def get_spans(
         self,
         expected_count: int | None = None,
         timeout: float = 5.0,
         poll_interval: float = 0.05,
-    ) -> tuple[Any, ...]:
+    ) -> tuple[SpanStub, ...]:
         import time
 
         deadline = time.time() + timeout
@@ -97,7 +112,7 @@ class BaseTelemetryCollector:
     def clear(self) -> None:
         self._clear_impl()
 
-    def _snapshot_spans(self) -> tuple[Any, ...]:  # pragma: no cover - interface hook
+    def _snapshot_spans(self) -> tuple[SpanStub, ...]:  # pragma: no cover - interface hook
         raise NotImplementedError
 
     def _snapshot_metrics(self) -> Any | None:  # pragma: no cover - interface hook
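
The new context property is a lightweight shim: it exposes the hex trace_id string as an integer under a context.trace_id attribute, the same shape tests would read from a real OpenTelemetry span. A minimal illustration of the intended use, not part of the commit and with made-up values, assuming SpanStub from the base module above is in scope:

    stub = SpanStub(name="chat_completion", trace_id="0af7651916cd43dd8448eb211c80319c")
    assert stub.context is not None
    # Formatting the integer back as 32 hex digits recovers the original id.
    assert f"{stub.context.trace_id:032x}" == "0af7651916cd43dd8448eb211c80319c"
    # Spans without a trace_id report no context at all.
    assert SpanStub(name="orphan").context is None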


@@ -23,13 +23,20 @@ from .base import BaseTelemetryCollector, SpanStub
 class InMemoryTelemetryCollector(BaseTelemetryCollector):
+    """In-memory telemetry collector for library-client tests.
+
+    Converts OpenTelemetry span objects to SpanStub objects to ensure
+    consistent interface with OTLP collector used in server mode.
+    """
+
     def __init__(self, span_exporter: InMemorySpanExporter, metric_reader: InMemoryMetricReader) -> None:
         self._span_exporter = span_exporter
         self._metric_reader = metric_reader
 
-    def _snapshot_spans(self) -> tuple[Any, ...]:
+    def _snapshot_spans(self) -> tuple[SpanStub, ...]:
         spans = []
         for span in self._span_exporter.get_finished_spans():
+            # Extract trace_id and span_id
             trace_id = None
             span_id = None
             context = getattr(span, "context", None)
@@ -40,28 +47,37 @@ class InMemoryTelemetryCollector(BaseTelemetryCollector):
                 trace_id = getattr(span, "trace_id", None)
                 span_id = getattr(span, "span_id", None)
 
-            stub = SpanStub(
-                span.name,
-                span.attributes,
-                getattr(span, "resource", None),
-                getattr(span, "events", None),
-                trace_id,
-                span_id,
-            )
-            spans.append(stub)
+            # Convert attributes to dict if needed
+            attrs = span.attributes
+            if attrs is not None and hasattr(attrs, "items"):
+                attrs = dict(attrs.items())
+            elif attrs is not None and not isinstance(attrs, dict):
+                attrs = dict(attrs)
+            elif attrs is None:
+                attrs = {}
+
+            spans.append(
+                SpanStub(
+                    name=span.name,
+                    attributes=attrs,
+                    trace_id=trace_id,
+                    span_id=span_id,
+                )
+            )
         return tuple(spans)
 
     def _snapshot_metrics(self) -> Any | None:
         data = self._metric_reader.get_metrics_data()
-        if data and data.resource_metrics:
-            all_metrics = []
-            for resource_metric in data.resource_metrics:
-                if resource_metric.scope_metrics:
-                    for scope_metric in resource_metric.scope_metrics:
-                        all_metrics.extend(scope_metric.metrics)
-            return all_metrics if all_metrics else None
-        return None
+        if not data or not data.resource_metrics:
+            return None
+
+        all_metrics = []
+        for resource_metric in data.resource_metrics:
+            if resource_metric.scope_metrics:
+                for scope_metric in resource_metric.scope_metrics:
+                    all_metrics.extend(scope_metric.metrics)
+        return all_metrics if all_metrics else None
 
     def _clear_impl(self) -> None:
         self._span_exporter.clear()
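
The attribute conversion above exists because OpenTelemetry exposes span.attributes as a read-only mapping rather than a plain dict, which makes equality checks in tests awkward. A standalone sketch of the same normalization, shown here only for illustration:

    from collections.abc import Mapping

    def normalize_attributes(attrs):
        # Mirrors the conversion in _snapshot_spans: always hand tests a plain dict.
        if attrs is None:
            return {}
        if isinstance(attrs, dict):
            return attrs
        if isinstance(attrs, Mapping) or hasattr(attrs, "items"):
            return dict(attrs.items())
        return dict(attrs)

    assert normalize_attributes(None) == {}
    assert normalize_attributes({"model": "test-model"}) == {"model": "test-model"}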


@@ -4,7 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-"""Telemetry tests verifying @trace_protocol decorator format across stack modes."""
+"""Telemetry tests verifying @trace_protocol decorator format across stack modes.
+
+Note: The mock_otlp_collector fixture automatically clears telemetry data
+before and after each test, ensuring test isolation.
+"""
 
 import json
@@ -44,8 +48,6 @@ def _span_has_message(span, text: str) -> bool:
 def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_model_id):
     """Verify streaming adds chunk_count and __type__=async_generator."""
-    mock_otlp_collector.clear()
-
     stream = llama_stack_client.chat.completions.create(
         model=text_model_id,
         messages=[{"role": "user", "content": "Test trace openai 1"}],
@@ -80,8 +82,6 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_model_id):
 def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, text_model_id):
     """Comprehensive validation of telemetry data format including spans and metrics."""
-    mock_otlp_collector.clear()
-
     response = llama_stack_client.chat.completions.create(
         model=text_model_id,
         messages=[{"role": "user", "content": "Test trace openai with temperature 0.7"}],