mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
fix(tests): add methods to standardize test span data access across all tests
This commit is contained in:
parent
79156bb08c
commit
e4e8a59325
2 changed files with 77 additions and 50 deletions
|
|
@ -29,6 +29,67 @@ class SpanStub:
|
||||||
return None
|
return None
|
||||||
return type("Context", (), {"trace_id": int(self.trace_id, 16)})()
|
return type("Context", (), {"trace_id": int(self.trace_id, 16)})()
|
||||||
|
|
||||||
|
def get_attributes(self) -> dict[str, Any]:
|
||||||
|
"""Get span attributes as a dictionary.
|
||||||
|
|
||||||
|
Handles different attribute types (mapping, dict, etc.) and returns
|
||||||
|
a consistent dictionary format.
|
||||||
|
"""
|
||||||
|
attrs = self.attributes
|
||||||
|
if attrs is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Handle mapping-like objects (e.g., mappingproxy)
|
||||||
|
try:
|
||||||
|
return dict(attrs.items()) # type: ignore[attr-defined]
|
||||||
|
except AttributeError:
|
||||||
|
try:
|
||||||
|
return dict(attrs)
|
||||||
|
except TypeError:
|
||||||
|
return dict(attrs) if attrs else {}
|
||||||
|
|
||||||
|
def get_attribute(self, key: str) -> Any:
|
||||||
|
"""Get a specific attribute value by key."""
|
||||||
|
attrs = self.get_attributes()
|
||||||
|
return attrs.get(key)
|
||||||
|
|
||||||
|
def get_trace_id(self) -> str | None:
|
||||||
|
"""Get trace ID in hex format.
|
||||||
|
|
||||||
|
Tries context.trace_id first, then falls back to direct trace_id.
|
||||||
|
"""
|
||||||
|
context = getattr(self, "context", None)
|
||||||
|
if context and getattr(context, "trace_id", None) is not None:
|
||||||
|
return f"{context.trace_id:032x}"
|
||||||
|
return getattr(self, "trace_id", None)
|
||||||
|
|
||||||
|
def has_message(self, text: str) -> bool:
|
||||||
|
"""Check if span contains a specific message in its args."""
|
||||||
|
args = self.get_attribute("__args__")
|
||||||
|
if not args or not isinstance(args, str):
|
||||||
|
return False
|
||||||
|
return text in args
|
||||||
|
|
||||||
|
def is_root_span(self) -> bool:
|
||||||
|
"""Check if this is a root span."""
|
||||||
|
return self.get_attribute("__root__") is True
|
||||||
|
|
||||||
|
def is_autotraced(self) -> bool:
|
||||||
|
"""Check if this span was automatically traced."""
|
||||||
|
return self.get_attribute("__autotraced__") is True
|
||||||
|
|
||||||
|
def get_span_type(self) -> str | None:
|
||||||
|
"""Get the span type (async, sync, async_generator)."""
|
||||||
|
return self.get_attribute("__type__")
|
||||||
|
|
||||||
|
def get_class_method(self) -> tuple[str | None, str | None]:
|
||||||
|
"""Get the class and method names for autotraced spans."""
|
||||||
|
return (self.get_attribute("__class__"), self.get_attribute("__method__"))
|
||||||
|
|
||||||
|
def get_location(self) -> str | None:
|
||||||
|
"""Get the location (library_client, server) for root spans."""
|
||||||
|
return self.get_attribute("__location__")
|
||||||
|
|
||||||
|
|
||||||
def _value_to_python(value: Any) -> Any:
|
def _value_to_python(value: Any) -> Any:
|
||||||
kind = value.WhichOneof("value")
|
kind = value.WhichOneof("value")
|
||||||
|
|
|
||||||
|
|
@ -13,39 +13,6 @@ before and after each test, ensuring test isolation.
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
def _span_attributes(span):
|
|
||||||
attrs = getattr(span, "attributes", None)
|
|
||||||
if attrs is None:
|
|
||||||
return {}
|
|
||||||
# ReadableSpan.attributes acts like a mapping
|
|
||||||
try:
|
|
||||||
return dict(attrs.items()) # type: ignore[attr-defined]
|
|
||||||
except AttributeError:
|
|
||||||
try:
|
|
||||||
return dict(attrs)
|
|
||||||
except TypeError:
|
|
||||||
return attrs
|
|
||||||
|
|
||||||
|
|
||||||
def _span_attr(span, key):
|
|
||||||
attrs = _span_attributes(span)
|
|
||||||
return attrs.get(key)
|
|
||||||
|
|
||||||
|
|
||||||
def _span_trace_id(span):
|
|
||||||
context = getattr(span, "context", None)
|
|
||||||
if context and getattr(context, "trace_id", None) is not None:
|
|
||||||
return f"{context.trace_id:032x}"
|
|
||||||
return getattr(span, "trace_id", None)
|
|
||||||
|
|
||||||
|
|
||||||
def _span_has_message(span, text: str) -> bool:
|
|
||||||
args = _span_attr(span, "__args__")
|
|
||||||
if not args or not isinstance(args, str):
|
|
||||||
return False
|
|
||||||
return text in args
|
|
||||||
|
|
||||||
|
|
||||||
def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_model_id):
|
def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_model_id):
|
||||||
"""Verify streaming adds chunk_count and __type__=async_generator."""
|
"""Verify streaming adds chunk_count and __type__=async_generator."""
|
||||||
stream = llama_stack_client.chat.completions.create(
|
stream = llama_stack_client.chat.completions.create(
|
||||||
|
|
@ -64,16 +31,16 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod
|
||||||
(
|
(
|
||||||
span
|
span
|
||||||
for span in reversed(spans)
|
for span in reversed(spans)
|
||||||
if _span_attr(span, "__type__") == "async_generator"
|
if span.get_span_type() == "async_generator"
|
||||||
and _span_attr(span, "chunk_count")
|
and span.get_attribute("chunk_count")
|
||||||
and _span_has_message(span, "Test trace openai 1")
|
and span.has_message("Test trace openai 1")
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert async_generator_span is not None
|
assert async_generator_span is not None
|
||||||
|
|
||||||
raw_chunk_count = _span_attr(async_generator_span, "chunk_count")
|
raw_chunk_count = async_generator_span.get_attribute("chunk_count")
|
||||||
assert raw_chunk_count is not None
|
assert raw_chunk_count is not None
|
||||||
chunk_count = int(raw_chunk_count)
|
chunk_count = int(raw_chunk_count)
|
||||||
|
|
||||||
|
|
@ -101,37 +68,36 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
|
||||||
# Verify spans
|
# Verify spans
|
||||||
spans = mock_otlp_collector.get_spans(expected_count=7)
|
spans = mock_otlp_collector.get_spans(expected_count=7)
|
||||||
target_span = next(
|
target_span = next(
|
||||||
(span for span in reversed(spans) if _span_has_message(span, "Test trace openai with temperature 0.7")),
|
(span for span in reversed(spans) if span.has_message("Test trace openai with temperature 0.7")),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
assert target_span is not None
|
assert target_span is not None
|
||||||
|
|
||||||
trace_id = _span_trace_id(target_span)
|
trace_id = target_span.get_trace_id()
|
||||||
assert trace_id is not None
|
assert trace_id is not None
|
||||||
|
|
||||||
spans = [span for span in spans if _span_trace_id(span) == trace_id]
|
spans = [span for span in spans if span.get_trace_id() == trace_id]
|
||||||
spans = [span for span in spans if _span_attr(span, "__root__") or _span_attr(span, "__autotraced__")]
|
spans = [span for span in spans if span.is_root_span() or span.is_autotraced()]
|
||||||
assert len(spans) >= 4
|
assert len(spans) >= 4
|
||||||
|
|
||||||
# Collect all model_ids found in spans
|
# Collect all model_ids found in spans
|
||||||
logged_model_ids = []
|
logged_model_ids = []
|
||||||
|
|
||||||
for span in spans:
|
for span in spans:
|
||||||
attrs = _span_attributes(span)
|
attrs = span.get_attributes()
|
||||||
assert attrs is not None
|
assert attrs is not None
|
||||||
|
|
||||||
# Root span is created manually by tracing middleware, not by @trace_protocol decorator
|
# Root span is created manually by tracing middleware, not by @trace_protocol decorator
|
||||||
is_root_span = attrs.get("__root__") is True
|
if span.is_root_span():
|
||||||
|
assert span.get_location() in ["library_client", "server"]
|
||||||
if is_root_span:
|
|
||||||
assert attrs.get("__location__") in ["library_client", "server"]
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
assert attrs.get("__autotraced__")
|
assert span.is_autotraced()
|
||||||
assert attrs.get("__class__") and attrs.get("__method__")
|
class_name, method_name = span.get_class_method()
|
||||||
assert attrs.get("__type__") in ["async", "sync", "async_generator"]
|
assert class_name and method_name
|
||||||
|
assert span.get_span_type() in ["async", "sync", "async_generator"]
|
||||||
|
|
||||||
args_field = attrs.get("__args__")
|
args_field = span.get_attribute("__args__")
|
||||||
if args_field:
|
if args_field:
|
||||||
args = json.loads(args_field)
|
args = json.loads(args_field)
|
||||||
if "model_id" in args:
|
if "model_id" in args:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue