mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat(tests): metrics tests (#3966)
# What does this PR do? 1. Make telemetry tests as easy as possible for users by expanding the `SpanStub` data class and creating the `MetricStub` dataclass as a way to consistently marshal telemetry data in test fixtures and unmarshal and handle it in tests. 2. Structure server and client tests to always follow the same standards for consistent testing experience by using the `SpanStub` and `MetricStub` data class objects. 3. Enable Metrics Testing for completions endpoint 4. Correct token metrics to use histograms instead of counts to capture tokens per request rather than a cumulative count of tokens over the lifecycle of the server. ## Test Plan These are tests
This commit is contained in:
parent
2619f3552e
commit
ba50790a28
7 changed files with 647 additions and 130 deletions
|
|
@ -4,48 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Telemetry tests verifying @trace_protocol decorator format across stack modes."""
|
||||
"""Telemetry tests verifying @trace_protocol decorator format across stack modes.
|
||||
|
||||
Note: The mock_otlp_collector fixture automatically clears telemetry data
|
||||
before and after each test, ensuring test isolation.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
|
||||
def _span_attributes(span):
|
||||
attrs = getattr(span, "attributes", None)
|
||||
if attrs is None:
|
||||
return {}
|
||||
# ReadableSpan.attributes acts like a mapping
|
||||
try:
|
||||
return dict(attrs.items()) # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
try:
|
||||
return dict(attrs)
|
||||
except TypeError:
|
||||
return attrs
|
||||
|
||||
|
||||
def _span_attr(span, key):
|
||||
attrs = _span_attributes(span)
|
||||
return attrs.get(key)
|
||||
|
||||
|
||||
def _span_trace_id(span):
|
||||
context = getattr(span, "context", None)
|
||||
if context and getattr(context, "trace_id", None) is not None:
|
||||
return f"{context.trace_id:032x}"
|
||||
return getattr(span, "trace_id", None)
|
||||
|
||||
|
||||
def _span_has_message(span, text: str) -> bool:
|
||||
args = _span_attr(span, "__args__")
|
||||
if not args or not isinstance(args, str):
|
||||
return False
|
||||
return text in args
|
||||
|
||||
|
||||
def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_model_id):
|
||||
"""Verify streaming adds chunk_count and __type__=async_generator."""
|
||||
mock_otlp_collector.clear()
|
||||
|
||||
stream = llama_stack_client.chat.completions.create(
|
||||
model=text_model_id,
|
||||
messages=[{"role": "user", "content": "Test trace openai 1"}],
|
||||
|
|
@ -62,16 +31,16 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod
|
|||
(
|
||||
span
|
||||
for span in reversed(spans)
|
||||
if _span_attr(span, "__type__") == "async_generator"
|
||||
and _span_attr(span, "chunk_count")
|
||||
and _span_has_message(span, "Test trace openai 1")
|
||||
if span.get_span_type() == "async_generator"
|
||||
and span.attributes.get("chunk_count")
|
||||
and span.has_message("Test trace openai 1")
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
assert async_generator_span is not None
|
||||
|
||||
raw_chunk_count = _span_attr(async_generator_span, "chunk_count")
|
||||
raw_chunk_count = async_generator_span.attributes.get("chunk_count")
|
||||
assert raw_chunk_count is not None
|
||||
chunk_count = int(raw_chunk_count)
|
||||
|
||||
|
|
@ -80,7 +49,6 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod
|
|||
|
||||
def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, text_model_id):
|
||||
"""Comprehensive validation of telemetry data format including spans and metrics."""
|
||||
mock_otlp_collector.clear()
|
||||
|
||||
response = llama_stack_client.chat.completions.create(
|
||||
model=text_model_id,
|
||||
|
|
@ -101,37 +69,36 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
|
|||
# Verify spans
|
||||
spans = mock_otlp_collector.get_spans(expected_count=7)
|
||||
target_span = next(
|
||||
(span for span in reversed(spans) if _span_has_message(span, "Test trace openai with temperature 0.7")),
|
||||
(span for span in reversed(spans) if span.has_message("Test trace openai with temperature 0.7")),
|
||||
None,
|
||||
)
|
||||
assert target_span is not None
|
||||
|
||||
trace_id = _span_trace_id(target_span)
|
||||
trace_id = target_span.get_trace_id()
|
||||
assert trace_id is not None
|
||||
|
||||
spans = [span for span in spans if _span_trace_id(span) == trace_id]
|
||||
spans = [span for span in spans if _span_attr(span, "__root__") or _span_attr(span, "__autotraced__")]
|
||||
spans = [span for span in spans if span.get_trace_id() == trace_id]
|
||||
spans = [span for span in spans if span.is_root_span() or span.is_autotraced()]
|
||||
assert len(spans) >= 4
|
||||
|
||||
# Collect all model_ids found in spans
|
||||
logged_model_ids = []
|
||||
|
||||
for span in spans:
|
||||
attrs = _span_attributes(span)
|
||||
attrs = span.attributes
|
||||
assert attrs is not None
|
||||
|
||||
# Root span is created manually by tracing middleware, not by @trace_protocol decorator
|
||||
is_root_span = attrs.get("__root__") is True
|
||||
|
||||
if is_root_span:
|
||||
assert attrs.get("__location__") in ["library_client", "server"]
|
||||
if span.is_root_span():
|
||||
assert span.get_location() in ["library_client", "server"]
|
||||
continue
|
||||
|
||||
assert attrs.get("__autotraced__")
|
||||
assert attrs.get("__class__") and attrs.get("__method__")
|
||||
assert attrs.get("__type__") in ["async", "sync", "async_generator"]
|
||||
assert span.is_autotraced()
|
||||
class_name, method_name = span.get_class_method()
|
||||
assert class_name and method_name
|
||||
assert span.get_span_type() in ["async", "sync", "async_generator"]
|
||||
|
||||
args_field = attrs.get("__args__")
|
||||
args_field = span.attributes.get("__args__")
|
||||
if args_field:
|
||||
args = json.loads(args_field)
|
||||
if "model_id" in args:
|
||||
|
|
@ -140,21 +107,39 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
|
|||
# At least one span should capture the fully qualified model ID
|
||||
assert text_model_id in logged_model_ids, f"Expected to find {text_model_id} in spans, but got {logged_model_ids}"
|
||||
|
||||
# TODO: re-enable this once metrics get fixed
|
||||
"""
|
||||
# Verify token usage metrics in response
|
||||
metrics = mock_otlp_collector.get_metrics()
|
||||
# Verify token usage metrics in response using polling
|
||||
expected_metrics = ["completion_tokens", "total_tokens", "prompt_tokens"]
|
||||
metrics = mock_otlp_collector.get_metrics(expected_count=len(expected_metrics), expect_model_id=text_model_id)
|
||||
assert len(metrics) > 0, "No metrics found within timeout"
|
||||
|
||||
assert metrics
|
||||
for metric in metrics:
|
||||
assert metric.name in ["completion_tokens", "total_tokens", "prompt_tokens"]
|
||||
assert metric.unit == "tokens"
|
||||
assert metric.data.data_points and len(metric.data.data_points) == 1
|
||||
match metric.name:
|
||||
case "completion_tokens":
|
||||
assert metric.data.data_points[0].value == usage["completion_tokens"]
|
||||
case "total_tokens":
|
||||
assert metric.data.data_points[0].value == usage["total_tokens"]
|
||||
case "prompt_tokens":
|
||||
assert metric.data.data_points[0].value == usage["prompt_tokens"
|
||||
"""
|
||||
# Filter metrics to only those from the specific model used in the request
|
||||
# Multiple metrics with the same name can exist (e.g., from safety models)
|
||||
inference_model_metrics = {}
|
||||
all_model_ids = set()
|
||||
|
||||
for name, metric in metrics.items():
|
||||
if name in expected_metrics:
|
||||
model_id = metric.attributes.get("model_id")
|
||||
all_model_ids.add(model_id)
|
||||
# Only include metrics from the specific model used in the test request
|
||||
if model_id == text_model_id:
|
||||
inference_model_metrics[name] = metric
|
||||
|
||||
# Verify expected metrics are present for our specific model
|
||||
for metric_name in expected_metrics:
|
||||
assert metric_name in inference_model_metrics, (
|
||||
f"Expected metric {metric_name} for model {text_model_id} not found. "
|
||||
f"Available models: {sorted(all_model_ids)}, "
|
||||
f"Available metrics for {text_model_id}: {list(inference_model_metrics.keys())}"
|
||||
)
|
||||
|
||||
# Verify metric values match usage data
|
||||
assert inference_model_metrics["completion_tokens"].value == usage["completion_tokens"], (
|
||||
f"Expected {usage['completion_tokens']} for completion_tokens, but got {inference_model_metrics['completion_tokens'].value}"
|
||||
)
|
||||
assert inference_model_metrics["total_tokens"].value == usage["total_tokens"], (
|
||||
f"Expected {usage['total_tokens']} for total_tokens, but got {inference_model_metrics['total_tokens'].value}"
|
||||
)
|
||||
assert inference_model_metrics["prompt_tokens"].value == usage["prompt_tokens"], (
|
||||
f"Expected {usage['prompt_tokens']} for prompt_tokens, but got {inference_model_metrics['prompt_tokens'].value}"
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue