diff --git a/scripts/integration-tests.sh b/scripts/integration-tests.sh
index cdd3e736f..372e97d8c 100755
--- a/scripts/integration-tests.sh
+++ b/scripts/integration-tests.sh
@@ -227,14 +227,15 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
    echo "=== Starting Llama Stack Server ==="
    export LLAMA_STACK_LOG_WIDTH=120
 
-    # Configure telemetry collector for server mode
-    # Use a fixed port for the OTEL collector so the server can connect to it
-    COLLECTOR_PORT=4317
-    export LLAMA_STACK_TEST_COLLECTOR_PORT="${COLLECTOR_PORT}"
-    export OTEL_EXPORTER_OTLP_ENDPOINT="http://127.0.0.1:${COLLECTOR_PORT}"
-    export OTEL_EXPORTER_OTLP_PROTOCOL="http/protobuf"
-    export OTEL_BSP_SCHEDULE_DELAY="200"
-    export OTEL_BSP_EXPORT_TIMEOUT="2000"
+    # Configure telemetry collector for server mode
+    # Use a fixed port for the OTEL collector so the server can connect to it
+    COLLECTOR_PORT=4317
+    export LLAMA_STACK_TEST_COLLECTOR_PORT="${COLLECTOR_PORT}"
+    export OTEL_EXPORTER_OTLP_ENDPOINT="http://127.0.0.1:${COLLECTOR_PORT}"
+    export OTEL_EXPORTER_OTLP_PROTOCOL="http/protobuf"
+    export OTEL_BSP_SCHEDULE_DELAY="200"
+    export OTEL_BSP_EXPORT_TIMEOUT="2000"
+    export OTEL_METRIC_EXPORT_INTERVAL="200"
 
    # remove "server:" from STACK_CONFIG
    stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://')
@@ -337,6 +338,9 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
    DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e LLAMA_STACK_TEST_INFERENCE_MODE=$INFERENCE_MODE"
    DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e LLAMA_STACK_TEST_STACK_CONFIG_TYPE=server"
    DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:${COLLECTOR_PORT}"
+    DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OTEL_METRIC_EXPORT_INTERVAL=200"
+    DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OTEL_BSP_SCHEDULE_DELAY=200"
+    DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OTEL_BSP_EXPORT_TIMEOUT=2000"
 
    # Pass through API keys if they exist
    [ -n "${TOGETHER_API_KEY:-}" ] && DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e TOGETHER_API_KEY=$TOGETHER_API_KEY"
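
For context: OTEL_BSP_* and OTEL_METRIC_EXPORT_INTERVAL are standard OpenTelemetry SDK environment variables, not Llama Stack inventions. A minimal sketch (not part of this change; assumes the opentelemetry-sdk and opentelemetry-exporter-otlp-proto-http packages) of how the server-side SDK picks them up:

    import os
    from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
    from opentelemetry.sdk.metrics import MeterProvider
    from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader

    # PeriodicExportingMetricReader defaults its export interval from
    # OTEL_METRIC_EXPORT_INTERVAL, so the 200 ms value above makes metrics
    # reach the test collector quickly instead of after the 60 s default.
    reader = PeriodicExportingMetricReader(
        OTLPMetricExporter(endpoint=os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] + "/v1/metrics")
    )
    provider = MeterProvider(metric_readers=[reader])

OTEL_BSP_SCHEDULE_DELAY and OTEL_BSP_EXPORT_TIMEOUT similarly shorten the BatchSpanProcessor's flush cadence so spans arrive within the tests' polling windows.
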
diff --git a/src/llama_stack/core/telemetry/telemetry.py b/src/llama_stack/core/telemetry/telemetry.py
index 1ba43724d..9476c961a 100644
--- a/src/llama_stack/core/telemetry/telemetry.py
+++ b/src/llama_stack/core/telemetry/telemetry.py
@@ -427,6 +427,7 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
     "counters": {},
     "gauges": {},
     "up_down_counters": {},
+    "histograms": {},
 }
 _global_lock = threading.Lock()
 _TRACER_PROVIDER = None
@@ -540,6 +541,16 @@ class Telemetry:
         )
         return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
 
+    def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
+        assert self.meter is not None
+        if name not in _GLOBAL_STORAGE["histograms"]:
+            _GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram(
+                name=name,
+                unit=unit,
+                description=f"Histogram for {name}",
+            )
+        return cast(metrics.Histogram, _GLOBAL_STORAGE["histograms"][name])
+
     def _log_metric(self, event: MetricEvent) -> None:
         # Add metric as an event to the current span
         try:
@@ -571,7 +582,16 @@ class Telemetry:
         # Log to OpenTelemetry meter if available
         if self.meter is None:
             return
-        if isinstance(event.value, int):
+
+        # Use histograms for token-related metrics (per-request measurements)
+        # Use counters for other cumulative metrics
+        token_metrics = {"prompt_tokens", "completion_tokens", "total_tokens"}
+
+        if event.metric in token_metrics:
+            # Token metrics are per-request measurements, use histogram
+            histogram = self._get_or_create_histogram(event.metric, event.unit)
+            histogram.record(event.value, attributes=_clean_attributes(event.attributes))
+        elif isinstance(event.value, int):
             counter = self._get_or_create_counter(event.metric, event.unit)
             counter.add(event.value, attributes=_clean_attributes(event.attributes))
         elif isinstance(event.value, float):
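
Why histograms: a counter only exposes a monotonic running total, so per-request token counts cannot be recovered from it, while a histogram data point carries the count, sum, and min/max of the individual recordings. A rough illustration using the OpenTelemetry metrics API (the same calls the change uses; the meter-provider wiring is assumed to be in place):

    from opentelemetry import metrics

    meter = metrics.get_meter("llama-stack-example")
    histogram = meter.create_histogram(
        name="prompt_tokens",
        unit="tokens",
        description="Histogram for prompt_tokens",
    )

    # Two requests: the exported histogram point reports count=2 and sum=30,
    # so both the per-request values and the total remain observable.
    histogram.record(10, attributes={"model_id": "llama3.2:3b"})
    histogram.record(20, attributes={"model_id": "llama3.2:3b"})
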
diff --git a/tests/integration/common/recordings/models-64a2277c90f0f42576f60c1030e3a020403d34a95f56931b792d5939f4cebc57-826d44c3.json b/tests/integration/common/recordings/models-64a2277c90f0f42576f60c1030e3a020403d34a95f56931b792d5939f4cebc57-826d44c3.json
new file mode 100644
index 000000000..878fcc650
--- /dev/null
+++ b/tests/integration/common/recordings/models-64a2277c90f0f42576f60c1030e3a020403d34a95f56931b792d5939f4cebc57-826d44c3.json
@@ -0,0 +1,89 @@
+{
+  "test_id": null,
+  "request": {
+    "method": "POST",
+    "url": "http://0.0.0.0:11434/v1/v1/models",
+    "headers": {},
+    "body": {},
+    "endpoint": "/v1/models",
+    "model": ""
+  },
+  "response": {
+    "body": [
+      {
+        "__type__": "openai.types.model.Model",
+        "__data__": {
+          "id": "llama3.2:3b-instruct-fp16",
+          "created": 1760453641,
+          "object": "model",
+          "owned_by": "library"
+        }
+      },
+      {
+        "__type__": "openai.types.model.Model",
+        "__data__": {
+          "id": "qwen3:4b",
+          "created": 1757615302,
+          "object": "model",
+          "owned_by": "library"
+        }
+      },
+      {
+        "__type__": "openai.types.model.Model",
+        "__data__": {
+          "id": "gpt-oss:latest",
+          "created": 1756395223,
+          "object": "model",
+          "owned_by": "library"
+        }
+      },
+      {
+        "__type__": "openai.types.model.Model",
+        "__data__": {
+          "id": "nomic-embed-text:latest",
+          "created": 1756318548,
+          "object": "model",
+          "owned_by": "library"
+        }
+      },
+      {
+        "__type__": "openai.types.model.Model",
+        "__data__": {
+          "id": "llama3.2:3b",
+          "created": 1755191039,
+          "object": "model",
+          "owned_by": "library"
+        }
+      },
+      {
+        "__type__": "openai.types.model.Model",
+        "__data__": {
+          "id": "all-minilm:l6-v2",
+          "created": 1753968177,
+          "object": "model",
+          "owned_by": "library"
+        }
+      },
+      {
+        "__type__": "openai.types.model.Model",
+        "__data__": {
+          "id": "llama3.2:1b",
+          "created": 1746124735,
+          "object": "model",
+          "owned_by": "library"
+        }
+      },
+      {
+        "__type__": "openai.types.model.Model",
+        "__data__": {
+          "id": "llama3.2:latest",
+          "created": 1746044170,
+          "object": "model",
+          "owned_by": "library"
+        }
+      }
+    ],
+    "is_streaming": false
+  },
+  "id_normalization_mapping": {}
+}
diff --git a/tests/integration/telemetry/collectors/base.py b/tests/integration/telemetry/collectors/base.py
index a85e6cf3f..c6c96e99a 100644
--- a/tests/integration/telemetry/collectors/base.py
+++ b/tests/integration/telemetry/collectors/base.py
@@ -6,20 +6,89 @@
 
 """Shared helpers for telemetry test collectors."""
 
+import os
+import time
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any
 
 
 @dataclass
-class SpanStub:
+class MetricStub:
+    """Unified metric interface for both in-memory and OTLP collectors."""
+
     name: str
-    attributes: dict[str, Any]
+    value: Any
+    attributes: dict[str, Any] | None = None
+
+
+@dataclass
+class SpanStub:
+    """Unified span interface for both in-memory and OTLP collectors."""
+
+    name: str
+    attributes: dict[str, Any] | None = None
     resource_attributes: dict[str, Any] | None = None
     events: list[dict[str, Any]] | None = None
     trace_id: str | None = None
     span_id: str | None = None
 
+    @property
+    def context(self):
+        """Provide context-like interface for trace_id compatibility."""
+        if self.trace_id is None:
+            return None
+        return type("Context", (), {"trace_id": int(self.trace_id, 16)})()
+
+    def get_trace_id(self) -> str | None:
+        """Get trace ID in hex format.
+
+        Tries context.trace_id first, then falls back to direct trace_id.
+        """
+        context = getattr(self, "context", None)
+        if context and getattr(context, "trace_id", None) is not None:
+            return f"{context.trace_id:032x}"
+        return getattr(self, "trace_id", None)
+
+    def has_message(self, text: str) -> bool:
+        """Check if span contains a specific message in its args."""
+        if self.attributes is None:
+            return False
+        args = self.attributes.get("__args__")
+        if not args or not isinstance(args, str):
+            return False
+        return text in args
+
+    def is_root_span(self) -> bool:
+        """Check if this is a root span."""
+        if self.attributes is None:
+            return False
+        return self.attributes.get("__root__") is True
+
+    def is_autotraced(self) -> bool:
+        """Check if this span was automatically traced."""
+        if self.attributes is None:
+            return False
+        return self.attributes.get("__autotraced__") is True
+
+    def get_span_type(self) -> str | None:
+        """Get the span type (async, sync, async_generator)."""
+        if self.attributes is None:
+            return None
+        return self.attributes.get("__type__")
+
+    def get_class_method(self) -> tuple[str | None, str | None]:
+        """Get the class and method names for autotraced spans."""
+        if self.attributes is None:
+            return None, None
+        return (self.attributes.get("__class__"), self.attributes.get("__method__"))
+
+    def get_location(self) -> str | None:
+        """Get the location (library_client, server) for root spans."""
+        if self.attributes is None:
+            return None
+        return self.attributes.get("__location__")
+
 
 def _value_to_python(value: Any) -> Any:
     kind = value.WhichOneof("value")
@@ -56,14 +125,65 @@ def events_to_list(events: Iterable[Any]) -> list[dict[str, Any]]:
 
 
 class BaseTelemetryCollector:
+    """Base class for telemetry collectors that ensures consistent return types.
+
+    All collectors must return SpanStub objects to ensure test compatibility
+    across both library-client and server modes.
+    """
+
+    # Default delay in seconds if OTEL_METRIC_EXPORT_INTERVAL is not set
+    _DEFAULT_BASELINE_STABILIZATION_DELAY = 0.2
+
+    def __init__(self):
+        self._metric_baseline: dict[tuple[str, str], float] = {}
+
+    @classmethod
+    def _get_baseline_stabilization_delay(cls) -> float:
+        """Get baseline stabilization delay from OTEL_METRIC_EXPORT_INTERVAL.
+
+        Adds 1.5x buffer for CI environments.
+        """
+        interval_ms = os.environ.get("OTEL_METRIC_EXPORT_INTERVAL")
+        if interval_ms:
+            try:
+                delay = float(interval_ms) / 1000.0
+            except (ValueError, TypeError):
+                delay = cls._DEFAULT_BASELINE_STABILIZATION_DELAY
+        else:
+            delay = cls._DEFAULT_BASELINE_STABILIZATION_DELAY
+
+        if os.environ.get("CI"):
+            delay *= 1.5
+
+        return delay
+
+    def _get_metric_key(self, metric: MetricStub) -> tuple[str, str]:
+        """Generate a stable key for a metric based on name and attributes."""
+        attrs = metric.attributes or {}
+        attr_key = ",".join(f"{k}={v}" for k, v in sorted(attrs.items()))
+        return (metric.name, attr_key)
+
+    def _compute_metric_delta(self, metric: MetricStub) -> int | float | None:
+        """Compute delta value for a metric from baseline.
+
+        Returns:
+            Delta value if metric was in baseline, absolute value if new, None if unchanged.
+        """
+        metric_key = self._get_metric_key(metric)
+
+        if metric_key in self._metric_baseline:
+            baseline_value = self._metric_baseline[metric_key]
+            delta = metric.value - baseline_value
+            return delta if delta > 0 else None
+        else:
+            return metric.value
+
     def get_spans(
         self,
         expected_count: int | None = None,
         timeout: float = 5.0,
         poll_interval: float = 0.05,
-    ) -> tuple[Any, ...]:
-        import time
-
+    ) -> tuple[SpanStub, ...]:
         deadline = time.time() + timeout
         min_count = expected_count if expected_count is not None else 1
         last_len: int | None = None
@@ -91,16 +211,292 @@ class BaseTelemetryCollector:
             last_len = len(spans)
             time.sleep(poll_interval)
 
-    def get_metrics(self) -> Any | None:
-        return self._snapshot_metrics()
+    def get_metrics(
+        self,
+        expected_count: int | None = None,
+        timeout: float = 5.0,
+        poll_interval: float = 0.05,
+        expect_model_id: str | None = None,
+    ) -> dict[str, MetricStub]:
+        """Poll until expected metrics are available or timeout is reached.
+
+        Returns metrics with delta values computed from baseline.
+        """
+        deadline = time.time() + timeout
+        min_count = expected_count if expected_count is not None else 1
+        accumulated_metrics = {}
+        seen_metric_names_with_model_id = set()
+
+        while time.time() < deadline:
+            current_metrics = self._snapshot_metrics()
+            if current_metrics:
+                for metric in current_metrics:
+                    delta_value = self._compute_metric_delta(metric)
+                    if delta_value is None:
+                        continue
+
+                    metric_with_delta = MetricStub(
+                        name=metric.name,
+                        value=delta_value,
+                        attributes=metric.attributes,
+                    )
+
+                    self._accumulate_metric(
+                        accumulated_metrics,
+                        metric_with_delta,
+                        expect_model_id,
+                        seen_metric_names_with_model_id,
+                    )
+
+            if self._has_enough_metrics(
+                accumulated_metrics, seen_metric_names_with_model_id, min_count, expect_model_id
+            ):
+                return accumulated_metrics
+
+            time.sleep(poll_interval)
+
+        return accumulated_metrics
+
+    def _accumulate_metric(
+        self,
+        accumulated: dict[str, MetricStub],
+        metric: MetricStub,
+        expect_model_id: str | None,
+        seen_with_model_id: set[str],
+    ) -> None:
+        """Accumulate a metric, preferring those matching expected model_id."""
+        metric_name = metric.name
+        matches_model_id = (
+            expect_model_id and metric.attributes and metric.attributes.get("model_id") == expect_model_id
+        )
+
+        if metric_name not in accumulated:
+            accumulated[metric_name] = metric
+            if matches_model_id:
+                seen_with_model_id.add(metric_name)
+            return
+
+        existing = accumulated[metric_name]
+        existing_matches = (
+            expect_model_id and existing.attributes and existing.attributes.get("model_id") == expect_model_id
+        )
+
+        if matches_model_id and not existing_matches:
+            accumulated[metric_name] = metric
+            seen_with_model_id.add(metric_name)
+        elif matches_model_id == existing_matches:
+            if metric.value > existing.value:
+                accumulated[metric_name] = metric
+                if matches_model_id:
+                    seen_with_model_id.add(metric_name)
+
+    def _has_enough_metrics(
+        self,
+        accumulated: dict[str, MetricStub],
+        seen_with_model_id: set[str],
+        min_count: int,
+        expect_model_id: str | None,
+    ) -> bool:
+        """Check if we have collected enough metrics."""
+        if len(accumulated) < min_count:
+            return False
+        if not expect_model_id:
+            return True
+        return len(seen_with_model_id) >= min_count
+
+    @staticmethod
+    def _convert_attributes_to_dict(attrs: Any) -> dict[str, Any]:
+        """Convert various attribute types to a consistent dictionary format.
+
+        Handles mappingproxy, dict, and other attribute types.
+        """
+        if attrs is None:
+            return {}
+
+        try:
+            return dict(attrs.items())  # type: ignore[attr-defined]
+        except AttributeError:
+            try:
+                return dict(attrs)
+            except TypeError:
+                return {}
+
+    @staticmethod
+    def _extract_trace_span_ids(span: Any) -> tuple[str | None, str | None]:
+        """Extract trace_id and span_id from OpenTelemetry span object.
+
+        Handles both context-based and direct attribute access.
+        """
+        trace_id = None
+        span_id = None
+
+        context = getattr(span, "context", None)
+        if context:
+            trace_id = f"{context.trace_id:032x}"
+            span_id = f"{context.span_id:016x}"
+        else:
+            trace_id = getattr(span, "trace_id", None)
+            span_id = getattr(span, "span_id", None)
+
+        return trace_id, span_id
+
+    @staticmethod
+    def _create_span_stub_from_opentelemetry(span: Any) -> SpanStub:
+        """Create SpanStub from OpenTelemetry span object.
+
+        This helper reduces code duplication between collectors.
+        """
+        trace_id, span_id = BaseTelemetryCollector._extract_trace_span_ids(span)
+        attributes = BaseTelemetryCollector._convert_attributes_to_dict(span.attributes) or {}
+
+        return SpanStub(
+            name=span.name,
+            attributes=attributes,
+            trace_id=trace_id,
+            span_id=span_id,
+        )
+
+    @staticmethod
+    def _create_span_stub_from_protobuf(span: Any, resource_attrs: dict[str, Any] | None = None) -> SpanStub:
+        """Create SpanStub from protobuf span object.
+
+        This helper handles the different structure of protobuf spans.
+        """
+        attributes = attributes_to_dict(span.attributes) or {}
+        events = events_to_list(span.events) if span.events else None
+        trace_id = span.trace_id.hex() if span.trace_id else None
+        span_id = span.span_id.hex() if span.span_id else None
+
+        return SpanStub(
+            name=span.name,
+            attributes=attributes,
+            resource_attributes=resource_attrs,
+            events=events,
+            trace_id=trace_id,
+            span_id=span_id,
+        )
+
+    @staticmethod
+    def _extract_metric_from_opentelemetry(metric: Any) -> MetricStub | None:
+        """Extract MetricStub from OpenTelemetry metric object.
+
+        This helper reduces code duplication between collectors.
+        """
+        if not (hasattr(metric, "name") and hasattr(metric, "data") and hasattr(metric.data, "data_points")):
+            return None
+
+        if not (metric.data.data_points and len(metric.data.data_points) > 0):
+            return None
+
+        data_point = metric.data.data_points[0]
+
+        if hasattr(data_point, "value"):
+            # Counter or Gauge
+            value = data_point.value
+        elif hasattr(data_point, "sum"):
+            # Histogram - use the sum of all recorded values
+            value = data_point.sum
+        else:
+            return None
+
+        attributes = {}
+        if hasattr(data_point, "attributes"):
+            attrs = data_point.attributes
+            if attrs is not None and hasattr(attrs, "items"):
+                attributes = dict(attrs.items())
+            elif attrs is not None and not isinstance(attrs, dict):
+                attributes = dict(attrs)
+
+        return MetricStub(
+            name=metric.name,
+            value=value,
+            attributes=attributes or {},
+        )
+
+    @staticmethod
+    def _create_metric_stubs_from_protobuf(metric: Any) -> list[MetricStub]:
+        """Create list of MetricStub objects from protobuf metric object.
+
+        Protobuf metrics can have sum, gauge, or histogram data. Each metric can have
+        multiple data points with different attributes, so we return one MetricStub
+        per data point.
+
+        Returns:
+            List of MetricStub objects, one per data point in the metric.
+        """
+        if not hasattr(metric, "name"):
+            return []
+
+        metric_stubs = []
+
+        for metric_type in ["sum", "gauge", "histogram"]:
+            if not hasattr(metric, metric_type):
+                continue
+
+            metric_data = getattr(metric, metric_type)
+            if not metric_data or not hasattr(metric_data, "data_points"):
+                continue
+
+            data_points = metric_data.data_points
+            if not data_points:
+                continue
+
+            for data_point in data_points:
+                attributes = attributes_to_dict(data_point.attributes) if hasattr(data_point, "attributes") else {}
+
+                value = BaseTelemetryCollector._extract_data_point_value(data_point, metric_type)
+                if value is None:
+                    continue
+
+                metric_stubs.append(
+                    MetricStub(
+                        name=metric.name,
+                        value=value,
+                        attributes=attributes,
+                    )
+                )
+
+            # Only process one metric type per metric
+            break
+
+        return metric_stubs
+
+    @staticmethod
+    def _extract_data_point_value(data_point: Any, metric_type: str) -> float | int | None:
+        """Extract value from a protobuf metric data point based on metric type."""
+        if metric_type == "sum":
+            if hasattr(data_point, "as_int"):
+                return data_point.as_int
+            if hasattr(data_point, "as_double"):
+                return data_point.as_double
+        elif metric_type == "gauge":
+            if hasattr(data_point, "as_double"):
+                return data_point.as_double
+        elif metric_type == "histogram":
+            # Histograms use sum field which represents cumulative sum of all recorded values
+            if hasattr(data_point, "sum"):
+                return data_point.sum
+
+        return None
 
     def clear(self) -> None:
+        """Clear telemetry data and establish baseline for metric delta computation."""
+        self._metric_baseline.clear()
         self._clear_impl()
 
-    def _snapshot_spans(self) -> tuple[Any, ...]:  # pragma: no cover - interface hook
+        delay = self._get_baseline_stabilization_delay()
+        time.sleep(delay)
+        baseline_metrics = self._snapshot_metrics()
+        if baseline_metrics:
+            for metric in baseline_metrics:
+                metric_key = self._get_metric_key(metric)
+                self._metric_baseline[metric_key] = metric.value
+
+    def _snapshot_spans(self) -> tuple[SpanStub, ...]:  # pragma: no cover - interface hook
         raise NotImplementedError
 
-    def _snapshot_metrics(self) -> Any | None:  # pragma: no cover - interface hook
+    def _snapshot_metrics(self) -> tuple[MetricStub, ...] | None:  # pragma: no cover - interface hook
         raise NotImplementedError
 
     def _clear_impl(self) -> None:  # pragma: no cover - interface hook
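
The baseline mechanism exists because cumulative OTLP metrics survive a collector-side clear(): the SDK keeps counting from process start, so the collector instead snapshots current values and reports only growth. A hypothetical test flow (collector and run_chat_completion are illustrative stand-ins, not real fixtures):

    collector.clear()                  # snapshots e.g. prompt_tokens == 120 as baseline
    run_chat_completion()              # the request adds 15 prompt tokens
    metrics = collector.get_metrics(expected_count=3)
    assert metrics["prompt_tokens"].value == 15   # delta since clear(), not 135
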
diff --git a/tests/integration/telemetry/collectors/in_memory.py b/tests/integration/telemetry/collectors/in_memory.py
index 2cf320f7b..7127b3816 100644
--- a/tests/integration/telemetry/collectors/in_memory.py
+++ b/tests/integration/telemetry/collectors/in_memory.py
@@ -6,8 +6,6 @@
 
 """In-memory telemetry collector for library-client tests."""
 
-from typing import Any
-
 import opentelemetry.metrics as otel_metrics
 import opentelemetry.trace as otel_trace
 from opentelemetry import metrics, trace
@@ -19,46 +17,42 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanE
 
 import llama_stack.core.telemetry.telemetry as telemetry_module
 
-from .base import BaseTelemetryCollector, SpanStub
+from .base import BaseTelemetryCollector, MetricStub, SpanStub
 
 
 class InMemoryTelemetryCollector(BaseTelemetryCollector):
+    """In-memory telemetry collector for library-client tests.
+
+    Converts OpenTelemetry span objects to SpanStub objects to ensure
+    consistent interface with OTLP collector used in server mode.
+    """
+
     def __init__(self, span_exporter: InMemorySpanExporter, metric_reader: InMemoryMetricReader) -> None:
+        super().__init__()
         self._span_exporter = span_exporter
         self._metric_reader = metric_reader
 
-    def _snapshot_spans(self) -> tuple[Any, ...]:
+    def _snapshot_spans(self) -> tuple[SpanStub, ...]:
         spans = []
         for span in self._span_exporter.get_finished_spans():
-            trace_id = None
-            span_id = None
-            context = getattr(span, "context", None)
-            if context:
-                trace_id = f"{context.trace_id:032x}"
-                span_id = f"{context.span_id:016x}"
-            else:
-                trace_id = getattr(span, "trace_id", None)
-                span_id = getattr(span, "span_id", None)
-
-            stub = SpanStub(
-                span.name,
-                span.attributes,
-                getattr(span, "resource", None),
-                getattr(span, "events", None),
-                trace_id,
-                span_id,
-            )
-            spans.append(stub)
-
+            spans.append(self._create_span_stub_from_opentelemetry(span))
         return tuple(spans)
 
-    def _snapshot_metrics(self) -> Any | None:
+    def _snapshot_metrics(self) -> tuple[MetricStub, ...] | None:
         data = self._metric_reader.get_metrics_data()
-        if data and data.resource_metrics:
-            resource_metric = data.resource_metrics[0]
+        if not data or not data.resource_metrics:
+            return None
+
+        metric_stubs = []
+        for resource_metric in data.resource_metrics:
             if resource_metric.scope_metrics:
-                return resource_metric.scope_metrics[0].metrics
-        return None
+                for scope_metric in resource_metric.scope_metrics:
+                    for metric in scope_metric.metrics:
+                        metric_stub = self._extract_metric_from_opentelemetry(metric)
+                        if metric_stub:
+                            metric_stubs.append(metric_stub)
+
+        return tuple(metric_stubs) if metric_stubs else None
 
     def _clear_impl(self) -> None:
         self._span_exporter.clear()
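
For reference, the exporter/reader pair this collector wraps comes from standard opentelemetry-sdk test utilities; a plausible wiring sketch (in practice the telemetry module installs the providers, so this is shown only to make the constructor's inputs concrete):

    from opentelemetry.sdk.metrics import MeterProvider
    from opentelemetry.sdk.metrics.export import InMemoryMetricReader
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import SimpleSpanProcessor
    from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

    # Spans are exported synchronously into memory; metrics are pulled on demand.
    span_exporter = InMemorySpanExporter()
    tracer_provider = TracerProvider()
    tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter))

    metric_reader = InMemoryMetricReader()
    meter_provider = MeterProvider(metric_readers=[metric_reader])

    collector = InMemoryTelemetryCollector(span_exporter, metric_reader)
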
diff --git a/tests/integration/telemetry/collectors/otlp.py b/tests/integration/telemetry/collectors/otlp.py
index 2d6cb0b7e..21702e447 100644
--- a/tests/integration/telemetry/collectors/otlp.py
+++ b/tests/integration/telemetry/collectors/otlp.py
@@ -9,20 +9,21 @@
 import gzip
 import os
 import threading
+import time
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from socketserver import ThreadingMixIn
-from typing import Any
 
 from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2 import ExportMetricsServiceRequest
 from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest
 
-from .base import BaseTelemetryCollector, SpanStub, attributes_to_dict, events_to_list
+from .base import BaseTelemetryCollector, MetricStub, SpanStub, attributes_to_dict
 
 
 class OtlpHttpTestCollector(BaseTelemetryCollector):
     def __init__(self) -> None:
+        super().__init__()
         self._spans: list[SpanStub] = []
-        self._metrics: list[Any] = []
+        self._metrics: list[MetricStub] = []
         self._lock = threading.Lock()
 
         class _ThreadingHTTPServer(ThreadingMixIn, HTTPServer):
@@ -47,11 +48,7 @@ class OtlpHttpTestCollector(BaseTelemetryCollector):
 
         for scope_spans in resource_spans.scope_spans:
             for span in scope_spans.spans:
-                attributes = attributes_to_dict(span.attributes)
-                events = events_to_list(span.events) if span.events else None
-                trace_id = span.trace_id.hex() if span.trace_id else None
-                span_id = span.span_id.hex() if span.span_id else None
-                new_spans.append(SpanStub(span.name, attributes, resource_attrs or None, events, trace_id, span_id))
+                new_spans.append(self._create_span_stub_from_protobuf(span, resource_attrs or None))
 
         if not new_spans:
             return
@@ -60,10 +57,13 @@ class OtlpHttpTestCollector(BaseTelemetryCollector):
         self._spans.extend(new_spans)
 
     def _handle_metrics(self, request: ExportMetricsServiceRequest) -> None:
-        new_metrics: list[Any] = []
+        new_metrics: list[MetricStub] = []
         for resource_metrics in request.resource_metrics:
             for scope_metrics in resource_metrics.scope_metrics:
-                new_metrics.extend(scope_metrics.metrics)
+                for metric in scope_metrics.metrics:
+                    # Handle multiple data points per metric (e.g., different attribute sets)
+                    metric_stubs = self._create_metric_stubs_from_protobuf(metric)
+                    new_metrics.extend(metric_stubs)
 
         if not new_metrics:
             return
@@ -75,11 +75,40 @@ class OtlpHttpTestCollector(BaseTelemetryCollector):
         with self._lock:
             return tuple(self._spans)
 
-    def _snapshot_metrics(self) -> Any | None:
+    def _snapshot_metrics(self) -> tuple[MetricStub, ...] | None:
         with self._lock:
-            return list(self._metrics) if self._metrics else None
+            return tuple(self._metrics) if self._metrics else None
 
     def _clear_impl(self) -> None:
+        """Clear telemetry over a period of time to prevent race conditions between tests."""
+        with self._lock:
+            self._spans.clear()
+            self._metrics.clear()
+
+        # Prevent race conditions where telemetry arrives after clear() but before
+        # the test starts, causing contamination between tests
+        deadline = time.time() + 2.0  # Maximum wait time
+        last_span_count = 0
+        last_metric_count = 0
+        stable_iterations = 0
+
+        while time.time() < deadline:
+            with self._lock:
+                current_span_count = len(self._spans)
+                current_metric_count = len(self._metrics)
+
+            if current_span_count == last_span_count and current_metric_count == last_metric_count:
+                stable_iterations += 1
+                if stable_iterations >= 4:  # 4 * 50ms = 200ms of stability
+                    break
+            else:
+                stable_iterations = 0
+                last_span_count = current_span_count
+                last_metric_count = current_metric_count
+
+            time.sleep(0.05)
+
+        # Final clear to remove any telemetry that arrived during stabilization
         with self._lock:
            self._spans.clear()
            self._metrics.clear()
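
The double clear in _clear_impl implements a drain-until-stable pattern: exports already in flight from the previous test can land after the first clear, so the collector waits until counts stop changing before clearing once more. The same idea in isolation (a toy sketch, not the PR's code):

    import time

    def drain_until_stable(count_items, clear_items, poll=0.05, stable_polls=4, timeout=2.0):
        # count_items/clear_items are caller-supplied callables over the shared buffer.
        clear_items()
        deadline = time.time() + timeout
        last, stable = count_items(), 0
        while time.time() < deadline and stable < stable_polls:
            time.sleep(poll)
            current = count_items()
            stable = stable + 1 if current == last else 0
            last = current
        clear_items()  # drop anything that trickled in while waiting
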
+""" import json -def _span_attributes(span): - attrs = getattr(span, "attributes", None) - if attrs is None: - return {} - # ReadableSpan.attributes acts like a mapping - try: - return dict(attrs.items()) # type: ignore[attr-defined] - except AttributeError: - try: - return dict(attrs) - except TypeError: - return attrs - - -def _span_attr(span, key): - attrs = _span_attributes(span) - return attrs.get(key) - - -def _span_trace_id(span): - context = getattr(span, "context", None) - if context and getattr(context, "trace_id", None) is not None: - return f"{context.trace_id:032x}" - return getattr(span, "trace_id", None) - - -def _span_has_message(span, text: str) -> bool: - args = _span_attr(span, "__args__") - if not args or not isinstance(args, str): - return False - return text in args - - def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_model_id): """Verify streaming adds chunk_count and __type__=async_generator.""" - mock_otlp_collector.clear() - stream = llama_stack_client.chat.completions.create( model=text_model_id, messages=[{"role": "user", "content": "Test trace openai 1"}], @@ -62,16 +31,16 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod ( span for span in reversed(spans) - if _span_attr(span, "__type__") == "async_generator" - and _span_attr(span, "chunk_count") - and _span_has_message(span, "Test trace openai 1") + if span.get_span_type() == "async_generator" + and span.attributes.get("chunk_count") + and span.has_message("Test trace openai 1") ), None, ) assert async_generator_span is not None - raw_chunk_count = _span_attr(async_generator_span, "chunk_count") + raw_chunk_count = async_generator_span.attributes.get("chunk_count") assert raw_chunk_count is not None chunk_count = int(raw_chunk_count) @@ -80,7 +49,6 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, text_model_id): """Comprehensive validation of telemetry data format including spans and metrics.""" - mock_otlp_collector.clear() response = llama_stack_client.chat.completions.create( model=text_model_id, @@ -101,37 +69,36 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, # Verify spans spans = mock_otlp_collector.get_spans(expected_count=7) target_span = next( - (span for span in reversed(spans) if _span_has_message(span, "Test trace openai with temperature 0.7")), + (span for span in reversed(spans) if span.has_message("Test trace openai with temperature 0.7")), None, ) assert target_span is not None - trace_id = _span_trace_id(target_span) + trace_id = target_span.get_trace_id() assert trace_id is not None - spans = [span for span in spans if _span_trace_id(span) == trace_id] - spans = [span for span in spans if _span_attr(span, "__root__") or _span_attr(span, "__autotraced__")] + spans = [span for span in spans if span.get_trace_id() == trace_id] + spans = [span for span in spans if span.is_root_span() or span.is_autotraced()] assert len(spans) >= 4 # Collect all model_ids found in spans logged_model_ids = [] for span in spans: - attrs = _span_attributes(span) + attrs = span.attributes assert attrs is not None # Root span is created manually by tracing middleware, not by @trace_protocol decorator - is_root_span = attrs.get("__root__") is True - - if is_root_span: - assert attrs.get("__location__") in ["library_client", "server"] + if span.is_root_span(): + assert span.get_location() in ["library_client", 
"server"] continue - assert attrs.get("__autotraced__") - assert attrs.get("__class__") and attrs.get("__method__") - assert attrs.get("__type__") in ["async", "sync", "async_generator"] + assert span.is_autotraced() + class_name, method_name = span.get_class_method() + assert class_name and method_name + assert span.get_span_type() in ["async", "sync", "async_generator"] - args_field = attrs.get("__args__") + args_field = span.attributes.get("__args__") if args_field: args = json.loads(args_field) if "model_id" in args: @@ -140,21 +107,39 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, # At least one span should capture the fully qualified model ID assert text_model_id in logged_model_ids, f"Expected to find {text_model_id} in spans, but got {logged_model_ids}" - # TODO: re-enable this once metrics get fixed - """ - # Verify token usage metrics in response - metrics = mock_otlp_collector.get_metrics() + # Verify token usage metrics in response using polling + expected_metrics = ["completion_tokens", "total_tokens", "prompt_tokens"] + metrics = mock_otlp_collector.get_metrics(expected_count=len(expected_metrics), expect_model_id=text_model_id) + assert len(metrics) > 0, "No metrics found within timeout" - assert metrics - for metric in metrics: - assert metric.name in ["completion_tokens", "total_tokens", "prompt_tokens"] - assert metric.unit == "tokens" - assert metric.data.data_points and len(metric.data.data_points) == 1 - match metric.name: - case "completion_tokens": - assert metric.data.data_points[0].value == usage["completion_tokens"] - case "total_tokens": - assert metric.data.data_points[0].value == usage["total_tokens"] - case "prompt_tokens": - assert metric.data.data_points[0].value == usage["prompt_tokens" - """ + # Filter metrics to only those from the specific model used in the request + # Multiple metrics with the same name can exist (e.g., from safety models) + inference_model_metrics = {} + all_model_ids = set() + + for name, metric in metrics.items(): + if name in expected_metrics: + model_id = metric.attributes.get("model_id") + all_model_ids.add(model_id) + # Only include metrics from the specific model used in the test request + if model_id == text_model_id: + inference_model_metrics[name] = metric + + # Verify expected metrics are present for our specific model + for metric_name in expected_metrics: + assert metric_name in inference_model_metrics, ( + f"Expected metric {metric_name} for model {text_model_id} not found. " + f"Available models: {sorted(all_model_ids)}, " + f"Available metrics for {text_model_id}: {list(inference_model_metrics.keys())}" + ) + + # Verify metric values match usage data + assert inference_model_metrics["completion_tokens"].value == usage["completion_tokens"], ( + f"Expected {usage['completion_tokens']} for completion_tokens, but got {inference_model_metrics['completion_tokens'].value}" + ) + assert inference_model_metrics["total_tokens"].value == usage["total_tokens"], ( + f"Expected {usage['total_tokens']} for total_tokens, but got {inference_model_metrics['total_tokens'].value}" + ) + assert inference_model_metrics["prompt_tokens"].value == usage["prompt_tokens"], ( + f"Expected {usage['prompt_tokens']} for prompt_tokens, but got {inference_model_metrics['prompt_tokens'].value}" + )