feat(tests): metrics tests (#3966)

# What does this PR do?
1. Make telemetry tests easier to write by expanding the `SpanStub` dataclass
   and adding a `MetricStub` dataclass, so telemetry data is marshalled
   consistently in test fixtures and unmarshalled and handled consistently in
   tests (both stubs are sketched after this list).
2. Structure server and client tests around the same `SpanStub` and
   `MetricStub` objects so they follow the same standards and give a
   consistent testing experience.
3. Enable metrics testing for the completions endpoint.
4. Record token metrics as histograms instead of counters, so they capture
   tokens per request rather than a cumulative count of tokens over the
   lifetime of the server (see the histogram sketch after this list).
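
A rough sketch of the two pieces referenced above: the stub dataclasses and the histogram-based token metrics. The field and method names here are assumptions inferred from the test diff in this commit, not the exact definitions added by the PR:

```python
# Hypothetical sketch: stub shapes the telemetry test fixtures could expose.
# Names are inferred from how the tests in this commit consume them.
from dataclasses import dataclass, field


@dataclass
class SpanStub:
    name: str
    attributes: dict = field(default_factory=dict)
    trace_id: str | None = None

    def get_trace_id(self) -> str | None:
        return self.trace_id

    def get_span_type(self) -> str | None:
        # Autotraced spans carry "async", "sync", or "async_generator" here.
        return self.attributes.get("__type__")

    def is_root_span(self) -> bool:
        return self.attributes.get("__root__") is True

    def is_autotraced(self) -> bool:
        return bool(self.attributes.get("__autotraced__"))

    def get_class_method(self) -> tuple:
        return self.attributes.get("__class__"), self.attributes.get("__method__")

    def has_message(self, text: str) -> bool:
        args = self.attributes.get("__args__")
        return isinstance(args, str) and text in args


@dataclass
class MetricStub:
    name: str
    value: int | float
    unit: str = ""
    attributes: dict = field(default_factory=dict)
```

And, roughly, what switching token metrics from a cumulative counter to a per-request histogram looks like on the recording side (instrument and attribute names are illustrative, not the PR's actual identifiers):

```python
# Hypothetical sketch: record per-request token usage as histogram data points
# instead of incrementing a cumulative counter.
from opentelemetry import metrics

meter = metrics.get_meter("llama_stack.inference")
prompt_tokens = meter.create_histogram("prompt_tokens", unit="tokens")


def record_prompt_tokens(count: int, model_id: str) -> None:
    # One data point per request, tagged with the model that served it, so tests
    # can assert on the per-request value rather than a running total.
    prompt_tokens.record(count, attributes={"model_id": model_id})
```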

## Test Plan
These are tests
Emilio Garcia 2025-11-05 13:26:15 -05:00 committed by GitHub
parent 2619f3552e
commit ba50790a28
7 changed files with 647 additions and 130 deletions


@@ -4,48 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Telemetry tests verifying @trace_protocol decorator format across stack modes."""
"""Telemetry tests verifying @trace_protocol decorator format across stack modes.
Note: The mock_otlp_collector fixture automatically clears telemetry data
before and after each test, ensuring test isolation.
"""
import json
def _span_attributes(span):
    attrs = getattr(span, "attributes", None)
    if attrs is None:
        return {}
    # ReadableSpan.attributes acts like a mapping
    try:
        return dict(attrs.items())  # type: ignore[attr-defined]
    except AttributeError:
        try:
            return dict(attrs)
        except TypeError:
            return attrs
def _span_attr(span, key):
    attrs = _span_attributes(span)
    return attrs.get(key)
def _span_trace_id(span):
    context = getattr(span, "context", None)
    if context and getattr(context, "trace_id", None) is not None:
        return f"{context.trace_id:032x}"
    return getattr(span, "trace_id", None)
def _span_has_message(span, text: str) -> bool:
    args = _span_attr(span, "__args__")
    if not args or not isinstance(args, str):
        return False
    return text in args
def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_model_id):
    """Verify streaming adds chunk_count and __type__=async_generator."""
    mock_otlp_collector.clear()
    stream = llama_stack_client.chat.completions.create(
        model=text_model_id,
        messages=[{"role": "user", "content": "Test trace openai 1"}],
@@ -62,16 +31,16 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod
        (
            span
            for span in reversed(spans)
            if _span_attr(span, "__type__") == "async_generator"
            and _span_attr(span, "chunk_count")
            and _span_has_message(span, "Test trace openai 1")
            if span.get_span_type() == "async_generator"
            and span.attributes.get("chunk_count")
            and span.has_message("Test trace openai 1")
        ),
        None,
    )
    assert async_generator_span is not None
    raw_chunk_count = _span_attr(async_generator_span, "chunk_count")
    raw_chunk_count = async_generator_span.attributes.get("chunk_count")
    assert raw_chunk_count is not None
    chunk_count = int(raw_chunk_count)
@@ -80,7 +49,6 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod
def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, text_model_id):
"""Comprehensive validation of telemetry data format including spans and metrics."""
mock_otlp_collector.clear()
response = llama_stack_client.chat.completions.create(
model=text_model_id,
@@ -101,37 +69,36 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
    # Verify spans
    spans = mock_otlp_collector.get_spans(expected_count=7)
    target_span = next(
        (span for span in reversed(spans) if _span_has_message(span, "Test trace openai with temperature 0.7")),
        (span for span in reversed(spans) if span.has_message("Test trace openai with temperature 0.7")),
        None,
    )
    assert target_span is not None
    trace_id = _span_trace_id(target_span)
    trace_id = target_span.get_trace_id()
    assert trace_id is not None
    spans = [span for span in spans if _span_trace_id(span) == trace_id]
    spans = [span for span in spans if _span_attr(span, "__root__") or _span_attr(span, "__autotraced__")]
    spans = [span for span in spans if span.get_trace_id() == trace_id]
    spans = [span for span in spans if span.is_root_span() or span.is_autotraced()]
    assert len(spans) >= 4
    # Collect all model_ids found in spans
    logged_model_ids = []
    for span in spans:
        attrs = _span_attributes(span)
        attrs = span.attributes
        assert attrs is not None
        # Root span is created manually by tracing middleware, not by @trace_protocol decorator
        is_root_span = attrs.get("__root__") is True
        if is_root_span:
            assert attrs.get("__location__") in ["library_client", "server"]
        if span.is_root_span():
            assert span.get_location() in ["library_client", "server"]
            continue
        assert attrs.get("__autotraced__")
        assert attrs.get("__class__") and attrs.get("__method__")
        assert attrs.get("__type__") in ["async", "sync", "async_generator"]
        assert span.is_autotraced()
        class_name, method_name = span.get_class_method()
        assert class_name and method_name
        assert span.get_span_type() in ["async", "sync", "async_generator"]
        args_field = attrs.get("__args__")
        args_field = span.attributes.get("__args__")
        if args_field:
            args = json.loads(args_field)
            if "model_id" in args:
@@ -140,21 +107,39 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
    # At least one span should capture the fully qualified model ID
    assert text_model_id in logged_model_ids, f"Expected to find {text_model_id} in spans, but got {logged_model_ids}"
    # TODO: re-enable this once metrics get fixed
    """
    # Verify token usage metrics in response
    metrics = mock_otlp_collector.get_metrics()
    # Verify token usage metrics in response using polling
    expected_metrics = ["completion_tokens", "total_tokens", "prompt_tokens"]
    metrics = mock_otlp_collector.get_metrics(expected_count=len(expected_metrics), expect_model_id=text_model_id)
    assert len(metrics) > 0, "No metrics found within timeout"
    assert metrics
    for metric in metrics:
        assert metric.name in ["completion_tokens", "total_tokens", "prompt_tokens"]
        assert metric.unit == "tokens"
        assert metric.data.data_points and len(metric.data.data_points) == 1
        match metric.name:
            case "completion_tokens":
                assert metric.data.data_points[0].value == usage["completion_tokens"]
            case "total_tokens":
                assert metric.data.data_points[0].value == usage["total_tokens"]
            case "prompt_tokens":
                assert metric.data.data_points[0].value == usage["prompt_tokens"]
    """
    # Filter metrics to only those from the specific model used in the request
    # Multiple metrics with the same name can exist (e.g., from safety models)
    inference_model_metrics = {}
    all_model_ids = set()
    for name, metric in metrics.items():
        if name in expected_metrics:
            model_id = metric.attributes.get("model_id")
            all_model_ids.add(model_id)
            # Only include metrics from the specific model used in the test request
            if model_id == text_model_id:
                inference_model_metrics[name] = metric
    # Verify expected metrics are present for our specific model
    for metric_name in expected_metrics:
        assert metric_name in inference_model_metrics, (
            f"Expected metric {metric_name} for model {text_model_id} not found. "
            f"Available models: {sorted(all_model_ids)}, "
            f"Available metrics for {text_model_id}: {list(inference_model_metrics.keys())}"
        )
    # Verify metric values match usage data
    assert inference_model_metrics["completion_tokens"].value == usage["completion_tokens"], (
        f"Expected {usage['completion_tokens']} for completion_tokens, but got {inference_model_metrics['completion_tokens'].value}"
    )
    assert inference_model_metrics["total_tokens"].value == usage["total_tokens"], (
        f"Expected {usage['total_tokens']} for total_tokens, but got {inference_model_metrics['total_tokens'].value}"
    )
    assert inference_model_metrics["prompt_tokens"].value == usage["prompt_tokens"], (
        f"Expected {usage['prompt_tokens']} for prompt_tokens, but got {inference_model_metrics['prompt_tokens'].value}"
    )