diff --git a/tests/integration/telemetry/collectors/base.py b/tests/integration/telemetry/collectors/base.py index c5de46cef..268fa7d7a 100644 --- a/tests/integration/telemetry/collectors/base.py +++ b/tests/integration/telemetry/collectors/base.py @@ -29,6 +29,67 @@ class SpanStub: return None return type("Context", (), {"trace_id": int(self.trace_id, 16)})() + def get_attributes(self) -> dict[str, Any]: + """Get span attributes as a dictionary. + + Handles different attribute types (mapping, dict, etc.) and returns + a consistent dictionary format. + """ + attrs = self.attributes + if attrs is None: + return {} + + # Handle mapping-like objects (e.g., mappingproxy) + try: + return dict(attrs.items()) # type: ignore[attr-defined] + except AttributeError: + try: + return dict(attrs) + except TypeError: + return dict(attrs) if attrs else {} + + def get_attribute(self, key: str) -> Any: + """Get a specific attribute value by key.""" + attrs = self.get_attributes() + return attrs.get(key) + + def get_trace_id(self) -> str | None: + """Get trace ID in hex format. + + Tries context.trace_id first, then falls back to direct trace_id. + """ + context = getattr(self, "context", None) + if context and getattr(context, "trace_id", None) is not None: + return f"{context.trace_id:032x}" + return getattr(self, "trace_id", None) + + def has_message(self, text: str) -> bool: + """Check if span contains a specific message in its args.""" + args = self.get_attribute("__args__") + if not args or not isinstance(args, str): + return False + return text in args + + def is_root_span(self) -> bool: + """Check if this is a root span.""" + return self.get_attribute("__root__") is True + + def is_autotraced(self) -> bool: + """Check if this span was automatically traced.""" + return self.get_attribute("__autotraced__") is True + + def get_span_type(self) -> str | None: + """Get the span type (async, sync, async_generator).""" + return self.get_attribute("__type__") + + def get_class_method(self) -> tuple[str | None, str | None]: + """Get the class and method names for autotraced spans.""" + return (self.get_attribute("__class__"), self.get_attribute("__method__")) + + def get_location(self) -> str | None: + """Get the location (library_client, server) for root spans.""" + return self.get_attribute("__location__") + def _value_to_python(value: Any) -> Any: kind = value.WhichOneof("value") diff --git a/tests/integration/telemetry/test_completions.py b/tests/integration/telemetry/test_completions.py index 70d266ee7..dc06cffff 100644 --- a/tests/integration/telemetry/test_completions.py +++ b/tests/integration/telemetry/test_completions.py @@ -13,39 +13,6 @@ before and after each test, ensuring test isolation. import json -def _span_attributes(span): - attrs = getattr(span, "attributes", None) - if attrs is None: - return {} - # ReadableSpan.attributes acts like a mapping - try: - return dict(attrs.items()) # type: ignore[attr-defined] - except AttributeError: - try: - return dict(attrs) - except TypeError: - return attrs - - -def _span_attr(span, key): - attrs = _span_attributes(span) - return attrs.get(key) - - -def _span_trace_id(span): - context = getattr(span, "context", None) - if context and getattr(context, "trace_id", None) is not None: - return f"{context.trace_id:032x}" - return getattr(span, "trace_id", None) - - -def _span_has_message(span, text: str) -> bool: - args = _span_attr(span, "__args__") - if not args or not isinstance(args, str): - return False - return text in args - - def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_model_id): """Verify streaming adds chunk_count and __type__=async_generator.""" stream = llama_stack_client.chat.completions.create( @@ -64,16 +31,16 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod ( span for span in reversed(spans) - if _span_attr(span, "__type__") == "async_generator" - and _span_attr(span, "chunk_count") - and _span_has_message(span, "Test trace openai 1") + if span.get_span_type() == "async_generator" + and span.get_attribute("chunk_count") + and span.has_message("Test trace openai 1") ), None, ) assert async_generator_span is not None - raw_chunk_count = _span_attr(async_generator_span, "chunk_count") + raw_chunk_count = async_generator_span.get_attribute("chunk_count") assert raw_chunk_count is not None chunk_count = int(raw_chunk_count) @@ -101,37 +68,36 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, # Verify spans spans = mock_otlp_collector.get_spans(expected_count=7) target_span = next( - (span for span in reversed(spans) if _span_has_message(span, "Test trace openai with temperature 0.7")), + (span for span in reversed(spans) if span.has_message("Test trace openai with temperature 0.7")), None, ) assert target_span is not None - trace_id = _span_trace_id(target_span) + trace_id = target_span.get_trace_id() assert trace_id is not None - spans = [span for span in spans if _span_trace_id(span) == trace_id] - spans = [span for span in spans if _span_attr(span, "__root__") or _span_attr(span, "__autotraced__")] + spans = [span for span in spans if span.get_trace_id() == trace_id] + spans = [span for span in spans if span.is_root_span() or span.is_autotraced()] assert len(spans) >= 4 # Collect all model_ids found in spans logged_model_ids = [] for span in spans: - attrs = _span_attributes(span) + attrs = span.get_attributes() assert attrs is not None # Root span is created manually by tracing middleware, not by @trace_protocol decorator - is_root_span = attrs.get("__root__") is True - - if is_root_span: - assert attrs.get("__location__") in ["library_client", "server"] + if span.is_root_span(): + assert span.get_location() in ["library_client", "server"] continue - assert attrs.get("__autotraced__") - assert attrs.get("__class__") and attrs.get("__method__") - assert attrs.get("__type__") in ["async", "sync", "async_generator"] + assert span.is_autotraced() + class_name, method_name = span.get_class_method() + assert class_name and method_name + assert span.get_span_type() in ["async", "sync", "async_generator"] - args_field = attrs.get("__args__") + args_field = span.get_attribute("__args__") if args_field: args = json.loads(args_field) if "model_id" in args: