diff --git a/docs/docs/providers/telemetry/inline_otel.mdx b/docs/docs/providers/telemetry/inline_otel.mdx
new file mode 100644
index 000000000..0c0491e8a
--- /dev/null
+++ b/docs/docs/providers/telemetry/inline_otel.mdx
@@ -0,0 +1,29 @@
+---
+description: "Native OpenTelemetry provider with full access to OTel Tracer and Meter APIs for advanced instrumentation."
+sidebar_label: Otel
+title: inline::otel
+---
+
+# inline::otel
+
+## Description
+
+Native OpenTelemetry provider with full access to OTel Tracer and Meter APIs for advanced instrumentation.
+
+## Configuration
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `service_name` | `str` | No | | The name of the service to be monitored. Is overridden by the OTEL_SERVICE_NAME or OTEL_RESOURCE_ATTRIBUTES environment variables. |
+| `service_version` | `str \| None` | No | | The version of the service to be monitored. Is overridden by the OTEL_RESOURCE_ATTRIBUTES environment variable. |
+| `deployment_environment` | `str \| None` | No | | The name of the environment of the service to be monitored. Is overridden by the OTEL_RESOURCE_ATTRIBUTES environment variable. |
+| `span_processor` | `BatchSpanProcessor \| SimpleSpanProcessor \| None` | No | batch | The span processor to use. Is overridden by the OTEL_SPAN_PROCESSOR environment variable. |
+
+## Sample Configuration
+
+```yaml
+service_name: ${env.OTEL_SERVICE_NAME:=llama-stack}
+service_version: ${env.OTEL_SERVICE_VERSION:=}
+deployment_environment: ${env.OTEL_DEPLOYMENT_ENVIRONMENT:=}
+span_processor: ${env.OTEL_SPAN_PROCESSOR:=batch}
+```
diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py
index 9f8f35bdf..a422bc9d0 100644
--- a/llama_stack/core/library_client.py
+++ b/llama_stack/core/library_client.py
@@ -32,7 +32,7 @@ from termcolor import cprint
 
 from llama_stack.core.build import print_pip_install_help
 from llama_stack.core.configure import parse_and_maybe_upgrade_config
-from llama_stack.core.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec
+from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
 from llama_stack.core.request_headers import (
     PROVIDER_DATA_VAR,
     request_provider_data_context,
@@ -49,7 +49,6 @@ from llama_stack.core.utils.context import preserve_contexts_async_generator
 from llama_stack.core.utils.exec import in_notebook
 from llama_stack.log import get_logger
 
-
 logger = get_logger(name=__name__, category="core")
 
 T = TypeVar("T")
diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py
index ab051cb2e..a1702ff13 100644
--- a/llama_stack/core/server/server.py
+++ b/llama_stack/core/server/server.py
@@ -63,7 +63,6 @@ from llama_stack.core.utils.context import preserve_contexts_async_generator
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api
 
-
 from .auth import AuthenticationMiddleware
 from .quota import QuotaMiddleware
 
@@ -236,9 +235,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
 
         try:
             if is_streaming:
-                gen = preserve_contexts_async_generator(
-                    sse_generator(func(**kwargs)), [PROVIDER_DATA_VAR]
-                )
+                gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), [PROVIDER_DATA_VAR])
                 return StreamingResponse(gen, media_type="text/event-stream")
             else:
                 value = func(**kwargs)
@@ -282,7 +279,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
         ]
     )
 
-    setattr(route_handler, "__signature__", sig.replace(parameters=new_params))
+    route_handler.__signature__ = sig.replace(parameters=new_params)
 
     return route_handler
 
diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py
index 4536275bd..0413e47c5 100644
--- a/llama_stack/core/stack.py
+++ b/llama_stack/core/stack.py
@@ -359,7 +359,6 @@ class Stack:
             await refresh_registry_once(impls)
 
         self.impls = impls
-
     # safely access impls without raising an exception
     def get_impls(self) -> dict[Api, Any]:
         if self.impls is None:
diff --git a/llama_stack/core/telemetry/__init__.py b/llama_stack/core/telemetry/__init__.py
index 3c22a16d4..b5e7174df 100644
--- a/llama_stack/core/telemetry/__init__.py
+++ b/llama_stack/core/telemetry/__init__.py
@@ -1,4 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
\ No newline at end of file
+# the root directory of this source tree.
diff --git a/llama_stack/core/telemetry/telemetry.py b/llama_stack/core/telemetry/telemetry.py
index ef2ddd6a2..254876b70 100644
--- a/llama_stack/core/telemetry/telemetry.py
+++ b/llama_stack/core/telemetry/telemetry.py
@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from abc import abstractmethod
-from fastapi import FastAPI
-from pydantic import BaseModel
-from opentelemetry.trace import Tracer
 
+from fastapi import FastAPI
 from opentelemetry.metrics import Meter
-from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.metrics import MeterProvider
 from opentelemetry.sdk.resources import Attributes
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.trace import Tracer
+from pydantic import BaseModel
 from sqlalchemy import Engine
 
 
@@ -19,39 +19,44 @@ class TelemetryProvider(BaseModel):
     """
     TelemetryProvider standardizes how telemetry is provided to the application.
     """
+
     @abstractmethod
     def fastapi_middleware(self, app: FastAPI, *args, **kwargs):
         """
         Injects FastAPI middleware that instruments the application for telemetry.
         """
         ...
-    
+
     @abstractmethod
     def sqlalchemy_instrumentation(self, engine: Engine | None = None):
         """
         Injects SQLAlchemy instrumentation that instruments the application for telemetry.
         """
         ...
-    
+
     @abstractmethod
-    def get_tracer(self,
-        instrumenting_module_name: str,
-        instrumenting_library_version: str | None = None,
-        tracer_provider: TracerProvider | None = None,
-        schema_url: str | None = None,
-        attributes: Attributes | None = None
+    def get_tracer(
+        self,
+        instrumenting_module_name: str,
+        instrumenting_library_version: str | None = None,
+        tracer_provider: TracerProvider | None = None,
+        schema_url: str | None = None,
+        attributes: Attributes | None = None,
     ) -> Tracer:
         """
         Gets a tracer.
         """
         ...
-    
+
     @abstractmethod
-    def get_meter(self, name: str,
-        version: str = "",
-        meter_provider: MeterProvider | None = None,
-        schema_url: str | None = None,
-        attributes: Attributes | None = None) -> Meter:
+    def get_meter(
+        self,
+        name: str,
+        version: str = "",
+        meter_provider: MeterProvider | None = None,
+        schema_url: str | None = None,
+        attributes: Attributes | None = None,
+    ) -> Meter:
         """
         Gets a meter.
""" diff --git a/llama_stack/providers/inline/telemetry/meta_reference/middleware.py b/llama_stack/providers/inline/telemetry/meta_reference/middleware.py index 6902bb125..219c344ef 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/middleware.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/middleware.py @@ -1,15 +1,22 @@ -from aiohttp import hdrs +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + from typing import Any +from aiohttp import hdrs + from llama_stack.apis.datatypes import Api from llama_stack.core.external import ExternalApiSpec from llama_stack.core.server.routes import find_matching_route, initialize_route_impls from llama_stack.log import get_logger from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace - logger = get_logger(name=__name__, category="telemetry::meta_reference") + class TracingMiddleware: def __init__( self, diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index 396238850..596b93551 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -10,7 +10,6 @@ import threading from typing import Any, cast from fastapi import FastAPI - from opentelemetry import metrics, trace from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter @@ -23,11 +22,6 @@ from opentelemetry.semconv.attributes import service_attributes from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from opentelemetry.util.types import Attributes -from llama_stack.core.external import ExternalApiSpec -from llama_stack.core.server.tracing import TelemetryProvider -from llama_stack.providers.inline.telemetry.meta_reference.middleware import TracingMiddleware - - from llama_stack.apis.telemetry import ( Event, MetricEvent, @@ -47,10 +41,13 @@ from llama_stack.apis.telemetry import ( UnstructuredLogEvent, ) from llama_stack.core.datatypes import Api +from llama_stack.core.external import ExternalApiSpec +from llama_stack.core.server.tracing import TelemetryProvider from llama_stack.log import get_logger from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import ( ConsoleSpanProcessor, ) +from llama_stack.providers.inline.telemetry.meta_reference.middleware import TracingMiddleware from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import ( SQLiteSpanProcessor, ) @@ -381,7 +378,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry, TelemetryProvider): max_depth=max_depth, ) ) - + def fastapi_middleware( self, app: FastAPI, diff --git a/llama_stack/providers/inline/telemetry/otel/__init__.py b/llama_stack/providers/inline/telemetry/otel/__init__.py index f432d3364..2370b0752 100644 --- a/llama_stack/providers/inline/telemetry/otel/__init__.py +++ b/llama_stack/providers/inline/telemetry/otel/__init__.py @@ -12,13 +12,12 @@ __all__ = ["OTelTelemetryConfig"] async def get_provider_impl(config: OTelTelemetryConfig, deps): """ Get the OTel telemetry provider implementation. - + This function is called by the Llama Stack registry to instantiate the provider. 
""" from .otel import OTelTelemetryProvider - + # The provider is synchronously initialized via Pydantic model_post_init # No async initialization needed return OTelTelemetryProvider(config=config) - diff --git a/llama_stack/providers/inline/telemetry/otel/config.py b/llama_stack/providers/inline/telemetry/otel/config.py index ad4982716..709944cd4 100644 --- a/llama_stack/providers/inline/telemetry/otel/config.py +++ b/llama_stack/providers/inline/telemetry/otel/config.py @@ -1,8 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + from typing import Any, Literal from pydantic import BaseModel, Field - type BatchSpanProcessor = Literal["batch"] type SimpleSpanProcessor = Literal["simple"] @@ -13,26 +18,27 @@ class OTelTelemetryConfig(BaseModel): Most configuration is set using environment variables. See https://opentelemetry.io/docs/specs/otel/configuration/sdk-configuration-variables/ for more information. """ + service_name: str = Field( - description="""The name of the service to be monitored. + description="""The name of the service to be monitored. Is overridden by the OTEL_SERVICE_NAME or OTEL_RESOURCE_ATTRIBUTES environment variables.""", ) service_version: str | None = Field( default=None, - description="""The version of the service to be monitored. - Is overriden by the OTEL_RESOURCE_ATTRIBUTES environment variable.""" + description="""The version of the service to be monitored. + Is overriden by the OTEL_RESOURCE_ATTRIBUTES environment variable.""", ) deployment_environment: str | None = Field( default=None, - description="""The name of the environment of the service to be monitored. - Is overriden by the OTEL_RESOURCE_ATTRIBUTES environment variable.""" + description="""The name of the environment of the service to be monitored. + Is overriden by the OTEL_RESOURCE_ATTRIBUTES environment variable.""", ) span_processor: BatchSpanProcessor | SimpleSpanProcessor | None = Field( - description="""The span processor to use. + description="""The span processor to use. Is overriden by the OTEL_SPAN_PROCESSOR environment variable.""", - default="batch" + default="batch", ) - + @classmethod def sample_run_config(cls, __distro_dir__: str = "") -> dict[str, Any]: """Sample configuration for use in distributions.""" diff --git a/llama_stack/providers/inline/telemetry/otel/otel.py b/llama_stack/providers/inline/telemetry/otel/otel.py index 08a2c9a63..68d91a78d 100644 --- a/llama_stack/providers/inline/telemetry/otel/otel.py +++ b/llama_stack/providers/inline/telemetry/otel/otel.py @@ -1,24 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+
 import os
 
-from opentelemetry import trace, metrics
+from fastapi import FastAPI
+from opentelemetry import metrics, trace
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
+from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
+from opentelemetry.metrics import Meter
+from opentelemetry.sdk.metrics import MeterProvider
 from opentelemetry.sdk.resources import Attributes, Resource
 from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import BatchSpanProcessor, SimpleSpanProcessor
-from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
-from opentelemetry.sdk.metrics import MeterProvider
-from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
 from opentelemetry.trace import Tracer
-from opentelemetry.metrics import Meter
-from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
+from sqlalchemy import Engine
 
 from llama_stack.core.telemetry.telemetry import TelemetryProvider
 from llama_stack.log import get_logger
 
-from sqlalchemy import Engine
-
 from .config import OTelTelemetryConfig
 
-from fastapi import FastAPI
-
 logger = get_logger(name=__name__, category="telemetry::otel")
 
@@ -27,6 +31,7 @@ class OTelTelemetryProvider(TelemetryProvider):
     """
     A simple OpenTelemetry-native telemetry provider.
     """
+
     config: OTelTelemetryConfig
 
     def model_post_init(self, __context):
@@ -56,66 +61,66 @@ class OTelTelemetryProvider(TelemetryProvider):
             tracer_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
         elif self.config.span_processor == "simple":
             tracer_provider.add_span_processor(SimpleSpanProcessor(otlp_span_exporter))
-        
+
         meter_provider = MeterProvider(resource=resource)
         metrics.set_meter_provider(meter_provider)
 
         # Do not fail the application, but warn the user if the endpoints are not set properly.
         if not os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"):
             if not os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"):
-                logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_TRACES_ENDPOINT is not set. Traces will not be exported.")
+                logger.warning(
+                    "OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_TRACES_ENDPOINT is not set. Traces will not be exported."
+                )
             if not os.environ.get("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT"):
-                logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_METRICS_ENDPOINT is not set. Metrics will not be exported.")
-
+                logger.warning(
+                    "OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_METRICS_ENDPOINT is not set. Metrics will not be exported."
+                )
 
     def fastapi_middleware(self, app: FastAPI):
         """
         Instrument FastAPI with OTel for automatic tracing and metrics.
-        
+
         Captures:
         - Distributed traces for all HTTP requests (via FastAPIInstrumentor)
         - HTTP metrics following semantic conventions (custom middleware)
         """
         # Enable automatic tracing
         FastAPIInstrumentor.instrument_app(app)
-        
+
         # Add custom middleware for HTTP metrics
         meter = self.get_meter("llama_stack.http.server")
-        
+
         # Create HTTP metrics following semantic conventions
         # https://opentelemetry.io/docs/specs/semconv/http/http-metrics/
         request_duration = meter.create_histogram(
-            "http.server.request.duration",
-            unit="ms",
-            description="Duration of HTTP server requests"
+            "http.server.request.duration", unit="ms", description="Duration of HTTP server requests"
         )
-        
+
         active_requests = meter.create_up_down_counter(
-            "http.server.active_requests",
-            unit="requests",
-            description="Number of active HTTP server requests"
+            "http.server.active_requests", unit="requests", description="Number of active HTTP server requests"
        )
-        
+
         request_count = meter.create_counter(
-            "http.server.request.count",
-            unit="requests",
-            description="Total number of HTTP server requests"
+            "http.server.request.count", unit="requests", description="Total number of HTTP server requests"
        )
-        
+
         # Add middleware to record metrics
         @app.middleware("http")  # type: ignore[misc]
         async def http_metrics_middleware(request, call_next):
             import time
-            
+
             # Record active request
-            active_requests.add(1, {
-                "http.method": request.method,
-                "http.route": request.url.path,
-            })
-            
+            active_requests.add(
+                1,
+                {
+                    "http.method": request.method,
+                    "http.route": request.url.path,
+                },
+            )
+
             start_time = time.time()
             status_code = 500  # Default to error
-            
+
             try:
                 response = await call_next(request)
                 status_code = response.status_code
@@ -124,22 +129,24 @@ class OTelTelemetryProvider(TelemetryProvider):
             finally:
                 # Record metrics
                 duration_ms = (time.time() - start_time) * 1000
-                
+
                 attributes = {
                     "http.method": request.method,
                     "http.route": request.url.path,
                     "http.status_code": status_code,
                 }
-                
+
                 request_duration.record(duration_ms, attributes)
                 request_count.add(1, attributes)
-                active_requests.add(-1, {
-                    "http.method": request.method,
-                    "http.route": request.url.path,
-                })
-            
-            return response
+                active_requests.add(
+                    -1,
+                    {
+                        "http.method": request.method,
+                        "http.route": request.url.path,
+                    },
+                )
+
+            return response
 
     def sqlalchemy_instrumentation(self, engine: Engine | None = None):
         kwargs = {}
@@ -147,34 +154,30 @@ class OTelTelemetryProvider(TelemetryProvider):
             kwargs["engine"] = engine
         SQLAlchemyInstrumentor().instrument(**kwargs)
 
-
-    def get_tracer(self,
-        instrumenting_module_name: str,
-        instrumenting_library_version: str | None = None,
-        tracer_provider: TracerProvider | None = None,
-        schema_url: str | None = None,
-        attributes: Attributes | None = None
+    def get_tracer(
+        self,
+        instrumenting_module_name: str,
+        instrumenting_library_version: str | None = None,
+        tracer_provider: TracerProvider | None = None,
+        schema_url: str | None = None,
+        attributes: Attributes | None = None,
     ) -> Tracer:
         return trace.get_tracer(
-            instrumenting_module_name=instrumenting_module_name,
-            instrumenting_library_version=instrumenting_library_version,
-            tracer_provider=tracer_provider,
-            schema_url=schema_url,
-            attributes=attributes
+            instrumenting_module_name=instrumenting_module_name,
+            instrumenting_library_version=instrumenting_library_version,
+            tracer_provider=tracer_provider,
+            schema_url=schema_url,
+            attributes=attributes,
         )
 
-
-    def get_meter(self,
-        name: str,
-        version: str = "",
-        meter_provider: MeterProvider | None = None,
-        schema_url: str | None = None,
-        attributes: Attributes | None = None
+    def get_meter(
+        self,
+        name: str,
+        version: str = "",
+        meter_provider: MeterProvider | None = None,
+        schema_url: str | None = None,
+        attributes: Attributes | None = None,
     ) -> Meter:
         return metrics.get_meter(
-            name=name,
-            version=version,
-            meter_provider=meter_provider,
-            schema_url=schema_url,
-            attributes=attributes
-        )
\ No newline at end of file
+            name=name, version=version, meter_provider=meter_provider, schema_url=schema_url, attributes=attributes
+        )
diff --git a/tests/integration/telemetry/__init__.py b/tests/integration/telemetry/__init__.py
index d4a3e15c8..756f351d8 100644
--- a/tests/integration/telemetry/__init__.py
+++ b/tests/integration/telemetry/__init__.py
@@ -3,4 +3,3 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
diff --git a/tests/integration/telemetry/mocking/README.md b/tests/integration/telemetry/mocking/README.md
index 853022622..3fedea75d 100644
--- a/tests/integration/telemetry/mocking/README.md
+++ b/tests/integration/telemetry/mocking/README.md
@@ -24,7 +24,7 @@ class MockServerBase(BaseModel):
     async def await_start(self):
         # Start server and wait until ready
         ...
-    
+
     def stop(self):
         # Stop server and cleanup
         ...
@@ -49,29 +49,29 @@ Add to `servers.py`:
 ```python
 class MockRedisServer(MockServerBase):
     """Mock Redis server."""
-    
+
     port: int = Field(default=6379)
-    
+
     # Non-Pydantic fields
     server: Any = Field(default=None, exclude=True)
-    
+
     def model_post_init(self, __context):
         self.server = None
-    
+
     async def await_start(self):
         """Start Redis mock and wait until ready."""
         # Start your server
         self.server = create_redis_server(self.port)
         self.server.start()
-        
+
         # Wait for port to be listening
         for _ in range(10):
             sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-            if sock.connect_ex(('localhost', self.port)) == 0:
+            if sock.connect_ex(("localhost", self.port)) == 0:
                 sock.close()
                 return  # Ready!
             await asyncio.sleep(0.1)
-    
+
     def stop(self):
         if self.server:
             self.server.stop()
@@ -101,11 +101,11 @@ The harness automatically:
 
 ## Benefits
 
-✅ **Parallel Startup** - All servers start simultaneously
-✅ **Type-Safe** - Pydantic validation
-✅ **Simple** - Just implement 2 methods
-✅ **Fast** - No HTTP polling, direct port checking
-✅ **Clean** - Async/await pattern
+✅ **Parallel Startup** - All servers start simultaneously
+✅ **Type-Safe** - Pydantic validation
+✅ **Simple** - Just implement 2 methods
+✅ **Fast** - No HTTP polling, direct port checking
+✅ **Clean** - Async/await pattern
 
 ## Usage in Tests
 
@@ -116,6 +116,7 @@ def mock_servers():
     yield servers
     stop_mock_servers(servers)
 
+
 # Access specific servers
 @pytest.fixture(scope="module")
 def mock_redis(mock_servers):
diff --git a/tests/integration/telemetry/mocking/__init__.py b/tests/integration/telemetry/mocking/__init__.py
index 99ad92856..3a934a002 100644
--- a/tests/integration/telemetry/mocking/__init__.py
+++ b/tests/integration/telemetry/mocking/__init__.py
@@ -14,9 +14,9 @@ This module provides:
 - Mock server harness for parallel async startup
 """
 
+from .harness import MockServerConfig, start_mock_servers_async, stop_mock_servers
 from .mock_base import MockServerBase
 from .servers import MockOTLPCollector, MockVLLMServer
-from .harness import MockServerConfig, start_mock_servers_async, stop_mock_servers
 
 __all__ = [
     "MockServerBase",
@@ -26,4 +26,3 @@ __all__ = [
     "start_mock_servers_async",
     "stop_mock_servers",
 ]
-
diff --git a/tests/integration/telemetry/mocking/harness.py b/tests/integration/telemetry/mocking/harness.py
index 09b80f70f..d877abbf9 100644
--- a/tests/integration/telemetry/mocking/harness.py
+++ b/tests/integration/telemetry/mocking/harness.py
@@ -14,7 +14,7 @@ HOW TO ADD A NEW MOCK SERVER:
 """
 
 import asyncio
-from typing import Any, Dict, List
+from typing import Any
 
 from pydantic import BaseModel, Field
 
@@ -24,10 +24,10 @@ from .mock_base import MockServerBase
 class MockServerConfig(BaseModel):
     """
     Configuration for a mock server to start.
-    
+
     **TO ADD A NEW MOCK SERVER:**
     Just create a MockServerConfig instance with your server class.
-    
+
     Example:
         MockServerConfig(
             name="Mock MyService",
@@ -35,73 +35,72 @@ class MockServerConfig(BaseModel):
             init_kwargs={"port": 9000, "config_param": "value"},
         )
     """
-    
+
     model_config = {"arbitrary_types_allowed": True}
-    
+
     name: str = Field(description="Display name for logging")
     server_class: type = Field(description="Mock server class (must inherit from MockServerBase)")
-    init_kwargs: Dict[str, Any] = Field(default_factory=dict, description="Kwargs to pass to server constructor")
+    init_kwargs: dict[str, Any] = Field(default_factory=dict, description="Kwargs to pass to server constructor")
 
 
-async def start_mock_servers_async(mock_servers_config: List[MockServerConfig]) -> Dict[str, MockServerBase]:
+async def start_mock_servers_async(mock_servers_config: list[MockServerConfig]) -> dict[str, MockServerBase]:
     """
     Start all mock servers in parallel and wait for them to be ready.
-    
+
     **HOW IT WORKS:**
     1. Creates all server instances
     2. Calls await_start() on all servers in parallel
     3. Returns when all are ready
-    
+
     **SIMPLE TO USE:**
         servers = await start_mock_servers_async([config1, config2, ...])
-    
+
     Args:
         mock_servers_config: List of mock server configurations
-    
+
     Returns:
         Dict mapping server name to server instance
     """
     servers = {}
     start_tasks = []
-    
+
     # Create all servers and prepare start tasks
     for config in mock_servers_config:
         server = config.server_class(**config.init_kwargs)
         servers[config.name] = server
         start_tasks.append(server.await_start())
-    
+
     # Start all servers in parallel
     try:
         await asyncio.gather(*start_tasks)
-        
+
         # Print readiness confirmation
         for name in servers.keys():
             print(f"[INFO] {name} ready")
-        
+
     except Exception as e:
         # If any server fails, stop all servers
         for server in servers.values():
             try:
                 server.stop()
-            except:
+            except Exception:
                 pass
-        raise RuntimeError(f"Failed to start mock servers: {e}")
-    
+        raise RuntimeError(f"Failed to start mock servers: {e}") from None
+
     return servers
 
 
-def stop_mock_servers(servers: Dict[str, Any]):
+def stop_mock_servers(servers: dict[str, Any]):
     """
     Stop all mock servers.
-    
+
     Args:
         servers: Dict of server instances from start_mock_servers_async()
     """
     for name, server in servers.items():
         try:
-            if hasattr(server, 'get_request_count'):
+            if hasattr(server, "get_request_count"):
                 print(f"\n[INFO] {name} received {server.get_request_count()} requests")
             server.stop()
         except Exception as e:
             print(f"[WARN] Error stopping {name}: {e}")
-
diff --git a/tests/integration/telemetry/mocking/mock_base.py b/tests/integration/telemetry/mocking/mock_base.py
index 803058457..5eebcab7a 100644
--- a/tests/integration/telemetry/mocking/mock_base.py
+++ b/tests/integration/telemetry/mocking/mock_base.py
@@ -10,25 +10,25 @@ Base class for mock servers with async startup support.
 All mock servers should inherit from MockServerBase and implement await_start().
 """
 
-import asyncio
 from abc import abstractmethod
-from pydantic import BaseModel, Field
+
+from pydantic import BaseModel
 
 
 class MockServerBase(BaseModel):
     """
     Pydantic base model for mock servers.
-    
+
     **TO CREATE A NEW MOCK SERVER:**
     1. Inherit from this class
     2. Implement async def await_start(self)
     3. Implement def stop(self)
     4. Done!
-    
+
     Example:
         class MyMockServer(MockServerBase):
             port: int = 8080
-            
+
             async def await_start(self):
                 # Start your server
                 self.server = create_server()
@@ -36,34 +36,33 @@ class MockServerBase(BaseModel):
                 # Wait until ready (can check internal state, no HTTP needed)
                 while not self.server.is_listening():
                     await asyncio.sleep(0.1)
-            
+
             def stop(self):
                 if self.server:
                     self.server.stop()
     """
-    
+
     model_config = {"arbitrary_types_allowed": True}
-    
+
     @abstractmethod
     async def await_start(self):
         """
         Start the server and wait until it's ready.
-        
+
         This method should:
         1. Start the server (synchronous or async)
        2. Wait until the server is fully ready to accept requests
        3. Return when ready
-        
+
         Subclasses can check internal state directly - no HTTP polling needed!
         """
         ...
-    
+
     @abstractmethod
     def stop(self):
         """
         Stop the server and clean up resources.
-        
+
         This method should gracefully shut down the server.
         """
         ...
-
diff --git a/tests/integration/telemetry/mocking/servers.py b/tests/integration/telemetry/mocking/servers.py
index e055f41b6..fd63f9baf 100644
--- a/tests/integration/telemetry/mocking/servers.py
+++ b/tests/integration/telemetry/mocking/servers.py
@@ -20,7 +20,7 @@ import json
 import socket
 import threading
 import time
-from typing import Any, Dict, List
+from typing import Any
 
 from pydantic import Field
 
@@ -30,10 +30,10 @@ from .mock_base import MockServerBase
 class MockOTLPCollector(MockServerBase):
     """
     Mock OTLP collector HTTP server.
-    
+
     Receives real OTLP exports from Llama Stack and stores them for verification.
     Runs on localhost:4318 (standard OTLP HTTP port).
-    
+
     Usage:
         collector = MockOTLPCollector()
         await collector.await_start()
@@ -41,115 +41,119 @@ class MockOTLPCollector(MockServerBase):
         print(f"Received {collector.get_trace_count()} traces")
         collector.stop()
     """
-    
+
     port: int = Field(default=4318, description="Port to run collector on")
-    
+
     # Non-Pydantic fields (set after initialization)
-    traces: List[Dict] = Field(default_factory=list, exclude=True)
-    metrics: List[Dict] = Field(default_factory=list, exclude=True)
+    traces: list[dict] = Field(default_factory=list, exclude=True)
+    metrics: list[dict] = Field(default_factory=list, exclude=True)
     server: Any = Field(default=None, exclude=True)
     server_thread: Any = Field(default=None, exclude=True)
-    
+
     def model_post_init(self, __context):
         """Initialize after Pydantic validation."""
         self.traces = []
         self.metrics = []
         self.server = None
         self.server_thread = None
-    
+
     def _create_handler_class(self):
         """Create the HTTP handler class for this collector instance."""
         collector_self = self
-        
+
         class OTLPHandler(http.server.BaseHTTPRequestHandler):
             """HTTP request handler for OTLP requests."""
-            
+
             def log_message(self, format, *args):
                 """Suppress HTTP server logs."""
                 pass
-            
-            def do_GET(self):
+
+            def do_GET(self):  # noqa: N802
                 """Handle GET requests."""
                 # No readiness endpoint needed - using await_start() instead
                 self.send_response(404)
                 self.end_headers()
-            
-            def do_POST(self):
+
+            def do_POST(self):  # noqa: N802
                 """Handle OTLP POST requests."""
-                content_length = int(self.headers.get('Content-Length', 0))
-                body = self.rfile.read(content_length) if content_length > 0 else b''
-                
+                content_length = int(self.headers.get("Content-Length", 0))
+                body = self.rfile.read(content_length) if content_length > 0 else b""
+
                 # Store the export request
-                if '/v1/traces' in self.path:
-                    collector_self.traces.append({
-                        'body': body,
-                        'timestamp': time.time(),
-                    })
-                elif '/v1/metrics' in self.path:
-                    collector_self.metrics.append({
-                        'body': body,
-                        'timestamp': time.time(),
-                    })
-                
+                if "/v1/traces" in self.path:
+                    collector_self.traces.append(
+                        {
+                            "body": body,
+                            "timestamp": time.time(),
+                        }
+                    )
+                elif "/v1/metrics" in self.path:
+                    collector_self.metrics.append(
+                        {
+                            "body": body,
+                            "timestamp": time.time(),
+                        }
+                    )
+
                 # Always return success (200 OK)
                 self.send_response(200)
-                self.send_header('Content-Type', 'application/json')
+                self.send_header("Content-Type", "application/json")
                 self.end_headers()
-                self.wfile.write(b'{}')
-        
+                self.wfile.write(b"{}")
+
         return OTLPHandler
-    
+
     async def await_start(self):
         """
         Start the OTLP collector and wait until ready.
-        
+
         This method is async and can be awaited to ensure the server is ready.
""" # Create handler and start the HTTP server handler_class = self._create_handler_class() - self.server = http.server.HTTPServer(('localhost', self.port), handler_class) + self.server = http.server.HTTPServer(("localhost", self.port), handler_class) self.server_thread = threading.Thread(target=self.server.serve_forever, daemon=True) self.server_thread.start() - + # Wait for server to be listening on the port for _ in range(10): try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - result = sock.connect_ex(('localhost', self.port)) + result = sock.connect_ex(("localhost", self.port)) sock.close() if result == 0: # Port is listening return - except: + except Exception: pass await asyncio.sleep(0.1) - + raise RuntimeError(f"OTLP collector failed to start on port {self.port}") - + def stop(self): """Stop the OTLP collector server.""" if self.server: self.server.shutdown() self.server.server_close() - + def clear(self): """Clear all captured telemetry data.""" self.traces = [] self.metrics = [] - + def get_trace_count(self) -> int: """Get number of trace export requests received.""" return len(self.traces) - + def get_metric_count(self) -> int: """Get number of metric export requests received.""" return len(self.metrics) - - def get_all_traces(self) -> List[Dict]: + + def get_all_traces(self) -> list[dict]: """Get all captured trace exports.""" return self.traces - - def get_all_metrics(self) -> List[Dict]: + + def get_all_metrics(self) -> list[dict]: """Get all captured metric exports.""" return self.metrics @@ -157,14 +161,14 @@ class MockOTLPCollector(MockServerBase): class MockVLLMServer(MockServerBase): """ Mock vLLM inference server with OpenAI-compatible API. - + Returns valid OpenAI Python client response objects for: - Chat completions (/v1/chat/completions) - Text completions (/v1/completions) - Model listing (/v1/models) - + Runs on localhost:8000 (standard vLLM port). 
- + Usage: server = MockVLLMServer(models=["my-model"]) await server.await_start() @@ -172,94 +176,97 @@ class MockVLLMServer(MockServerBase): print(f"Handled {server.get_request_count()} requests") server.stop() """ - + port: int = Field(default=8000, description="Port to run server on") - models: List[str] = Field( - default_factory=lambda: ["meta-llama/Llama-3.2-1B-Instruct"], - description="List of model IDs to serve" + models: list[str] = Field( + default_factory=lambda: ["meta-llama/Llama-3.2-1B-Instruct"], description="List of model IDs to serve" ) - + # Non-Pydantic fields - requests_received: List[Dict] = Field(default_factory=list, exclude=True) + requests_received: list[dict] = Field(default_factory=list, exclude=True) server: Any = Field(default=None, exclude=True) server_thread: Any = Field(default=None, exclude=True) - + def model_post_init(self, __context): """Initialize after Pydantic validation.""" self.requests_received = [] self.server = None self.server_thread = None - + def _create_handler_class(self): """Create the HTTP handler class for this vLLM instance.""" server_self = self - + class VLLMHandler(http.server.BaseHTTPRequestHandler): """HTTP request handler for vLLM API.""" - + def log_message(self, format, *args): """Suppress HTTP server logs.""" pass - - def log_request(self, code='-', size='-'): + + def log_request(self, code="-", size="-"): """Log incoming requests for debugging.""" print(f"[DEBUG] Mock vLLM received: {self.command} {self.path} -> {code}") - - def do_GET(self): + + def do_GET(self): # noqa: N802 """Handle GET requests (models list, health check).""" # Log GET requests too - server_self.requests_received.append({ - 'path': self.path, - 'method': 'GET', - 'timestamp': time.time(), - }) - - if self.path == '/v1/models': + server_self.requests_received.append( + { + "path": self.path, + "method": "GET", + "timestamp": time.time(), + } + ) + + if self.path == "/v1/models": response = self._create_models_list_response() self._send_json_response(200, response) - - elif self.path == '/health' or self.path == '/v1/health': + + elif self.path == "/health" or self.path == "/v1/health": self._send_json_response(200, {"status": "healthy"}) - + else: self.send_response(404) self.end_headers() - - def do_POST(self): + + def do_POST(self): # noqa: N802 """Handle POST requests (chat/text completions).""" - content_length = int(self.headers.get('Content-Length', 0)) - body = self.rfile.read(content_length) if content_length > 0 else b'{}' - + content_length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(content_length) if content_length > 0 else b"{}" + try: request_data = json.loads(body) - except: + except Exception: request_data = {} - + # Log the request - server_self.requests_received.append({ - 'path': self.path, - 'body': request_data, - 'timestamp': time.time(), - }) - + server_self.requests_received.append( + { + "path": self.path, + "body": request_data, + "timestamp": time.time(), + } + ) + # Route to appropriate handler - if '/chat/completions' in self.path: + if "/chat/completions" in self.path: response = self._create_chat_completion_response(request_data) self._send_json_response(200, response) - - elif '/completions' in self.path: + + elif "/completions" in self.path: response = self._create_text_completion_response(request_data) self._send_json_response(200, response) - + else: self._send_json_response(200, {"status": "ok"}) - + # ---------------------------------------------------------------- # Response Generators # **TO 
MODIFY RESPONSES:** Edit these methods # ---------------------------------------------------------------- - - def _create_models_list_response(self) -> Dict: + + def _create_models_list_response(self) -> dict: """Create OpenAI models list response with configured models.""" return { "object": "list", @@ -271,13 +278,13 @@ class MockVLLMServer(MockServerBase): "owned_by": "meta", } for model_id in server_self.models - ] + ], } - - def _create_chat_completion_response(self, request_data: Dict) -> Dict: + + def _create_chat_completion_response(self, request_data: dict) -> dict: """ Create OpenAI ChatCompletion response. - + Returns a valid response matching openai.types.ChatCompletion """ return { @@ -285,16 +292,18 @@ class MockVLLMServer(MockServerBase): "object": "chat.completion", "created": int(time.time()), "model": request_data.get("model", "meta-llama/Llama-3.2-1B-Instruct"), - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": "This is a test response from mock vLLM server.", - "tool_calls": None, - }, - "logprobs": None, - "finish_reason": "stop", - }], + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "This is a test response from mock vLLM server.", + "tool_calls": None, + }, + "logprobs": None, + "finish_reason": "stop", + } + ], "usage": { "prompt_tokens": 25, "completion_tokens": 15, @@ -304,11 +313,11 @@ class MockVLLMServer(MockServerBase): "system_fingerprint": None, "service_tier": None, } - - def _create_text_completion_response(self, request_data: Dict) -> Dict: + + def _create_text_completion_response(self, request_data: dict) -> dict: """ Create OpenAI Completion response. - + Returns a valid response matching openai.types.Completion """ return { @@ -316,12 +325,14 @@ class MockVLLMServer(MockServerBase): "object": "text_completion", "created": int(time.time()), "model": request_data.get("model", "meta-llama/Llama-3.2-1B-Instruct"), - "choices": [{ - "text": "This is a test completion.", - "index": 0, - "logprobs": None, - "finish_reason": "stop", - }], + "choices": [ + { + "text": "This is a test completion.", + "index": 0, + "logprobs": None, + "finish_reason": "stop", + } + ], "usage": { "prompt_tokens": 10, "completion_tokens": 8, @@ -330,58 +341,57 @@ class MockVLLMServer(MockServerBase): }, "system_fingerprint": None, } - - def _send_json_response(self, status_code: int, data: Dict): + + def _send_json_response(self, status_code: int, data: dict): """Helper to send JSON response.""" self.send_response(status_code) - self.send_header('Content-Type', 'application/json') + self.send_header("Content-Type", "application/json") self.end_headers() self.wfile.write(json.dumps(data).encode()) - + return VLLMHandler - + async def await_start(self): """ Start the vLLM server and wait until ready. - + This method is async and can be awaited to ensure the server is ready. 
""" # Create handler and start the HTTP server handler_class = self._create_handler_class() - self.server = http.server.HTTPServer(('localhost', self.port), handler_class) + self.server = http.server.HTTPServer(("localhost", self.port), handler_class) self.server_thread = threading.Thread(target=self.server.serve_forever, daemon=True) self.server_thread.start() - + # Wait for server to be listening on the port for _ in range(10): try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - result = sock.connect_ex(('localhost', self.port)) + result = sock.connect_ex(("localhost", self.port)) sock.close() if result == 0: # Port is listening return - except: + except Exception: pass await asyncio.sleep(0.1) - + raise RuntimeError(f"vLLM server failed to start on port {self.port}") - + def stop(self): """Stop the vLLM server.""" if self.server: self.server.shutdown() self.server.server_close() - + def clear(self): """Clear request history.""" self.requests_received = [] - + def get_request_count(self) -> int: """Get number of requests received.""" return len(self.requests_received) - - def get_all_requests(self) -> List[Dict]: + + def get_all_requests(self) -> list[dict]: """Get all received requests with their bodies.""" return self.requests_received - diff --git a/tests/integration/telemetry/test_otel_e2e.py b/tests/integration/telemetry/test_otel_e2e.py index 06ad79383..3df36db30 100644 --- a/tests/integration/telemetry/test_otel_e2e.py +++ b/tests/integration/telemetry/test_otel_e2e.py @@ -34,7 +34,7 @@ import os import socket import subprocess import time -from typing import Any, Dict, List +from typing import Any import pytest import requests @@ -44,28 +44,28 @@ from pydantic import BaseModel, Field # Mock servers are in the mocking/ subdirectory from .mocking import ( MockOTLPCollector, - MockVLLMServer, MockServerConfig, + MockVLLMServer, start_mock_servers_async, stop_mock_servers, ) - # ============================================================================ # DATA MODELS # ============================================================================ + class TelemetryTestCase(BaseModel): """ Pydantic model defining expected telemetry for an API call. - + **TO ADD A NEW TEST CASE:** Add to TEST_CASES list below. """ - + name: str = Field(description="Unique test case identifier") http_method: str = Field(description="HTTP method (GET, POST, etc.)") api_path: str = Field(description="API path (e.g., '/v1/models')") - request_body: Dict[str, Any] | None = Field(default=None) + request_body: dict[str, Any] | None = Field(default=None) expected_http_status: int = Field(default=200) expected_trace_exports: int = Field(default=1, description="Minimum number of trace exports expected") expected_metric_exports: int = Field(default=0, description="Minimum number of metric exports expected") @@ -103,71 +103,74 @@ TEST_CASES = [ # TEST INFRASTRUCTURE # ============================================================================ + class TelemetryTestRunner: """ Executes TelemetryTestCase instances against real Llama Stack. - + **HOW IT WORKS:** 1. Makes real HTTP request to the stack 2. Waits for telemetry export 3. 
     3. Verifies exports were sent to mock collector
     """
-    
+
     def __init__(self, base_url: str, collector: MockOTLPCollector):
         self.base_url = base_url
         self.collector = collector
-    
+
     def run_test_case(self, test_case: TelemetryTestCase, verbose: bool = False) -> bool:
         """Execute a single test case and verify telemetry."""
         initial_traces = self.collector.get_trace_count()
         initial_metrics = self.collector.get_metric_count()
-        
+
         if verbose:
             print(f"\n--- {test_case.name} ---")
             print(f"  {test_case.http_method} {test_case.api_path}")
-        
+
         # Make real HTTP request to Llama Stack
         try:
             url = f"{self.base_url}{test_case.api_path}"
-            
+
             if test_case.http_method == "GET":
                 response = requests.get(url, timeout=5)
             elif test_case.http_method == "POST":
                 response = requests.post(url, json=test_case.request_body or {}, timeout=5)
             else:
                 response = requests.request(test_case.http_method, url, timeout=5)
-            
+
             if verbose:
                 print(f"  HTTP Response: {response.status_code}")
-            
+
             status_match = response.status_code == test_case.expected_http_status
-            
+
         except requests.exceptions.RequestException as e:
             if verbose:
                 print(f"  Request failed: {e}")
             status_match = False
-        
+
         # Wait for automatic instrumentation to export telemetry
         # Traces export immediately, metrics export every 1 second (configured via env var)
         time.sleep(2.0)  # Wait for both traces and metrics to export
-        
+
         # Verify traces were exported to mock collector
         new_traces = self.collector.get_trace_count() - initial_traces
         traces_exported = new_traces >= test_case.expected_trace_exports
-        
+
         # Verify metrics were exported (if expected)
         new_metrics = self.collector.get_metric_count() - initial_metrics
         metrics_exported = new_metrics >= test_case.expected_metric_exports
-        
+
         if verbose:
-            print(f"  Expected: >={test_case.expected_trace_exports} trace exports, >={test_case.expected_metric_exports} metric exports")
+            print(
+                f"  Expected: >={test_case.expected_trace_exports} trace exports, >={test_case.expected_metric_exports} metric exports"
+            )
             print(f"  Actual: {new_traces} trace exports, {new_metrics} metric exports")
             result = status_match and traces_exported and metrics_exported
             print(f"  Result: {'PASS' if result else 'FAIL'}")
-        
+
         return status_match and traces_exported and metrics_exported
-    
-    def run_all_test_cases(self, test_cases: List[TelemetryTestCase], verbose: bool = True) -> Dict[str, bool]:
+
+    def run_all_test_cases(self, test_cases: list[TelemetryTestCase], verbose: bool = True) -> dict[str, bool]:
         """Run all test cases and return results."""
         results = {}
         for test_case in test_cases:
@@ -179,11 +182,12 @@ class TelemetryTestRunner:
 # HELPER FUNCTIONS
 # ============================================================================
 
+
 def is_port_available(port: int) -> bool:
     """Check if a TCP port is available for binding."""
     try:
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
-            sock.bind(('localhost', port))
+            sock.bind(("localhost", port))
             return True
     except OSError:
         return False
@@ -193,20 +197,21 @@ def is_port_available(port: int) -> bool:
 # PYTEST FIXTURES
 # ============================================================================
 
+
 @pytest.fixture(scope="module")
 def mock_servers():
     """
     Fixture: Start all mock servers in parallel using async harness.
-    
+
     **TO ADD A NEW MOCK SERVER:** Just add a MockServerConfig to the MOCK_SERVERS list below.
     """
     import asyncio
-    
+
     # ========================================================================
     # MOCK SERVER CONFIGURATION
     # **TO ADD A NEW MOCK:** Just add a MockServerConfig instance below
-    # 
+    #
     # Example:
     #   MockServerConfig(
     #       name="Mock MyService",
@@ -214,7 +219,7 @@ def mock_servers():
     #       init_kwargs={"port": 9000, "param": "value"},
     #   ),
     # ========================================================================
-    MOCK_SERVERS = [
+    mock_servers_config = [
         MockServerConfig(
             name="Mock OTLP Collector",
             server_class=MockOTLPCollector,
@@ -230,17 +235,17 @@ def mock_servers():
         ),
         # Add more mock servers here - they will start in parallel automatically!
     ]
-    
+
     # Start all servers in parallel
-    servers = asyncio.run(start_mock_servers_async(MOCK_SERVERS))
-    
+    servers = asyncio.run(start_mock_servers_async(mock_servers_config))
+
     # Verify vLLM models
     models_response = requests.get("http://localhost:8000/v1/models", timeout=1)
     models_data = models_response.json()
     print(f"[INFO] Mock vLLM serving {len(models_data['data'])} models: {[m['id'] for m in models_data['data']]}")
-    
+
     yield servers
-    
+
     # Stop all servers
     stop_mock_servers(servers)
@@ -261,22 +266,22 @@ def mock_vllm_server(mock_servers):
 def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server):
     """
     Fixture: Start real Llama Stack server with automatic OTel instrumentation.
-    
+
     **THIS IS THE MAIN FIXTURE** - it runs:
         opentelemetry-instrument llama stack run --config run.yaml
-    
+
     **TO MODIFY STACK CONFIG:** Edit run_config dict below
     """
     config_dir = tmp_path_factory.mktemp("otel-stack-config")
-    
+
     # Ensure mock vLLM is ready and accessible before starting Llama Stack
-    print(f"\n[INFO] Verifying mock vLLM is accessible at http://localhost:8000...")
+    print("\n[INFO] Verifying mock vLLM is accessible at http://localhost:8000...")
     try:
         vllm_models = requests.get("http://localhost:8000/v1/models", timeout=2)
         print(f"[INFO] Mock vLLM models endpoint response: {vllm_models.status_code}")
     except Exception as e:
         pytest.fail(f"Mock vLLM not accessible before starting Llama Stack: {e}")
-    
+
     # Create run.yaml with inference provider
     # **TO ADD MORE PROVIDERS:** Add to providers dict
     run_config = {
@@ -300,19 +305,19 @@ def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server):
             }
         ],
     }
-    
+
     config_file = config_dir / "run.yaml"
     with open(config_file, "w") as f:
         yaml.dump(run_config, f)
-    
+
     # Find available port for Llama Stack
     port = 5555
     while not is_port_available(port) and port < 5600:
         port += 1
-    
+
     if port >= 5600:
         pytest.skip("No available ports for test server")
-    
+
     # Set environment variables for OTel instrumentation
     # NOTE: These only affect the subprocess, not other tests
     env = os.environ.copy()
@@ -321,29 +326,32 @@ def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server):
     env["OTEL_SERVICE_NAME"] = "llama-stack-e2e-test"
     env["LLAMA_STACK_PORT"] = str(port)
     env["OTEL_PYTHON_LOGGING_AUTO_INSTRUMENTATION_ENABLED"] = "true"
-    
+
     # Configure fast metric export for testing (default is 60 seconds)
     # This makes metrics export every 500ms instead of every 60 seconds
     env["OTEL_METRIC_EXPORT_INTERVAL"] = "500"  # milliseconds
     env["OTEL_METRIC_EXPORT_TIMEOUT"] = "1000"  # milliseconds
-    
+
     # Disable inference recording to ensure real requests to our mock vLLM
     # This is critical - without this, Llama Stack replays cached responses
     # Safe to remove here as it only affects the subprocess environment
     if "LLAMA_STACK_TEST_INFERENCE_MODE" in env:
         del env["LLAMA_STACK_TEST_INFERENCE_MODE"]
-    
+
     # Start server with automatic instrumentation
     cmd = [
         "opentelemetry-instrument",  # ← Automatic instrumentation wrapper
-        "llama", "stack", "run",
+        "llama",
+        "stack",
+        "run",
         str(config_file),
-        "--port", str(port),
+        "--port",
+        str(port),
     ]
-    
+
     print(f"\n[INFO] Starting Llama Stack with OTel instrumentation on port {port}")
     print(f"[INFO] Command: {' '.join(cmd)}")
-    
+
     process = subprocess.Popen(
         cmd,
         env=env,
@@ -351,11 +359,11 @@ def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server):
         stderr=subprocess.PIPE,
         text=True,
     )
-    
+
     # Wait for server to start
     max_wait = 30
     base_url = f"http://localhost:{port}"
-    
+
     for i in range(max_wait):
         try:
             response = requests.get(f"{base_url}/v1/health", timeout=1)
@@ -368,16 +376,16 @@ def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server):
                 stdout, stderr = process.communicate(timeout=5)
                 pytest.fail(f"Server failed to start.\nStdout: {stdout}\nStderr: {stderr}")
             time.sleep(1)
-    
+
     yield {
-        'base_url': base_url,
-        'port': port,
-        'collector': mock_otlp_collector,
-        'vllm_server': mock_vllm_server,
+        "base_url": base_url,
+        "port": port,
+        "collector": mock_otlp_collector,
+        "vllm_server": mock_vllm_server,
     }
-    
+
     # Cleanup
-    print(f"\n[INFO] Stopping Llama Stack server")
+    print("\n[INFO] Stopping Llama Stack server")
     process.terminate()
     try:
         process.wait(timeout=5)
@@ -391,26 +399,27 @@ def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server):
 # **TO ADD NEW E2E TESTS:** Add methods to this class
 # ============================================================================
 
+
 @pytest.mark.slow
 class TestOTelE2E:
     """
     End-to-end tests with real Llama Stack server.
-    
+
     These tests verify the complete flow:
     - Real Llama Stack with opentelemetry-instrument
     - Real API calls
     - Real automatic instrumentation
     - Mock OTLP collector captures exports
     """
-    
+
     def test_server_starts_with_auto_instrumentation(self, llama_stack_server):
         """Verify server starts successfully with opentelemetry-instrument."""
-        base_url = llama_stack_server['base_url']
-        
+        base_url = llama_stack_server["base_url"]
+
         # Try different health check endpoints
         health_endpoints = ["/health", "/v1/health", "/"]
         server_responding = False
-        
+
         for endpoint in health_endpoints:
             try:
                 response = requests.get(f"{base_url}{endpoint}", timeout=5)
@@ -420,36 +429,36 @@ class TestOTelE2E:
                     break
             except Exception as e:
                 print(f"[DEBUG] {endpoint} failed: {e}")
-        
+
         assert server_responding, f"Server not responding on any endpoint at {base_url}"
-        
+
         print(f"\n[PASS] Llama Stack running with OTel at {base_url}")
-    
+
     def test_all_test_cases_via_runner(self, llama_stack_server):
         """
         **MAIN TEST:** Run all TelemetryTestCase instances.
-        
+
         This executes all test cases defined in TEST_CASES list.
         **TO ADD MORE TESTS:** Add to TEST_CASES at top of file
         """
-        base_url = llama_stack_server['base_url']
-        collector = llama_stack_server['collector']
-        
+        base_url = llama_stack_server["base_url"]
+        collector = llama_stack_server["collector"]
+
         # Create test runner
         runner = TelemetryTestRunner(base_url, collector)
-        
+
         # Execute all test cases
         results = runner.run_all_test_cases(TEST_CASES, verbose=True)
-        
+
         # Print summary
-        print(f"\n{'='*50}")
-        print(f"TEST CASE SUMMARY")
-        print(f"{'='*50}")
+        print(f"\n{'=' * 50}")
+        print("TEST CASE SUMMARY")
+        print(f"{'=' * 50}")
         passed = sum(1 for p in results.values() if p)
         total = len(results)
         print(f"Passed: {passed}/{total}\n")
-        
+
         for name, result in results.items():
             status = "[PASS]" if result else "[FAIL]"
             print(f"  {status} {name}")
-        print(f"{'='*50}\n")
+        print(f"{'=' * 50}\n")
diff --git a/tests/unit/providers/telemetry/meta_reference.py b/tests/unit/providers/telemetry/meta_reference.py
index 26146e133..c7c81f01f 100644
--- a/tests/unit/providers/telemetry/meta_reference.py
+++ b/tests/unit/providers/telemetry/meta_reference.py
@@ -4,8 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
-
 import pytest
 
 import llama_stack.providers.inline.telemetry.meta_reference.telemetry as telemetry_module
@@ -38,7 +36,7 @@ def test_warns_when_traces_endpoints_missing(monkeypatch: pytest.MonkeyPatch, ca
     monkeypatch.delenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", raising=False)
     monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False)
 
-    caplog.set_level(logging.WARNING)
+    caplog.set_level("WARNING")
 
     config = _make_config_with_sinks(TelemetrySink.OTEL_TRACE)
     telemetry_module.TelemetryAdapter(config=config, deps={})
@@ -57,7 +55,7 @@ def test_warns_when_metrics_endpoints_missing(monkeypatch: pytest.MonkeyPatch, c
     monkeypatch.delenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", raising=False)
     monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False)
 
-    caplog.set_level(logging.WARNING)
+    caplog.set_level("WARNING")
 
     config = _make_config_with_sinks(TelemetrySink.OTEL_METRIC)
     telemetry_module.TelemetryAdapter(config=config, deps={})
@@ -76,7 +74,7 @@ def test_no_warning_when_traces_endpoints_present(monkeypatch: pytest.MonkeyPatc
     monkeypatch.setenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "https://otel.example:4318/v1/traces")
     monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://otel.example:4318")
 
-    caplog.set_level(logging.WARNING)
+    caplog.set_level("WARNING")
 
     config = _make_config_with_sinks(TelemetrySink.OTEL_TRACE)
     telemetry_module.TelemetryAdapter(config=config, deps={})
@@ -91,7 +89,7 @@ def test_no_warning_when_metrics_endpoints_present(monkeypatch: pytest.MonkeyPat
     monkeypatch.setenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", "https://otel.example:4318/v1/metrics")
     monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://otel.example:4318")
 
-    caplog.set_level(logging.WARNING)
+    caplog.set_level("WARNING")
 
     config = _make_config_with_sinks(TelemetrySink.OTEL_METRIC)
     telemetry_module.TelemetryAdapter(config=config, deps={})
diff --git a/tests/unit/providers/telemetry/test_otel.py b/tests/unit/providers/telemetry/test_otel.py
index 5d10d74a8..2db22f37a 100644
--- a/tests/unit/providers/telemetry/test_otel.py
+++ b/tests/unit/providers/telemetry/test_otel.py
@@ -41,17 +41,17 @@ class TestOTelTelemetryProviderInitialization:
     def test_initialization_with_valid_config(self, otel_config, monkeypatch):
         """Test that provider initializes correctly with valid configuration."""
         monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4318")
-        
+
         provider = OTelTelemetryProvider(config=otel_config)
-        
+
         assert provider.config == otel_config
 
     def test_initialization_sets_service_attributes(self, otel_config, monkeypatch):
         """Test that service attributes are properly configured."""
         monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4318")
-        
+
         provider = OTelTelemetryProvider(config=otel_config)
-        
+
         assert provider.config.service_name == "test-service"
         assert provider.config.service_version == "1.0.0"
         assert provider.config.deployment_environment == "test"
@@ -65,9 +65,9 @@ class TestOTelTelemetryProviderInitialization:
             deployment_environment="test",
             span_processor="batch",
         )
-        
+
         provider = OTelTelemetryProvider(config=config)
-        
+
         assert provider.config.span_processor == "batch"
 
     def test_warns_when_endpoints_missing(self, otel_config, monkeypatch, caplog):
@@ -76,9 +76,9 @@ class TestOTelTelemetryProviderInitialization:
         monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False)
         monkeypatch.delenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", raising=False)
         monkeypatch.delenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", raising=False)
-        
+
         OTelTelemetryProvider(config=otel_config)
-        
+
         # Check that warnings were logged
         assert any("Traces will not be exported" in record.message for record in caplog.records)
         assert any("Metrics will not be exported" in record.message for record in caplog.records)
@@ -90,44 +90,41 @@ class TestOTelTelemetryProviderTracerAPI:
     def test_get_tracer_returns_tracer(self, otel_provider):
         """Test that get_tracer returns a valid Tracer instance."""
         tracer = otel_provider.get_tracer("test.module")
-        
+
         assert tracer is not None
         assert isinstance(tracer, Tracer)
 
     def test_get_tracer_with_version(self, otel_provider):
         """Test that get_tracer works with version parameter."""
         tracer = otel_provider.get_tracer(
-            instrumenting_module_name="test.module",
-            instrumenting_library_version="1.0.0"
+            instrumenting_module_name="test.module", instrumenting_library_version="1.0.0"
         )
-        
+
         assert tracer is not None
         assert isinstance(tracer, Tracer)
 
     def test_get_tracer_with_attributes(self, otel_provider):
         """Test that get_tracer works with attributes."""
         tracer = otel_provider.get_tracer(
-            instrumenting_module_name="test.module",
-            attributes={"component": "test", "tier": "backend"}
+            instrumenting_module_name="test.module", attributes={"component": "test", "tier": "backend"}
         )
-        
+
         assert tracer is not None
         assert isinstance(tracer, Tracer)
 
     def test_get_tracer_with_schema_url(self, otel_provider):
         """Test that get_tracer works with schema URL."""
         tracer = otel_provider.get_tracer(
-            instrumenting_module_name="test.module",
-            schema_url="https://example.com/schema"
+            instrumenting_module_name="test.module", schema_url="https://example.com/schema"
         )
-        
+
         assert tracer is not None
         assert isinstance(tracer, Tracer)
 
     def test_tracer_can_create_spans(self, otel_provider):
         """Test that tracer can create spans."""
         tracer = otel_provider.get_tracer("test.module")
-        
+
         with tracer.start_as_current_span("test.operation") as span:
             assert span is not None
             assert span.is_recording()
@@ -135,11 +132,8 @@ class TestOTelTelemetryProviderTracerAPI:
     def test_tracer_can_create_spans_with_attributes(self, otel_provider):
         """Test that tracer can create spans with attributes."""
         tracer = otel_provider.get_tracer("test.module")
-        
-        with tracer.start_as_current_span(
-            "test.operation",
-            attributes={"user.id": "123", "request.id": "abc"}
-        ) as span:
+
+        with tracer.start_as_current_span("test.operation", attributes={"user.id": "123", "request.id": "abc"}) as span:
             assert span is not None
             assert span.is_recording()
 
@@ -147,7 +141,7 @@ class TestOTelTelemetryProviderTracerAPI:
         """Test that multiple tracers can be created."""
         tracer1 = otel_provider.get_tracer("module.one")
         tracer2 = otel_provider.get_tracer("module.two")
-        
+
         assert tracer1 is not None
         assert tracer2 is not None
         # Tracers with different names might be the same instance or different
@@ -164,50 +158,37 @@ class TestOTelTelemetryProviderMeterAPI:
     def test_get_meter_returns_meter(self, otel_provider):
         """Test that get_meter returns a valid Meter instance."""
         meter = otel_provider.get_meter("test.meter")
-        
+
         assert meter is not None
         assert isinstance(meter, Meter)
 
     def test_get_meter_with_version(self, otel_provider):
         """Test that get_meter works with version parameter."""
-        meter = otel_provider.get_meter(
-            name="test.meter",
-            version="1.0.0"
-        )
-        
+        meter = otel_provider.get_meter(name="test.meter", version="1.0.0")
+
         assert meter is not None
         assert isinstance(meter, Meter)
 
     def test_get_meter_with_attributes(self, otel_provider):
         """Test that get_meter works with attributes."""
-        meter = otel_provider.get_meter(
-            name="test.meter",
-            attributes={"service": "test", "env": "dev"}
-        )
-        
+        meter = otel_provider.get_meter(name="test.meter", attributes={"service": "test", "env": "dev"})
+
         assert meter is not None
         assert isinstance(meter, Meter)
 
     def test_get_meter_with_schema_url(self, otel_provider):
         """Test that get_meter works with schema URL."""
-        meter = otel_provider.get_meter(
-            name="test.meter",
-            schema_url="https://example.com/schema"
-        )
-        
+        meter = otel_provider.get_meter(name="test.meter", schema_url="https://example.com/schema")
+
         assert meter is not None
         assert isinstance(meter, Meter)
 
     def test_meter_can_create_counter(self, otel_provider):
         """Test that meter can create counters."""
         meter = otel_provider.get_meter("test.meter")
-        
-        counter = meter.create_counter(
-            "test.requests.total",
-            unit="requests",
-            description="Total requests"
-        )
-        
+
+        counter = meter.create_counter("test.requests.total", unit="requests", description="Total requests")
+
         assert counter is not None
         # Test that counter can be used
         counter.add(1, {"endpoint": "/test"})
@@ -215,13 +196,9 @@ class TestOTelTelemetryProviderMeterAPI:
     def test_meter_can_create_histogram(self, otel_provider):
         """Test that meter can create histograms."""
         meter = otel_provider.get_meter("test.meter")
-        
-        histogram = meter.create_histogram(
-            "test.request.duration",
-            unit="ms",
-            description="Request duration"
-        )
-        
+
+        histogram = meter.create_histogram("test.request.duration", unit="ms", description="Request duration")
+
         assert histogram is not None
         # Test that histogram can be used
         histogram.record(42.5, {"method": "GET"})
@@ -229,13 +206,11 @@ class TestOTelTelemetryProviderMeterAPI:
     def test_meter_can_create_up_down_counter(self, otel_provider):
         """Test that meter can create up/down counters."""
         meter = otel_provider.get_meter("test.meter")
-        
+
         up_down_counter = meter.create_up_down_counter(
-            "test.active.connections",
-            unit="connections",
-            description="Active connections"
+            "test.active.connections", unit="connections", description="Active connections"
         )
-        
+
         assert up_down_counter is not None
         # Test that up/down counter can be used
         up_down_counter.add(5)
@@ -244,31 +219,28 @@ class TestOTelTelemetryProviderMeterAPI:
     def test_meter_can_create_observable_gauge(self, otel_provider):
         """Test that meter can create observable gauges."""
         meter = otel_provider.get_meter("test.meter")
-        
+
         def gauge_callback(options):
             return [{"attributes": {"host": "localhost"}, "value": 42.0}]
-        
+
         gauge = meter.create_observable_gauge(
-            "test.memory.usage",
-            callbacks=[gauge_callback],
-            unit="bytes",
-            description="Memory usage"
+            "test.memory.usage", callbacks=[gauge_callback], unit="bytes", description="Memory usage"
         )
-        
+
         assert gauge is not None
 
     def test_multiple_instruments_from_same_meter(self, otel_provider):
         """Test that a meter can create multiple instruments."""
         meter = otel_provider.get_meter("test.meter")
-        
+
         counter = meter.create_counter("test.counter")
         histogram = meter.create_histogram("test.histogram")
         up_down_counter = meter.create_up_down_counter("test.gauge")
-        
+
         assert counter is not None
         assert histogram is not None
         assert up_down_counter is not None
-        
+
         # Verify they all work
         counter.add(1)
         histogram.record(10.0)
@@ -281,107 +253,101 @@ class TestOTelTelemetryProviderNativeUsage:
     def test_complete_tracing_workflow(self, otel_provider):
         """Test a complete tracing workflow using native OTel API."""
         tracer = otel_provider.get_tracer("llama_stack.inference")
-        
+
         # Create parent span
         with tracer.start_as_current_span("inference.request") as parent_span:
             parent_span.set_attribute("model", "llama-3.2-1b")
             parent_span.set_attribute("user", "test-user")
-            
+
             # Create child span
             with tracer.start_as_current_span("model.load") as child_span:
                 child_span.set_attribute("model.size", "1B")
                 assert child_span.is_recording()
-            
+
             # Create another child span
             with tracer.start_as_current_span("inference.execute") as child_span:
                 child_span.set_attribute("tokens.input", 25)
                 child_span.set_attribute("tokens.output", 150)
                 assert child_span.is_recording()
-            
+
             assert parent_span.is_recording()
 
     def test_complete_metrics_workflow(self, otel_provider):
         """Test a complete metrics workflow using native OTel API."""
         meter = otel_provider.get_meter("llama_stack.metrics")
-        
+
         # Create various instruments
-        request_counter = meter.create_counter(
-            "llama.requests.total",
-            unit="requests",
-            description="Total requests"
-        )
-        
+        request_counter = meter.create_counter("llama.requests.total", unit="requests", description="Total requests")
+
         latency_histogram = meter.create_histogram(
-            "llama.inference.duration",
-            unit="ms",
-            description="Inference duration"
+            "llama.inference.duration", unit="ms", description="Inference duration"
         )
-        
+
         active_sessions = meter.create_up_down_counter(
-            "llama.sessions.active",
-            unit="sessions",
-            description="Active sessions"
+            "llama.sessions.active", unit="sessions", description="Active sessions"
         )
-        
+
         # Use the instruments
         request_counter.add(1, {"endpoint": "/chat", "status": "success"})
         latency_histogram.record(123.45, {"model": "llama-3.2-1b"})
         active_sessions.add(1)
         active_sessions.add(-1)
-        
+
         # No exceptions means success
 
     def test_concurrent_tracer_usage(self, otel_provider):
         """Test that multiple threads can use tracers concurrently."""
+
         def create_spans(thread_id):
             tracer = otel_provider.get_tracer(f"test.module.{thread_id}")
             for i in range(10):
                 with tracer.start_as_current_span(f"operation.{i}") as span:
                     span.set_attribute("thread.id", thread_id)
                     span.set_attribute("iteration", i)
-        
+
         with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
             futures = [executor.submit(create_spans, i) for i in range(10)]
             concurrent.futures.wait(futures)
-        
+
         # If we get here without exceptions, thread safety is working
 
     def test_concurrent_meter_usage(self, otel_provider):
        """Test that
multiple threads can use meters concurrently.""" + def record_metrics(thread_id): meter = otel_provider.get_meter(f"test.meter.{thread_id}") counter = meter.create_counter(f"test.counter.{thread_id}") histogram = meter.create_histogram(f"test.histogram.{thread_id}") - + for i in range(10): counter.add(1, {"thread": str(thread_id)}) histogram.record(float(i * 10), {"thread": str(thread_id)}) - + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: futures = [executor.submit(record_metrics, i) for i in range(10)] concurrent.futures.wait(futures) - + # If we get here without exceptions, thread safety is working def test_mixed_tracing_and_metrics(self, otel_provider): """Test using both tracing and metrics together.""" tracer = otel_provider.get_tracer("test.module") meter = otel_provider.get_meter("test.meter") - + counter = meter.create_counter("operations.count") histogram = meter.create_histogram("operation.duration") - + # Trace an operation while recording metrics with tracer.start_as_current_span("test.operation") as span: counter.add(1) span.set_attribute("step", "start") - + histogram.record(50.0) span.set_attribute("step", "processing") - + counter.add(1) span.set_attribute("step", "complete") - + # No exceptions means success @@ -391,14 +357,14 @@ class TestOTelTelemetryProviderFastAPIMiddleware: def test_fastapi_middleware(self, otel_provider): """Test that fastapi_middleware can be called.""" mock_app = MagicMock() - + # Should not raise an exception otel_provider.fastapi_middleware(mock_app) def test_fastapi_middleware_is_idempotent(self, otel_provider): """Test that calling fastapi_middleware multiple times is safe.""" mock_app = MagicMock() - + # Should be able to call multiple times without error otel_provider.fastapi_middleware(mock_app) # Note: Second call might warn but shouldn't fail @@ -411,27 +377,27 @@ class TestOTelTelemetryProviderEdgeCases: def test_tracer_with_empty_module_name(self, otel_provider): """Test that get_tracer works with empty module name.""" tracer = otel_provider.get_tracer("") - + assert tracer is not None assert isinstance(tracer, Tracer) def test_meter_with_empty_name(self, otel_provider): """Test that get_meter works with empty name.""" meter = otel_provider.get_meter("") - + assert meter is not None assert isinstance(meter, Meter) def test_meter_instruments_with_special_characters(self, otel_provider): """Test that metric names with dots, underscores, and hyphens work.""" meter = otel_provider.get_meter("test.meter") - + counter = meter.create_counter("test.counter_name-special") histogram = meter.create_histogram("test.histogram_name-special") - + assert counter is not None assert histogram is not None - + # Verify they can be used counter.add(1) histogram.record(10.0) @@ -440,7 +406,7 @@ class TestOTelTelemetryProviderEdgeCases: """Test that counters work with zero value.""" meter = otel_provider.get_meter("test.meter") counter = meter.create_counter("test.counter") - + # Should not raise an exception counter.add(0.0) @@ -448,7 +414,7 @@ class TestOTelTelemetryProviderEdgeCases: """Test that histograms accept negative values.""" meter = otel_provider.get_meter("test.meter") histogram = meter.create_histogram("test.histogram") - + # Should not raise an exception histogram.record(-10.0) @@ -456,7 +422,7 @@ class TestOTelTelemetryProviderEdgeCases: """Test that up/down counters work with negative values.""" meter = otel_provider.get_meter("test.meter") up_down_counter = meter.create_up_down_counter("test.updown") - + # Should 
not raise an exception up_down_counter.add(-5.0) @@ -464,7 +430,7 @@ class TestOTelTelemetryProviderEdgeCases: """Test that empty attributes dict is handled correctly.""" meter = otel_provider.get_meter("test.meter") counter = meter.create_counter("test.counter") - + # Should not raise an exception counter.add(1.0, attributes={}) @@ -472,7 +438,7 @@ class TestOTelTelemetryProviderEdgeCases: """Test that None attributes are handled correctly.""" meter = otel_provider.get_meter("test.meter") counter = meter.create_counter("test.counter") - + # Should not raise an exception counter.add(1.0, attributes=None) @@ -484,28 +450,28 @@ class TestOTelTelemetryProviderRealisticScenarios: """Simulate telemetry for a complete inference request.""" tracer = otel_provider.get_tracer("llama_stack.inference") meter = otel_provider.get_meter("llama_stack.metrics") - + # Create instruments request_counter = meter.create_counter("llama.requests.total") token_counter = meter.create_counter("llama.tokens.total") latency_histogram = meter.create_histogram("llama.request.duration_ms") in_flight_gauge = meter.create_up_down_counter("llama.requests.in_flight") - + # Simulate request with tracer.start_as_current_span("inference.request") as request_span: request_span.set_attribute("model.id", "llama-3.2-1b") request_span.set_attribute("user.id", "test-user") - + request_counter.add(1, {"model": "llama-3.2-1b"}) in_flight_gauge.add(1) - + # Simulate token counting token_counter.add(25, {"type": "input", "model": "llama-3.2-1b"}) token_counter.add(150, {"type": "output", "model": "llama-3.2-1b"}) - + # Simulate latency latency_histogram.record(125.5, {"model": "llama-3.2-1b"}) - + in_flight_gauge.add(-1) request_span.set_attribute("tokens.input", 25) request_span.set_attribute("tokens.output", 150) @@ -514,36 +480,36 @@ class TestOTelTelemetryProviderRealisticScenarios: """Simulate a multi-step workflow with nested spans.""" tracer = otel_provider.get_tracer("llama_stack.workflow") meter = otel_provider.get_meter("llama_stack.workflow.metrics") - + step_counter = meter.create_counter("workflow.steps.completed") - + with tracer.start_as_current_span("workflow.execute") as root_span: root_span.set_attribute("workflow.id", "wf-123") - + # Step 1: Validate with tracer.start_as_current_span("step.validate") as span: span.set_attribute("validation.result", "pass") step_counter.add(1, {"step": "validate", "status": "success"}) - + # Step 2: Process with tracer.start_as_current_span("step.process") as span: span.set_attribute("items.processed", 100) step_counter.add(1, {"step": "process", "status": "success"}) - + # Step 3: Finalize with tracer.start_as_current_span("step.finalize") as span: span.set_attribute("output.size", 1024) step_counter.add(1, {"step": "finalize", "status": "success"}) - + root_span.set_attribute("workflow.status", "completed") def test_error_handling_with_telemetry(self, otel_provider): """Test telemetry when errors occur.""" tracer = otel_provider.get_tracer("llama_stack.errors") meter = otel_provider.get_meter("llama_stack.errors.metrics") - + error_counter = meter.create_counter("llama.errors.total") - + with tracer.start_as_current_span("operation.with.error") as span: try: span.set_attribute("step", "processing") @@ -553,24 +519,24 @@ class TestOTelTelemetryProviderRealisticScenarios: span.record_exception(e) span.set_status(trace.Status(trace.StatusCode.ERROR, str(e))) error_counter.add(1, {"error.type": "ValueError"}) - + # Should not raise - error was handled def 
test_batch_operations_telemetry(self, otel_provider): """Test telemetry for batch operations.""" tracer = otel_provider.get_tracer("llama_stack.batch") meter = otel_provider.get_meter("llama_stack.batch.metrics") - + batch_counter = meter.create_counter("llama.batch.items.processed") batch_duration = meter.create_histogram("llama.batch.duration_ms") - + with tracer.start_as_current_span("batch.process") as batch_span: batch_span.set_attribute("batch.size", 100) - + for i in range(100): with tracer.start_as_current_span(f"item.{i}") as item_span: item_span.set_attribute("item.index", i) batch_counter.add(1, {"status": "success"}) - + batch_duration.record(5000.0, {"batch.size": "100"}) batch_span.set_attribute("batch.status", "completed")
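
The pattern these tests exercise is: construct the provider, then drive the stock OTel `Tracer`/`Meter` objects it hands back. A minimal sketch of that pattern outside pytest, assuming the config class is named `OTelTelemetryProviderConfig` and guessing the import path (neither appears in this patch; only the constructor call, the config fields, and the `OTEL_EXPORTER_OTLP_ENDPOINT` variable do):

```python
import os

# Hypothetical import path -- adjust to wherever the provider lives in-tree.
from llama_stack.providers.inline.telemetry.otel import (
    OTelTelemetryProvider,
    OTelTelemetryProviderConfig,
)

# Without an OTLP endpoint, the provider logs "... will not be exported" warnings.
os.environ.setdefault("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4318")

provider = OTelTelemetryProvider(
    config=OTelTelemetryProviderConfig(
        service_name="llama-stack",
        service_version="1.0.0",
        deployment_environment="dev",
        span_processor="batch",
    )
)

# From here on this is plain OpenTelemetry API usage, as in the tests above.
tracer = provider.get_tracer("my.module")
meter = provider.get_meter("my.meter")
requests_total = meter.create_counter("my.requests.total", unit="requests")

with tracer.start_as_current_span("my.operation") as span:
    span.set_attribute("request.id", "abc")
    requests_total.add(1, {"endpoint": "/chat"})
```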