diff --git a/src/llama_stack/core/telemetry/telemetry.py b/src/llama_stack/core/telemetry/telemetry.py
index 1ba43724d..9476c961a 100644
--- a/src/llama_stack/core/telemetry/telemetry.py
+++ b/src/llama_stack/core/telemetry/telemetry.py
@@ -427,6 +427,7 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
     "counters": {},
     "gauges": {},
     "up_down_counters": {},
+    "histograms": {},
 }
 _global_lock = threading.Lock()
 _TRACER_PROVIDER = None
@@ -540,6 +541,16 @@ class Telemetry:
         )
         return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
 
+    def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
+        assert self.meter is not None
+        if name not in _GLOBAL_STORAGE["histograms"]:
+            _GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram(
+                name=name,
+                unit=unit,
+                description=f"Histogram for {name}",
+            )
+        return cast(metrics.Histogram, _GLOBAL_STORAGE["histograms"][name])
+
     def _log_metric(self, event: MetricEvent) -> None:
         # Add metric as an event to the current span
         try:
@@ -571,7 +582,16 @@
         # Log to OpenTelemetry meter if available
         if self.meter is None:
             return
-        if isinstance(event.value, int):
+
+        # Use histograms for token-related metrics (per-request measurements)
+        # Use counters for other cumulative metrics
+        token_metrics = {"prompt_tokens", "completion_tokens", "total_tokens"}
+
+        if event.metric in token_metrics:
+            # Token metrics are per-request measurements, use histogram
+            histogram = self._get_or_create_histogram(event.metric, event.unit)
+            histogram.record(event.value, attributes=_clean_attributes(event.attributes))
+        elif isinstance(event.value, int):
             counter = self._get_or_create_counter(event.metric, event.unit)
             counter.add(event.value, attributes=_clean_attributes(event.attributes))
         elif isinstance(event.value, float):
diff --git a/tests/integration/telemetry/collectors/base.py b/tests/integration/telemetry/collectors/base.py
index f88ab37bf..963da5b8e 100644
--- a/tests/integration/telemetry/collectors/base.py
+++ b/tests/integration/telemetry/collectors/base.py
@@ -6,7 +6,8 @@
 
 """Shared helpers for telemetry test collectors."""
 
-from collections.abc import Iterable, Mapping
+import time
+from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any
 
@@ -19,29 +20,13 @@ class MetricStub:
     value: Any
     attributes: dict[str, Any] | None = None
 
-    def get_value(self) -> Any:
-        """Get the metric value."""
-        return self.value
-
-    def get_name(self) -> str:
-        """Get the metric name."""
-        return self.name
-
-    def get_attributes(self) -> dict[str, Any]:
-        """Get metric attributes as a dictionary."""
-        return self.attributes or {}
-
-    def get_attribute(self, key: str) -> Any:
-        """Get a specific attribute value by key."""
-        return self.get_attributes().get(key)
-
 
 @dataclass
 class SpanStub:
     """Unified span interface for both in-memory and OTLP collectors."""
 
     name: str
-    attributes: Mapping[str, Any] | None = None
+    attributes: dict[str, Any] | None = None
     resource_attributes: dict[str, Any] | None = None
     events: list[dict[str, Any]] | None = None
     trace_id: str | None = None
@@ -54,19 +39,6 @@ class SpanStub:
             return None
         return type("Context", (), {"trace_id": int(self.trace_id, 16)})()
 
-    def get_attributes(self) -> dict[str, Any]:
-        """Get span attributes as a dictionary.
-
-        Handles different attribute types (mapping, dict, etc.) and returns
-        a consistent dictionary format.
- """ - return BaseTelemetryCollector._convert_attributes_to_dict(self.attributes) - - def get_attribute(self, key: str) -> Any: - """Get a specific attribute value by key.""" - attrs = self.get_attributes() - return attrs.get(key) - def get_trace_id(self) -> str | None: """Get trace ID in hex format. @@ -79,30 +51,42 @@ class SpanStub: def has_message(self, text: str) -> bool: """Check if span contains a specific message in its args.""" - args = self.get_attribute("__args__") + if self.attributes is None: + return False + args = self.attributes.get("__args__") if not args or not isinstance(args, str): return False return text in args def is_root_span(self) -> bool: """Check if this is a root span.""" - return self.get_attribute("__root__") is True + if self.attributes is None: + return False + return self.attributes.get("__root__") is True def is_autotraced(self) -> bool: """Check if this span was automatically traced.""" - return self.get_attribute("__autotraced__") is True + if self.attributes is None: + return False + return self.attributes.get("__autotraced__") is True def get_span_type(self) -> str | None: """Get the span type (async, sync, async_generator).""" - return self.get_attribute("__type__") + if self.attributes is None: + return None + return self.attributes.get("__type__") def get_class_method(self) -> tuple[str | None, str | None]: """Get the class and method names for autotraced spans.""" - return (self.get_attribute("__class__"), self.get_attribute("__method__")) + if self.attributes is None: + return None, None + return (self.attributes.get("__class__"), self.attributes.get("__method__")) def get_location(self) -> str | None: """Get the location (library_client, server) for root spans.""" - return self.get_attribute("__location__") + if self.attributes is None: + return None + return self.attributes.get("__location__") def _value_to_python(value: Any) -> Any: @@ -152,8 +136,6 @@ class BaseTelemetryCollector: timeout: float = 5.0, poll_interval: float = 0.05, ) -> tuple[SpanStub, ...]: - import time - deadline = time.time() + timeout min_count = expected_count if expected_count is not None else 1 last_len: int | None = None @@ -188,8 +170,8 @@ class BaseTelemetryCollector: poll_interval: float = 0.05, ) -> dict[str, MetricStub]: """Get metrics with polling until metrics are available or timeout is reached.""" - import time + # metrics need to be collected since get requests delete stored metrics deadline = time.time() + timeout min_count = expected_count if expected_count is not None else 1 accumulated_metrics = {} @@ -197,14 +179,11 @@ class BaseTelemetryCollector: while time.time() < deadline: current_metrics = self._snapshot_metrics() if current_metrics: - # Accumulate new metrics without losing existing ones for metric in current_metrics: - metric_name = metric.get_name() + metric_name = metric.name if metric_name not in accumulated_metrics: accumulated_metrics[metric_name] = metric else: - # If we already have this metric, keep the latest one - # (in case metrics are updated with new values) accumulated_metrics[metric_name] = metric # Check if we have enough metrics @@ -258,7 +237,7 @@ class BaseTelemetryCollector: This helper reduces code duplication between collectors. 
""" trace_id, span_id = BaseTelemetryCollector._extract_trace_span_ids(span) - attributes = BaseTelemetryCollector._convert_attributes_to_dict(span.attributes) + attributes = BaseTelemetryCollector._convert_attributes_to_dict(span.attributes) or {} return SpanStub( name=span.name, @@ -273,7 +252,7 @@ class BaseTelemetryCollector: This helper handles the different structure of protobuf spans. """ - attributes = attributes_to_dict(span.attributes) + attributes = attributes_to_dict(span.attributes) or {} events = events_to_list(span.events) if span.events else None trace_id = span.trace_id.hex() if span.trace_id else None span_id = span.span_id.hex() if span.span_id else None @@ -300,12 +279,22 @@ class BaseTelemetryCollector: return None # Get the value from the first data point - value = metric.data.data_points[0].value + data_point = metric.data.data_points[0] + + # Handle different metric types + if hasattr(data_point, "value"): + # Counter or Gauge + value = data_point.value + elif hasattr(data_point, "sum"): + # Histogram - use the sum of all recorded values + value = data_point.sum + else: + return None # Extract attributes if available attributes = {} - if hasattr(metric.data.data_points[0], "attributes"): - attrs = metric.data.data_points[0].attributes + if hasattr(data_point, "attributes"): + attrs = data_point.attributes if attrs is not None and hasattr(attrs, "items"): attributes = dict(attrs.items()) elif attrs is not None and not isinstance(attrs, dict): @@ -314,9 +303,48 @@ class BaseTelemetryCollector: return MetricStub( name=metric.name, value=value, - attributes=attributes if attributes else None, + attributes=attributes or {}, ) + @staticmethod + def _create_metric_stub_from_protobuf(metric: Any) -> MetricStub | None: + """Create MetricStub from protobuf metric object. + + Protobuf metrics have a different structure than OpenTelemetry metrics. + They can have sum, gauge, or histogram data. 
+ """ + if not hasattr(metric, "name"): + return None + + # Try to extract value from different metric types + for metric_type in ["sum", "gauge", "histogram"]: + if hasattr(metric, metric_type): + metric_data = getattr(metric, metric_type) + if metric_data and hasattr(metric_data, "data_points"): + data_points = metric_data.data_points + if data_points and len(data_points) > 0: + data_point = data_points[0] + + # Extract attributes first (needed for all metric types) + attributes = ( + attributes_to_dict(data_point.attributes) if hasattr(data_point, "attributes") else {} + ) + + # Extract value based on metric type + if metric_type == "sum": + value = data_point.as_int + elif metric_type == "gauge": + value = data_point.as_double + else: # histogram + value = data_point.sum + + return MetricStub( + name=metric.name, + value=value, + attributes=attributes, + ) + return None + def clear(self) -> None: self._clear_impl() diff --git a/tests/integration/telemetry/collectors/otlp.py b/tests/integration/telemetry/collectors/otlp.py index a3535f818..024eb3ee5 100644 --- a/tests/integration/telemetry/collectors/otlp.py +++ b/tests/integration/telemetry/collectors/otlp.py @@ -11,7 +11,6 @@ import os import threading from http.server import BaseHTTPRequestHandler, HTTPServer from socketserver import ThreadingMixIn -from typing import Any from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2 import ExportMetricsServiceRequest from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest @@ -83,54 +82,6 @@ class OtlpHttpTestCollector(BaseTelemetryCollector): self._spans.clear() self._metrics.clear() - def _create_metric_stub_from_protobuf(self, metric: Any) -> MetricStub | None: - """Create MetricStub from protobuf metric object. - - Protobuf metrics have a different structure than OpenTelemetry metrics. - They can have sum, gauge, or histogram data. 
- """ - if not hasattr(metric, "name"): - return None - - # Try to extract value from different metric types - for metric_type in ["sum", "gauge", "histogram"]: - if hasattr(metric, metric_type): - metric_data = getattr(metric, metric_type) - if metric_data and hasattr(metric_data, "data_points"): - data_points = metric_data.data_points - if data_points and len(data_points) > 0: - data_point = data_points[0] - - # Extract value based on metric type - if metric_type == "sum": - value = data_point.as_int - elif metric_type == "gauge": - value = data_point.as_double - else: # histogram - value = data_point.count - - # Extract attributes if available - attributes = self._extract_attributes_from_data_point(data_point) - - return MetricStub( - name=metric.name, - value=value, - attributes=attributes if attributes else None, - ) - - return None - - def _extract_attributes_from_data_point(self, data_point: Any) -> dict[str, Any]: - """Extract attributes from a protobuf data point.""" - if not hasattr(data_point, "attributes"): - return {} - - attrs = data_point.attributes - if not attrs: - return {} - - return {kv.key: kv.value.string_value or kv.value.int_value or kv.value.double_value for kv in attrs} - def shutdown(self) -> None: self._server.shutdown() self._server.server_close() diff --git a/tests/integration/telemetry/test_completions.py b/tests/integration/telemetry/test_completions.py index d72f9e660..d1b97ef34 100644 --- a/tests/integration/telemetry/test_completions.py +++ b/tests/integration/telemetry/test_completions.py @@ -32,7 +32,7 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod span for span in reversed(spans) if span.get_span_type() == "async_generator" - and span.get_attribute("chunk_count") + and span.attributes.get("chunk_count") and span.has_message("Test trace openai 1") ), None, @@ -40,7 +40,7 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod assert async_generator_span is not None - raw_chunk_count = async_generator_span.get_attribute("chunk_count") + raw_chunk_count = async_generator_span.attributes.get("chunk_count") assert raw_chunk_count is not None chunk_count = int(raw_chunk_count) @@ -85,7 +85,7 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, logged_model_ids = [] for span in spans: - attrs = span.get_attributes() + attrs = span.attributes assert attrs is not None # Root span is created manually by tracing middleware, not by @trace_protocol decorator @@ -98,7 +98,7 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, assert class_name and method_name assert span.get_span_type() in ["async", "sync", "async_generator"] - args_field = span.get_attribute("__args__") + args_field = span.attributes.get("__args__") if args_field: args = json.loads(args_field) if "model_id" in args: @@ -115,37 +115,32 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, # Filter metrics to only those from the specific model used in the request # This prevents issues when multiple metrics with the same name exist from different models # (e.g., when safety models like llama-guard are also called) - model_metrics = {} + inference_model_metrics = {} all_model_ids = set() for name, metric in metrics.items(): if name in expected_metrics: - model_id = metric.get_attribute("model_id") + model_id = metric.attributes.get("model_id") all_model_ids.add(model_id) # Only include metrics from the specific model used in the test request if 
             if model_id == text_model_id:
-                model_metrics[name] = metric
-
-    # Provide helpful error message if we have metrics from multiple models
-    if len(all_model_ids) > 1:
-        print(f"Note: Found metrics from multiple models: {sorted(all_model_ids)}")
-        print(f"Filtering to only metrics from test model: {text_model_id}")
+                inference_model_metrics[name] = metric
 
     # Verify expected metrics are present for our specific model
     for metric_name in expected_metrics:
-        assert metric_name in model_metrics, (
+        assert metric_name in inference_model_metrics, (
             f"Expected metric {metric_name} for model {text_model_id} not found. "
             f"Available models: {sorted(all_model_ids)}, "
-            f"Available metrics for {text_model_id}: {list(model_metrics.keys())}"
+            f"Available metrics for {text_model_id}: {list(inference_model_metrics.keys())}"
         )
 
     # Verify metric values match usage data
-    assert model_metrics["completion_tokens"].get_value() == usage["completion_tokens"], (
-        f"Expected {usage['completion_tokens']} for completion_tokens, but got {model_metrics['completion_tokens'].get_value()}"
+    assert inference_model_metrics["completion_tokens"].value == usage["completion_tokens"], (
+        f"Expected {usage['completion_tokens']} for completion_tokens, but got {inference_model_metrics['completion_tokens'].value}"
     )
-    assert model_metrics["total_tokens"].get_value() == usage["total_tokens"], (
-        f"Expected {usage['total_tokens']} for total_tokens, but got {model_metrics['total_tokens'].get_value()}"
+    assert inference_model_metrics["total_tokens"].value == usage["total_tokens"], (
+        f"Expected {usage['total_tokens']} for total_tokens, but got {inference_model_metrics['total_tokens'].value}"
     )
-    assert model_metrics["prompt_tokens"].get_value() == usage["prompt_tokens"], (
-        f"Expected {usage['prompt_tokens']} for prompt_tokens, but got {model_metrics['prompt_tokens'].get_value()}"
+    assert inference_model_metrics["prompt_tokens"].value == usage["prompt_tokens"], (
+        f"Expected {usage['prompt_tokens']} for prompt_tokens, but got {inference_model_metrics['prompt_tokens'].value}"
     )
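
For reference, a minimal sketch (not part of the patch) of how token metrics recorded through `meter.create_histogram(...)` come back out of the OpenTelemetry Python SDK. The meter name, unit, and recorded values below are illustrative assumptions, not taken from the llama-stack code:

```python
# Sketch only: token metrics recorded on a Histogram instrument surface as
# HistogramDataPoint objects whose .sum / .count the test collectors read.
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import InMemoryMetricReader

reader = InMemoryMetricReader()
meter = MeterProvider(metric_readers=[reader]).get_meter("sketch")  # arbitrary meter name

# Mirrors the histogram path added in Telemetry._log_metric() for a token metric
histogram = meter.create_histogram(name="prompt_tokens", unit="tokens", description="Histogram for prompt_tokens")
histogram.record(12, attributes={"model_id": "example-model"})  # one data point per request
histogram.record(30, attributes={"model_id": "example-model"})

point = (
    reader.get_metrics_data()
    .resource_metrics[0]
    .scope_metrics[0]
    .metrics[0]
    .data.data_points[0]
)
print(point.count, point.sum)  # 2 42
```

Reading the histogram's `sum` rather than a counter value is what lets the collectors above keep reporting a single number per metric while each request still contributes an individual per-request measurement.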