fix(pr specific): passes pre-commit

Emilio Garcia 2025-10-03 12:35:09 -04:00
parent 4aa2dc110d
commit 2b7a765d02
20 changed files with 547 additions and 516 deletions

View file

@ -0,0 +1,33 @@
---
description: "Native OpenTelemetry provider with full access to OTel Tracer and Meter APIs for advanced instrumentation."
sidebar_label: Otel
title: inline::otel
---
# inline::otel
## Description
Native OpenTelemetry provider with full access to OTel Tracer and Meter APIs for advanced instrumentation.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `service_name` | `<class 'str'>` | No | | The name of the service to be monitored. Is overridden by the OTEL_SERVICE_NAME or OTEL_RESOURCE_ATTRIBUTES environment variables. |
| `service_version` | `str \| None` | No | | The version of the service to be monitored. Is overridden by the OTEL_RESOURCE_ATTRIBUTES environment variable. |
| `deployment_environment` | `str \| None` | No | | The name of the environment of the service to be monitored. Is overridden by the OTEL_RESOURCE_ATTRIBUTES environment variable. |
| `span_processor` | `BatchSpanProcessor \| SimpleSpanProcessor \| None` | No | batch | The span processor to use. Is overridden by the OTEL_SPAN_PROCESSOR environment variable. |
## Sample Configuration
```yaml
service_name: ${env.OTEL_SERVICE_NAME:=llama-stack}
service_version: ${env.OTEL_SERVICE_VERSION:=}
deployment_environment: ${env.OTEL_DEPLOYMENT_ENVIRONMENT:=}
span_processor: ${env.OTEL_SPAN_PROCESSOR:=batch}
```
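
Beyond the YAML above, this provider exposes the raw OTel `Tracer` and `Meter` APIs. Below is a minimal sketch of driving it directly from Python; the module paths and the collector endpoint are assumptions, while the class, config fields, and `get_tracer`/`get_meter` signatures appear elsewhere in this commit.

```python
import os

# Assumed module paths; adjust to wherever the inline otel provider lives in your tree.
from llama_stack.providers.inline.telemetry.otel.config import OTelTelemetryConfig
from llama_stack.providers.inline.telemetry.otel.otel import OTelTelemetryProvider

# Point the OTLP exporters at a collector before the provider is constructed
# (the provider only warns if no endpoint is configured).
os.environ.setdefault("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4318")

provider = OTelTelemetryProvider(
    config=OTelTelemetryConfig(
        service_name="llama-stack",
        service_version="0.1.0",
        deployment_environment="dev",
        span_processor="batch",
    )
)

# Full OTel APIs are available for custom instrumentation.
tracer = provider.get_tracer("my_app.module")
meter = provider.get_meter("my_app.metrics")
with tracer.start_as_current_span("custom-operation"):
    meter.create_counter("custom.operations").add(1)
```

With `span_processor: batch` spans are buffered and exported periodically; `simple` exports each span as it ends, which is easier to reason about in tests but slower under load.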

View file

@ -32,7 +32,7 @@ from termcolor import cprint
from llama_stack.core.build import print_pip_install_help
from llama_stack.core.configure import parse_and_maybe_upgrade_config
-from llama_stack.core.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec
+from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.request_headers import (
    PROVIDER_DATA_VAR,
    request_provider_data_context,
@ -49,7 +49,6 @@ from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.core.utils.exec import in_notebook
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="core")

T = TypeVar("T")

View file

@ -63,7 +63,6 @@ from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api

from .auth import AuthenticationMiddleware
from .quota import QuotaMiddleware
@ -236,9 +235,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
        try:
            if is_streaming:
-                gen = preserve_contexts_async_generator(
-                    sse_generator(func(**kwargs)), [PROVIDER_DATA_VAR]
-                )
+                gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), [PROVIDER_DATA_VAR])
                return StreamingResponse(gen, media_type="text/event-stream")
            else:
                value = func(**kwargs)
@ -282,7 +279,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
        ]
    )
-    setattr(route_handler, "__signature__", sig.replace(parameters=new_params))
+    route_handler.__signature__ = sig.replace(parameters=new_params)

    return route_handler

View file

@ -359,7 +359,6 @@ class Stack:
await refresh_registry_once(impls) await refresh_registry_once(impls)
self.impls = impls self.impls = impls
# safely access impls without raising an exception # safely access impls without raising an exception
def get_impls(self) -> dict[Api, Any]: def get_impls(self) -> dict[Api, Any]:
if self.impls is None: if self.impls is None:

View file

@ -1,4 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import abstractmethod

-from fastapi import FastAPI
-from pydantic import BaseModel
-from opentelemetry.trace import Tracer
+from fastapi import FastAPI
from opentelemetry.metrics import Meter
-from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.resources import Attributes
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.trace import Tracer
+from pydantic import BaseModel
from sqlalchemy import Engine
@ -19,39 +19,44 @@ class TelemetryProvider(BaseModel):
""" """
TelemetryProvider standardizes how telemetry is provided to the application. TelemetryProvider standardizes how telemetry is provided to the application.
""" """
@abstractmethod @abstractmethod
def fastapi_middleware(self, app: FastAPI, *args, **kwargs): def fastapi_middleware(self, app: FastAPI, *args, **kwargs):
""" """
Injects FastAPI middleware that instruments the application for telemetry. Injects FastAPI middleware that instruments the application for telemetry.
""" """
... ...
@abstractmethod @abstractmethod
def sqlalchemy_instrumentation(self, engine: Engine | None = None): def sqlalchemy_instrumentation(self, engine: Engine | None = None):
""" """
Injects SQLAlchemy instrumentation that instruments the application for telemetry. Injects SQLAlchemy instrumentation that instruments the application for telemetry.
""" """
... ...
@abstractmethod @abstractmethod
def get_tracer(self, def get_tracer(
instrumenting_module_name: str, self,
instrumenting_library_version: str | None = None, instrumenting_module_name: str,
tracer_provider: TracerProvider | None = None, instrumenting_library_version: str | None = None,
schema_url: str | None = None, tracer_provider: TracerProvider | None = None,
attributes: Attributes | None = None schema_url: str | None = None,
attributes: Attributes | None = None,
) -> Tracer: ) -> Tracer:
""" """
Gets a tracer. Gets a tracer.
""" """
... ...
@abstractmethod @abstractmethod
def get_meter(self, name: str, def get_meter(
version: str = "", self,
meter_provider: MeterProvider | None = None, name: str,
schema_url: str | None = None, version: str = "",
attributes: Attributes | None = None) -> Meter: meter_provider: MeterProvider | None = None,
schema_url: str | None = None,
attributes: Attributes | None = None,
) -> Meter:
""" """
Gets a meter. Gets a meter.
""" """

View file

@ -1,15 +1,22 @@
-from aiohttp import hdrs
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
from typing import Any

+from aiohttp import hdrs
from llama_stack.apis.datatypes import Api
from llama_stack.core.external import ExternalApiSpec
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace

logger = get_logger(name=__name__, category="telemetry::meta_reference")


class TracingMiddleware:
    def __init__(
        self,

View file

@ -10,7 +10,6 @@ import threading
from typing import Any, cast

from fastapi import FastAPI
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
@ -23,11 +22,6 @@ from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.util.types import Attributes

-from llama_stack.core.external import ExternalApiSpec
-from llama_stack.core.server.tracing import TelemetryProvider
-from llama_stack.providers.inline.telemetry.meta_reference.middleware import TracingMiddleware
from llama_stack.apis.telemetry import (
    Event,
    MetricEvent,
@ -47,10 +41,13 @@ from llama_stack.apis.telemetry import (
    UnstructuredLogEvent,
)
from llama_stack.core.datatypes import Api
+from llama_stack.core.external import ExternalApiSpec
+from llama_stack.core.server.tracing import TelemetryProvider
from llama_stack.log import get_logger
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
    ConsoleSpanProcessor,
)
+from llama_stack.providers.inline.telemetry.meta_reference.middleware import TracingMiddleware
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
    SQLiteSpanProcessor,
)
@ -381,7 +378,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry, TelemetryProvider):
                max_depth=max_depth,
            )
        )

    def fastapi_middleware(
        self,
        app: FastAPI,

View file

@ -12,13 +12,12 @@ __all__ = ["OTelTelemetryConfig"]
async def get_provider_impl(config: OTelTelemetryConfig, deps):
    """
    Get the OTel telemetry provider implementation.

    This function is called by the Llama Stack registry to instantiate
    the provider.
    """
    from .otel import OTelTelemetryProvider

    # The provider is synchronously initialized via Pydantic model_post_init
    # No async initialization needed
    return OTelTelemetryProvider(config=config)

View file

@ -1,8 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
from typing import Any, Literal

from pydantic import BaseModel, Field

type BatchSpanProcessor = Literal["batch"]
type SimpleSpanProcessor = Literal["simple"]
@ -13,26 +18,27 @@ class OTelTelemetryConfig(BaseModel):
    Most configuration is set using environment variables.
    See https://opentelemetry.io/docs/specs/otel/configuration/sdk-configuration-variables/ for more information.
    """

    service_name: str = Field(
        description="""The name of the service to be monitored.
        Is overridden by the OTEL_SERVICE_NAME or OTEL_RESOURCE_ATTRIBUTES environment variables.""",
    )
    service_version: str | None = Field(
        default=None,
        description="""The version of the service to be monitored.
-        Is overriden by the OTEL_RESOURCE_ATTRIBUTES environment variable."""
+        Is overriden by the OTEL_RESOURCE_ATTRIBUTES environment variable.""",
    )
    deployment_environment: str | None = Field(
        default=None,
        description="""The name of the environment of the service to be monitored.
-        Is overriden by the OTEL_RESOURCE_ATTRIBUTES environment variable."""
+        Is overriden by the OTEL_RESOURCE_ATTRIBUTES environment variable.""",
    )
    span_processor: BatchSpanProcessor | SimpleSpanProcessor | None = Field(
        description="""The span processor to use.
        Is overriden by the OTEL_SPAN_PROCESSOR environment variable.""",
-        default="batch"
+        default="batch",
    )

    @classmethod
    def sample_run_config(cls, __distro_dir__: str = "") -> dict[str, Any]:
        """Sample configuration for use in distributions."""

View file

@ -1,24 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
import os

-from opentelemetry import trace, metrics
+from fastapi import FastAPI
+from opentelemetry import metrics, trace
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
+from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
+from opentelemetry.metrics import Meter
+from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.resources import Attributes, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SimpleSpanProcessor
-from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
-from opentelemetry.sdk.metrics import MeterProvider
-from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.trace import Tracer
-from opentelemetry.metrics import Meter
-from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
+from sqlalchemy import Engine

from llama_stack.core.telemetry.telemetry import TelemetryProvider
from llama_stack.log import get_logger
-from sqlalchemy import Engine

from .config import OTelTelemetryConfig
-from fastapi import FastAPI

logger = get_logger(name=__name__, category="telemetry::otel")
@ -27,6 +31,7 @@ class OTelTelemetryProvider(TelemetryProvider):
""" """
A simple Open Telemetry native telemetry provider. A simple Open Telemetry native telemetry provider.
""" """
config: OTelTelemetryConfig config: OTelTelemetryConfig
def model_post_init(self, __context): def model_post_init(self, __context):
@ -56,66 +61,66 @@ class OTelTelemetryProvider(TelemetryProvider):
            tracer_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
        elif self.config.span_processor == "simple":
            tracer_provider.add_span_processor(SimpleSpanProcessor(otlp_span_exporter))

        meter_provider = MeterProvider(resource=resource)
        metrics.set_meter_provider(meter_provider)

        # Do not fail the application, but warn the user if the endpoints are not set properly.
        if not os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"):
            if not os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"):
-                logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_TRACES_ENDPOINT is not set. Traces will not be exported.")
+                logger.warning(
+                    "OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_TRACES_ENDPOINT is not set. Traces will not be exported."
+                )
            if not os.environ.get("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT"):
-                logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_METRICS_ENDPOINT is not set. Metrics will not be exported.")
+                logger.warning(
+                    "OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_METRICS_ENDPOINT is not set. Metrics will not be exported."
+                )

    def fastapi_middleware(self, app: FastAPI):
        """
        Instrument FastAPI with OTel for automatic tracing and metrics.

        Captures:
        - Distributed traces for all HTTP requests (via FastAPIInstrumentor)
        - HTTP metrics following semantic conventions (custom middleware)
        """
        # Enable automatic tracing
        FastAPIInstrumentor.instrument_app(app)

        # Add custom middleware for HTTP metrics
        meter = self.get_meter("llama_stack.http.server")

        # Create HTTP metrics following semantic conventions
        # https://opentelemetry.io/docs/specs/semconv/http/http-metrics/
        request_duration = meter.create_histogram(
-            "http.server.request.duration",
-            unit="ms",
-            description="Duration of HTTP server requests"
+            "http.server.request.duration", unit="ms", description="Duration of HTTP server requests"
        )
        active_requests = meter.create_up_down_counter(
-            "http.server.active_requests",
-            unit="requests",
-            description="Number of active HTTP server requests"
+            "http.server.active_requests", unit="requests", description="Number of active HTTP server requests"
        )
        request_count = meter.create_counter(
-            "http.server.request.count",
-            unit="requests",
-            description="Total number of HTTP server requests"
+            "http.server.request.count", unit="requests", description="Total number of HTTP server requests"
        )

        # Add middleware to record metrics
        @app.middleware("http")  # type: ignore[misc]
        async def http_metrics_middleware(request, call_next):
            import time

            # Record active request
-            active_requests.add(1, {
-                "http.method": request.method,
-                "http.route": request.url.path,
-            })
+            active_requests.add(
+                1,
+                {
+                    "http.method": request.method,
+                    "http.route": request.url.path,
+                },
+            )

            start_time = time.time()
            status_code = 500  # Default to error

            try:
                response = await call_next(request)
                status_code = response.status_code
@ -124,22 +129,24 @@ class OTelTelemetryProvider(TelemetryProvider):
            finally:
                # Record metrics
                duration_ms = (time.time() - start_time) * 1000
                attributes = {
                    "http.method": request.method,
                    "http.route": request.url.path,
                    "http.status_code": status_code,
                }

                request_duration.record(duration_ms, attributes)
                request_count.add(1, attributes)
-                active_requests.add(-1, {
-                    "http.method": request.method,
-                    "http.route": request.url.path,
-                })
+                active_requests.add(
+                    -1,
+                    {
+                        "http.method": request.method,
+                        "http.route": request.url.path,
+                    },
+                )

            return response

    def sqlalchemy_instrumentation(self, engine: Engine | None = None):
        kwargs = {}
@ -147,34 +154,30 @@ class OTelTelemetryProvider(TelemetryProvider):
kwargs["engine"] = engine kwargs["engine"] = engine
SQLAlchemyInstrumentor().instrument(**kwargs) SQLAlchemyInstrumentor().instrument(**kwargs)
def get_tracer(
def get_tracer(self, self,
instrumenting_module_name: str, instrumenting_module_name: str,
instrumenting_library_version: str | None = None, instrumenting_library_version: str | None = None,
tracer_provider: TracerProvider | None = None, tracer_provider: TracerProvider | None = None,
schema_url: str | None = None, schema_url: str | None = None,
attributes: Attributes | None = None attributes: Attributes | None = None,
) -> Tracer: ) -> Tracer:
return trace.get_tracer( return trace.get_tracer(
instrumenting_module_name=instrumenting_module_name, instrumenting_module_name=instrumenting_module_name,
instrumenting_library_version=instrumenting_library_version, instrumenting_library_version=instrumenting_library_version,
tracer_provider=tracer_provider, tracer_provider=tracer_provider,
schema_url=schema_url, schema_url=schema_url,
attributes=attributes attributes=attributes,
) )
def get_meter(
def get_meter(self, self,
name: str, name: str,
version: str = "", version: str = "",
meter_provider: MeterProvider | None = None, meter_provider: MeterProvider | None = None,
schema_url: str | None = None, schema_url: str | None = None,
attributes: Attributes | None = None attributes: Attributes | None = None,
) -> Meter: ) -> Meter:
return metrics.get_meter( return metrics.get_meter(
name=name, name=name, version=version, meter_provider=meter_provider, schema_url=schema_url, attributes=attributes
version=version, )
meter_provider=meter_provider,
schema_url=schema_url,
attributes=attributes
)

View file

@ -3,4 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -24,7 +24,7 @@ class MockServerBase(BaseModel):
    async def await_start(self):
        # Start server and wait until ready
        ...

    def stop(self):
        # Stop server and cleanup
        ...
@ -49,29 +49,29 @@ Add to `servers.py`:
```python
class MockRedisServer(MockServerBase):
    """Mock Redis server."""

    port: int = Field(default=6379)

    # Non-Pydantic fields
    server: Any = Field(default=None, exclude=True)

    def model_post_init(self, __context):
        self.server = None

    async def await_start(self):
        """Start Redis mock and wait until ready."""
        # Start your server
        self.server = create_redis_server(self.port)
        self.server.start()

        # Wait for port to be listening
        for _ in range(10):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-            if sock.connect_ex(('localhost', self.port)) == 0:
+            if sock.connect_ex(("localhost", self.port)) == 0:
                sock.close()
                return  # Ready!
            await asyncio.sleep(0.1)

    def stop(self):
        if self.server:
            self.server.stop()
@ -101,11 +101,11 @@ The harness automatically:
## Benefits

**Parallel Startup** - All servers start simultaneously
**Type-Safe** - Pydantic validation
**Simple** - Just implement 2 methods
**Fast** - No HTTP polling, direct port checking
**Clean** - Async/await pattern

## Usage in Tests
@ -116,6 +116,7 @@ def mock_servers():
    yield servers
    stop_mock_servers(servers)

# Access specific servers
@pytest.fixture(scope="module")
def mock_redis(mock_servers):

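For orientation, here is a hedged sketch of the module-scoped `mock_servers` fixture the harness docs above describe, assuming a synchronous fixture that drives the async harness with `asyncio.run`; the helper names come from the `mocking` package in this commit, while the exact wiring and ports are illustrative.

```python
import asyncio

import pytest

# Helper classes and functions are defined in the mocking package shown in this commit;
# the fixture layout below is an illustrative assumption, not the repo's exact test code.
from .mocking import (
    MockOTLPCollector,
    MockServerConfig,
    MockVLLMServer,
    start_mock_servers_async,
    stop_mock_servers,
)

MOCK_SERVERS = [
    MockServerConfig(name="Mock OTLP Collector", server_class=MockOTLPCollector, init_kwargs={"port": 4318}),
    MockServerConfig(name="Mock vLLM Server", server_class=MockVLLMServer, init_kwargs={"port": 8000}),
]


@pytest.fixture(scope="module")
def mock_servers():
    # Start every mock server in parallel and wait until each reports ready.
    servers = asyncio.run(start_mock_servers_async(MOCK_SERVERS))
    yield servers
    stop_mock_servers(servers)
```
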
View file

@ -14,9 +14,9 @@ This module provides:
- Mock server harness for parallel async startup
"""

+from .harness import MockServerConfig, start_mock_servers_async, stop_mock_servers
from .mock_base import MockServerBase
from .servers import MockOTLPCollector, MockVLLMServer
-from .harness import MockServerConfig, start_mock_servers_async, stop_mock_servers

__all__ = [
    "MockServerBase",
@ -26,4 +26,3 @@ __all__ = [
"start_mock_servers_async", "start_mock_servers_async",
"stop_mock_servers", "stop_mock_servers",
] ]

View file

@ -14,7 +14,7 @@ HOW TO ADD A NEW MOCK SERVER:
""" """
import asyncio import asyncio
from typing import Any, Dict, List from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -24,10 +24,10 @@ from .mock_base import MockServerBase
class MockServerConfig(BaseModel):
    """
    Configuration for a mock server to start.

    **TO ADD A NEW MOCK SERVER:**
    Just create a MockServerConfig instance with your server class.

    Example:
        MockServerConfig(
            name="Mock MyService",
@ -35,73 +35,72 @@ class MockServerConfig(BaseModel):
init_kwargs={"port": 9000, "config_param": "value"}, init_kwargs={"port": 9000, "config_param": "value"},
) )
""" """
model_config = {"arbitrary_types_allowed": True} model_config = {"arbitrary_types_allowed": True}
name: str = Field(description="Display name for logging") name: str = Field(description="Display name for logging")
server_class: type = Field(description="Mock server class (must inherit from MockServerBase)") server_class: type = Field(description="Mock server class (must inherit from MockServerBase)")
init_kwargs: Dict[str, Any] = Field(default_factory=dict, description="Kwargs to pass to server constructor") init_kwargs: dict[str, Any] = Field(default_factory=dict, description="Kwargs to pass to server constructor")
async def start_mock_servers_async(mock_servers_config: List[MockServerConfig]) -> Dict[str, MockServerBase]: async def start_mock_servers_async(mock_servers_config: list[MockServerConfig]) -> dict[str, MockServerBase]:
""" """
Start all mock servers in parallel and wait for them to be ready. Start all mock servers in parallel and wait for them to be ready.
**HOW IT WORKS:** **HOW IT WORKS:**
1. Creates all server instances 1. Creates all server instances
2. Calls await_start() on all servers in parallel 2. Calls await_start() on all servers in parallel
3. Returns when all are ready 3. Returns when all are ready
**SIMPLE TO USE:** **SIMPLE TO USE:**
servers = await start_mock_servers_async([config1, config2, ...]) servers = await start_mock_servers_async([config1, config2, ...])
Args: Args:
mock_servers_config: List of mock server configurations mock_servers_config: List of mock server configurations
Returns: Returns:
Dict mapping server name to server instance Dict mapping server name to server instance
""" """
servers = {} servers = {}
start_tasks = [] start_tasks = []
# Create all servers and prepare start tasks # Create all servers and prepare start tasks
for config in mock_servers_config: for config in mock_servers_config:
server = config.server_class(**config.init_kwargs) server = config.server_class(**config.init_kwargs)
servers[config.name] = server servers[config.name] = server
start_tasks.append(server.await_start()) start_tasks.append(server.await_start())
# Start all servers in parallel # Start all servers in parallel
try: try:
await asyncio.gather(*start_tasks) await asyncio.gather(*start_tasks)
# Print readiness confirmation # Print readiness confirmation
for name in servers.keys(): for name in servers.keys():
print(f"[INFO] {name} ready") print(f"[INFO] {name} ready")
except Exception as e: except Exception as e:
# If any server fails, stop all servers # If any server fails, stop all servers
for server in servers.values(): for server in servers.values():
try: try:
server.stop() server.stop()
except: except Exception:
pass pass
raise RuntimeError(f"Failed to start mock servers: {e}") raise RuntimeError(f"Failed to start mock servers: {e}") from None
return servers return servers
def stop_mock_servers(servers: Dict[str, Any]): def stop_mock_servers(servers: dict[str, Any]):
""" """
Stop all mock servers. Stop all mock servers.
Args: Args:
servers: Dict of server instances from start_mock_servers_async() servers: Dict of server instances from start_mock_servers_async()
""" """
for name, server in servers.items(): for name, server in servers.items():
try: try:
if hasattr(server, 'get_request_count'): if hasattr(server, "get_request_count"):
print(f"\n[INFO] {name} received {server.get_request_count()} requests") print(f"\n[INFO] {name} received {server.get_request_count()} requests")
server.stop() server.stop()
except Exception as e: except Exception as e:
print(f"[WARN] Error stopping {name}: {e}") print(f"[WARN] Error stopping {name}: {e}")

View file

@ -10,25 +10,25 @@ Base class for mock servers with async startup support.
All mock servers should inherit from MockServerBase and implement await_start().
"""

-import asyncio
from abc import abstractmethod

-from pydantic import BaseModel, Field
+from pydantic import BaseModel


class MockServerBase(BaseModel):
    """
    Pydantic base model for mock servers.

    **TO CREATE A NEW MOCK SERVER:**
    1. Inherit from this class
    2. Implement async def await_start(self)
    3. Implement def stop(self)
    4. Done!

    Example:
        class MyMockServer(MockServerBase):
            port: int = 8080

            async def await_start(self):
                # Start your server
                self.server = create_server()
@ -36,34 +36,33 @@ class MockServerBase(BaseModel):
                # Wait until ready (can check internal state, no HTTP needed)
                while not self.server.is_listening():
                    await asyncio.sleep(0.1)

            def stop(self):
                if self.server:
                    self.server.stop()
    """

    model_config = {"arbitrary_types_allowed": True}

    @abstractmethod
    async def await_start(self):
        """
        Start the server and wait until it's ready.

        This method should:
        1. Start the server (synchronous or async)
        2. Wait until the server is fully ready to accept requests
        3. Return when ready

        Subclasses can check internal state directly - no HTTP polling needed!
        """
        ...

    @abstractmethod
    def stop(self):
        """
        Stop the server and clean up resources.

        This method should gracefully shut down the server.
        """
        ...

View file

@ -20,7 +20,7 @@ import json
import socket
import threading
import time
-from typing import Any, Dict, List
+from typing import Any

from pydantic import Field
@ -30,10 +30,10 @@ from .mock_base import MockServerBase
class MockOTLPCollector(MockServerBase):
    """
    Mock OTLP collector HTTP server.

    Receives real OTLP exports from Llama Stack and stores them for verification.
    Runs on localhost:4318 (standard OTLP HTTP port).

    Usage:
        collector = MockOTLPCollector()
        await collector.await_start()
@ -41,115 +41,119 @@ class MockOTLPCollector(MockServerBase):
print(f"Received {collector.get_trace_count()} traces") print(f"Received {collector.get_trace_count()} traces")
collector.stop() collector.stop()
""" """
port: int = Field(default=4318, description="Port to run collector on") port: int = Field(default=4318, description="Port to run collector on")
# Non-Pydantic fields (set after initialization) # Non-Pydantic fields (set after initialization)
traces: List[Dict] = Field(default_factory=list, exclude=True) traces: list[dict] = Field(default_factory=list, exclude=True)
metrics: List[Dict] = Field(default_factory=list, exclude=True) metrics: list[dict] = Field(default_factory=list, exclude=True)
server: Any = Field(default=None, exclude=True) server: Any = Field(default=None, exclude=True)
server_thread: Any = Field(default=None, exclude=True) server_thread: Any = Field(default=None, exclude=True)
def model_post_init(self, __context): def model_post_init(self, __context):
"""Initialize after Pydantic validation.""" """Initialize after Pydantic validation."""
self.traces = [] self.traces = []
self.metrics = [] self.metrics = []
self.server = None self.server = None
self.server_thread = None self.server_thread = None
def _create_handler_class(self): def _create_handler_class(self):
"""Create the HTTP handler class for this collector instance.""" """Create the HTTP handler class for this collector instance."""
collector_self = self collector_self = self
class OTLPHandler(http.server.BaseHTTPRequestHandler): class OTLPHandler(http.server.BaseHTTPRequestHandler):
"""HTTP request handler for OTLP requests.""" """HTTP request handler for OTLP requests."""
def log_message(self, format, *args): def log_message(self, format, *args):
"""Suppress HTTP server logs.""" """Suppress HTTP server logs."""
pass pass
def do_GET(self): def do_GET(self): # noqa: N802
"""Handle GET requests.""" """Handle GET requests."""
# No readiness endpoint needed - using await_start() instead # No readiness endpoint needed - using await_start() instead
self.send_response(404) self.send_response(404)
self.end_headers() self.end_headers()
def do_POST(self): def do_POST(self): # noqa: N802
"""Handle OTLP POST requests.""" """Handle OTLP POST requests."""
content_length = int(self.headers.get('Content-Length', 0)) content_length = int(self.headers.get("Content-Length", 0))
body = self.rfile.read(content_length) if content_length > 0 else b'' body = self.rfile.read(content_length) if content_length > 0 else b""
# Store the export request # Store the export request
if '/v1/traces' in self.path: if "/v1/traces" in self.path:
collector_self.traces.append({ collector_self.traces.append(
'body': body, {
'timestamp': time.time(), "body": body,
}) "timestamp": time.time(),
elif '/v1/metrics' in self.path: }
collector_self.metrics.append({ )
'body': body, elif "/v1/metrics" in self.path:
'timestamp': time.time(), collector_self.metrics.append(
}) {
"body": body,
"timestamp": time.time(),
}
)
# Always return success (200 OK) # Always return success (200 OK)
self.send_response(200) self.send_response(200)
self.send_header('Content-Type', 'application/json') self.send_header("Content-Type", "application/json")
self.end_headers() self.end_headers()
self.wfile.write(b'{}') self.wfile.write(b"{}")
return OTLPHandler return OTLPHandler
async def await_start(self): async def await_start(self):
""" """
Start the OTLP collector and wait until ready. Start the OTLP collector and wait until ready.
This method is async and can be awaited to ensure the server is ready. This method is async and can be awaited to ensure the server is ready.
""" """
# Create handler and start the HTTP server # Create handler and start the HTTP server
handler_class = self._create_handler_class() handler_class = self._create_handler_class()
self.server = http.server.HTTPServer(('localhost', self.port), handler_class) self.server = http.server.HTTPServer(("localhost", self.port), handler_class)
self.server_thread = threading.Thread(target=self.server.serve_forever, daemon=True) self.server_thread = threading.Thread(target=self.server.serve_forever, daemon=True)
self.server_thread.start() self.server_thread.start()
# Wait for server to be listening on the port # Wait for server to be listening on the port
for _ in range(10): for _ in range(10):
try: try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
result = sock.connect_ex(('localhost', self.port)) result = sock.connect_ex(("localhost", self.port))
sock.close() sock.close()
if result == 0: if result == 0:
# Port is listening # Port is listening
return return
except: except Exception:
pass pass
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
raise RuntimeError(f"OTLP collector failed to start on port {self.port}") raise RuntimeError(f"OTLP collector failed to start on port {self.port}")
def stop(self): def stop(self):
"""Stop the OTLP collector server.""" """Stop the OTLP collector server."""
if self.server: if self.server:
self.server.shutdown() self.server.shutdown()
self.server.server_close() self.server.server_close()
def clear(self): def clear(self):
"""Clear all captured telemetry data.""" """Clear all captured telemetry data."""
self.traces = [] self.traces = []
self.metrics = [] self.metrics = []
def get_trace_count(self) -> int: def get_trace_count(self) -> int:
"""Get number of trace export requests received.""" """Get number of trace export requests received."""
return len(self.traces) return len(self.traces)
def get_metric_count(self) -> int: def get_metric_count(self) -> int:
"""Get number of metric export requests received.""" """Get number of metric export requests received."""
return len(self.metrics) return len(self.metrics)
def get_all_traces(self) -> List[Dict]: def get_all_traces(self) -> list[dict]:
"""Get all captured trace exports.""" """Get all captured trace exports."""
return self.traces return self.traces
def get_all_metrics(self) -> List[Dict]: def get_all_metrics(self) -> list[dict]:
"""Get all captured metric exports.""" """Get all captured metric exports."""
return self.metrics return self.metrics
@ -157,14 +161,14 @@ class MockOTLPCollector(MockServerBase):
class MockVLLMServer(MockServerBase):
    """
    Mock vLLM inference server with OpenAI-compatible API.

    Returns valid OpenAI Python client response objects for:
    - Chat completions (/v1/chat/completions)
    - Text completions (/v1/completions)
    - Model listing (/v1/models)

    Runs on localhost:8000 (standard vLLM port).

    Usage:
        server = MockVLLMServer(models=["my-model"])
        await server.await_start()
        print(f"Handled {server.get_request_count()} requests")
        server.stop()
    """

    port: int = Field(default=8000, description="Port to run server on")
-    models: List[str] = Field(
-        default_factory=lambda: ["meta-llama/Llama-3.2-1B-Instruct"],
-        description="List of model IDs to serve"
+    models: list[str] = Field(
+        default_factory=lambda: ["meta-llama/Llama-3.2-1B-Instruct"], description="List of model IDs to serve"
    )

    # Non-Pydantic fields
-    requests_received: List[Dict] = Field(default_factory=list, exclude=True)
+    requests_received: list[dict] = Field(default_factory=list, exclude=True)
    server: Any = Field(default=None, exclude=True)
    server_thread: Any = Field(default=None, exclude=True)

    def model_post_init(self, __context):
        """Initialize after Pydantic validation."""
        self.requests_received = []
        self.server = None
        self.server_thread = None

    def _create_handler_class(self):
        """Create the HTTP handler class for this vLLM instance."""
        server_self = self

        class VLLMHandler(http.server.BaseHTTPRequestHandler):
            """HTTP request handler for vLLM API."""

            def log_message(self, format, *args):
                """Suppress HTTP server logs."""
                pass

-            def log_request(self, code='-', size='-'):
+            def log_request(self, code="-", size="-"):
                """Log incoming requests for debugging."""
                print(f"[DEBUG] Mock vLLM received: {self.command} {self.path} -> {code}")

-            def do_GET(self):
+            def do_GET(self):  # noqa: N802
                """Handle GET requests (models list, health check)."""
                # Log GET requests too
-                server_self.requests_received.append({
-                    'path': self.path,
-                    'method': 'GET',
-                    'timestamp': time.time(),
-                })
+                server_self.requests_received.append(
+                    {
+                        "path": self.path,
+                        "method": "GET",
+                        "timestamp": time.time(),
+                    }
+                )

-                if self.path == '/v1/models':
+                if self.path == "/v1/models":
                    response = self._create_models_list_response()
                    self._send_json_response(200, response)
-                elif self.path == '/health' or self.path == '/v1/health':
+                elif self.path == "/health" or self.path == "/v1/health":
                    self._send_json_response(200, {"status": "healthy"})
                else:
                    self.send_response(404)
                    self.end_headers()

-            def do_POST(self):
+            def do_POST(self):  # noqa: N802
                """Handle POST requests (chat/text completions)."""
-                content_length = int(self.headers.get('Content-Length', 0))
-                body = self.rfile.read(content_length) if content_length > 0 else b'{}'
+                content_length = int(self.headers.get("Content-Length", 0))
+                body = self.rfile.read(content_length) if content_length > 0 else b"{}"

                try:
                    request_data = json.loads(body)
-                except:
+                except Exception:
                    request_data = {}

                # Log the request
-                server_self.requests_received.append({
-                    'path': self.path,
-                    'body': request_data,
-                    'timestamp': time.time(),
-                })
+                server_self.requests_received.append(
+                    {
+                        "path": self.path,
+                        "body": request_data,
+                        "timestamp": time.time(),
+                    }
+                )

                # Route to appropriate handler
-                if '/chat/completions' in self.path:
+                if "/chat/completions" in self.path:
                    response = self._create_chat_completion_response(request_data)
                    self._send_json_response(200, response)
-                elif '/completions' in self.path:
+                elif "/completions" in self.path:
                    response = self._create_text_completion_response(request_data)
                    self._send_json_response(200, response)
                else:
                    self._send_json_response(200, {"status": "ok"})

            # ----------------------------------------------------------------
            # Response Generators
            # **TO MODIFY RESPONSES:** Edit these methods
            # ----------------------------------------------------------------

-            def _create_models_list_response(self) -> Dict:
+            def _create_models_list_response(self) -> dict:
                """Create OpenAI models list response with configured models."""
                return {
                    "object": "list",
@ -271,13 +278,13 @@ class MockVLLMServer(MockServerBase):
"owned_by": "meta", "owned_by": "meta",
} }
for model_id in server_self.models for model_id in server_self.models
] ],
} }
def _create_chat_completion_response(self, request_data: Dict) -> Dict: def _create_chat_completion_response(self, request_data: dict) -> dict:
""" """
Create OpenAI ChatCompletion response. Create OpenAI ChatCompletion response.
Returns a valid response matching openai.types.ChatCompletion Returns a valid response matching openai.types.ChatCompletion
""" """
return { return {
@ -285,16 +292,18 @@ class MockVLLMServer(MockServerBase):
"object": "chat.completion", "object": "chat.completion",
"created": int(time.time()), "created": int(time.time()),
"model": request_data.get("model", "meta-llama/Llama-3.2-1B-Instruct"), "model": request_data.get("model", "meta-llama/Llama-3.2-1B-Instruct"),
"choices": [{ "choices": [
"index": 0, {
"message": { "index": 0,
"role": "assistant", "message": {
"content": "This is a test response from mock vLLM server.", "role": "assistant",
"tool_calls": None, "content": "This is a test response from mock vLLM server.",
}, "tool_calls": None,
"logprobs": None, },
"finish_reason": "stop", "logprobs": None,
}], "finish_reason": "stop",
}
],
"usage": { "usage": {
"prompt_tokens": 25, "prompt_tokens": 25,
"completion_tokens": 15, "completion_tokens": 15,
@ -304,11 +313,11 @@ class MockVLLMServer(MockServerBase):
"system_fingerprint": None, "system_fingerprint": None,
"service_tier": None, "service_tier": None,
} }
def _create_text_completion_response(self, request_data: Dict) -> Dict: def _create_text_completion_response(self, request_data: dict) -> dict:
""" """
Create OpenAI Completion response. Create OpenAI Completion response.
Returns a valid response matching openai.types.Completion Returns a valid response matching openai.types.Completion
""" """
return { return {
@ -316,12 +325,14 @@ class MockVLLMServer(MockServerBase):
"object": "text_completion", "object": "text_completion",
"created": int(time.time()), "created": int(time.time()),
"model": request_data.get("model", "meta-llama/Llama-3.2-1B-Instruct"), "model": request_data.get("model", "meta-llama/Llama-3.2-1B-Instruct"),
"choices": [{ "choices": [
"text": "This is a test completion.", {
"index": 0, "text": "This is a test completion.",
"logprobs": None, "index": 0,
"finish_reason": "stop", "logprobs": None,
}], "finish_reason": "stop",
}
],
"usage": { "usage": {
"prompt_tokens": 10, "prompt_tokens": 10,
"completion_tokens": 8, "completion_tokens": 8,
@ -330,58 +341,57 @@ class MockVLLMServer(MockServerBase):
                    },
                    "system_fingerprint": None,
                }

-            def _send_json_response(self, status_code: int, data: Dict):
+            def _send_json_response(self, status_code: int, data: dict):
                """Helper to send JSON response."""
                self.send_response(status_code)
-                self.send_header('Content-Type', 'application/json')
+                self.send_header("Content-Type", "application/json")
                self.end_headers()
                self.wfile.write(json.dumps(data).encode())

        return VLLMHandler

    async def await_start(self):
        """
        Start the vLLM server and wait until ready.

        This method is async and can be awaited to ensure the server is ready.
        """
        # Create handler and start the HTTP server
        handler_class = self._create_handler_class()
-        self.server = http.server.HTTPServer(('localhost', self.port), handler_class)
+        self.server = http.server.HTTPServer(("localhost", self.port), handler_class)
        self.server_thread = threading.Thread(target=self.server.serve_forever, daemon=True)
        self.server_thread.start()

        # Wait for server to be listening on the port
        for _ in range(10):
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-                result = sock.connect_ex(('localhost', self.port))
+                result = sock.connect_ex(("localhost", self.port))
                sock.close()
                if result == 0:
                    # Port is listening
                    return
-            except:
+            except Exception:
                pass
            await asyncio.sleep(0.1)

        raise RuntimeError(f"vLLM server failed to start on port {self.port}")

    def stop(self):
        """Stop the vLLM server."""
        if self.server:
            self.server.shutdown()
            self.server.server_close()

    def clear(self):
        """Clear request history."""
        self.requests_received = []

    def get_request_count(self) -> int:
        """Get number of requests received."""
        return len(self.requests_received)

-    def get_all_requests(self) -> List[Dict]:
+    def get_all_requests(self) -> list[dict]:
        """Get all received requests with their bodies."""
        return self.requests_received

View file

@ -34,7 +34,7 @@ import os
import socket
import subprocess
import time
-from typing import Any, Dict, List
+from typing import Any

import pytest
import requests
@ -44,28 +44,28 @@ from pydantic import BaseModel, Field
# Mock servers are in the mocking/ subdirectory
from .mocking import (
    MockOTLPCollector,
-    MockVLLMServer,
    MockServerConfig,
+    MockVLLMServer,
    start_mock_servers_async,
    stop_mock_servers,
)

# ============================================================================
# DATA MODELS
# ============================================================================


class TelemetryTestCase(BaseModel):
    """
    Pydantic model defining expected telemetry for an API call.

    **TO ADD A NEW TEST CASE:** Add to TEST_CASES list below.
    """

    name: str = Field(description="Unique test case identifier")
    http_method: str = Field(description="HTTP method (GET, POST, etc.)")
    api_path: str = Field(description="API path (e.g., '/v1/models')")
-    request_body: Dict[str, Any] | None = Field(default=None)
+    request_body: dict[str, Any] | None = Field(default=None)
    expected_http_status: int = Field(default=200)
    expected_trace_exports: int = Field(default=1, description="Minimum number of trace exports expected")
    expected_metric_exports: int = Field(default=0, description="Minimum number of metric exports expected")
@ -103,71 +103,74 @@ TEST_CASES = [
# TEST INFRASTRUCTURE
# ============================================================================


class TelemetryTestRunner:
    """
    Executes TelemetryTestCase instances against real Llama Stack.

    **HOW IT WORKS:**
    1. Makes real HTTP request to the stack
    2. Waits for telemetry export
    3. Verifies exports were sent to mock collector
    """

    def __init__(self, base_url: str, collector: MockOTLPCollector):
        self.base_url = base_url
        self.collector = collector

    def run_test_case(self, test_case: TelemetryTestCase, verbose: bool = False) -> bool:
        """Execute a single test case and verify telemetry."""
        initial_traces = self.collector.get_trace_count()
        initial_metrics = self.collector.get_metric_count()

        if verbose:
            print(f"\n--- {test_case.name} ---")
            print(f"  {test_case.http_method} {test_case.api_path}")

        # Make real HTTP request to Llama Stack
        try:
            url = f"{self.base_url}{test_case.api_path}"
            if test_case.http_method == "GET":
                response = requests.get(url, timeout=5)
            elif test_case.http_method == "POST":
                response = requests.post(url, json=test_case.request_body or {}, timeout=5)
            else:
                response = requests.request(test_case.http_method, url, timeout=5)

            if verbose:
                print(f"  HTTP Response: {response.status_code}")

            status_match = response.status_code == test_case.expected_http_status
        except requests.exceptions.RequestException as e:
            if verbose:
                print(f"  Request failed: {e}")
            status_match = False

        # Wait for automatic instrumentation to export telemetry
        # Traces export immediately, metrics export every 1 second (configured via env var)
        time.sleep(2.0)  # Wait for both traces and metrics to export

        # Verify traces were exported to mock collector
        new_traces = self.collector.get_trace_count() - initial_traces
        traces_exported = new_traces >= test_case.expected_trace_exports

        # Verify metrics were exported (if expected)
        new_metrics = self.collector.get_metric_count() - initial_metrics
        metrics_exported = new_metrics >= test_case.expected_metric_exports

        if verbose:
-            print(f"  Expected: >={test_case.expected_trace_exports} trace exports, >={test_case.expected_metric_exports} metric exports")
+            print(
+                f"  Expected: >={test_case.expected_trace_exports} trace exports, >={test_case.expected_metric_exports} metric exports"
+            )
            print(f"  Actual: {new_traces} trace exports, {new_metrics} metric exports")

        result = status_match and traces_exported and metrics_exported
print(f" Result: {'PASS' if result else 'FAIL'}") print(f" Result: {'PASS' if result else 'FAIL'}")
return status_match and traces_exported and metrics_exported return status_match and traces_exported and metrics_exported
def run_all_test_cases(self, test_cases: List[TelemetryTestCase], verbose: bool = True) -> Dict[str, bool]: def run_all_test_cases(self, test_cases: list[TelemetryTestCase], verbose: bool = True) -> dict[str, bool]:
"""Run all test cases and return results.""" """Run all test cases and return results."""
results = {} results = {}
for test_case in test_cases: for test_case in test_cases:
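The fixtures further down wire this runner up automatically; driven by hand it is just two calls. A minimal sketch, assuming a stack already listening on port 5555 and a collector obtained from the mock_otlp_collector fixture defined later in this file:

```python
# Sketch only: `collector` comes from the mock_otlp_collector fixture and the
# URL from the llama_stack_server fixture; neither is constructed here.
runner = TelemetryTestRunner("http://localhost:5555", collector)
results = runner.run_all_test_cases(TEST_CASES, verbose=True)
failed = [name for name, ok in results.items() if not ok]
assert not failed, f"Failing test cases: {failed}"
```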
@ -179,11 +182,12 @@ class TelemetryTestRunner:
# HELPER FUNCTIONS # HELPER FUNCTIONS
# ============================================================================ # ============================================================================
def is_port_available(port: int) -> bool: def is_port_available(port: int) -> bool:
"""Check if a TCP port is available for binding.""" """Check if a TCP port is available for binding."""
try: try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(('localhost', port)) sock.bind(("localhost", port))
return True return True
except OSError: except OSError:
return False return False
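For reference, the port scan the server fixture below performs inline could equally be wrapped as a small helper (a sketch, not part of the test file):

```python
def find_free_port(start: int = 5555, end: int = 5600) -> int:
    """Return the first free port in [start, end), mirroring the fixture's inline scan."""
    for candidate in range(start, end):
        if is_port_available(candidate):
            return candidate
    raise RuntimeError(f"No free port in range {start}-{end}")
```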
@ -193,20 +197,21 @@ def is_port_available(port: int) -> bool:
# PYTEST FIXTURES # PYTEST FIXTURES
# ============================================================================ # ============================================================================
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def mock_servers(): def mock_servers():
""" """
Fixture: Start all mock servers in parallel using async harness. Fixture: Start all mock servers in parallel using async harness.
**TO ADD A NEW MOCK SERVER:** **TO ADD A NEW MOCK SERVER:**
Just add a MockServerConfig to the MOCK_SERVERS list below. Just add a MockServerConfig to the MOCK_SERVERS list below.
""" """
import asyncio import asyncio
# ======================================================================== # ========================================================================
# MOCK SERVER CONFIGURATION # MOCK SERVER CONFIGURATION
# **TO ADD A NEW MOCK:** Just add a MockServerConfig instance below # **TO ADD A NEW MOCK:** Just add a MockServerConfig instance below
# #
# Example: # Example:
# MockServerConfig( # MockServerConfig(
# name="Mock MyService", # name="Mock MyService",
@ -214,7 +219,7 @@ def mock_servers():
# init_kwargs={"port": 9000, "param": "value"}, # init_kwargs={"port": 9000, "param": "value"},
# ), # ),
# ======================================================================== # ========================================================================
MOCK_SERVERS = [ mock_servers_config = [
MockServerConfig( MockServerConfig(
name="Mock OTLP Collector", name="Mock OTLP Collector",
server_class=MockOTLPCollector, server_class=MockOTLPCollector,
@ -230,17 +235,17 @@ def mock_servers():
), ),
# Add more mock servers here - they will start in parallel automatically! # Add more mock servers here - they will start in parallel automatically!
] ]
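Since the hunk above elides the middle of the commented example, a complete hypothetical entry would look like this, assuming MockServerConfig takes the three fields used by the entries in this list (name, server_class, init_kwargs); MockMyService is a placeholder class name:

```python
# Hypothetical additional entry; MockMyService stands in for any mock server
# class exported from .mocking.
MockServerConfig(
    name="Mock MyService",
    server_class=MockMyService,
    init_kwargs={"port": 9000, "param": "value"},
),
```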
# Start all servers in parallel # Start all servers in parallel
servers = asyncio.run(start_mock_servers_async(MOCK_SERVERS)) servers = asyncio.run(start_mock_servers_async(mock_servers_config))
# Verify vLLM models # Verify vLLM models
models_response = requests.get("http://localhost:8000/v1/models", timeout=1) models_response = requests.get("http://localhost:8000/v1/models", timeout=1)
models_data = models_response.json() models_data = models_response.json()
print(f"[INFO] Mock vLLM serving {len(models_data['data'])} models: {[m['id'] for m in models_data['data']]}") print(f"[INFO] Mock vLLM serving {len(models_data['data'])} models: {[m['id'] for m in models_data['data']]}")
yield servers yield servers
# Stop all servers # Stop all servers
stop_mock_servers(servers) stop_mock_servers(servers)
@ -261,22 +266,22 @@ def mock_vllm_server(mock_servers):
def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server): def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server):
""" """
Fixture: Start real Llama Stack server with automatic OTel instrumentation. Fixture: Start real Llama Stack server with automatic OTel instrumentation.
**THIS IS THE MAIN FIXTURE** - it runs: **THIS IS THE MAIN FIXTURE** - it runs:
opentelemetry-instrument llama stack run --config run.yaml opentelemetry-instrument llama stack run --config run.yaml
**TO MODIFY STACK CONFIG:** Edit run_config dict below **TO MODIFY STACK CONFIG:** Edit run_config dict below
""" """
config_dir = tmp_path_factory.mktemp("otel-stack-config") config_dir = tmp_path_factory.mktemp("otel-stack-config")
# Ensure mock vLLM is ready and accessible before starting Llama Stack # Ensure mock vLLM is ready and accessible before starting Llama Stack
print(f"\n[INFO] Verifying mock vLLM is accessible at http://localhost:8000...") print("\n[INFO] Verifying mock vLLM is accessible at http://localhost:8000...")
try: try:
vllm_models = requests.get("http://localhost:8000/v1/models", timeout=2) vllm_models = requests.get("http://localhost:8000/v1/models", timeout=2)
print(f"[INFO] Mock vLLM models endpoint response: {vllm_models.status_code}") print(f"[INFO] Mock vLLM models endpoint response: {vllm_models.status_code}")
except Exception as e: except Exception as e:
pytest.fail(f"Mock vLLM not accessible before starting Llama Stack: {e}") pytest.fail(f"Mock vLLM not accessible before starting Llama Stack: {e}")
# Create run.yaml with inference provider # Create run.yaml with inference provider
# **TO ADD MORE PROVIDERS:** Add to providers dict # **TO ADD MORE PROVIDERS:** Add to providers dict
run_config = { run_config = {
@ -300,19 +305,19 @@ def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server):
} }
], ],
} }
config_file = config_dir / "run.yaml" config_file = config_dir / "run.yaml"
with open(config_file, "w") as f: with open(config_file, "w") as f:
yaml.dump(run_config, f) yaml.dump(run_config, f)
# Find available port for Llama Stack # Find available port for Llama Stack
port = 5555 port = 5555
while not is_port_available(port) and port < 5600: while not is_port_available(port) and port < 5600:
port += 1 port += 1
if port >= 5600: if port >= 5600:
pytest.skip("No available ports for test server") pytest.skip("No available ports for test server")
# Set environment variables for OTel instrumentation # Set environment variables for OTel instrumentation
# NOTE: These only affect the subprocess, not other tests # NOTE: These only affect the subprocess, not other tests
env = os.environ.copy() env = os.environ.copy()
@ -321,29 +326,32 @@ def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server):
env["OTEL_SERVICE_NAME"] = "llama-stack-e2e-test" env["OTEL_SERVICE_NAME"] = "llama-stack-e2e-test"
env["LLAMA_STACK_PORT"] = str(port) env["LLAMA_STACK_PORT"] = str(port)
env["OTEL_PYTHON_LOGGING_AUTO_INSTRUMENTATION_ENABLED"] = "true" env["OTEL_PYTHON_LOGGING_AUTO_INSTRUMENTATION_ENABLED"] = "true"
# Configure fast metric export for testing (default is 60 seconds) # Configure fast metric export for testing (default is 60 seconds)
# This makes metrics export every 500ms instead of every 60 seconds # This makes metrics export every 500ms instead of every 60 seconds
env["OTEL_METRIC_EXPORT_INTERVAL"] = "500" # milliseconds env["OTEL_METRIC_EXPORT_INTERVAL"] = "500" # milliseconds
env["OTEL_METRIC_EXPORT_TIMEOUT"] = "1000" # milliseconds env["OTEL_METRIC_EXPORT_TIMEOUT"] = "1000" # milliseconds
# Disable inference recording to ensure real requests to our mock vLLM # Disable inference recording to ensure real requests to our mock vLLM
# This is critical - without this, Llama Stack replays cached responses # This is critical - without this, Llama Stack replays cached responses
# Safe to remove here as it only affects the subprocess environment # Safe to remove here as it only affects the subprocess environment
if "LLAMA_STACK_TEST_INFERENCE_MODE" in env: if "LLAMA_STACK_TEST_INFERENCE_MODE" in env:
del env["LLAMA_STACK_TEST_INFERENCE_MODE"] del env["LLAMA_STACK_TEST_INFERENCE_MODE"]
# Start server with automatic instrumentation # Start server with automatic instrumentation
cmd = [ cmd = [
"opentelemetry-instrument", # ← Automatic instrumentation wrapper "opentelemetry-instrument", # ← Automatic instrumentation wrapper
"llama", "stack", "run", "llama",
"stack",
"run",
str(config_file), str(config_file),
"--port", str(port), "--port",
str(port),
] ]
print(f"\n[INFO] Starting Llama Stack with OTel instrumentation on port {port}") print(f"\n[INFO] Starting Llama Stack with OTel instrumentation on port {port}")
print(f"[INFO] Command: {' '.join(cmd)}") print(f"[INFO] Command: {' '.join(cmd)}")
process = subprocess.Popen( process = subprocess.Popen(
cmd, cmd,
env=env, env=env,
@ -351,11 +359,11 @@ def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server):
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
text=True, text=True,
) )
# Wait for server to start # Wait for server to start
max_wait = 30 max_wait = 30
base_url = f"http://localhost:{port}" base_url = f"http://localhost:{port}"
for i in range(max_wait): for i in range(max_wait):
try: try:
response = requests.get(f"{base_url}/v1/health", timeout=1) response = requests.get(f"{base_url}/v1/health", timeout=1)
@ -368,16 +376,16 @@ def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server):
stdout, stderr = process.communicate(timeout=5) stdout, stderr = process.communicate(timeout=5)
pytest.fail(f"Server failed to start.\nStdout: {stdout}\nStderr: {stderr}") pytest.fail(f"Server failed to start.\nStdout: {stdout}\nStderr: {stderr}")
time.sleep(1) time.sleep(1)
yield { yield {
'base_url': base_url, "base_url": base_url,
'port': port, "port": port,
'collector': mock_otlp_collector, "collector": mock_otlp_collector,
'vllm_server': mock_vllm_server, "vllm_server": mock_vllm_server,
} }
# Cleanup # Cleanup
print(f"\n[INFO] Stopping Llama Stack server") print("\n[INFO] Stopping Llama Stack server")
process.terminate() process.terminate()
try: try:
process.wait(timeout=5) process.wait(timeout=5)
@ -391,26 +399,27 @@ def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server):
# **TO ADD NEW E2E TESTS:** Add methods to this class # **TO ADD NEW E2E TESTS:** Add methods to this class
# ============================================================================ # ============================================================================
@pytest.mark.slow @pytest.mark.slow
class TestOTelE2E: class TestOTelE2E:
""" """
End-to-end tests with real Llama Stack server. End-to-end tests with real Llama Stack server.
These tests verify the complete flow: These tests verify the complete flow:
- Real Llama Stack with opentelemetry-instrument - Real Llama Stack with opentelemetry-instrument
- Real API calls - Real API calls
- Real automatic instrumentation - Real automatic instrumentation
- Mock OTLP collector captures exports - Mock OTLP collector captures exports
""" """
def test_server_starts_with_auto_instrumentation(self, llama_stack_server): def test_server_starts_with_auto_instrumentation(self, llama_stack_server):
"""Verify server starts successfully with opentelemetry-instrument.""" """Verify server starts successfully with opentelemetry-instrument."""
base_url = llama_stack_server['base_url'] base_url = llama_stack_server["base_url"]
# Try different health check endpoints # Try different health check endpoints
health_endpoints = ["/health", "/v1/health", "/"] health_endpoints = ["/health", "/v1/health", "/"]
server_responding = False server_responding = False
for endpoint in health_endpoints: for endpoint in health_endpoints:
try: try:
response = requests.get(f"{base_url}{endpoint}", timeout=5) response = requests.get(f"{base_url}{endpoint}", timeout=5)
@ -420,36 +429,36 @@ class TestOTelE2E:
break break
except Exception as e: except Exception as e:
print(f"[DEBUG] {endpoint} failed: {e}") print(f"[DEBUG] {endpoint} failed: {e}")
assert server_responding, f"Server not responding on any endpoint at {base_url}" assert server_responding, f"Server not responding on any endpoint at {base_url}"
print(f"\n[PASS] Llama Stack running with OTel at {base_url}") print(f"\n[PASS] Llama Stack running with OTel at {base_url}")
def test_all_test_cases_via_runner(self, llama_stack_server): def test_all_test_cases_via_runner(self, llama_stack_server):
""" """
**MAIN TEST:** Run all TelemetryTestCase instances. **MAIN TEST:** Run all TelemetryTestCase instances.
This executes all test cases defined in TEST_CASES list. This executes all test cases defined in TEST_CASES list.
**TO ADD MORE TESTS:** Add to TEST_CASES at top of file **TO ADD MORE TESTS:** Add to TEST_CASES at top of file
""" """
base_url = llama_stack_server['base_url'] base_url = llama_stack_server["base_url"]
collector = llama_stack_server['collector'] collector = llama_stack_server["collector"]
# Create test runner # Create test runner
runner = TelemetryTestRunner(base_url, collector) runner = TelemetryTestRunner(base_url, collector)
# Execute all test cases # Execute all test cases
results = runner.run_all_test_cases(TEST_CASES, verbose=True) results = runner.run_all_test_cases(TEST_CASES, verbose=True)
# Print summary # Print summary
print(f"\n{'='*50}") print(f"\n{'=' * 50}")
print(f"TEST CASE SUMMARY") print("TEST CASE SUMMARY")
print(f"{'='*50}") print(f"{'=' * 50}")
passed = sum(1 for p in results.values() if p) passed = sum(1 for p in results.values() if p)
total = len(results) total = len(results)
print(f"Passed: {passed}/{total}\n") print(f"Passed: {passed}/{total}\n")
for name, result in results.items(): for name, result in results.items():
status = "[PASS]" if result else "[FAIL]" status = "[PASS]" if result else "[FAIL]"
print(f" {status} {name}") print(f" {status} {name}")
print(f"{'='*50}\n") print(f"{'=' * 50}\n")


@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import pytest import pytest
import llama_stack.providers.inline.telemetry.meta_reference.telemetry as telemetry_module import llama_stack.providers.inline.telemetry.meta_reference.telemetry as telemetry_module
@ -38,7 +36,7 @@ def test_warns_when_traces_endpoints_missing(monkeypatch: pytest.MonkeyPatch, ca
monkeypatch.delenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", raising=False) monkeypatch.delenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", raising=False)
monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False)
caplog.set_level(logging.WARNING) caplog.set_level("WARNING")
config = _make_config_with_sinks(TelemetrySink.OTEL_TRACE) config = _make_config_with_sinks(TelemetrySink.OTEL_TRACE)
telemetry_module.TelemetryAdapter(config=config, deps={}) telemetry_module.TelemetryAdapter(config=config, deps={})
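The switch away from `logging.WARNING` works because pytest's `caplog.set_level` accepts the level name as a string as well as the numeric constant, which is what lets the `logging` import be dropped above. A minimal standalone check of that behaviour (illustrative, not part of this test file):

```python
import logging

def test_caplog_accepts_string_level(caplog):
    caplog.set_level("WARNING")
    logging.getLogger("demo").warning("endpoint missing")
    assert any("endpoint missing" in record.getMessage() for record in caplog.records)
```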
@ -57,7 +55,7 @@ def test_warns_when_metrics_endpoints_missing(monkeypatch: pytest.MonkeyPatch, c
monkeypatch.delenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", raising=False) monkeypatch.delenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", raising=False)
monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False)
caplog.set_level(logging.WARNING) caplog.set_level("WARNING")
config = _make_config_with_sinks(TelemetrySink.OTEL_METRIC) config = _make_config_with_sinks(TelemetrySink.OTEL_METRIC)
telemetry_module.TelemetryAdapter(config=config, deps={}) telemetry_module.TelemetryAdapter(config=config, deps={})
@ -76,7 +74,7 @@ def test_no_warning_when_traces_endpoints_present(monkeypatch: pytest.MonkeyPatc
monkeypatch.setenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "https://otel.example:4318/v1/traces") monkeypatch.setenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "https://otel.example:4318/v1/traces")
monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://otel.example:4318") monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://otel.example:4318")
caplog.set_level(logging.WARNING) caplog.set_level("WARNING")
config = _make_config_with_sinks(TelemetrySink.OTEL_TRACE) config = _make_config_with_sinks(TelemetrySink.OTEL_TRACE)
telemetry_module.TelemetryAdapter(config=config, deps={}) telemetry_module.TelemetryAdapter(config=config, deps={})
@ -91,7 +89,7 @@ def test_no_warning_when_metrics_endpoints_present(monkeypatch: pytest.MonkeyPat
monkeypatch.setenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", "https://otel.example:4318/v1/metrics") monkeypatch.setenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", "https://otel.example:4318/v1/metrics")
monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://otel.example:4318") monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://otel.example:4318")
caplog.set_level(logging.WARNING) caplog.set_level("WARNING")
config = _make_config_with_sinks(TelemetrySink.OTEL_METRIC) config = _make_config_with_sinks(TelemetrySink.OTEL_METRIC)
telemetry_module.TelemetryAdapter(config=config, deps={}) telemetry_module.TelemetryAdapter(config=config, deps={})


@ -41,17 +41,17 @@ class TestOTelTelemetryProviderInitialization:
def test_initialization_with_valid_config(self, otel_config, monkeypatch): def test_initialization_with_valid_config(self, otel_config, monkeypatch):
"""Test that provider initializes correctly with valid configuration.""" """Test that provider initializes correctly with valid configuration."""
monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4318") monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4318")
provider = OTelTelemetryProvider(config=otel_config) provider = OTelTelemetryProvider(config=otel_config)
assert provider.config == otel_config assert provider.config == otel_config
def test_initialization_sets_service_attributes(self, otel_config, monkeypatch): def test_initialization_sets_service_attributes(self, otel_config, monkeypatch):
"""Test that service attributes are properly configured.""" """Test that service attributes are properly configured."""
monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4318") monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4318")
provider = OTelTelemetryProvider(config=otel_config) provider = OTelTelemetryProvider(config=otel_config)
assert provider.config.service_name == "test-service" assert provider.config.service_name == "test-service"
assert provider.config.service_version == "1.0.0" assert provider.config.service_version == "1.0.0"
assert provider.config.deployment_environment == "test" assert provider.config.deployment_environment == "test"
@ -65,9 +65,9 @@ class TestOTelTelemetryProviderInitialization:
deployment_environment="test", deployment_environment="test",
span_processor="batch", span_processor="batch",
) )
provider = OTelTelemetryProvider(config=config) provider = OTelTelemetryProvider(config=config)
assert provider.config.span_processor == "batch" assert provider.config.span_processor == "batch"
def test_warns_when_endpoints_missing(self, otel_config, monkeypatch, caplog): def test_warns_when_endpoints_missing(self, otel_config, monkeypatch, caplog):
@ -76,9 +76,9 @@ class TestOTelTelemetryProviderInitialization:
monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False)
monkeypatch.delenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", raising=False) monkeypatch.delenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", raising=False)
monkeypatch.delenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", raising=False) monkeypatch.delenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", raising=False)
OTelTelemetryProvider(config=otel_config) OTelTelemetryProvider(config=otel_config)
# Check that warnings were logged # Check that warnings were logged
assert any("Traces will not be exported" in record.message for record in caplog.records) assert any("Traces will not be exported" in record.message for record in caplog.records)
assert any("Metrics will not be exported" in record.message for record in caplog.records) assert any("Metrics will not be exported" in record.message for record in caplog.records)
@ -90,44 +90,41 @@ class TestOTelTelemetryProviderTracerAPI:
def test_get_tracer_returns_tracer(self, otel_provider): def test_get_tracer_returns_tracer(self, otel_provider):
"""Test that get_tracer returns a valid Tracer instance.""" """Test that get_tracer returns a valid Tracer instance."""
tracer = otel_provider.get_tracer("test.module") tracer = otel_provider.get_tracer("test.module")
assert tracer is not None assert tracer is not None
assert isinstance(tracer, Tracer) assert isinstance(tracer, Tracer)
def test_get_tracer_with_version(self, otel_provider): def test_get_tracer_with_version(self, otel_provider):
"""Test that get_tracer works with version parameter.""" """Test that get_tracer works with version parameter."""
tracer = otel_provider.get_tracer( tracer = otel_provider.get_tracer(
instrumenting_module_name="test.module", instrumenting_module_name="test.module", instrumenting_library_version="1.0.0"
instrumenting_library_version="1.0.0"
) )
assert tracer is not None assert tracer is not None
assert isinstance(tracer, Tracer) assert isinstance(tracer, Tracer)
def test_get_tracer_with_attributes(self, otel_provider): def test_get_tracer_with_attributes(self, otel_provider):
"""Test that get_tracer works with attributes.""" """Test that get_tracer works with attributes."""
tracer = otel_provider.get_tracer( tracer = otel_provider.get_tracer(
instrumenting_module_name="test.module", instrumenting_module_name="test.module", attributes={"component": "test", "tier": "backend"}
attributes={"component": "test", "tier": "backend"}
) )
assert tracer is not None assert tracer is not None
assert isinstance(tracer, Tracer) assert isinstance(tracer, Tracer)
def test_get_tracer_with_schema_url(self, otel_provider): def test_get_tracer_with_schema_url(self, otel_provider):
"""Test that get_tracer works with schema URL.""" """Test that get_tracer works with schema URL."""
tracer = otel_provider.get_tracer( tracer = otel_provider.get_tracer(
instrumenting_module_name="test.module", instrumenting_module_name="test.module", schema_url="https://example.com/schema"
schema_url="https://example.com/schema"
) )
assert tracer is not None assert tracer is not None
assert isinstance(tracer, Tracer) assert isinstance(tracer, Tracer)
def test_tracer_can_create_spans(self, otel_provider): def test_tracer_can_create_spans(self, otel_provider):
"""Test that tracer can create spans.""" """Test that tracer can create spans."""
tracer = otel_provider.get_tracer("test.module") tracer = otel_provider.get_tracer("test.module")
with tracer.start_as_current_span("test.operation") as span: with tracer.start_as_current_span("test.operation") as span:
assert span is not None assert span is not None
assert span.is_recording() assert span.is_recording()
@ -135,11 +132,8 @@ class TestOTelTelemetryProviderTracerAPI:
def test_tracer_can_create_spans_with_attributes(self, otel_provider): def test_tracer_can_create_spans_with_attributes(self, otel_provider):
"""Test that tracer can create spans with attributes.""" """Test that tracer can create spans with attributes."""
tracer = otel_provider.get_tracer("test.module") tracer = otel_provider.get_tracer("test.module")
with tracer.start_as_current_span( with tracer.start_as_current_span("test.operation", attributes={"user.id": "123", "request.id": "abc"}) as span:
"test.operation",
attributes={"user.id": "123", "request.id": "abc"}
) as span:
assert span is not None assert span is not None
assert span.is_recording() assert span.is_recording()
@ -147,7 +141,7 @@ class TestOTelTelemetryProviderTracerAPI:
"""Test that multiple tracers can be created.""" """Test that multiple tracers can be created."""
tracer1 = otel_provider.get_tracer("module.one") tracer1 = otel_provider.get_tracer("module.one")
tracer2 = otel_provider.get_tracer("module.two") tracer2 = otel_provider.get_tracer("module.two")
assert tracer1 is not None assert tracer1 is not None
assert tracer2 is not None assert tracer2 is not None
# Tracers with different names might be the same instance or different # Tracers with different names might be the same instance or different
@ -164,50 +158,37 @@ class TestOTelTelemetryProviderMeterAPI:
def test_get_meter_returns_meter(self, otel_provider): def test_get_meter_returns_meter(self, otel_provider):
"""Test that get_meter returns a valid Meter instance.""" """Test that get_meter returns a valid Meter instance."""
meter = otel_provider.get_meter("test.meter") meter = otel_provider.get_meter("test.meter")
assert meter is not None assert meter is not None
assert isinstance(meter, Meter) assert isinstance(meter, Meter)
def test_get_meter_with_version(self, otel_provider): def test_get_meter_with_version(self, otel_provider):
"""Test that get_meter works with version parameter.""" """Test that get_meter works with version parameter."""
meter = otel_provider.get_meter( meter = otel_provider.get_meter(name="test.meter", version="1.0.0")
name="test.meter",
version="1.0.0"
)
assert meter is not None assert meter is not None
assert isinstance(meter, Meter) assert isinstance(meter, Meter)
def test_get_meter_with_attributes(self, otel_provider): def test_get_meter_with_attributes(self, otel_provider):
"""Test that get_meter works with attributes.""" """Test that get_meter works with attributes."""
meter = otel_provider.get_meter( meter = otel_provider.get_meter(name="test.meter", attributes={"service": "test", "env": "dev"})
name="test.meter",
attributes={"service": "test", "env": "dev"}
)
assert meter is not None assert meter is not None
assert isinstance(meter, Meter) assert isinstance(meter, Meter)
def test_get_meter_with_schema_url(self, otel_provider): def test_get_meter_with_schema_url(self, otel_provider):
"""Test that get_meter works with schema URL.""" """Test that get_meter works with schema URL."""
meter = otel_provider.get_meter( meter = otel_provider.get_meter(name="test.meter", schema_url="https://example.com/schema")
name="test.meter",
schema_url="https://example.com/schema"
)
assert meter is not None assert meter is not None
assert isinstance(meter, Meter) assert isinstance(meter, Meter)
def test_meter_can_create_counter(self, otel_provider): def test_meter_can_create_counter(self, otel_provider):
"""Test that meter can create counters.""" """Test that meter can create counters."""
meter = otel_provider.get_meter("test.meter") meter = otel_provider.get_meter("test.meter")
counter = meter.create_counter( counter = meter.create_counter("test.requests.total", unit="requests", description="Total requests")
"test.requests.total",
unit="requests",
description="Total requests"
)
assert counter is not None assert counter is not None
# Test that counter can be used # Test that counter can be used
counter.add(1, {"endpoint": "/test"}) counter.add(1, {"endpoint": "/test"})
@ -215,13 +196,9 @@ class TestOTelTelemetryProviderMeterAPI:
def test_meter_can_create_histogram(self, otel_provider): def test_meter_can_create_histogram(self, otel_provider):
"""Test that meter can create histograms.""" """Test that meter can create histograms."""
meter = otel_provider.get_meter("test.meter") meter = otel_provider.get_meter("test.meter")
histogram = meter.create_histogram( histogram = meter.create_histogram("test.request.duration", unit="ms", description="Request duration")
"test.request.duration",
unit="ms",
description="Request duration"
)
assert histogram is not None assert histogram is not None
# Test that histogram can be used # Test that histogram can be used
histogram.record(42.5, {"method": "GET"}) histogram.record(42.5, {"method": "GET"})
@ -229,13 +206,11 @@ class TestOTelTelemetryProviderMeterAPI:
def test_meter_can_create_up_down_counter(self, otel_provider): def test_meter_can_create_up_down_counter(self, otel_provider):
"""Test that meter can create up/down counters.""" """Test that meter can create up/down counters."""
meter = otel_provider.get_meter("test.meter") meter = otel_provider.get_meter("test.meter")
up_down_counter = meter.create_up_down_counter( up_down_counter = meter.create_up_down_counter(
"test.active.connections", "test.active.connections", unit="connections", description="Active connections"
unit="connections",
description="Active connections"
) )
assert up_down_counter is not None assert up_down_counter is not None
# Test that up/down counter can be used # Test that up/down counter can be used
up_down_counter.add(5) up_down_counter.add(5)
@ -244,31 +219,28 @@ class TestOTelTelemetryProviderMeterAPI:
def test_meter_can_create_observable_gauge(self, otel_provider): def test_meter_can_create_observable_gauge(self, otel_provider):
"""Test that meter can create observable gauges.""" """Test that meter can create observable gauges."""
meter = otel_provider.get_meter("test.meter") meter = otel_provider.get_meter("test.meter")
def gauge_callback(options): def gauge_callback(options):
return [{"attributes": {"host": "localhost"}, "value": 42.0}] return [{"attributes": {"host": "localhost"}, "value": 42.0}]
gauge = meter.create_observable_gauge( gauge = meter.create_observable_gauge(
"test.memory.usage", "test.memory.usage", callbacks=[gauge_callback], unit="bytes", description="Memory usage"
callbacks=[gauge_callback],
unit="bytes",
description="Memory usage"
) )
assert gauge is not None assert gauge is not None
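Note that the callback above returns plain dicts; with the standard opentelemetry-python metrics API, an observable-gauge callback normally yields Observation objects. A rough sketch of that upstream API shape (not code from this test file):

```python
# Typical observable-gauge callback using the upstream opentelemetry-python API.
from opentelemetry.metrics import CallbackOptions, Observation

def memory_usage_callback(options: CallbackOptions):
    # One Observation per measurement; attributes are optional.
    yield Observation(42.0, {"host": "localhost"})
```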
def test_multiple_instruments_from_same_meter(self, otel_provider): def test_multiple_instruments_from_same_meter(self, otel_provider):
"""Test that a meter can create multiple instruments.""" """Test that a meter can create multiple instruments."""
meter = otel_provider.get_meter("test.meter") meter = otel_provider.get_meter("test.meter")
counter = meter.create_counter("test.counter") counter = meter.create_counter("test.counter")
histogram = meter.create_histogram("test.histogram") histogram = meter.create_histogram("test.histogram")
up_down_counter = meter.create_up_down_counter("test.gauge") up_down_counter = meter.create_up_down_counter("test.gauge")
assert counter is not None assert counter is not None
assert histogram is not None assert histogram is not None
assert up_down_counter is not None assert up_down_counter is not None
# Verify they all work # Verify they all work
counter.add(1) counter.add(1)
histogram.record(10.0) histogram.record(10.0)
@ -281,107 +253,101 @@ class TestOTelTelemetryProviderNativeUsage:
def test_complete_tracing_workflow(self, otel_provider): def test_complete_tracing_workflow(self, otel_provider):
"""Test a complete tracing workflow using native OTel API.""" """Test a complete tracing workflow using native OTel API."""
tracer = otel_provider.get_tracer("llama_stack.inference") tracer = otel_provider.get_tracer("llama_stack.inference")
# Create parent span # Create parent span
with tracer.start_as_current_span("inference.request") as parent_span: with tracer.start_as_current_span("inference.request") as parent_span:
parent_span.set_attribute("model", "llama-3.2-1b") parent_span.set_attribute("model", "llama-3.2-1b")
parent_span.set_attribute("user", "test-user") parent_span.set_attribute("user", "test-user")
# Create child span # Create child span
with tracer.start_as_current_span("model.load") as child_span: with tracer.start_as_current_span("model.load") as child_span:
child_span.set_attribute("model.size", "1B") child_span.set_attribute("model.size", "1B")
assert child_span.is_recording() assert child_span.is_recording()
# Create another child span # Create another child span
with tracer.start_as_current_span("inference.execute") as child_span: with tracer.start_as_current_span("inference.execute") as child_span:
child_span.set_attribute("tokens.input", 25) child_span.set_attribute("tokens.input", 25)
child_span.set_attribute("tokens.output", 150) child_span.set_attribute("tokens.output", 150)
assert child_span.is_recording() assert child_span.is_recording()
assert parent_span.is_recording() assert parent_span.is_recording()
def test_complete_metrics_workflow(self, otel_provider): def test_complete_metrics_workflow(self, otel_provider):
"""Test a complete metrics workflow using native OTel API.""" """Test a complete metrics workflow using native OTel API."""
meter = otel_provider.get_meter("llama_stack.metrics") meter = otel_provider.get_meter("llama_stack.metrics")
# Create various instruments # Create various instruments
request_counter = meter.create_counter( request_counter = meter.create_counter("llama.requests.total", unit="requests", description="Total requests")
"llama.requests.total",
unit="requests",
description="Total requests"
)
latency_histogram = meter.create_histogram( latency_histogram = meter.create_histogram(
"llama.inference.duration", "llama.inference.duration", unit="ms", description="Inference duration"
unit="ms",
description="Inference duration"
) )
active_sessions = meter.create_up_down_counter( active_sessions = meter.create_up_down_counter(
"llama.sessions.active", "llama.sessions.active", unit="sessions", description="Active sessions"
unit="sessions",
description="Active sessions"
) )
# Use the instruments # Use the instruments
request_counter.add(1, {"endpoint": "/chat", "status": "success"}) request_counter.add(1, {"endpoint": "/chat", "status": "success"})
latency_histogram.record(123.45, {"model": "llama-3.2-1b"}) latency_histogram.record(123.45, {"model": "llama-3.2-1b"})
active_sessions.add(1) active_sessions.add(1)
active_sessions.add(-1) active_sessions.add(-1)
# No exceptions means success # No exceptions means success
def test_concurrent_tracer_usage(self, otel_provider): def test_concurrent_tracer_usage(self, otel_provider):
"""Test that multiple threads can use tracers concurrently.""" """Test that multiple threads can use tracers concurrently."""
def create_spans(thread_id): def create_spans(thread_id):
tracer = otel_provider.get_tracer(f"test.module.{thread_id}") tracer = otel_provider.get_tracer(f"test.module.{thread_id}")
for i in range(10): for i in range(10):
with tracer.start_as_current_span(f"operation.{i}") as span: with tracer.start_as_current_span(f"operation.{i}") as span:
span.set_attribute("thread.id", thread_id) span.set_attribute("thread.id", thread_id)
span.set_attribute("iteration", i) span.set_attribute("iteration", i)
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [executor.submit(create_spans, i) for i in range(10)] futures = [executor.submit(create_spans, i) for i in range(10)]
concurrent.futures.wait(futures) concurrent.futures.wait(futures)
# If we get here without exceptions, thread safety is working # If we get here without exceptions, thread safety is working
def test_concurrent_meter_usage(self, otel_provider): def test_concurrent_meter_usage(self, otel_provider):
"""Test that multiple threads can use meters concurrently.""" """Test that multiple threads can use meters concurrently."""
def record_metrics(thread_id): def record_metrics(thread_id):
meter = otel_provider.get_meter(f"test.meter.{thread_id}") meter = otel_provider.get_meter(f"test.meter.{thread_id}")
counter = meter.create_counter(f"test.counter.{thread_id}") counter = meter.create_counter(f"test.counter.{thread_id}")
histogram = meter.create_histogram(f"test.histogram.{thread_id}") histogram = meter.create_histogram(f"test.histogram.{thread_id}")
for i in range(10): for i in range(10):
counter.add(1, {"thread": str(thread_id)}) counter.add(1, {"thread": str(thread_id)})
histogram.record(float(i * 10), {"thread": str(thread_id)}) histogram.record(float(i * 10), {"thread": str(thread_id)})
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [executor.submit(record_metrics, i) for i in range(10)] futures = [executor.submit(record_metrics, i) for i in range(10)]
concurrent.futures.wait(futures) concurrent.futures.wait(futures)
# If we get here without exceptions, thread safety is working # If we get here without exceptions, thread safety is working
def test_mixed_tracing_and_metrics(self, otel_provider): def test_mixed_tracing_and_metrics(self, otel_provider):
"""Test using both tracing and metrics together.""" """Test using both tracing and metrics together."""
tracer = otel_provider.get_tracer("test.module") tracer = otel_provider.get_tracer("test.module")
meter = otel_provider.get_meter("test.meter") meter = otel_provider.get_meter("test.meter")
counter = meter.create_counter("operations.count") counter = meter.create_counter("operations.count")
histogram = meter.create_histogram("operation.duration") histogram = meter.create_histogram("operation.duration")
# Trace an operation while recording metrics # Trace an operation while recording metrics
with tracer.start_as_current_span("test.operation") as span: with tracer.start_as_current_span("test.operation") as span:
counter.add(1) counter.add(1)
span.set_attribute("step", "start") span.set_attribute("step", "start")
histogram.record(50.0) histogram.record(50.0)
span.set_attribute("step", "processing") span.set_attribute("step", "processing")
counter.add(1) counter.add(1)
span.set_attribute("step", "complete") span.set_attribute("step", "complete")
# No exceptions means success # No exceptions means success
@ -391,14 +357,14 @@ class TestOTelTelemetryProviderFastAPIMiddleware:
def test_fastapi_middleware(self, otel_provider): def test_fastapi_middleware(self, otel_provider):
"""Test that fastapi_middleware can be called.""" """Test that fastapi_middleware can be called."""
mock_app = MagicMock() mock_app = MagicMock()
# Should not raise an exception # Should not raise an exception
otel_provider.fastapi_middleware(mock_app) otel_provider.fastapi_middleware(mock_app)
def test_fastapi_middleware_is_idempotent(self, otel_provider): def test_fastapi_middleware_is_idempotent(self, otel_provider):
"""Test that calling fastapi_middleware multiple times is safe.""" """Test that calling fastapi_middleware multiple times is safe."""
mock_app = MagicMock() mock_app = MagicMock()
# Should be able to call multiple times without error # Should be able to call multiple times without error
otel_provider.fastapi_middleware(mock_app) otel_provider.fastapi_middleware(mock_app)
# Note: Second call might warn but shouldn't fail # Note: Second call might warn but shouldn't fail
@ -411,27 +377,27 @@ class TestOTelTelemetryProviderEdgeCases:
def test_tracer_with_empty_module_name(self, otel_provider): def test_tracer_with_empty_module_name(self, otel_provider):
"""Test that get_tracer works with empty module name.""" """Test that get_tracer works with empty module name."""
tracer = otel_provider.get_tracer("") tracer = otel_provider.get_tracer("")
assert tracer is not None assert tracer is not None
assert isinstance(tracer, Tracer) assert isinstance(tracer, Tracer)
def test_meter_with_empty_name(self, otel_provider): def test_meter_with_empty_name(self, otel_provider):
"""Test that get_meter works with empty name.""" """Test that get_meter works with empty name."""
meter = otel_provider.get_meter("") meter = otel_provider.get_meter("")
assert meter is not None assert meter is not None
assert isinstance(meter, Meter) assert isinstance(meter, Meter)
def test_meter_instruments_with_special_characters(self, otel_provider): def test_meter_instruments_with_special_characters(self, otel_provider):
"""Test that metric names with dots, underscores, and hyphens work.""" """Test that metric names with dots, underscores, and hyphens work."""
meter = otel_provider.get_meter("test.meter") meter = otel_provider.get_meter("test.meter")
counter = meter.create_counter("test.counter_name-special") counter = meter.create_counter("test.counter_name-special")
histogram = meter.create_histogram("test.histogram_name-special") histogram = meter.create_histogram("test.histogram_name-special")
assert counter is not None assert counter is not None
assert histogram is not None assert histogram is not None
# Verify they can be used # Verify they can be used
counter.add(1) counter.add(1)
histogram.record(10.0) histogram.record(10.0)
@ -440,7 +406,7 @@ class TestOTelTelemetryProviderEdgeCases:
"""Test that counters work with zero value.""" """Test that counters work with zero value."""
meter = otel_provider.get_meter("test.meter") meter = otel_provider.get_meter("test.meter")
counter = meter.create_counter("test.counter") counter = meter.create_counter("test.counter")
# Should not raise an exception # Should not raise an exception
counter.add(0.0) counter.add(0.0)
@ -448,7 +414,7 @@ class TestOTelTelemetryProviderEdgeCases:
"""Test that histograms accept negative values.""" """Test that histograms accept negative values."""
meter = otel_provider.get_meter("test.meter") meter = otel_provider.get_meter("test.meter")
histogram = meter.create_histogram("test.histogram") histogram = meter.create_histogram("test.histogram")
# Should not raise an exception # Should not raise an exception
histogram.record(-10.0) histogram.record(-10.0)
@ -456,7 +422,7 @@ class TestOTelTelemetryProviderEdgeCases:
"""Test that up/down counters work with negative values.""" """Test that up/down counters work with negative values."""
meter = otel_provider.get_meter("test.meter") meter = otel_provider.get_meter("test.meter")
up_down_counter = meter.create_up_down_counter("test.updown") up_down_counter = meter.create_up_down_counter("test.updown")
# Should not raise an exception # Should not raise an exception
up_down_counter.add(-5.0) up_down_counter.add(-5.0)
@ -464,7 +430,7 @@ class TestOTelTelemetryProviderEdgeCases:
"""Test that empty attributes dict is handled correctly.""" """Test that empty attributes dict is handled correctly."""
meter = otel_provider.get_meter("test.meter") meter = otel_provider.get_meter("test.meter")
counter = meter.create_counter("test.counter") counter = meter.create_counter("test.counter")
# Should not raise an exception # Should not raise an exception
counter.add(1.0, attributes={}) counter.add(1.0, attributes={})
@ -472,7 +438,7 @@ class TestOTelTelemetryProviderEdgeCases:
"""Test that None attributes are handled correctly.""" """Test that None attributes are handled correctly."""
meter = otel_provider.get_meter("test.meter") meter = otel_provider.get_meter("test.meter")
counter = meter.create_counter("test.counter") counter = meter.create_counter("test.counter")
# Should not raise an exception # Should not raise an exception
counter.add(1.0, attributes=None) counter.add(1.0, attributes=None)
@ -484,28 +450,28 @@ class TestOTelTelemetryProviderRealisticScenarios:
"""Simulate telemetry for a complete inference request.""" """Simulate telemetry for a complete inference request."""
tracer = otel_provider.get_tracer("llama_stack.inference") tracer = otel_provider.get_tracer("llama_stack.inference")
meter = otel_provider.get_meter("llama_stack.metrics") meter = otel_provider.get_meter("llama_stack.metrics")
# Create instruments # Create instruments
request_counter = meter.create_counter("llama.requests.total") request_counter = meter.create_counter("llama.requests.total")
token_counter = meter.create_counter("llama.tokens.total") token_counter = meter.create_counter("llama.tokens.total")
latency_histogram = meter.create_histogram("llama.request.duration_ms") latency_histogram = meter.create_histogram("llama.request.duration_ms")
in_flight_gauge = meter.create_up_down_counter("llama.requests.in_flight") in_flight_gauge = meter.create_up_down_counter("llama.requests.in_flight")
# Simulate request # Simulate request
with tracer.start_as_current_span("inference.request") as request_span: with tracer.start_as_current_span("inference.request") as request_span:
request_span.set_attribute("model.id", "llama-3.2-1b") request_span.set_attribute("model.id", "llama-3.2-1b")
request_span.set_attribute("user.id", "test-user") request_span.set_attribute("user.id", "test-user")
request_counter.add(1, {"model": "llama-3.2-1b"}) request_counter.add(1, {"model": "llama-3.2-1b"})
in_flight_gauge.add(1) in_flight_gauge.add(1)
# Simulate token counting # Simulate token counting
token_counter.add(25, {"type": "input", "model": "llama-3.2-1b"}) token_counter.add(25, {"type": "input", "model": "llama-3.2-1b"})
token_counter.add(150, {"type": "output", "model": "llama-3.2-1b"}) token_counter.add(150, {"type": "output", "model": "llama-3.2-1b"})
# Simulate latency # Simulate latency
latency_histogram.record(125.5, {"model": "llama-3.2-1b"}) latency_histogram.record(125.5, {"model": "llama-3.2-1b"})
in_flight_gauge.add(-1) in_flight_gauge.add(-1)
request_span.set_attribute("tokens.input", 25) request_span.set_attribute("tokens.input", 25)
request_span.set_attribute("tokens.output", 150) request_span.set_attribute("tokens.output", 150)
@ -514,36 +480,36 @@ class TestOTelTelemetryProviderRealisticScenarios:
"""Simulate a multi-step workflow with nested spans.""" """Simulate a multi-step workflow with nested spans."""
tracer = otel_provider.get_tracer("llama_stack.workflow") tracer = otel_provider.get_tracer("llama_stack.workflow")
meter = otel_provider.get_meter("llama_stack.workflow.metrics") meter = otel_provider.get_meter("llama_stack.workflow.metrics")
step_counter = meter.create_counter("workflow.steps.completed") step_counter = meter.create_counter("workflow.steps.completed")
with tracer.start_as_current_span("workflow.execute") as root_span: with tracer.start_as_current_span("workflow.execute") as root_span:
root_span.set_attribute("workflow.id", "wf-123") root_span.set_attribute("workflow.id", "wf-123")
# Step 1: Validate # Step 1: Validate
with tracer.start_as_current_span("step.validate") as span: with tracer.start_as_current_span("step.validate") as span:
span.set_attribute("validation.result", "pass") span.set_attribute("validation.result", "pass")
step_counter.add(1, {"step": "validate", "status": "success"}) step_counter.add(1, {"step": "validate", "status": "success"})
# Step 2: Process # Step 2: Process
with tracer.start_as_current_span("step.process") as span: with tracer.start_as_current_span("step.process") as span:
span.set_attribute("items.processed", 100) span.set_attribute("items.processed", 100)
step_counter.add(1, {"step": "process", "status": "success"}) step_counter.add(1, {"step": "process", "status": "success"})
# Step 3: Finalize # Step 3: Finalize
with tracer.start_as_current_span("step.finalize") as span: with tracer.start_as_current_span("step.finalize") as span:
span.set_attribute("output.size", 1024) span.set_attribute("output.size", 1024)
step_counter.add(1, {"step": "finalize", "status": "success"}) step_counter.add(1, {"step": "finalize", "status": "success"})
root_span.set_attribute("workflow.status", "completed") root_span.set_attribute("workflow.status", "completed")
def test_error_handling_with_telemetry(self, otel_provider): def test_error_handling_with_telemetry(self, otel_provider):
"""Test telemetry when errors occur.""" """Test telemetry when errors occur."""
tracer = otel_provider.get_tracer("llama_stack.errors") tracer = otel_provider.get_tracer("llama_stack.errors")
meter = otel_provider.get_meter("llama_stack.errors.metrics") meter = otel_provider.get_meter("llama_stack.errors.metrics")
error_counter = meter.create_counter("llama.errors.total") error_counter = meter.create_counter("llama.errors.total")
with tracer.start_as_current_span("operation.with.error") as span: with tracer.start_as_current_span("operation.with.error") as span:
try: try:
span.set_attribute("step", "processing") span.set_attribute("step", "processing")
@ -553,24 +519,24 @@ class TestOTelTelemetryProviderRealisticScenarios:
span.record_exception(e) span.record_exception(e)
span.set_status(trace.Status(trace.StatusCode.ERROR, str(e))) span.set_status(trace.Status(trace.StatusCode.ERROR, str(e)))
error_counter.add(1, {"error.type": "ValueError"}) error_counter.add(1, {"error.type": "ValueError"})
# Should not raise - error was handled # Should not raise - error was handled
def test_batch_operations_telemetry(self, otel_provider): def test_batch_operations_telemetry(self, otel_provider):
"""Test telemetry for batch operations.""" """Test telemetry for batch operations."""
tracer = otel_provider.get_tracer("llama_stack.batch") tracer = otel_provider.get_tracer("llama_stack.batch")
meter = otel_provider.get_meter("llama_stack.batch.metrics") meter = otel_provider.get_meter("llama_stack.batch.metrics")
batch_counter = meter.create_counter("llama.batch.items.processed") batch_counter = meter.create_counter("llama.batch.items.processed")
batch_duration = meter.create_histogram("llama.batch.duration_ms") batch_duration = meter.create_histogram("llama.batch.duration_ms")
with tracer.start_as_current_span("batch.process") as batch_span: with tracer.start_as_current_span("batch.process") as batch_span:
batch_span.set_attribute("batch.size", 100) batch_span.set_attribute("batch.size", 100)
for i in range(100): for i in range(100):
with tracer.start_as_current_span(f"item.{i}") as item_span: with tracer.start_as_current_span(f"item.{i}") as item_span:
item_span.set_attribute("item.index", i) item_span.set_attribute("item.index", i)
batch_counter.add(1, {"status": "success"}) batch_counter.add(1, {"status": "success"})
batch_duration.record(5000.0, {"batch.size": "100"}) batch_duration.record(5000.0, {"batch.size": "100"})
batch_span.set_attribute("batch.status", "completed") batch_span.set_attribute("batch.status", "completed")