Fix protobuf metric parsing in telemetry metrics tests

- Add _create_metric_stub_from_protobuf method to correctly parse protobuf metrics
- Add _extract_attributes_from_data_point helper method
- Change metric handling to use protobuf-specific parsing instead of OpenTelemetry native parsing
- Add missing typing import
- Add OTEL_METRIC_EXPORT_INTERVAL environment variable for test configuration (see the sketch after this message)

This fixes the CI failure where metrics were not being properly extracted from
protobuf data in server mode tests.
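For context on the last bullet: OTEL_METRIC_EXPORT_INTERVAL is the standard OpenTelemetry SDK setting (in milliseconds) for the periodic metric reader, so lowering it to 200 ms lets server-mode tests see exported metrics quickly instead of waiting for the 60 s default. A minimal Python sketch of that effect, assuming opentelemetry-sdk and the http/protobuf OTLP exporter are installed; the wiring below is illustrative and not the stack's actual setup:

    import os

    from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
    from opentelemetry.sdk.metrics import MeterProvider
    from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader

    # With OTEL_METRIC_EXPORT_INTERVAL=200 the reader flushes every 200 ms,
    # so tests can poll the collector without long waits.
    os.environ.setdefault("OTEL_METRIC_EXPORT_INTERVAL", "200")

    # PeriodicExportingMetricReader picks the interval up from the environment
    # when export_interval_millis is not passed explicitly.
    reader = PeriodicExportingMetricReader(OTLPMetricExporter())
    provider = MeterProvider(metric_readers=[reader])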
Emilio Garcia 2025-11-03 10:45:30 -08:00 committed by Eric Huang
parent 415fd9e36b
commit 7a19488787
8 changed files with 420 additions and 125 deletions

View file

@@ -215,6 +215,7 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
   export OTEL_EXPORTER_OTLP_PROTOCOL="http/protobuf"
   export OTEL_BSP_SCHEDULE_DELAY="200"
   export OTEL_BSP_EXPORT_TIMEOUT="2000"
+  export OTEL_METRIC_EXPORT_INTERVAL="200"
   # remove "server:" from STACK_CONFIG
   stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://')
@@ -311,6 +312,9 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
   DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e LLAMA_STACK_TEST_INFERENCE_MODE=$INFERENCE_MODE"
   DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e LLAMA_STACK_TEST_STACK_CONFIG_TYPE=server"
   DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:${COLLECTOR_PORT}"
+  DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OTEL_METRIC_EXPORT_INTERVAL=200"
+  DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OTEL_BSP_SCHEDULE_DELAY=200"
+  DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OTEL_BSP_EXPORT_TIMEOUT=2000"
   # Pass through API keys if they exist
   [ -n "${TOGETHER_API_KEY:-}" ] && DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e TOGETHER_API_KEY=$TOGETHER_API_KEY"

View file

@@ -427,6 +427,7 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
     "counters": {},
     "gauges": {},
     "up_down_counters": {},
+    "histograms": {},
 }
 _global_lock = threading.Lock()
 _TRACER_PROVIDER = None
@@ -540,6 +541,16 @@ class Telemetry:
             )
         return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
 
+    def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
+        assert self.meter is not None
+        if name not in _GLOBAL_STORAGE["histograms"]:
+            _GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram(
+                name=name,
+                unit=unit,
+                description=f"Histogram for {name}",
+            )
+        return cast(metrics.Histogram, _GLOBAL_STORAGE["histograms"][name])
+
     def _log_metric(self, event: MetricEvent) -> None:
         # Add metric as an event to the current span
         try:
@@ -571,7 +582,16 @@ class Telemetry:
         # Log to OpenTelemetry meter if available
         if self.meter is None:
             return
-        if isinstance(event.value, int):
+
+        # Use histograms for token-related metrics (per-request measurements)
+        # Use counters for other cumulative metrics
+        token_metrics = {"prompt_tokens", "completion_tokens", "total_tokens"}
+        if event.metric in token_metrics:
+            # Token metrics are per-request measurements, use histogram
+            histogram = self._get_or_create_histogram(event.metric, event.unit)
+            histogram.record(event.value, attributes=_clean_attributes(event.attributes))
+        elif isinstance(event.value, int):
             counter = self._get_or_create_counter(event.metric, event.unit)
             counter.add(event.value, attributes=_clean_attributes(event.attributes))
         elif isinstance(event.value, float):
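The change above records per-request token counts on histograms instead of counters. A minimal illustrative sketch of the difference, using only the public OpenTelemetry metrics API (the meter name and attributes are made up, not the stack's Telemetry class): a counter keeps one running total per attribute set, while a histogram data point keeps count, sum, and buckets, which is why the test collectors read the `sum` field for histogram points.

    from opentelemetry import metrics

    meter = metrics.get_meter("telemetry-example")

    # Per-request measurement: each record() call is an individual sample, and
    # the exported HistogramDataPoint carries count, sum, and bucket_counts.
    prompt_tokens = meter.create_histogram("prompt_tokens", unit="tokens")
    prompt_tokens.record(42, attributes={"model_id": "example-model"})
    prompt_tokens.record(17, attributes={"model_id": "example-model"})

    # Cumulative metric: add() accumulates a single running total, exported as
    # a NumberDataPoint with one value.
    requests_total = meter.create_counter("requests_total", unit="requests")
    requests_total.add(1, attributes={"model_id": "example-model"})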

View file

@@ -84,5 +84,6 @@
       }
     ],
     "is_streaming": false
-  }
+  },
+  "id_normalization_mapping": {}
 }

View file

@@ -6,20 +6,88 @@
 """Shared helpers for telemetry test collectors."""
 
+import time
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any
 
 
 @dataclass
-class SpanStub:
+class MetricStub:
+    """Unified metric interface for both in-memory and OTLP collectors."""
+
     name: str
-    attributes: dict[str, Any]
+    value: Any
+    attributes: dict[str, Any] | None = None
+
+
+@dataclass
+class SpanStub:
+    """Unified span interface for both in-memory and OTLP collectors."""
+
+    name: str
+    attributes: dict[str, Any] | None = None
     resource_attributes: dict[str, Any] | None = None
     events: list[dict[str, Any]] | None = None
    trace_id: str | None = None
    span_id: str | None = None
+
+    @property
+    def context(self):
+        """Provide context-like interface for trace_id compatibility."""
+        if self.trace_id is None:
+            return None
+        return type("Context", (), {"trace_id": int(self.trace_id, 16)})()
+
+    def get_trace_id(self) -> str | None:
+        """Get trace ID in hex format.
+
+        Tries context.trace_id first, then falls back to direct trace_id.
+        """
+        context = getattr(self, "context", None)
+        if context and getattr(context, "trace_id", None) is not None:
+            return f"{context.trace_id:032x}"
+        return getattr(self, "trace_id", None)
+
+    def has_message(self, text: str) -> bool:
+        """Check if span contains a specific message in its args."""
+        if self.attributes is None:
+            return False
+        args = self.attributes.get("__args__")
+        if not args or not isinstance(args, str):
+            return False
+        return text in args
+
+    def is_root_span(self) -> bool:
+        """Check if this is a root span."""
+        if self.attributes is None:
+            return False
+        return self.attributes.get("__root__") is True
+
+    def is_autotraced(self) -> bool:
+        """Check if this span was automatically traced."""
+        if self.attributes is None:
+            return False
+        return self.attributes.get("__autotraced__") is True
+
+    def get_span_type(self) -> str | None:
+        """Get the span type (async, sync, async_generator)."""
+        if self.attributes is None:
+            return None
+        return self.attributes.get("__type__")
+
+    def get_class_method(self) -> tuple[str | None, str | None]:
+        """Get the class and method names for autotraced spans."""
+        if self.attributes is None:
+            return None, None
+        return (self.attributes.get("__class__"), self.attributes.get("__method__"))
+
+    def get_location(self) -> str | None:
+        """Get the location (library_client, server) for root spans."""
+        if self.attributes is None:
+            return None
+        return self.attributes.get("__location__")
 
 
 def _value_to_python(value: Any) -> Any:
     kind = value.WhichOneof("value")
@@ -56,14 +124,18 @@ def events_to_list(events: Iterable[Any]) -> list[dict[str, Any]]:
 class BaseTelemetryCollector:
+    """Base class for telemetry collectors that ensures consistent return types.
+
+    All collectors must return SpanStub objects to ensure test compatibility
+    across both library-client and server modes.
+    """
+
     def get_spans(
         self,
         expected_count: int | None = None,
         timeout: float = 5.0,
         poll_interval: float = 0.05,
-    ) -> tuple[Any, ...]:
-        import time
-
+    ) -> tuple[SpanStub, ...]:
         deadline = time.time() + timeout
         min_count = expected_count if expected_count is not None else 1
         last_len: int | None = None
@@ -91,16 +163,206 @@ class BaseTelemetryCollector:
             last_len = len(spans)
             time.sleep(poll_interval)
 
-    def get_metrics(self) -> Any | None:
-        return self._snapshot_metrics()
+    def get_metrics(
+        self,
+        expected_count: int | None = None,
+        timeout: float = 5.0,
+        poll_interval: float = 0.05,
+        expect_model_id: str | None = None,
+    ) -> dict[str, MetricStub]:
+        """Get metrics with polling until metrics are available or timeout is reached."""
+        # metrics need to be collected since get requests delete stored metrics
+        deadline = time.time() + timeout
+        min_count = expected_count if expected_count is not None else 1
+        accumulated_metrics = {}
+        count_metrics_with_model_id = 0
+
+        while time.time() < deadline:
+            current_metrics = self._snapshot_metrics()
+            if current_metrics:
+                for metric in current_metrics:
+                    metric_name = metric.name
+                    if metric_name not in accumulated_metrics:
+                        accumulated_metrics[metric_name] = metric
+                        if (
+                            expect_model_id
+                            and metric.attributes
+                            and metric.attributes.get("model_id") == expect_model_id
+                        ):
+                            count_metrics_with_model_id += 1
+                    else:
+                        accumulated_metrics[metric_name] = metric
+
+            # Check if we have enough metrics
+            if len(accumulated_metrics) >= min_count:
+                if not expect_model_id:
+                    return accumulated_metrics
+                if count_metrics_with_model_id >= min_count:
+                    return accumulated_metrics
+
+            time.sleep(poll_interval)
+
+        return accumulated_metrics
+
+    @staticmethod
+    def _convert_attributes_to_dict(attrs: Any) -> dict[str, Any]:
+        """Convert various attribute types to a consistent dictionary format.
+
+        Handles mappingproxy, dict, and other attribute types.
+        """
+        if attrs is None:
+            return {}
+        try:
+            return dict(attrs.items())  # type: ignore[attr-defined]
+        except AttributeError:
+            try:
+                return dict(attrs)
+            except TypeError:
+                return dict(attrs) if attrs else {}
+
+    @staticmethod
+    def _extract_trace_span_ids(span: Any) -> tuple[str | None, str | None]:
+        """Extract trace_id and span_id from OpenTelemetry span object.
+
+        Handles both context-based and direct attribute access.
+        """
+        trace_id = None
+        span_id = None
+        context = getattr(span, "context", None)
+        if context:
+            trace_id = f"{context.trace_id:032x}"
+            span_id = f"{context.span_id:016x}"
+        else:
+            trace_id = getattr(span, "trace_id", None)
+            span_id = getattr(span, "span_id", None)
+        return trace_id, span_id
+
+    @staticmethod
+    def _create_span_stub_from_opentelemetry(span: Any) -> SpanStub:
+        """Create SpanStub from OpenTelemetry span object.
+
+        This helper reduces code duplication between collectors.
+        """
+        trace_id, span_id = BaseTelemetryCollector._extract_trace_span_ids(span)
+        attributes = BaseTelemetryCollector._convert_attributes_to_dict(span.attributes) or {}
+
+        return SpanStub(
+            name=span.name,
+            attributes=attributes,
+            trace_id=trace_id,
+            span_id=span_id,
+        )
+
+    @staticmethod
+    def _create_span_stub_from_protobuf(span: Any, resource_attrs: dict[str, Any] | None = None) -> SpanStub:
+        """Create SpanStub from protobuf span object.
+
+        This helper handles the different structure of protobuf spans.
+        """
+        attributes = attributes_to_dict(span.attributes) or {}
+        events = events_to_list(span.events) if span.events else None
+        trace_id = span.trace_id.hex() if span.trace_id else None
+        span_id = span.span_id.hex() if span.span_id else None
+
+        return SpanStub(
+            name=span.name,
+            attributes=attributes,
+            resource_attributes=resource_attrs,
+            events=events,
+            trace_id=trace_id,
+            span_id=span_id,
+        )
+
+    @staticmethod
+    def _extract_metric_from_opentelemetry(metric: Any) -> MetricStub | None:
+        """Extract MetricStub from OpenTelemetry metric object.
+
+        This helper reduces code duplication between collectors.
+        """
+        if not (hasattr(metric, "name") and hasattr(metric, "data") and hasattr(metric.data, "data_points")):
+            return None
+
+        if not (metric.data.data_points and len(metric.data.data_points) > 0):
+            return None
+
+        # Get the value from the first data point
+        data_point = metric.data.data_points[0]
+
+        # Handle different metric types
+        if hasattr(data_point, "value"):
+            # Counter or Gauge
+            value = data_point.value
+        elif hasattr(data_point, "sum"):
+            # Histogram - use the sum of all recorded values
+            value = data_point.sum
+        else:
+            return None
+
+        # Extract attributes if available
+        attributes = {}
+        if hasattr(data_point, "attributes"):
+            attrs = data_point.attributes
+            if attrs is not None and hasattr(attrs, "items"):
+                attributes = dict(attrs.items())
+            elif attrs is not None and not isinstance(attrs, dict):
+                attributes = dict(attrs)
+
+        return MetricStub(
+            name=metric.name,
+            value=value,
+            attributes=attributes or {},
+        )
+
+    @staticmethod
+    def _create_metric_stub_from_protobuf(metric: Any) -> MetricStub | None:
+        """Create MetricStub from protobuf metric object.
+
+        Protobuf metrics have a different structure than OpenTelemetry metrics.
+        They can have sum, gauge, or histogram data.
+        """
+        if not hasattr(metric, "name"):
+            return None
+
+        # Try to extract value from different metric types
+        for metric_type in ["sum", "gauge", "histogram"]:
+            if hasattr(metric, metric_type):
+                metric_data = getattr(metric, metric_type)
+                if metric_data and hasattr(metric_data, "data_points"):
+                    data_points = metric_data.data_points
+                    if data_points and len(data_points) > 0:
+                        data_point = data_points[0]
+
+                        # Extract attributes first (needed for all metric types)
+                        attributes = (
+                            attributes_to_dict(data_point.attributes) if hasattr(data_point, "attributes") else {}
+                        )
+
+                        # Extract value based on metric type
+                        if metric_type == "sum":
+                            value = data_point.as_int
+                        elif metric_type == "gauge":
+                            value = data_point.as_double
+                        else:  # histogram
+                            value = data_point.sum
+
+                        return MetricStub(
+                            name=metric.name,
+                            value=value,
+                            attributes=attributes,
+                        )
+
+        return None
+
     def clear(self) -> None:
         self._clear_impl()
 
-    def _snapshot_spans(self) -> tuple[Any, ...]:  # pragma: no cover - interface hook
+    def _snapshot_spans(self) -> tuple[SpanStub, ...]:  # pragma: no cover - interface hook
         raise NotImplementedError
 
-    def _snapshot_metrics(self) -> Any | None:  # pragma: no cover - interface hook
+    def _snapshot_metrics(self) -> tuple[MetricStub, ...] | None:  # pragma: no cover - interface hook
         raise NotImplementedError
 
     def _clear_impl(self) -> None:  # pragma: no cover - interface hook
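The protobuf-specific parsing added above follows the opentelemetry-proto layout, where a Metric carries a oneof of sum / gauge / histogram data and integer values land in `as_int` on the data point. A small illustrative sketch of that shape, independent of the collector code (metric name, value, and attributes are made up):

    from opentelemetry.proto.metrics.v1.metrics_pb2 import Metric

    # Build a Sum metric with one integer data point, the shape a token
    # counter arrives in over OTLP/HTTP protobuf.
    metric = Metric(name="prompt_tokens", unit="tokens")
    point = metric.sum.data_points.add()
    point.as_int = 42
    attr = point.attributes.add()
    attr.key = "model_id"
    attr.value.string_value = "example-model"

    # A parser can also branch on the populated oneof instead of probing fields.
    assert metric.WhichOneof("data") == "sum"
    assert metric.sum.data_points[0].as_int == 42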

View file

@@ -6,8 +6,6 @@
 """In-memory telemetry collector for library-client tests."""
 
-from typing import Any
-
 import opentelemetry.metrics as otel_metrics
 import opentelemetry.trace as otel_trace
 from opentelemetry import metrics, trace
@@ -19,47 +17,42 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
 import llama_stack.core.telemetry.telemetry as telemetry_module
 
-from .base import BaseTelemetryCollector, SpanStub
+from .base import BaseTelemetryCollector, MetricStub, SpanStub
 
 
 class InMemoryTelemetryCollector(BaseTelemetryCollector):
+    """In-memory telemetry collector for library-client tests.
+
+    Converts OpenTelemetry span objects to SpanStub objects to ensure
+    consistent interface with OTLP collector used in server mode.
+    """
+
     def __init__(self, span_exporter: InMemorySpanExporter, metric_reader: InMemoryMetricReader) -> None:
         self._span_exporter = span_exporter
         self._metric_reader = metric_reader
 
-    def _snapshot_spans(self) -> tuple[Any, ...]:
+    def _snapshot_spans(self) -> tuple[SpanStub, ...]:
         spans = []
         for span in self._span_exporter.get_finished_spans():
-            trace_id = None
-            span_id = None
-            context = getattr(span, "context", None)
-            if context:
-                trace_id = f"{context.trace_id:032x}"
-                span_id = f"{context.span_id:016x}"
-            else:
-                trace_id = getattr(span, "trace_id", None)
-                span_id = getattr(span, "span_id", None)
-            stub = SpanStub(
-                span.name,
-                span.attributes,
-                getattr(span, "resource", None),
-                getattr(span, "events", None),
-                trace_id,
-                span_id,
-            )
-            spans.append(stub)
+            spans.append(self._create_span_stub_from_opentelemetry(span))
         return tuple(spans)
 
-    def _snapshot_metrics(self) -> Any | None:
+    def _snapshot_metrics(self) -> tuple[MetricStub, ...] | None:
         data = self._metric_reader.get_metrics_data()
-        if data and data.resource_metrics:
-            resource_metric = data.resource_metrics[0]
-            if resource_metric.scope_metrics:
-                return resource_metric.scope_metrics[0].metrics
-        return None
+        if not data or not data.resource_metrics:
+            return None
+
+        metric_stubs = []
+        for resource_metric in data.resource_metrics:
+            if resource_metric.scope_metrics:
+                for scope_metric in resource_metric.scope_metrics:
+                    for metric in scope_metric.metrics:
+                        metric_stub = self._extract_metric_from_opentelemetry(metric)
+                        if metric_stub:
+                            metric_stubs.append(metric_stub)
+        return tuple(metric_stubs) if metric_stubs else None
 
     def _clear_impl(self) -> None:
         self._span_exporter.clear()
         self._metric_reader.get_metrics_data()

View file

@@ -9,20 +9,20 @@
 import gzip
 import os
 import threading
+import time
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from socketserver import ThreadingMixIn
-from typing import Any
 
 from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2 import ExportMetricsServiceRequest
 from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest
 
-from .base import BaseTelemetryCollector, SpanStub, attributes_to_dict, events_to_list
+from .base import BaseTelemetryCollector, MetricStub, SpanStub, attributes_to_dict
 
 
 class OtlpHttpTestCollector(BaseTelemetryCollector):
     def __init__(self) -> None:
         self._spans: list[SpanStub] = []
-        self._metrics: list[Any] = []
+        self._metrics: list[MetricStub] = []
         self._lock = threading.Lock()
 
         class _ThreadingHTTPServer(ThreadingMixIn, HTTPServer):
@@ -47,11 +47,7 @@ class OtlpHttpTestCollector(BaseTelemetryCollector):
             for scope_spans in resource_spans.scope_spans:
                 for span in scope_spans.spans:
-                    attributes = attributes_to_dict(span.attributes)
-                    events = events_to_list(span.events) if span.events else None
-                    trace_id = span.trace_id.hex() if span.trace_id else None
-                    span_id = span.span_id.hex() if span.span_id else None
-                    new_spans.append(SpanStub(span.name, attributes, resource_attrs or None, events, trace_id, span_id))
+                    new_spans.append(self._create_span_stub_from_protobuf(span, resource_attrs or None))
 
         if not new_spans:
             return
@@ -60,10 +56,13 @@ class OtlpHttpTestCollector(BaseTelemetryCollector):
             self._spans.extend(new_spans)
 
     def _handle_metrics(self, request: ExportMetricsServiceRequest) -> None:
-        new_metrics: list[Any] = []
+        new_metrics: list[MetricStub] = []
         for resource_metrics in request.resource_metrics:
             for scope_metrics in resource_metrics.scope_metrics:
-                new_metrics.extend(scope_metrics.metrics)
+                for metric in scope_metrics.metrics:
+                    metric_stub = self._create_metric_stub_from_protobuf(metric)
+                    if metric_stub:
+                        new_metrics.append(metric_stub)
 
         if not new_metrics:
             return
@@ -75,11 +74,40 @@ class OtlpHttpTestCollector(BaseTelemetryCollector):
         with self._lock:
             return tuple(self._spans)
 
-    def _snapshot_metrics(self) -> Any | None:
+    def _snapshot_metrics(self) -> tuple[MetricStub, ...] | None:
         with self._lock:
-            return list(self._metrics) if self._metrics else None
+            return tuple(self._metrics) if self._metrics else None
 
     def _clear_impl(self) -> None:
+        """Clear telemetry over a period of time to prevent race conditions between tests."""
+        with self._lock:
+            self._spans.clear()
+            self._metrics.clear()
+
+        # Prevent race conditions where telemetry arrives after clear() but before
+        # the test starts, causing contamination between tests
+        deadline = time.time() + 2.0  # Maximum wait time
+        last_span_count = 0
+        last_metric_count = 0
+        stable_iterations = 0
+
+        while time.time() < deadline:
+            with self._lock:
+                current_span_count = len(self._spans)
+                current_metric_count = len(self._metrics)
+
+            if current_span_count == last_span_count and current_metric_count == last_metric_count:
+                stable_iterations += 1
+                if stable_iterations >= 4:  # 4 * 50ms = 200ms of stability
+                    break
+            else:
+                stable_iterations = 0
+                last_span_count = current_span_count
+                last_metric_count = current_metric_count
+
+            time.sleep(0.05)
+
+        # Final clear to remove any telemetry that arrived during stabilization
         with self._lock:
             self._spans.clear()
             self._metrics.clear()
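For reference, the request-handler side of this collector (not shown in the hunks above) has to decode the OTLP/HTTP body into the ExportTraceServiceRequest / ExportMetricsServiceRequest messages imported at the top. A rough sketch of that decoding with hypothetical names, assuming the standard gzip content encoding used by the OTLP exporters:

    import gzip

    from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2 import ExportMetricsServiceRequest

    def decode_metrics_export(body: bytes, content_encoding: str | None) -> ExportMetricsServiceRequest:
        """Parse an OTLP/HTTP metrics export body, decompressing it if needed."""
        if content_encoding == "gzip":
            body = gzip.decompress(body)
        request = ExportMetricsServiceRequest()
        request.ParseFromString(body)
        return request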

View file

@@ -30,7 +30,7 @@
           "index": 0,
           "logprobs": null,
           "message": {
-            "content": "import torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\n# Load the pre-trained model and tokenizer\nmodel_name = \"CompVis/transformers-base-uncased\"\nmodel = AutoModelForCausalLM.from_pretrained(model_name)\ntokenizer = AutoTokenizer.from_pretrained(model_name)\n\n# Set the temperature to 0.7\ntemperature = 0.7\n\n# Define a function to generate text\ndef generate_text(prompt, max_length=100):\n input",
+            "content": "To test the trace function from OpenAI's API with a temperature of 0.7, you can use the following Python code:\n\n```python\nimport json\n\n# Import the required libraries\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\n# Set the API endpoint and model name\nmodel_name = \"dalle-mini\"\n\n# Initialize the model and tokenizer\nmodel = AutoModelForCausalLM.from_pretrained(model_name)\ntokenizer = AutoTokenizer.from_pretrained(model_name)\n\n",
             "refusal": null,
             "role": "assistant",
             "annotations": null,
@@ -55,5 +55,6 @@
       }
     },
     "is_streaming": false
-  }
+  },
+  "id_normalization_mapping": {}
 }

View file

@@ -4,48 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-"""Telemetry tests verifying @trace_protocol decorator format across stack modes."""
+"""Telemetry tests verifying @trace_protocol decorator format across stack modes.
+
+Note: The mock_otlp_collector fixture automatically clears telemetry data
+before and after each test, ensuring test isolation.
+"""
 
 import json
 
-
-def _span_attributes(span):
-    attrs = getattr(span, "attributes", None)
-    if attrs is None:
-        return {}
-    # ReadableSpan.attributes acts like a mapping
-    try:
-        return dict(attrs.items())  # type: ignore[attr-defined]
-    except AttributeError:
-        try:
-            return dict(attrs)
-        except TypeError:
-            return attrs
-
-
-def _span_attr(span, key):
-    attrs = _span_attributes(span)
-    return attrs.get(key)
-
-
-def _span_trace_id(span):
-    context = getattr(span, "context", None)
-    if context and getattr(context, "trace_id", None) is not None:
-        return f"{context.trace_id:032x}"
-    return getattr(span, "trace_id", None)
-
-
-def _span_has_message(span, text: str) -> bool:
-    args = _span_attr(span, "__args__")
-    if not args or not isinstance(args, str):
-        return False
-    return text in args
-
 
 def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_model_id):
     """Verify streaming adds chunk_count and __type__=async_generator."""
-    mock_otlp_collector.clear()
     stream = llama_stack_client.chat.completions.create(
         model=text_model_id,
         messages=[{"role": "user", "content": "Test trace openai 1"}],
@@ -62,16 +31,16 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_model_id):
         (
             span
             for span in reversed(spans)
-            if _span_attr(span, "__type__") == "async_generator"
-            and _span_attr(span, "chunk_count")
-            and _span_has_message(span, "Test trace openai 1")
+            if span.get_span_type() == "async_generator"
+            and span.attributes.get("chunk_count")
+            and span.has_message("Test trace openai 1")
         ),
         None,
     )
 
     assert async_generator_span is not None
 
-    raw_chunk_count = _span_attr(async_generator_span, "chunk_count")
+    raw_chunk_count = async_generator_span.attributes.get("chunk_count")
     assert raw_chunk_count is not None
     chunk_count = int(raw_chunk_count)
@@ -80,7 +49,6 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_model_id):
 def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, text_model_id):
     """Comprehensive validation of telemetry data format including spans and metrics."""
-    mock_otlp_collector.clear()
 
     response = llama_stack_client.chat.completions.create(
         model=text_model_id,
@@ -101,37 +69,36 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, text_model_id):
     # Verify spans
     spans = mock_otlp_collector.get_spans(expected_count=7)
 
     target_span = next(
-        (span for span in reversed(spans) if _span_has_message(span, "Test trace openai with temperature 0.7")),
+        (span for span in reversed(spans) if span.has_message("Test trace openai with temperature 0.7")),
         None,
     )
     assert target_span is not None
 
-    trace_id = _span_trace_id(target_span)
+    trace_id = target_span.get_trace_id()
     assert trace_id is not None
-    spans = [span for span in spans if _span_trace_id(span) == trace_id]
-    spans = [span for span in spans if _span_attr(span, "__root__") or _span_attr(span, "__autotraced__")]
+    spans = [span for span in spans if span.get_trace_id() == trace_id]
+    spans = [span for span in spans if span.is_root_span() or span.is_autotraced()]
     assert len(spans) >= 4
 
     # Collect all model_ids found in spans
     logged_model_ids = []
     for span in spans:
-        attrs = _span_attributes(span)
+        attrs = span.attributes
         assert attrs is not None
 
         # Root span is created manually by tracing middleware, not by @trace_protocol decorator
-        is_root_span = attrs.get("__root__") is True
-        if is_root_span:
-            assert attrs.get("__location__") in ["library_client", "server"]
+        if span.is_root_span():
+            assert span.get_location() in ["library_client", "server"]
             continue
 
-        assert attrs.get("__autotraced__")
-        assert attrs.get("__class__") and attrs.get("__method__")
-        assert attrs.get("__type__") in ["async", "sync", "async_generator"]
+        assert span.is_autotraced()
+        class_name, method_name = span.get_class_method()
+        assert class_name and method_name
+        assert span.get_span_type() in ["async", "sync", "async_generator"]
 
-        args_field = attrs.get("__args__")
+        args_field = span.attributes.get("__args__")
         if args_field:
             args = json.loads(args_field)
             if "model_id" in args:
@@ -140,21 +107,40 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, text_model_id):
     # At least one span should capture the fully qualified model ID
     assert text_model_id in logged_model_ids, f"Expected to find {text_model_id} in spans, but got {logged_model_ids}"
 
-    # TODO: re-enable this once metrics get fixed
-    """
-    # Verify token usage metrics in response
-    metrics = mock_otlp_collector.get_metrics()
-    assert metrics
-    for metric in metrics:
-        assert metric.name in ["completion_tokens", "total_tokens", "prompt_tokens"]
-        assert metric.unit == "tokens"
-        assert metric.data.data_points and len(metric.data.data_points) == 1
-        match metric.name:
-            case "completion_tokens":
-                assert metric.data.data_points[0].value == usage["completion_tokens"]
-            case "total_tokens":
-                assert metric.data.data_points[0].value == usage["total_tokens"]
-            case "prompt_tokens":
-                assert metric.data.data_points[0].value == usage["prompt_tokens"
-    """
+    # Verify token usage metrics in response using polling
+    expected_metrics = ["completion_tokens", "total_tokens", "prompt_tokens"]
+    metrics = mock_otlp_collector.get_metrics(expected_count=len(expected_metrics), expect_model_id=text_model_id)
+    assert len(metrics) > 0, "No metrics found within timeout"
+
+    # Filter metrics to only those from the specific model used in the request
+    # This prevents issues when multiple metrics with the same name exist from different models
+    # (e.g., when safety models like llama-guard are also called)
+    inference_model_metrics = {}
+    all_model_ids = set()
+
+    for name, metric in metrics.items():
+        if name in expected_metrics:
+            model_id = metric.attributes.get("model_id")
+            all_model_ids.add(model_id)
+            # Only include metrics from the specific model used in the test request
+            if model_id == text_model_id:
+                inference_model_metrics[name] = metric
+
+    # Verify expected metrics are present for our specific model
+    for metric_name in expected_metrics:
+        assert metric_name in inference_model_metrics, (
+            f"Expected metric {metric_name} for model {text_model_id} not found. "
+            f"Available models: {sorted(all_model_ids)}, "
+            f"Available metrics for {text_model_id}: {list(inference_model_metrics.keys())}"
+        )
+
+    # Verify metric values match usage data
+    assert inference_model_metrics["completion_tokens"].value == usage["completion_tokens"], (
+        f"Expected {usage['completion_tokens']} for completion_tokens, but got {inference_model_metrics['completion_tokens'].value}"
+    )
+    assert inference_model_metrics["total_tokens"].value == usage["total_tokens"], (
+        f"Expected {usage['total_tokens']} for total_tokens, but got {inference_model_metrics['total_tokens'].value}"
+    )
+    assert inference_model_metrics["prompt_tokens"].value == usage["prompt_tokens"], (
+        f"Expected {usage['prompt_tokens']} for prompt_tokens, but got {inference_model_metrics['prompt_tokens'].value}"
+    )