From e5370ffa744d31840f628fde9b3300668718d7aa Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Mon, 27 Oct 2025 15:02:24 -0700 Subject: [PATCH] tele tests # What does this PR do? ## Test Plan --- tests/integration/fixtures/common.py | 33 +++++ .../telemetry/collectors/__init__.py | 19 +++ .../integration/telemetry/collectors/base.py | 108 ++++++++++++++++ .../telemetry/collectors/in_memory.py | 71 +++++++++++ .../integration/telemetry/collectors/otlp.py | 115 ++++++++++++++++++ tests/integration/telemetry/conftest.py | 113 ++++++++--------- .../integration/telemetry/test_completions.py | 50 ++++---- 7 files changed, 418 insertions(+), 91 deletions(-) create mode 100644 tests/integration/telemetry/collectors/__init__.py create mode 100644 tests/integration/telemetry/collectors/base.py create mode 100644 tests/integration/telemetry/collectors/in_memory.py create mode 100644 tests/integration/telemetry/collectors/otlp.py diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index 5fbf2c099..41822f850 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -88,6 +88,35 @@ def wait_for_server_ready(base_url: str, timeout: int = 30, process: subprocess. return False +def stop_server_on_port(port: int, timeout: float = 10.0) -> None: + """Terminate any server processes bound to the given port.""" + + try: + output = subprocess.check_output(["lsof", "-ti", f":{port}"], text=True) + except (subprocess.CalledProcessError, FileNotFoundError): + return + + pids = {int(line) for line in output.splitlines() if line.strip()} + if not pids: + return + + deadline = time.time() + timeout + for sig in (signal.SIGTERM, signal.SIGKILL): + for pid in list(pids): + try: + os.kill(pid, sig) + except ProcessLookupError: + pids.discard(pid) + + while not is_port_available(port) and time.time() < deadline: + time.sleep(0.1) + + if is_port_available(port): + return + + raise RuntimeError(f"Unable to free port {port} for test server restart") + + def get_provider_data(): # TODO: this needs to be generalized so each provider can have a sample provider data just # like sample run config on which we can do replace_env_vars() @@ -199,6 +228,10 @@ def instantiate_llama_stack_client(session): port = int(parts[2]) if len(parts) > 2 else int(os.environ.get("LLAMA_STACK_PORT", DEFAULT_PORT)) base_url = f"http://localhost:{port}" + force_restart = os.environ.get("LLAMA_STACK_TEST_FORCE_SERVER_RESTART") == "1" + if force_restart: + stop_server_on_port(port) + # Check if port is available if is_port_available(port): print(f"Starting llama stack server with config '{config_name}' on port {port}...") diff --git a/tests/integration/telemetry/collectors/__init__.py b/tests/integration/telemetry/collectors/__init__.py new file mode 100644 index 000000000..23d75a4a0 --- /dev/null +++ b/tests/integration/telemetry/collectors/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Telemetry collector helpers for integration tests.""" + +from .base import BaseTelemetryCollector, SpanStub +from .in_memory import InMemoryTelemetryCollector, InMemoryTelemetryManager +from .otlp import OtlpHttpTestCollector + +__all__ = [ + "BaseTelemetryCollector", + "SpanStub", + "InMemoryTelemetryCollector", + "InMemoryTelemetryManager", + "OtlpHttpTestCollector", +] diff --git a/tests/integration/telemetry/collectors/base.py b/tests/integration/telemetry/collectors/base.py new file mode 100644 index 000000000..9717e44b6 --- /dev/null +++ b/tests/integration/telemetry/collectors/base.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Shared helpers for telemetry test collectors.""" + +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any + + +@dataclass +class SpanStub: + name: str + attributes: dict[str, Any] + resource_attributes: dict[str, Any] | None = None + events: list[dict[str, Any]] | None = None + + +def _value_to_python(value: Any) -> Any: + kind = value.WhichOneof("value") + if kind == "string_value": + return value.string_value + if kind == "int_value": + return value.int_value + if kind == "double_value": + return value.double_value + if kind == "bool_value": + return value.bool_value + if kind == "bytes_value": + return value.bytes_value + if kind == "array_value": + return [_value_to_python(item) for item in value.array_value.values] + if kind == "kvlist_value": + return {kv.key: _value_to_python(kv.value) for kv in value.kvlist_value.values} + return None + + +def attributes_to_dict(key_values: Iterable[Any]) -> dict[str, Any]: + return {key_value.key: _value_to_python(key_value.value) for key_value in key_values} + + +def events_to_list(events: Iterable[Any]) -> list[dict[str, Any]]: + return [ + { + "name": event.name, + "timestamp": event.time_unix_nano, + "attributes": attributes_to_dict(event.attributes), + } + for event in events + ] + + +class BaseTelemetryCollector: + def get_spans( + self, + expected_count: int | None = None, + timeout: float = 5.0, + poll_interval: float = 0.05, + ) -> tuple[Any, ...]: + import time + + deadline = time.time() + timeout + min_count = expected_count if expected_count is not None else 1 + last_len: int | None = None + stable_iterations = 0 + + while True: + spans = tuple(self._snapshot_spans()) + + if len(spans) >= min_count: + if expected_count is not None and len(spans) >= expected_count: + return spans + + if last_len == len(spans): + stable_iterations += 1 + if stable_iterations >= 2: + return spans + else: + stable_iterations = 1 + else: + stable_iterations = 0 + + if time.time() >= deadline: + return spans + + last_len = len(spans) + time.sleep(poll_interval) + + def get_metrics(self) -> Any | None: + return self._snapshot_metrics() + + def clear(self) -> None: + self._clear_impl() + + def _snapshot_spans(self) -> tuple[Any, ...]: # pragma: no cover - interface hook + raise NotImplementedError + + def _snapshot_metrics(self) -> Any | None: # pragma: no cover - interface hook + raise NotImplementedError + + def _clear_impl(self) -> None: # pragma: no cover - interface hook + raise NotImplementedError + + def shutdown(self) -> None: + """Optional hook for subclasses with background workers.""" diff --git a/tests/integration/telemetry/collectors/in_memory.py b/tests/integration/telemetry/collectors/in_memory.py new file mode 100644 index 000000000..613b3860a --- /dev/null +++ b/tests/integration/telemetry/collectors/in_memory.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""In-memory telemetry collector for library-client tests.""" + +from typing import Any + +import opentelemetry.metrics as otel_metrics +import opentelemetry.trace as otel_trace +from opentelemetry import metrics, trace +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import InMemoryMetricReader +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + +import llama_stack.core.telemetry.telemetry as telemetry_module + +from .base import BaseTelemetryCollector + + +class InMemoryTelemetryCollector(BaseTelemetryCollector): + def __init__(self, span_exporter: InMemorySpanExporter, metric_reader: InMemoryMetricReader) -> None: + self._span_exporter = span_exporter + self._metric_reader = metric_reader + + def _snapshot_spans(self) -> tuple[Any, ...]: + return tuple(self._span_exporter.get_finished_spans()) + + def _snapshot_metrics(self) -> Any | None: + data = self._metric_reader.get_metrics_data() + if data and data.resource_metrics: + resource_metric = data.resource_metrics[0] + if resource_metric.scope_metrics: + return resource_metric.scope_metrics[0].metrics + return None + + def _clear_impl(self) -> None: + self._span_exporter.clear() + self._metric_reader.get_metrics_data() + + +class InMemoryTelemetryManager: + def __init__(self) -> None: + if hasattr(otel_trace, "_TRACER_PROVIDER_SET_ONCE"): + otel_trace._TRACER_PROVIDER_SET_ONCE._done = False # type: ignore[attr-defined] + if hasattr(otel_metrics, "_METER_PROVIDER_SET_ONCE"): + otel_metrics._METER_PROVIDER_SET_ONCE._done = False # type: ignore[attr-defined] + + span_exporter = InMemorySpanExporter() + tracer_provider = TracerProvider() + tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter)) + trace.set_tracer_provider(tracer_provider) + + metric_reader = InMemoryMetricReader() + meter_provider = MeterProvider(metric_readers=[metric_reader]) + metrics.set_meter_provider(meter_provider) + + telemetry_module._TRACER_PROVIDER = tracer_provider + + self.collector = InMemoryTelemetryCollector(span_exporter, metric_reader) + self._tracer_provider = tracer_provider + self._meter_provider = meter_provider + + def shutdown(self) -> None: + telemetry_module._TRACER_PROVIDER = None + self._tracer_provider.shutdown() + self._meter_provider.shutdown() diff --git a/tests/integration/telemetry/collectors/otlp.py b/tests/integration/telemetry/collectors/otlp.py new file mode 100644 index 000000000..cfd7e1b8b --- /dev/null +++ b/tests/integration/telemetry/collectors/otlp.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""OTLP HTTP telemetry collector used for server-mode tests.""" + +import gzip +import threading +from http.server import BaseHTTPRequestHandler, HTTPServer +from socketserver import ThreadingMixIn +from typing import Any + +from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2 import ExportMetricsServiceRequest +from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest + +from .base import BaseTelemetryCollector, SpanStub, attributes_to_dict, events_to_list + + +class OtlpHttpTestCollector(BaseTelemetryCollector): + def __init__(self) -> None: + self._spans: list[SpanStub] = [] + self._metrics: list[Any] = [] + self._lock = threading.Lock() + + class _ThreadingHTTPServer(ThreadingMixIn, HTTPServer): + daemon_threads = True + allow_reuse_address = True + + self._server = _ThreadingHTTPServer(("127.0.0.1", 0), _CollectorHandler) + self._server.collector = self # type: ignore[attr-defined] + host, port = self._server.server_address[:2] + self.endpoint = f"http://{host}:{port}" + + self._thread = threading.Thread(target=self._server.serve_forever, name="otel-test-collector", daemon=True) + self._thread.start() + + def _handle_traces(self, request: ExportTraceServiceRequest) -> None: + new_spans: list[SpanStub] = [] + + for resource_spans in request.resource_spans: + resource_attrs = attributes_to_dict(resource_spans.resource.attributes) + + for scope_spans in resource_spans.scope_spans: + for span in scope_spans.spans: + attributes = attributes_to_dict(span.attributes) + events = events_to_list(span.events) if span.events else None + new_spans.append(SpanStub(span.name, attributes, resource_attrs or None, events)) + + if not new_spans: + return + + with self._lock: + self._spans.extend(new_spans) + + def _handle_metrics(self, request: ExportMetricsServiceRequest) -> None: + new_metrics: list[Any] = [] + for resource_metrics in request.resource_metrics: + for scope_metrics in resource_metrics.scope_metrics: + new_metrics.extend(scope_metrics.metrics) + + if not new_metrics: + return + + with self._lock: + self._metrics.extend(new_metrics) + + def _snapshot_spans(self) -> tuple[SpanStub, ...]: + with self._lock: + return tuple(self._spans) + + def _snapshot_metrics(self) -> Any | None: + with self._lock: + return list(self._metrics) if self._metrics else None + + def _clear_impl(self) -> None: + with self._lock: + self._spans.clear() + self._metrics.clear() + + def shutdown(self) -> None: + self._server.shutdown() + self._server.server_close() + self._thread.join(timeout=1) + + +class _CollectorHandler(BaseHTTPRequestHandler): + def do_POST(self): + collector: OtlpHttpTestCollector = self.server.collector # type: ignore[attr-defined] + length = int(self.headers.get("content-length", "0")) + body = self.rfile.read(length) + if self.headers.get("content-encoding") == "gzip": + body = gzip.decompress(body) + + if self.path == "/v1/traces": + request = ExportTraceServiceRequest() + request.ParseFromString(body) + collector._handle_traces(request) + self._respond_ok() + elif self.path == "/v1/metrics": + request = ExportMetricsServiceRequest() + request.ParseFromString(body) + collector._handle_metrics(request) + self._respond_ok() + else: + self.send_response(404) + self.end_headers() + + def log_message(self, format, *args): + return + + def _respond_ok(self) -> None: + self.send_response(200) + self.end_headers() diff --git a/tests/integration/telemetry/conftest.py b/tests/integration/telemetry/conftest.py index b055e47ac..dfb400ce7 100644 --- a/tests/integration/telemetry/conftest.py +++ b/tests/integration/telemetry/conftest.py @@ -4,92 +4,77 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -"""Telemetry test configuration using OpenTelemetry SDK exporters. +"""Telemetry test configuration supporting both library and server test modes.""" -This conftest provides in-memory telemetry collection for library_client mode only. -Tests using these fixtures should skip in server mode since the in-memory collector -cannot access spans from a separate server process. -""" +import os -from typing import Any - -import opentelemetry.metrics as otel_metrics -import opentelemetry.trace as otel_trace import pytest -from opentelemetry import metrics, trace -from opentelemetry.sdk.metrics import MeterProvider -from opentelemetry.sdk.metrics.export import InMemoryMetricReader -from opentelemetry.sdk.trace import ReadableSpan, TracerProvider -from opentelemetry.sdk.trace.export import SimpleSpanProcessor -from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter import llama_stack.core.telemetry.telemetry as telemetry_module from llama_stack.testing.api_recorder import patch_httpx_for_test_id from tests.integration.fixtures.common import instantiate_llama_stack_client - - -class TestCollector: - def __init__(self, span_exp, metric_read): - assert span_exp and metric_read - self.span_exporter = span_exp - self.metric_reader = metric_read - - def get_spans(self) -> tuple[ReadableSpan, ...]: - return self.span_exporter.get_finished_spans() - - def get_metrics(self) -> Any | None: - metrics = self.metric_reader.get_metrics_data() - if metrics and metrics.resource_metrics: - return metrics.resource_metrics[0].scope_metrics[0].metrics - return None - - def clear(self) -> None: - self.span_exporter.clear() - self.metric_reader.get_metrics_data() +from tests.integration.telemetry.collectors import InMemoryTelemetryManager, OtlpHttpTestCollector @pytest.fixture(scope="session") -def _telemetry_providers(): - """Set up in-memory OTEL providers before llama_stack_client initializes.""" - # Reset set-once flags to allow re-initialization - if hasattr(otel_trace, "_TRACER_PROVIDER_SET_ONCE"): - otel_trace._TRACER_PROVIDER_SET_ONCE._done = False # type: ignore - if hasattr(otel_metrics, "_METER_PROVIDER_SET_ONCE"): - otel_metrics._METER_PROVIDER_SET_ONCE._done = False # type: ignore +def telemetry_test_collector(): + stack_mode = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client") - # Create in-memory exporters/readers - span_exporter = InMemorySpanExporter() - tracer_provider = TracerProvider() - tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter)) - trace.set_tracer_provider(tracer_provider) + if stack_mode == "server": + try: + collector = OtlpHttpTestCollector() + except RuntimeError as exc: + pytest.skip(str(exc)) + env_overrides = { + "OTEL_EXPORTER_OTLP_ENDPOINT": collector.endpoint, + "OTEL_EXPORTER_OTLP_PROTOCOL": "http/protobuf", + "OTEL_BSP_SCHEDULE_DELAY": "200", + "OTEL_BSP_EXPORT_TIMEOUT": "2000", + } - metric_reader = InMemoryMetricReader() - meter_provider = MeterProvider(metric_readers=[metric_reader]) - metrics.set_meter_provider(meter_provider) + previous_env = {key: os.environ.get(key) for key in env_overrides} + previous_force_restart = os.environ.get("LLAMA_STACK_TEST_FORCE_SERVER_RESTART") - # Set module-level provider so TelemetryAdapter uses our in-memory providers - telemetry_module._TRACER_PROVIDER = tracer_provider + for key, value in env_overrides.items(): + os.environ[key] = value - yield (span_exporter, metric_reader, tracer_provider, meter_provider) + os.environ["LLAMA_STACK_TEST_FORCE_SERVER_RESTART"] = "1" + telemetry_module._TRACER_PROVIDER = None - telemetry_module._TRACER_PROVIDER = None - tracer_provider.shutdown() - meter_provider.shutdown() + try: + yield collector + finally: + collector.shutdown() + for key, prior in previous_env.items(): + if prior is None: + os.environ.pop(key, None) + else: + os.environ[key] = prior + if previous_force_restart is None: + os.environ.pop("LLAMA_STACK_TEST_FORCE_SERVER_RESTART", None) + else: + os.environ["LLAMA_STACK_TEST_FORCE_SERVER_RESTART"] = previous_force_restart + else: + manager = InMemoryTelemetryManager() + try: + yield manager.collector + finally: + manager.shutdown() @pytest.fixture(scope="session") -def llama_stack_client(_telemetry_providers, request): - """Override llama_stack_client to ensure in-memory telemetry providers are used.""" +def llama_stack_client(telemetry_test_collector, request): + """Ensure telemetry collector is ready before initializing the stack client.""" patch_httpx_for_test_id() client = instantiate_llama_stack_client(request.session) - return client @pytest.fixture -def mock_otlp_collector(_telemetry_providers): +def mock_otlp_collector(telemetry_test_collector): """Provides access to telemetry data and clears between tests.""" - span_exporter, metric_reader, _, _ = _telemetry_providers - collector = TestCollector(span_exporter, metric_reader) - yield collector - collector.clear() + telemetry_test_collector.clear() + try: + yield telemetry_test_collector + finally: + telemetry_test_collector.clear() diff --git a/tests/integration/telemetry/test_completions.py b/tests/integration/telemetry/test_completions.py index 77ca4d51c..49ee4de32 100644 --- a/tests/integration/telemetry/test_completions.py +++ b/tests/integration/telemetry/test_completions.py @@ -4,17 +4,9 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -"""Telemetry tests verifying @trace_protocol decorator format using in-memory exporter.""" +"""Telemetry tests verifying @trace_protocol decorator format across stack modes.""" import json -import os - -import pytest - -pytestmark = pytest.mark.skipif( - os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE") == "server", - reason="In-memory telemetry tests only work in library_client mode (server mode runs in separate process)", -) def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_model_id): @@ -29,18 +21,20 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod chunks = list(stream) assert len(chunks) > 0 - spans = mock_otlp_collector.get_spans() + spans = mock_otlp_collector.get_spans(expected_count=5) assert len(spans) > 0 - chunk_count = None - for span in spans: - if span.attributes.get("__type__") == "async_generator": - chunk_count = span.attributes.get("chunk_count") - if chunk_count: - chunk_count = int(chunk_count) - break + async_generator_span = next( + (s for s in spans if s.attributes.get("__type__") == "async_generator" and s.attributes.get("chunk_count")), + None, + ) + + assert async_generator_span is not None + + raw_chunk_count = async_generator_span.attributes.get("chunk_count") + assert raw_chunk_count is not None + chunk_count = int(raw_chunk_count) - assert chunk_count is not None assert chunk_count == len(chunks) @@ -63,8 +57,9 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, assert usage.get("total_tokens") and usage["total_tokens"] > 0 # Verify spans - spans = mock_otlp_collector.get_spans() - assert len(spans) == 5 + spans = mock_otlp_collector.get_spans(expected_count=7) + spans = [span for span in spans if span.attributes.get("__root__") or span.attributes.get("__autotraced__")] + assert len(spans) >= 5 # we only need this captured one time logged_model_id = None @@ -77,15 +72,16 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, is_root_span = attrs.get("__root__") is True if is_root_span: - # Root spans have different attributes assert attrs.get("__location__") in ["library_client", "server"] - else: - # Non-root spans are created by @trace_protocol decorator - assert attrs.get("__autotraced__") - assert attrs.get("__class__") and attrs.get("__method__") - assert attrs.get("__type__") in ["async", "sync", "async_generator"] + continue - args = json.loads(attrs["__args__"]) + assert attrs.get("__autotraced__") + assert attrs.get("__class__") and attrs.get("__method__") + assert attrs.get("__type__") in ["async", "sync", "async_generator"] + + args_field = attrs.get("__args__") + if args_field: + args = json.loads(args_field) if "model_id" in args: logged_model_id = args["model_id"]