fix(telemetry): token counters changed to histograms to reflect count per request

This commit is contained in:
Emilio Garcia 2025-10-30 12:59:24 -04:00
parent 0e0bc8aba7
commit 23fce9718c
4 changed files with 114 additions and 120 deletions

View file

@ -427,6 +427,7 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
"counters": {}, "counters": {},
"gauges": {}, "gauges": {},
"up_down_counters": {}, "up_down_counters": {},
"histograms": {},
} }
_global_lock = threading.Lock() _global_lock = threading.Lock()
_TRACER_PROVIDER = None _TRACER_PROVIDER = None
@ -540,6 +541,16 @@ class Telemetry:
) )
return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name]) return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["histograms"]:
_GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram(
name=name,
unit=unit,
description=f"Histogram for {name}",
)
return cast(metrics.Histogram, _GLOBAL_STORAGE["histograms"][name])
def _log_metric(self, event: MetricEvent) -> None: def _log_metric(self, event: MetricEvent) -> None:
# Add metric as an event to the current span # Add metric as an event to the current span
try: try:
@ -571,7 +582,16 @@ class Telemetry:
# Log to OpenTelemetry meter if available # Log to OpenTelemetry meter if available
if self.meter is None: if self.meter is None:
return return
if isinstance(event.value, int):
# Use histograms for token-related metrics (per-request measurements)
# Use counters for other cumulative metrics
token_metrics = {"prompt_tokens", "completion_tokens", "total_tokens"}
if event.metric in token_metrics:
# Token metrics are per-request measurements, use histogram
histogram = self._get_or_create_histogram(event.metric, event.unit)
histogram.record(event.value, attributes=_clean_attributes(event.attributes))
elif isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit) counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=_clean_attributes(event.attributes)) counter.add(event.value, attributes=_clean_attributes(event.attributes))
elif isinstance(event.value, float): elif isinstance(event.value, float):

View file

@ -6,7 +6,8 @@
"""Shared helpers for telemetry test collectors.""" """Shared helpers for telemetry test collectors."""
from collections.abc import Iterable, Mapping import time
from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
@ -19,29 +20,13 @@ class MetricStub:
value: Any value: Any
attributes: dict[str, Any] | None = None attributes: dict[str, Any] | None = None
def get_value(self) -> Any:
"""Get the metric value."""
return self.value
def get_name(self) -> str:
"""Get the metric name."""
return self.name
def get_attributes(self) -> dict[str, Any]:
"""Get metric attributes as a dictionary."""
return self.attributes or {}
def get_attribute(self, key: str) -> Any:
"""Get a specific attribute value by key."""
return self.get_attributes().get(key)
@dataclass @dataclass
class SpanStub: class SpanStub:
"""Unified span interface for both in-memory and OTLP collectors.""" """Unified span interface for both in-memory and OTLP collectors."""
name: str name: str
attributes: Mapping[str, Any] | None = None attributes: dict[str, Any] | None = None
resource_attributes: dict[str, Any] | None = None resource_attributes: dict[str, Any] | None = None
events: list[dict[str, Any]] | None = None events: list[dict[str, Any]] | None = None
trace_id: str | None = None trace_id: str | None = None
@ -54,19 +39,6 @@ class SpanStub:
return None return None
return type("Context", (), {"trace_id": int(self.trace_id, 16)})() return type("Context", (), {"trace_id": int(self.trace_id, 16)})()
def get_attributes(self) -> dict[str, Any]:
"""Get span attributes as a dictionary.
Handles different attribute types (mapping, dict, etc.) and returns
a consistent dictionary format.
"""
return BaseTelemetryCollector._convert_attributes_to_dict(self.attributes)
def get_attribute(self, key: str) -> Any:
"""Get a specific attribute value by key."""
attrs = self.get_attributes()
return attrs.get(key)
def get_trace_id(self) -> str | None: def get_trace_id(self) -> str | None:
"""Get trace ID in hex format. """Get trace ID in hex format.
@ -79,30 +51,42 @@ class SpanStub:
def has_message(self, text: str) -> bool: def has_message(self, text: str) -> bool:
"""Check if span contains a specific message in its args.""" """Check if span contains a specific message in its args."""
args = self.get_attribute("__args__") if self.attributes is None:
return False
args = self.attributes.get("__args__")
if not args or not isinstance(args, str): if not args or not isinstance(args, str):
return False return False
return text in args return text in args
def is_root_span(self) -> bool: def is_root_span(self) -> bool:
"""Check if this is a root span.""" """Check if this is a root span."""
return self.get_attribute("__root__") is True if self.attributes is None:
return False
return self.attributes.get("__root__") is True
def is_autotraced(self) -> bool: def is_autotraced(self) -> bool:
"""Check if this span was automatically traced.""" """Check if this span was automatically traced."""
return self.get_attribute("__autotraced__") is True if self.attributes is None:
return False
return self.attributes.get("__autotraced__") is True
def get_span_type(self) -> str | None: def get_span_type(self) -> str | None:
"""Get the span type (async, sync, async_generator).""" """Get the span type (async, sync, async_generator)."""
return self.get_attribute("__type__") if self.attributes is None:
return None
return self.attributes.get("__type__")
def get_class_method(self) -> tuple[str | None, str | None]: def get_class_method(self) -> tuple[str | None, str | None]:
"""Get the class and method names for autotraced spans.""" """Get the class and method names for autotraced spans."""
return (self.get_attribute("__class__"), self.get_attribute("__method__")) if self.attributes is None:
return None, None
return (self.attributes.get("__class__"), self.attributes.get("__method__"))
def get_location(self) -> str | None: def get_location(self) -> str | None:
"""Get the location (library_client, server) for root spans.""" """Get the location (library_client, server) for root spans."""
return self.get_attribute("__location__") if self.attributes is None:
return None
return self.attributes.get("__location__")
def _value_to_python(value: Any) -> Any: def _value_to_python(value: Any) -> Any:
@ -152,8 +136,6 @@ class BaseTelemetryCollector:
timeout: float = 5.0, timeout: float = 5.0,
poll_interval: float = 0.05, poll_interval: float = 0.05,
) -> tuple[SpanStub, ...]: ) -> tuple[SpanStub, ...]:
import time
deadline = time.time() + timeout deadline = time.time() + timeout
min_count = expected_count if expected_count is not None else 1 min_count = expected_count if expected_count is not None else 1
last_len: int | None = None last_len: int | None = None
@ -188,8 +170,8 @@ class BaseTelemetryCollector:
poll_interval: float = 0.05, poll_interval: float = 0.05,
) -> dict[str, MetricStub]: ) -> dict[str, MetricStub]:
"""Get metrics with polling until metrics are available or timeout is reached.""" """Get metrics with polling until metrics are available or timeout is reached."""
import time
# metrics need to be collected since get requests delete stored metrics
deadline = time.time() + timeout deadline = time.time() + timeout
min_count = expected_count if expected_count is not None else 1 min_count = expected_count if expected_count is not None else 1
accumulated_metrics = {} accumulated_metrics = {}
@ -197,14 +179,11 @@ class BaseTelemetryCollector:
while time.time() < deadline: while time.time() < deadline:
current_metrics = self._snapshot_metrics() current_metrics = self._snapshot_metrics()
if current_metrics: if current_metrics:
# Accumulate new metrics without losing existing ones
for metric in current_metrics: for metric in current_metrics:
metric_name = metric.get_name() metric_name = metric.name
if metric_name not in accumulated_metrics: if metric_name not in accumulated_metrics:
accumulated_metrics[metric_name] = metric accumulated_metrics[metric_name] = metric
else: else:
# If we already have this metric, keep the latest one
# (in case metrics are updated with new values)
accumulated_metrics[metric_name] = metric accumulated_metrics[metric_name] = metric
# Check if we have enough metrics # Check if we have enough metrics
@ -258,7 +237,7 @@ class BaseTelemetryCollector:
This helper reduces code duplication between collectors. This helper reduces code duplication between collectors.
""" """
trace_id, span_id = BaseTelemetryCollector._extract_trace_span_ids(span) trace_id, span_id = BaseTelemetryCollector._extract_trace_span_ids(span)
attributes = BaseTelemetryCollector._convert_attributes_to_dict(span.attributes) attributes = BaseTelemetryCollector._convert_attributes_to_dict(span.attributes) or {}
return SpanStub( return SpanStub(
name=span.name, name=span.name,
@ -273,7 +252,7 @@ class BaseTelemetryCollector:
This helper handles the different structure of protobuf spans. This helper handles the different structure of protobuf spans.
""" """
attributes = attributes_to_dict(span.attributes) attributes = attributes_to_dict(span.attributes) or {}
events = events_to_list(span.events) if span.events else None events = events_to_list(span.events) if span.events else None
trace_id = span.trace_id.hex() if span.trace_id else None trace_id = span.trace_id.hex() if span.trace_id else None
span_id = span.span_id.hex() if span.span_id else None span_id = span.span_id.hex() if span.span_id else None
@ -300,12 +279,22 @@ class BaseTelemetryCollector:
return None return None
# Get the value from the first data point # Get the value from the first data point
value = metric.data.data_points[0].value data_point = metric.data.data_points[0]
# Handle different metric types
if hasattr(data_point, "value"):
# Counter or Gauge
value = data_point.value
elif hasattr(data_point, "sum"):
# Histogram - use the sum of all recorded values
value = data_point.sum
else:
return None
# Extract attributes if available # Extract attributes if available
attributes = {} attributes = {}
if hasattr(metric.data.data_points[0], "attributes"): if hasattr(data_point, "attributes"):
attrs = metric.data.data_points[0].attributes attrs = data_point.attributes
if attrs is not None and hasattr(attrs, "items"): if attrs is not None and hasattr(attrs, "items"):
attributes = dict(attrs.items()) attributes = dict(attrs.items())
elif attrs is not None and not isinstance(attrs, dict): elif attrs is not None and not isinstance(attrs, dict):
@ -314,9 +303,48 @@ class BaseTelemetryCollector:
return MetricStub( return MetricStub(
name=metric.name, name=metric.name,
value=value, value=value,
attributes=attributes if attributes else None, attributes=attributes or {},
) )
@staticmethod
def _create_metric_stub_from_protobuf(metric: Any) -> MetricStub | None:
"""Create MetricStub from protobuf metric object.
Protobuf metrics have a different structure than OpenTelemetry metrics.
They can have sum, gauge, or histogram data.
"""
if not hasattr(metric, "name"):
return None
# Try to extract value from different metric types
for metric_type in ["sum", "gauge", "histogram"]:
if hasattr(metric, metric_type):
metric_data = getattr(metric, metric_type)
if metric_data and hasattr(metric_data, "data_points"):
data_points = metric_data.data_points
if data_points and len(data_points) > 0:
data_point = data_points[0]
# Extract attributes first (needed for all metric types)
attributes = (
attributes_to_dict(data_point.attributes) if hasattr(data_point, "attributes") else {}
)
# Extract value based on metric type
if metric_type == "sum":
value = data_point.as_int
elif metric_type == "gauge":
value = data_point.as_double
else: # histogram
value = data_point.sum
return MetricStub(
name=metric.name,
value=value,
attributes=attributes,
)
return None
def clear(self) -> None: def clear(self) -> None:
self._clear_impl() self._clear_impl()

View file

@ -11,7 +11,6 @@ import os
import threading import threading
from http.server import BaseHTTPRequestHandler, HTTPServer from http.server import BaseHTTPRequestHandler, HTTPServer
from socketserver import ThreadingMixIn from socketserver import ThreadingMixIn
from typing import Any
from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2 import ExportMetricsServiceRequest from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2 import ExportMetricsServiceRequest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest
@ -83,54 +82,6 @@ class OtlpHttpTestCollector(BaseTelemetryCollector):
self._spans.clear() self._spans.clear()
self._metrics.clear() self._metrics.clear()
def _create_metric_stub_from_protobuf(self, metric: Any) -> MetricStub | None:
"""Create MetricStub from protobuf metric object.
Protobuf metrics have a different structure than OpenTelemetry metrics.
They can have sum, gauge, or histogram data.
"""
if not hasattr(metric, "name"):
return None
# Try to extract value from different metric types
for metric_type in ["sum", "gauge", "histogram"]:
if hasattr(metric, metric_type):
metric_data = getattr(metric, metric_type)
if metric_data and hasattr(metric_data, "data_points"):
data_points = metric_data.data_points
if data_points and len(data_points) > 0:
data_point = data_points[0]
# Extract value based on metric type
if metric_type == "sum":
value = data_point.as_int
elif metric_type == "gauge":
value = data_point.as_double
else: # histogram
value = data_point.count
# Extract attributes if available
attributes = self._extract_attributes_from_data_point(data_point)
return MetricStub(
name=metric.name,
value=value,
attributes=attributes if attributes else None,
)
return None
def _extract_attributes_from_data_point(self, data_point: Any) -> dict[str, Any]:
"""Extract attributes from a protobuf data point."""
if not hasattr(data_point, "attributes"):
return {}
attrs = data_point.attributes
if not attrs:
return {}
return {kv.key: kv.value.string_value or kv.value.int_value or kv.value.double_value for kv in attrs}
def shutdown(self) -> None: def shutdown(self) -> None:
self._server.shutdown() self._server.shutdown()
self._server.server_close() self._server.server_close()

View file

@ -32,7 +32,7 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod
span span
for span in reversed(spans) for span in reversed(spans)
if span.get_span_type() == "async_generator" if span.get_span_type() == "async_generator"
and span.get_attribute("chunk_count") and span.attributes.get("chunk_count")
and span.has_message("Test trace openai 1") and span.has_message("Test trace openai 1")
), ),
None, None,
@ -40,7 +40,7 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod
assert async_generator_span is not None assert async_generator_span is not None
raw_chunk_count = async_generator_span.get_attribute("chunk_count") raw_chunk_count = async_generator_span.attributes.get("chunk_count")
assert raw_chunk_count is not None assert raw_chunk_count is not None
chunk_count = int(raw_chunk_count) chunk_count = int(raw_chunk_count)
@ -85,7 +85,7 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
logged_model_ids = [] logged_model_ids = []
for span in spans: for span in spans:
attrs = span.get_attributes() attrs = span.attributes
assert attrs is not None assert attrs is not None
# Root span is created manually by tracing middleware, not by @trace_protocol decorator # Root span is created manually by tracing middleware, not by @trace_protocol decorator
@ -98,7 +98,7 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
assert class_name and method_name assert class_name and method_name
assert span.get_span_type() in ["async", "sync", "async_generator"] assert span.get_span_type() in ["async", "sync", "async_generator"]
args_field = span.get_attribute("__args__") args_field = span.attributes.get("__args__")
if args_field: if args_field:
args = json.loads(args_field) args = json.loads(args_field)
if "model_id" in args: if "model_id" in args:
@ -115,37 +115,32 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
# Filter metrics to only those from the specific model used in the request # Filter metrics to only those from the specific model used in the request
# This prevents issues when multiple metrics with the same name exist from different models # This prevents issues when multiple metrics with the same name exist from different models
# (e.g., when safety models like llama-guard are also called) # (e.g., when safety models like llama-guard are also called)
model_metrics = {} inference_model_metrics = {}
all_model_ids = set() all_model_ids = set()
for name, metric in metrics.items(): for name, metric in metrics.items():
if name in expected_metrics: if name in expected_metrics:
model_id = metric.get_attribute("model_id") model_id = metric.attributes.get("model_id")
all_model_ids.add(model_id) all_model_ids.add(model_id)
# Only include metrics from the specific model used in the test request # Only include metrics from the specific model used in the test request
if model_id == text_model_id: if model_id == text_model_id:
model_metrics[name] = metric inference_model_metrics[name] = metric
# Provide helpful error message if we have metrics from multiple models
if len(all_model_ids) > 1:
print(f"Note: Found metrics from multiple models: {sorted(all_model_ids)}")
print(f"Filtering to only metrics from test model: {text_model_id}")
# Verify expected metrics are present for our specific model # Verify expected metrics are present for our specific model
for metric_name in expected_metrics: for metric_name in expected_metrics:
assert metric_name in model_metrics, ( assert metric_name in inference_model_metrics, (
f"Expected metric {metric_name} for model {text_model_id} not found. " f"Expected metric {metric_name} for model {text_model_id} not found. "
f"Available models: {sorted(all_model_ids)}, " f"Available models: {sorted(all_model_ids)}, "
f"Available metrics for {text_model_id}: {list(model_metrics.keys())}" f"Available metrics for {text_model_id}: {list(inference_model_metrics.keys())}"
) )
# Verify metric values match usage data # Verify metric values match usage data
assert model_metrics["completion_tokens"].get_value() == usage["completion_tokens"], ( assert inference_model_metrics["completion_tokens"].value == usage["completion_tokens"], (
f"Expected {usage['completion_tokens']} for completion_tokens, but got {model_metrics['completion_tokens'].get_value()}" f"Expected {usage['completion_tokens']} for completion_tokens, but got {inference_model_metrics['completion_tokens'].value}"
) )
assert model_metrics["total_tokens"].get_value() == usage["total_tokens"], ( assert inference_model_metrics["total_tokens"].value == usage["total_tokens"], (
f"Expected {usage['total_tokens']} for total_tokens, but got {model_metrics['total_tokens'].get_value()}" f"Expected {usage['total_tokens']} for total_tokens, but got {inference_model_metrics['total_tokens'].value}"
) )
assert model_metrics["prompt_tokens"].get_value() == usage["prompt_tokens"], ( assert inference_model_metrics["prompt_tokens"].value == usage["prompt_tokens"], (
f"Expected {usage['prompt_tokens']} for prompt_tokens, but got {model_metrics['prompt_tokens'].get_value()}" f"Expected {usage['prompt_tokens']} for prompt_tokens, but got {inference_model_metrics['prompt_tokens'].value}"
) )