mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
fix(telemetry): token counters changed to histograms to reflect count per request
This commit is contained in:
parent
0e0bc8aba7
commit
23fce9718c
4 changed files with 114 additions and 120 deletions
|
|
@ -427,6 +427,7 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
|
||||||
"counters": {},
|
"counters": {},
|
||||||
"gauges": {},
|
"gauges": {},
|
||||||
"up_down_counters": {},
|
"up_down_counters": {},
|
||||||
|
"histograms": {},
|
||||||
}
|
}
|
||||||
_global_lock = threading.Lock()
|
_global_lock = threading.Lock()
|
||||||
_TRACER_PROVIDER = None
|
_TRACER_PROVIDER = None
|
||||||
|
|
@ -540,6 +541,16 @@ class Telemetry:
|
||||||
)
|
)
|
||||||
return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
|
return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
|
||||||
|
|
||||||
|
def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
|
||||||
|
assert self.meter is not None
|
||||||
|
if name not in _GLOBAL_STORAGE["histograms"]:
|
||||||
|
_GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram(
|
||||||
|
name=name,
|
||||||
|
unit=unit,
|
||||||
|
description=f"Histogram for {name}",
|
||||||
|
)
|
||||||
|
return cast(metrics.Histogram, _GLOBAL_STORAGE["histograms"][name])
|
||||||
|
|
||||||
def _log_metric(self, event: MetricEvent) -> None:
|
def _log_metric(self, event: MetricEvent) -> None:
|
||||||
# Add metric as an event to the current span
|
# Add metric as an event to the current span
|
||||||
try:
|
try:
|
||||||
|
|
@ -571,7 +582,16 @@ class Telemetry:
|
||||||
# Log to OpenTelemetry meter if available
|
# Log to OpenTelemetry meter if available
|
||||||
if self.meter is None:
|
if self.meter is None:
|
||||||
return
|
return
|
||||||
if isinstance(event.value, int):
|
|
||||||
|
# Use histograms for token-related metrics (per-request measurements)
|
||||||
|
# Use counters for other cumulative metrics
|
||||||
|
token_metrics = {"prompt_tokens", "completion_tokens", "total_tokens"}
|
||||||
|
|
||||||
|
if event.metric in token_metrics:
|
||||||
|
# Token metrics are per-request measurements, use histogram
|
||||||
|
histogram = self._get_or_create_histogram(event.metric, event.unit)
|
||||||
|
histogram.record(event.value, attributes=_clean_attributes(event.attributes))
|
||||||
|
elif isinstance(event.value, int):
|
||||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||||
counter.add(event.value, attributes=_clean_attributes(event.attributes))
|
counter.add(event.value, attributes=_clean_attributes(event.attributes))
|
||||||
elif isinstance(event.value, float):
|
elif isinstance(event.value, float):
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,8 @@
|
||||||
|
|
||||||
"""Shared helpers for telemetry test collectors."""
|
"""Shared helpers for telemetry test collectors."""
|
||||||
|
|
||||||
from collections.abc import Iterable, Mapping
|
import time
|
||||||
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
@ -19,29 +20,13 @@ class MetricStub:
|
||||||
value: Any
|
value: Any
|
||||||
attributes: dict[str, Any] | None = None
|
attributes: dict[str, Any] | None = None
|
||||||
|
|
||||||
def get_value(self) -> Any:
|
|
||||||
"""Get the metric value."""
|
|
||||||
return self.value
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
"""Get the metric name."""
|
|
||||||
return self.name
|
|
||||||
|
|
||||||
def get_attributes(self) -> dict[str, Any]:
|
|
||||||
"""Get metric attributes as a dictionary."""
|
|
||||||
return self.attributes or {}
|
|
||||||
|
|
||||||
def get_attribute(self, key: str) -> Any:
|
|
||||||
"""Get a specific attribute value by key."""
|
|
||||||
return self.get_attributes().get(key)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SpanStub:
|
class SpanStub:
|
||||||
"""Unified span interface for both in-memory and OTLP collectors."""
|
"""Unified span interface for both in-memory and OTLP collectors."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
attributes: Mapping[str, Any] | None = None
|
attributes: dict[str, Any] | None = None
|
||||||
resource_attributes: dict[str, Any] | None = None
|
resource_attributes: dict[str, Any] | None = None
|
||||||
events: list[dict[str, Any]] | None = None
|
events: list[dict[str, Any]] | None = None
|
||||||
trace_id: str | None = None
|
trace_id: str | None = None
|
||||||
|
|
@ -54,19 +39,6 @@ class SpanStub:
|
||||||
return None
|
return None
|
||||||
return type("Context", (), {"trace_id": int(self.trace_id, 16)})()
|
return type("Context", (), {"trace_id": int(self.trace_id, 16)})()
|
||||||
|
|
||||||
def get_attributes(self) -> dict[str, Any]:
|
|
||||||
"""Get span attributes as a dictionary.
|
|
||||||
|
|
||||||
Handles different attribute types (mapping, dict, etc.) and returns
|
|
||||||
a consistent dictionary format.
|
|
||||||
"""
|
|
||||||
return BaseTelemetryCollector._convert_attributes_to_dict(self.attributes)
|
|
||||||
|
|
||||||
def get_attribute(self, key: str) -> Any:
|
|
||||||
"""Get a specific attribute value by key."""
|
|
||||||
attrs = self.get_attributes()
|
|
||||||
return attrs.get(key)
|
|
||||||
|
|
||||||
def get_trace_id(self) -> str | None:
|
def get_trace_id(self) -> str | None:
|
||||||
"""Get trace ID in hex format.
|
"""Get trace ID in hex format.
|
||||||
|
|
||||||
|
|
@ -79,30 +51,42 @@ class SpanStub:
|
||||||
|
|
||||||
def has_message(self, text: str) -> bool:
|
def has_message(self, text: str) -> bool:
|
||||||
"""Check if span contains a specific message in its args."""
|
"""Check if span contains a specific message in its args."""
|
||||||
args = self.get_attribute("__args__")
|
if self.attributes is None:
|
||||||
|
return False
|
||||||
|
args = self.attributes.get("__args__")
|
||||||
if not args or not isinstance(args, str):
|
if not args or not isinstance(args, str):
|
||||||
return False
|
return False
|
||||||
return text in args
|
return text in args
|
||||||
|
|
||||||
def is_root_span(self) -> bool:
|
def is_root_span(self) -> bool:
|
||||||
"""Check if this is a root span."""
|
"""Check if this is a root span."""
|
||||||
return self.get_attribute("__root__") is True
|
if self.attributes is None:
|
||||||
|
return False
|
||||||
|
return self.attributes.get("__root__") is True
|
||||||
|
|
||||||
def is_autotraced(self) -> bool:
|
def is_autotraced(self) -> bool:
|
||||||
"""Check if this span was automatically traced."""
|
"""Check if this span was automatically traced."""
|
||||||
return self.get_attribute("__autotraced__") is True
|
if self.attributes is None:
|
||||||
|
return False
|
||||||
|
return self.attributes.get("__autotraced__") is True
|
||||||
|
|
||||||
def get_span_type(self) -> str | None:
|
def get_span_type(self) -> str | None:
|
||||||
"""Get the span type (async, sync, async_generator)."""
|
"""Get the span type (async, sync, async_generator)."""
|
||||||
return self.get_attribute("__type__")
|
if self.attributes is None:
|
||||||
|
return None
|
||||||
|
return self.attributes.get("__type__")
|
||||||
|
|
||||||
def get_class_method(self) -> tuple[str | None, str | None]:
|
def get_class_method(self) -> tuple[str | None, str | None]:
|
||||||
"""Get the class and method names for autotraced spans."""
|
"""Get the class and method names for autotraced spans."""
|
||||||
return (self.get_attribute("__class__"), self.get_attribute("__method__"))
|
if self.attributes is None:
|
||||||
|
return None, None
|
||||||
|
return (self.attributes.get("__class__"), self.attributes.get("__method__"))
|
||||||
|
|
||||||
def get_location(self) -> str | None:
|
def get_location(self) -> str | None:
|
||||||
"""Get the location (library_client, server) for root spans."""
|
"""Get the location (library_client, server) for root spans."""
|
||||||
return self.get_attribute("__location__")
|
if self.attributes is None:
|
||||||
|
return None
|
||||||
|
return self.attributes.get("__location__")
|
||||||
|
|
||||||
|
|
||||||
def _value_to_python(value: Any) -> Any:
|
def _value_to_python(value: Any) -> Any:
|
||||||
|
|
@ -152,8 +136,6 @@ class BaseTelemetryCollector:
|
||||||
timeout: float = 5.0,
|
timeout: float = 5.0,
|
||||||
poll_interval: float = 0.05,
|
poll_interval: float = 0.05,
|
||||||
) -> tuple[SpanStub, ...]:
|
) -> tuple[SpanStub, ...]:
|
||||||
import time
|
|
||||||
|
|
||||||
deadline = time.time() + timeout
|
deadline = time.time() + timeout
|
||||||
min_count = expected_count if expected_count is not None else 1
|
min_count = expected_count if expected_count is not None else 1
|
||||||
last_len: int | None = None
|
last_len: int | None = None
|
||||||
|
|
@ -188,8 +170,8 @@ class BaseTelemetryCollector:
|
||||||
poll_interval: float = 0.05,
|
poll_interval: float = 0.05,
|
||||||
) -> dict[str, MetricStub]:
|
) -> dict[str, MetricStub]:
|
||||||
"""Get metrics with polling until metrics are available or timeout is reached."""
|
"""Get metrics with polling until metrics are available or timeout is reached."""
|
||||||
import time
|
|
||||||
|
|
||||||
|
# metrics need to be collected since get requests delete stored metrics
|
||||||
deadline = time.time() + timeout
|
deadline = time.time() + timeout
|
||||||
min_count = expected_count if expected_count is not None else 1
|
min_count = expected_count if expected_count is not None else 1
|
||||||
accumulated_metrics = {}
|
accumulated_metrics = {}
|
||||||
|
|
@ -197,14 +179,11 @@ class BaseTelemetryCollector:
|
||||||
while time.time() < deadline:
|
while time.time() < deadline:
|
||||||
current_metrics = self._snapshot_metrics()
|
current_metrics = self._snapshot_metrics()
|
||||||
if current_metrics:
|
if current_metrics:
|
||||||
# Accumulate new metrics without losing existing ones
|
|
||||||
for metric in current_metrics:
|
for metric in current_metrics:
|
||||||
metric_name = metric.get_name()
|
metric_name = metric.name
|
||||||
if metric_name not in accumulated_metrics:
|
if metric_name not in accumulated_metrics:
|
||||||
accumulated_metrics[metric_name] = metric
|
accumulated_metrics[metric_name] = metric
|
||||||
else:
|
else:
|
||||||
# If we already have this metric, keep the latest one
|
|
||||||
# (in case metrics are updated with new values)
|
|
||||||
accumulated_metrics[metric_name] = metric
|
accumulated_metrics[metric_name] = metric
|
||||||
|
|
||||||
# Check if we have enough metrics
|
# Check if we have enough metrics
|
||||||
|
|
@ -258,7 +237,7 @@ class BaseTelemetryCollector:
|
||||||
This helper reduces code duplication between collectors.
|
This helper reduces code duplication between collectors.
|
||||||
"""
|
"""
|
||||||
trace_id, span_id = BaseTelemetryCollector._extract_trace_span_ids(span)
|
trace_id, span_id = BaseTelemetryCollector._extract_trace_span_ids(span)
|
||||||
attributes = BaseTelemetryCollector._convert_attributes_to_dict(span.attributes)
|
attributes = BaseTelemetryCollector._convert_attributes_to_dict(span.attributes) or {}
|
||||||
|
|
||||||
return SpanStub(
|
return SpanStub(
|
||||||
name=span.name,
|
name=span.name,
|
||||||
|
|
@ -273,7 +252,7 @@ class BaseTelemetryCollector:
|
||||||
|
|
||||||
This helper handles the different structure of protobuf spans.
|
This helper handles the different structure of protobuf spans.
|
||||||
"""
|
"""
|
||||||
attributes = attributes_to_dict(span.attributes)
|
attributes = attributes_to_dict(span.attributes) or {}
|
||||||
events = events_to_list(span.events) if span.events else None
|
events = events_to_list(span.events) if span.events else None
|
||||||
trace_id = span.trace_id.hex() if span.trace_id else None
|
trace_id = span.trace_id.hex() if span.trace_id else None
|
||||||
span_id = span.span_id.hex() if span.span_id else None
|
span_id = span.span_id.hex() if span.span_id else None
|
||||||
|
|
@ -300,12 +279,22 @@ class BaseTelemetryCollector:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Get the value from the first data point
|
# Get the value from the first data point
|
||||||
value = metric.data.data_points[0].value
|
data_point = metric.data.data_points[0]
|
||||||
|
|
||||||
|
# Handle different metric types
|
||||||
|
if hasattr(data_point, "value"):
|
||||||
|
# Counter or Gauge
|
||||||
|
value = data_point.value
|
||||||
|
elif hasattr(data_point, "sum"):
|
||||||
|
# Histogram - use the sum of all recorded values
|
||||||
|
value = data_point.sum
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
# Extract attributes if available
|
# Extract attributes if available
|
||||||
attributes = {}
|
attributes = {}
|
||||||
if hasattr(metric.data.data_points[0], "attributes"):
|
if hasattr(data_point, "attributes"):
|
||||||
attrs = metric.data.data_points[0].attributes
|
attrs = data_point.attributes
|
||||||
if attrs is not None and hasattr(attrs, "items"):
|
if attrs is not None and hasattr(attrs, "items"):
|
||||||
attributes = dict(attrs.items())
|
attributes = dict(attrs.items())
|
||||||
elif attrs is not None and not isinstance(attrs, dict):
|
elif attrs is not None and not isinstance(attrs, dict):
|
||||||
|
|
@ -314,9 +303,48 @@ class BaseTelemetryCollector:
|
||||||
return MetricStub(
|
return MetricStub(
|
||||||
name=metric.name,
|
name=metric.name,
|
||||||
value=value,
|
value=value,
|
||||||
attributes=attributes if attributes else None,
|
attributes=attributes or {},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_metric_stub_from_protobuf(metric: Any) -> MetricStub | None:
|
||||||
|
"""Create MetricStub from protobuf metric object.
|
||||||
|
|
||||||
|
Protobuf metrics have a different structure than OpenTelemetry metrics.
|
||||||
|
They can have sum, gauge, or histogram data.
|
||||||
|
"""
|
||||||
|
if not hasattr(metric, "name"):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Try to extract value from different metric types
|
||||||
|
for metric_type in ["sum", "gauge", "histogram"]:
|
||||||
|
if hasattr(metric, metric_type):
|
||||||
|
metric_data = getattr(metric, metric_type)
|
||||||
|
if metric_data and hasattr(metric_data, "data_points"):
|
||||||
|
data_points = metric_data.data_points
|
||||||
|
if data_points and len(data_points) > 0:
|
||||||
|
data_point = data_points[0]
|
||||||
|
|
||||||
|
# Extract attributes first (needed for all metric types)
|
||||||
|
attributes = (
|
||||||
|
attributes_to_dict(data_point.attributes) if hasattr(data_point, "attributes") else {}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract value based on metric type
|
||||||
|
if metric_type == "sum":
|
||||||
|
value = data_point.as_int
|
||||||
|
elif metric_type == "gauge":
|
||||||
|
value = data_point.as_double
|
||||||
|
else: # histogram
|
||||||
|
value = data_point.sum
|
||||||
|
|
||||||
|
return MetricStub(
|
||||||
|
name=metric.name,
|
||||||
|
value=value,
|
||||||
|
attributes=attributes,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
self._clear_impl()
|
self._clear_impl()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ import os
|
||||||
import threading
|
import threading
|
||||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||||
from socketserver import ThreadingMixIn
|
from socketserver import ThreadingMixIn
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2 import ExportMetricsServiceRequest
|
from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2 import ExportMetricsServiceRequest
|
||||||
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest
|
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest
|
||||||
|
|
@ -83,54 +82,6 @@ class OtlpHttpTestCollector(BaseTelemetryCollector):
|
||||||
self._spans.clear()
|
self._spans.clear()
|
||||||
self._metrics.clear()
|
self._metrics.clear()
|
||||||
|
|
||||||
def _create_metric_stub_from_protobuf(self, metric: Any) -> MetricStub | None:
|
|
||||||
"""Create MetricStub from protobuf metric object.
|
|
||||||
|
|
||||||
Protobuf metrics have a different structure than OpenTelemetry metrics.
|
|
||||||
They can have sum, gauge, or histogram data.
|
|
||||||
"""
|
|
||||||
if not hasattr(metric, "name"):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Try to extract value from different metric types
|
|
||||||
for metric_type in ["sum", "gauge", "histogram"]:
|
|
||||||
if hasattr(metric, metric_type):
|
|
||||||
metric_data = getattr(metric, metric_type)
|
|
||||||
if metric_data and hasattr(metric_data, "data_points"):
|
|
||||||
data_points = metric_data.data_points
|
|
||||||
if data_points and len(data_points) > 0:
|
|
||||||
data_point = data_points[0]
|
|
||||||
|
|
||||||
# Extract value based on metric type
|
|
||||||
if metric_type == "sum":
|
|
||||||
value = data_point.as_int
|
|
||||||
elif metric_type == "gauge":
|
|
||||||
value = data_point.as_double
|
|
||||||
else: # histogram
|
|
||||||
value = data_point.count
|
|
||||||
|
|
||||||
# Extract attributes if available
|
|
||||||
attributes = self._extract_attributes_from_data_point(data_point)
|
|
||||||
|
|
||||||
return MetricStub(
|
|
||||||
name=metric.name,
|
|
||||||
value=value,
|
|
||||||
attributes=attributes if attributes else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _extract_attributes_from_data_point(self, data_point: Any) -> dict[str, Any]:
|
|
||||||
"""Extract attributes from a protobuf data point."""
|
|
||||||
if not hasattr(data_point, "attributes"):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
attrs = data_point.attributes
|
|
||||||
if not attrs:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
return {kv.key: kv.value.string_value or kv.value.int_value or kv.value.double_value for kv in attrs}
|
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
def shutdown(self) -> None:
|
||||||
self._server.shutdown()
|
self._server.shutdown()
|
||||||
self._server.server_close()
|
self._server.server_close()
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod
|
||||||
span
|
span
|
||||||
for span in reversed(spans)
|
for span in reversed(spans)
|
||||||
if span.get_span_type() == "async_generator"
|
if span.get_span_type() == "async_generator"
|
||||||
and span.get_attribute("chunk_count")
|
and span.attributes.get("chunk_count")
|
||||||
and span.has_message("Test trace openai 1")
|
and span.has_message("Test trace openai 1")
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
|
|
@ -40,7 +40,7 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod
|
||||||
|
|
||||||
assert async_generator_span is not None
|
assert async_generator_span is not None
|
||||||
|
|
||||||
raw_chunk_count = async_generator_span.get_attribute("chunk_count")
|
raw_chunk_count = async_generator_span.attributes.get("chunk_count")
|
||||||
assert raw_chunk_count is not None
|
assert raw_chunk_count is not None
|
||||||
chunk_count = int(raw_chunk_count)
|
chunk_count = int(raw_chunk_count)
|
||||||
|
|
||||||
|
|
@ -85,7 +85,7 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
|
||||||
logged_model_ids = []
|
logged_model_ids = []
|
||||||
|
|
||||||
for span in spans:
|
for span in spans:
|
||||||
attrs = span.get_attributes()
|
attrs = span.attributes
|
||||||
assert attrs is not None
|
assert attrs is not None
|
||||||
|
|
||||||
# Root span is created manually by tracing middleware, not by @trace_protocol decorator
|
# Root span is created manually by tracing middleware, not by @trace_protocol decorator
|
||||||
|
|
@ -98,7 +98,7 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
|
||||||
assert class_name and method_name
|
assert class_name and method_name
|
||||||
assert span.get_span_type() in ["async", "sync", "async_generator"]
|
assert span.get_span_type() in ["async", "sync", "async_generator"]
|
||||||
|
|
||||||
args_field = span.get_attribute("__args__")
|
args_field = span.attributes.get("__args__")
|
||||||
if args_field:
|
if args_field:
|
||||||
args = json.loads(args_field)
|
args = json.loads(args_field)
|
||||||
if "model_id" in args:
|
if "model_id" in args:
|
||||||
|
|
@ -115,37 +115,32 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
|
||||||
# Filter metrics to only those from the specific model used in the request
|
# Filter metrics to only those from the specific model used in the request
|
||||||
# This prevents issues when multiple metrics with the same name exist from different models
|
# This prevents issues when multiple metrics with the same name exist from different models
|
||||||
# (e.g., when safety models like llama-guard are also called)
|
# (e.g., when safety models like llama-guard are also called)
|
||||||
model_metrics = {}
|
inference_model_metrics = {}
|
||||||
all_model_ids = set()
|
all_model_ids = set()
|
||||||
|
|
||||||
for name, metric in metrics.items():
|
for name, metric in metrics.items():
|
||||||
if name in expected_metrics:
|
if name in expected_metrics:
|
||||||
model_id = metric.get_attribute("model_id")
|
model_id = metric.attributes.get("model_id")
|
||||||
all_model_ids.add(model_id)
|
all_model_ids.add(model_id)
|
||||||
# Only include metrics from the specific model used in the test request
|
# Only include metrics from the specific model used in the test request
|
||||||
if model_id == text_model_id:
|
if model_id == text_model_id:
|
||||||
model_metrics[name] = metric
|
inference_model_metrics[name] = metric
|
||||||
|
|
||||||
# Provide helpful error message if we have metrics from multiple models
|
|
||||||
if len(all_model_ids) > 1:
|
|
||||||
print(f"Note: Found metrics from multiple models: {sorted(all_model_ids)}")
|
|
||||||
print(f"Filtering to only metrics from test model: {text_model_id}")
|
|
||||||
|
|
||||||
# Verify expected metrics are present for our specific model
|
# Verify expected metrics are present for our specific model
|
||||||
for metric_name in expected_metrics:
|
for metric_name in expected_metrics:
|
||||||
assert metric_name in model_metrics, (
|
assert metric_name in inference_model_metrics, (
|
||||||
f"Expected metric {metric_name} for model {text_model_id} not found. "
|
f"Expected metric {metric_name} for model {text_model_id} not found. "
|
||||||
f"Available models: {sorted(all_model_ids)}, "
|
f"Available models: {sorted(all_model_ids)}, "
|
||||||
f"Available metrics for {text_model_id}: {list(model_metrics.keys())}"
|
f"Available metrics for {text_model_id}: {list(inference_model_metrics.keys())}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify metric values match usage data
|
# Verify metric values match usage data
|
||||||
assert model_metrics["completion_tokens"].get_value() == usage["completion_tokens"], (
|
assert inference_model_metrics["completion_tokens"].value == usage["completion_tokens"], (
|
||||||
f"Expected {usage['completion_tokens']} for completion_tokens, but got {model_metrics['completion_tokens'].get_value()}"
|
f"Expected {usage['completion_tokens']} for completion_tokens, but got {inference_model_metrics['completion_tokens'].value}"
|
||||||
)
|
)
|
||||||
assert model_metrics["total_tokens"].get_value() == usage["total_tokens"], (
|
assert inference_model_metrics["total_tokens"].value == usage["total_tokens"], (
|
||||||
f"Expected {usage['total_tokens']} for total_tokens, but got {model_metrics['total_tokens'].get_value()}"
|
f"Expected {usage['total_tokens']} for total_tokens, but got {inference_model_metrics['total_tokens'].value}"
|
||||||
)
|
)
|
||||||
assert model_metrics["prompt_tokens"].get_value() == usage["prompt_tokens"], (
|
assert inference_model_metrics["prompt_tokens"].value == usage["prompt_tokens"], (
|
||||||
f"Expected {usage['prompt_tokens']} for prompt_tokens, but got {model_metrics['prompt_tokens'].get_value()}"
|
f"Expected {usage['prompt_tokens']} for prompt_tokens, but got {inference_model_metrics['prompt_tokens'].value}"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue