fix(pr specific): passes pre-commit

This commit is contained in:
Emilio Garcia 2025-10-03 12:35:09 -04:00
parent 4aa2dc110d
commit 2b7a765d02
20 changed files with 547 additions and 516 deletions

View file

@ -0,0 +1,33 @@
---
description: "Native OpenTelemetry provider with full access to OTel Tracer and Meter APIs for advanced instrumentation."
sidebar_label: Otel
title: inline::otel
---
# inline::otel
## Description
Native OpenTelemetry provider with full access to OTel Tracer and Meter APIs for advanced instrumentation.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `service_name` | `<class 'str'>` | No | | The name of the service to be monitored.
Is overridden by the OTEL_SERVICE_NAME or OTEL_RESOURCE_ATTRIBUTES environment variables. |
| `service_version` | `str \| None` | No | | The version of the service to be monitored.
Is overriden by the OTEL_RESOURCE_ATTRIBUTES environment variable. |
| `deployment_environment` | `str \| None` | No | | The name of the environment of the service to be monitored.
Is overriden by the OTEL_RESOURCE_ATTRIBUTES environment variable. |
| `span_processor` | `BatchSpanProcessor \| SimpleSpanProcessor \| None` | No | batch | The span processor to use.
Is overriden by the OTEL_SPAN_PROCESSOR environment variable. |
## Sample Configuration
```yaml
service_name: ${env.OTEL_SERVICE_NAME:=llama-stack}
service_version: ${env.OTEL_SERVICE_VERSION:=}
deployment_environment: ${env.OTEL_DEPLOYMENT_ENVIRONMENT:=}
span_processor: ${env.OTEL_SPAN_PROCESSOR:=batch}
```

View file

@ -32,7 +32,7 @@ from termcolor import cprint
from llama_stack.core.build import print_pip_install_help
from llama_stack.core.configure import parse_and_maybe_upgrade_config
from llama_stack.core.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
@ -49,7 +49,6 @@ from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.core.utils.exec import in_notebook
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="core")
T = TypeVar("T")

View file

@ -63,7 +63,6 @@ from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from .auth import AuthenticationMiddleware
from .quota import QuotaMiddleware
@ -236,9 +235,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
try:
if is_streaming:
gen = preserve_contexts_async_generator(
sse_generator(func(**kwargs)), [PROVIDER_DATA_VAR]
)
gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), [PROVIDER_DATA_VAR])
return StreamingResponse(gen, media_type="text/event-stream")
else:
value = func(**kwargs)
@ -282,7 +279,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
]
)
setattr(route_handler, "__signature__", sig.replace(parameters=new_params))
route_handler.__signature__ = sig.replace(parameters=new_params)
return route_handler

View file

@ -359,7 +359,6 @@ class Stack:
await refresh_registry_once(impls)
self.impls = impls
# safely access impls without raising an exception
def get_impls(self) -> dict[Api, Any]:
if self.impls is None:

View file

@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import abstractmethod
from fastapi import FastAPI
from pydantic import BaseModel
from opentelemetry.trace import Tracer
from fastapi import FastAPI
from opentelemetry.metrics import Meter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.resources import Attributes
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import Tracer
from pydantic import BaseModel
from sqlalchemy import Engine
@ -19,6 +19,7 @@ class TelemetryProvider(BaseModel):
"""
TelemetryProvider standardizes how telemetry is provided to the application.
"""
@abstractmethod
def fastapi_middleware(self, app: FastAPI, *args, **kwargs):
"""
@ -34,12 +35,13 @@ class TelemetryProvider(BaseModel):
...
@abstractmethod
def get_tracer(self,
instrumenting_module_name: str,
instrumenting_library_version: str | None = None,
tracer_provider: TracerProvider | None = None,
schema_url: str | None = None,
attributes: Attributes | None = None
def get_tracer(
self,
instrumenting_module_name: str,
instrumenting_library_version: str | None = None,
tracer_provider: TracerProvider | None = None,
schema_url: str | None = None,
attributes: Attributes | None = None,
) -> Tracer:
"""
Gets a tracer.
@ -47,11 +49,14 @@ class TelemetryProvider(BaseModel):
...
@abstractmethod
def get_meter(self, name: str,
version: str = "",
meter_provider: MeterProvider | None = None,
schema_url: str | None = None,
attributes: Attributes | None = None) -> Meter:
def get_meter(
self,
name: str,
version: str = "",
meter_provider: MeterProvider | None = None,
schema_url: str | None = None,
attributes: Attributes | None = None,
) -> Meter:
"""
Gets a meter.
"""

View file

@ -1,15 +1,22 @@
from aiohttp import hdrs
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from aiohttp import hdrs
from llama_stack.apis.datatypes import Api
from llama_stack.core.external import ExternalApiSpec
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace
logger = get_logger(name=__name__, category="telemetry::meta_reference")
class TracingMiddleware:
def __init__(
self,

View file

@ -10,7 +10,6 @@ import threading
from typing import Any, cast
from fastapi import FastAPI
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
@ -23,11 +22,6 @@ from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.util.types import Attributes
from llama_stack.core.external import ExternalApiSpec
from llama_stack.core.server.tracing import TelemetryProvider
from llama_stack.providers.inline.telemetry.meta_reference.middleware import TracingMiddleware
from llama_stack.apis.telemetry import (
Event,
MetricEvent,
@ -47,10 +41,13 @@ from llama_stack.apis.telemetry import (
UnstructuredLogEvent,
)
from llama_stack.core.datatypes import Api
from llama_stack.core.external import ExternalApiSpec
from llama_stack.core.server.tracing import TelemetryProvider
from llama_stack.log import get_logger
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
ConsoleSpanProcessor,
)
from llama_stack.providers.inline.telemetry.meta_reference.middleware import TracingMiddleware
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
SQLiteSpanProcessor,
)

View file

@ -21,4 +21,3 @@ async def get_provider_impl(config: OTelTelemetryConfig, deps):
# The provider is synchronously initialized via Pydantic model_post_init
# No async initialization needed
return OTelTelemetryProvider(config=config)

View file

@ -1,8 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Literal
from pydantic import BaseModel, Field
type BatchSpanProcessor = Literal["batch"]
type SimpleSpanProcessor = Literal["simple"]
@ -13,6 +18,7 @@ class OTelTelemetryConfig(BaseModel):
Most configuration is set using environment variables.
See https://opentelemetry.io/docs/specs/otel/configuration/sdk-configuration-variables/ for more information.
"""
service_name: str = Field(
description="""The name of the service to be monitored.
Is overridden by the OTEL_SERVICE_NAME or OTEL_RESOURCE_ATTRIBUTES environment variables.""",
@ -20,17 +26,17 @@ class OTelTelemetryConfig(BaseModel):
service_version: str | None = Field(
default=None,
description="""The version of the service to be monitored.
Is overriden by the OTEL_RESOURCE_ATTRIBUTES environment variable."""
Is overriden by the OTEL_RESOURCE_ATTRIBUTES environment variable.""",
)
deployment_environment: str | None = Field(
default=None,
description="""The name of the environment of the service to be monitored.
Is overriden by the OTEL_RESOURCE_ATTRIBUTES environment variable."""
Is overriden by the OTEL_RESOURCE_ATTRIBUTES environment variable.""",
)
span_processor: BatchSpanProcessor | SimpleSpanProcessor | None = Field(
description="""The span processor to use.
Is overriden by the OTEL_SPAN_PROCESSOR environment variable.""",
default="batch"
default="batch",
)
@classmethod

View file

@ -1,24 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from opentelemetry import trace, metrics
from fastapi import FastAPI
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.metrics import Meter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.resources import Attributes, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SimpleSpanProcessor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.trace import Tracer
from opentelemetry.metrics import Meter
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from sqlalchemy import Engine
from llama_stack.core.telemetry.telemetry import TelemetryProvider
from llama_stack.log import get_logger
from sqlalchemy import Engine
from .config import OTelTelemetryConfig
from fastapi import FastAPI
logger = get_logger(name=__name__, category="telemetry::otel")
@ -27,6 +31,7 @@ class OTelTelemetryProvider(TelemetryProvider):
"""
A simple Open Telemetry native telemetry provider.
"""
config: OTelTelemetryConfig
def model_post_init(self, __context):
@ -63,10 +68,13 @@ class OTelTelemetryProvider(TelemetryProvider):
# Do not fail the application, but warn the user if the endpoints are not set properly.
if not os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"):
if not os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"):
logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_TRACES_ENDPOINT is not set. Traces will not be exported.")
logger.warning(
"OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_TRACES_ENDPOINT is not set. Traces will not be exported."
)
if not os.environ.get("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT"):
logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_METRICS_ENDPOINT is not set. Metrics will not be exported.")
logger.warning(
"OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_METRICS_ENDPOINT is not set. Metrics will not be exported."
)
def fastapi_middleware(self, app: FastAPI):
"""
@ -85,21 +93,15 @@ class OTelTelemetryProvider(TelemetryProvider):
# Create HTTP metrics following semantic conventions
# https://opentelemetry.io/docs/specs/semconv/http/http-metrics/
request_duration = meter.create_histogram(
"http.server.request.duration",
unit="ms",
description="Duration of HTTP server requests"
"http.server.request.duration", unit="ms", description="Duration of HTTP server requests"
)
active_requests = meter.create_up_down_counter(
"http.server.active_requests",
unit="requests",
description="Number of active HTTP server requests"
"http.server.active_requests", unit="requests", description="Number of active HTTP server requests"
)
request_count = meter.create_counter(
"http.server.request.count",
unit="requests",
description="Total number of HTTP server requests"
"http.server.request.count", unit="requests", description="Total number of HTTP server requests"
)
# Add middleware to record metrics
@ -108,10 +110,13 @@ class OTelTelemetryProvider(TelemetryProvider):
import time
# Record active request
active_requests.add(1, {
"http.method": request.method,
"http.route": request.url.path,
})
active_requests.add(
1,
{
"http.method": request.method,
"http.route": request.url.path,
},
)
start_time = time.time()
status_code = 500 # Default to error
@ -133,48 +138,46 @@ class OTelTelemetryProvider(TelemetryProvider):
request_duration.record(duration_ms, attributes)
request_count.add(1, attributes)
active_requests.add(-1, {
"http.method": request.method,
"http.route": request.url.path,
})
active_requests.add(
-1,
{
"http.method": request.method,
"http.route": request.url.path,
},
)
return response
def sqlalchemy_instrumentation(self, engine: Engine | None = None):
kwargs = {}
if engine:
kwargs["engine"] = engine
SQLAlchemyInstrumentor().instrument(**kwargs)
def get_tracer(self,
instrumenting_module_name: str,
instrumenting_library_version: str | None = None,
tracer_provider: TracerProvider | None = None,
schema_url: str | None = None,
attributes: Attributes | None = None
def get_tracer(
self,
instrumenting_module_name: str,
instrumenting_library_version: str | None = None,
tracer_provider: TracerProvider | None = None,
schema_url: str | None = None,
attributes: Attributes | None = None,
) -> Tracer:
return trace.get_tracer(
instrumenting_module_name=instrumenting_module_name,
instrumenting_library_version=instrumenting_library_version,
tracer_provider=tracer_provider,
schema_url=schema_url,
attributes=attributes
attributes=attributes,
)
def get_meter(self,
name: str,
version: str = "",
meter_provider: MeterProvider | None = None,
schema_url: str | None = None,
attributes: Attributes | None = None
def get_meter(
self,
name: str,
version: str = "",
meter_provider: MeterProvider | None = None,
schema_url: str | None = None,
attributes: Attributes | None = None,
) -> Meter:
return metrics.get_meter(
name=name,
version=version,
meter_provider=meter_provider,
schema_url=schema_url,
attributes=attributes
name=name, version=version, meter_provider=meter_provider, schema_url=schema_url, attributes=attributes
)

View file

@ -3,4 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -67,7 +67,7 @@ class MockRedisServer(MockServerBase):
# Wait for port to be listening
for _ in range(10):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if sock.connect_ex(('localhost', self.port)) == 0:
if sock.connect_ex(("localhost", self.port)) == 0:
sock.close()
return # Ready!
await asyncio.sleep(0.1)
@ -116,6 +116,7 @@ def mock_servers():
yield servers
stop_mock_servers(servers)
# Access specific servers
@pytest.fixture(scope="module")
def mock_redis(mock_servers):

View file

@ -14,9 +14,9 @@ This module provides:
- Mock server harness for parallel async startup
"""
from .harness import MockServerConfig, start_mock_servers_async, stop_mock_servers
from .mock_base import MockServerBase
from .servers import MockOTLPCollector, MockVLLMServer
from .harness import MockServerConfig, start_mock_servers_async, stop_mock_servers
__all__ = [
"MockServerBase",
@ -26,4 +26,3 @@ __all__ = [
"start_mock_servers_async",
"stop_mock_servers",
]

View file

@ -14,7 +14,7 @@ HOW TO ADD A NEW MOCK SERVER:
"""
import asyncio
from typing import Any, Dict, List
from typing import Any
from pydantic import BaseModel, Field
@ -40,10 +40,10 @@ class MockServerConfig(BaseModel):
name: str = Field(description="Display name for logging")
server_class: type = Field(description="Mock server class (must inherit from MockServerBase)")
init_kwargs: Dict[str, Any] = Field(default_factory=dict, description="Kwargs to pass to server constructor")
init_kwargs: dict[str, Any] = Field(default_factory=dict, description="Kwargs to pass to server constructor")
async def start_mock_servers_async(mock_servers_config: List[MockServerConfig]) -> Dict[str, MockServerBase]:
async def start_mock_servers_async(mock_servers_config: list[MockServerConfig]) -> dict[str, MockServerBase]:
"""
Start all mock servers in parallel and wait for them to be ready.
@ -83,14 +83,14 @@ async def start_mock_servers_async(mock_servers_config: List[MockServerConfig])
for server in servers.values():
try:
server.stop()
except:
except Exception:
pass
raise RuntimeError(f"Failed to start mock servers: {e}")
raise RuntimeError(f"Failed to start mock servers: {e}") from None
return servers
def stop_mock_servers(servers: Dict[str, Any]):
def stop_mock_servers(servers: dict[str, Any]):
"""
Stop all mock servers.
@ -99,9 +99,8 @@ def stop_mock_servers(servers: Dict[str, Any]):
"""
for name, server in servers.items():
try:
if hasattr(server, 'get_request_count'):
if hasattr(server, "get_request_count"):
print(f"\n[INFO] {name} received {server.get_request_count()} requests")
server.stop()
except Exception as e:
print(f"[WARN] Error stopping {name}: {e}")

View file

@ -10,9 +10,9 @@ Base class for mock servers with async startup support.
All mock servers should inherit from MockServerBase and implement await_start().
"""
import asyncio
from abc import abstractmethod
from pydantic import BaseModel, Field
from pydantic import BaseModel
class MockServerBase(BaseModel):
@ -66,4 +66,3 @@ class MockServerBase(BaseModel):
This method should gracefully shut down the server.
"""
...

View file

@ -20,7 +20,7 @@ import json
import socket
import threading
import time
from typing import Any, Dict, List
from typing import Any
from pydantic import Field
@ -45,8 +45,8 @@ class MockOTLPCollector(MockServerBase):
port: int = Field(default=4318, description="Port to run collector on")
# Non-Pydantic fields (set after initialization)
traces: List[Dict] = Field(default_factory=list, exclude=True)
metrics: List[Dict] = Field(default_factory=list, exclude=True)
traces: list[dict] = Field(default_factory=list, exclude=True)
metrics: list[dict] = Field(default_factory=list, exclude=True)
server: Any = Field(default=None, exclude=True)
server_thread: Any = Field(default=None, exclude=True)
@ -68,34 +68,38 @@ class MockOTLPCollector(MockServerBase):
"""Suppress HTTP server logs."""
pass
def do_GET(self):
def do_GET(self): # noqa: N802
"""Handle GET requests."""
# No readiness endpoint needed - using await_start() instead
self.send_response(404)
self.end_headers()
def do_POST(self):
def do_POST(self): # noqa: N802
"""Handle OTLP POST requests."""
content_length = int(self.headers.get('Content-Length', 0))
body = self.rfile.read(content_length) if content_length > 0 else b''
content_length = int(self.headers.get("Content-Length", 0))
body = self.rfile.read(content_length) if content_length > 0 else b""
# Store the export request
if '/v1/traces' in self.path:
collector_self.traces.append({
'body': body,
'timestamp': time.time(),
})
elif '/v1/metrics' in self.path:
collector_self.metrics.append({
'body': body,
'timestamp': time.time(),
})
if "/v1/traces" in self.path:
collector_self.traces.append(
{
"body": body,
"timestamp": time.time(),
}
)
elif "/v1/metrics" in self.path:
collector_self.metrics.append(
{
"body": body,
"timestamp": time.time(),
}
)
# Always return success (200 OK)
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.send_header("Content-Type", "application/json")
self.end_headers()
self.wfile.write(b'{}')
self.wfile.write(b"{}")
return OTLPHandler
@ -107,7 +111,7 @@ class MockOTLPCollector(MockServerBase):
"""
# Create handler and start the HTTP server
handler_class = self._create_handler_class()
self.server = http.server.HTTPServer(('localhost', self.port), handler_class)
self.server = http.server.HTTPServer(("localhost", self.port), handler_class)
self.server_thread = threading.Thread(target=self.server.serve_forever, daemon=True)
self.server_thread.start()
@ -115,12 +119,12 @@ class MockOTLPCollector(MockServerBase):
for _ in range(10):
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
result = sock.connect_ex(('localhost', self.port))
result = sock.connect_ex(("localhost", self.port))
sock.close()
if result == 0:
# Port is listening
return
except:
except Exception:
pass
await asyncio.sleep(0.1)
@ -145,11 +149,11 @@ class MockOTLPCollector(MockServerBase):
"""Get number of metric export requests received."""
return len(self.metrics)
def get_all_traces(self) -> List[Dict]:
def get_all_traces(self) -> list[dict]:
"""Get all captured trace exports."""
return self.traces
def get_all_metrics(self) -> List[Dict]:
def get_all_metrics(self) -> list[dict]:
"""Get all captured metric exports."""
return self.metrics
@ -174,13 +178,12 @@ class MockVLLMServer(MockServerBase):
"""
port: int = Field(default=8000, description="Port to run server on")
models: List[str] = Field(
default_factory=lambda: ["meta-llama/Llama-3.2-1B-Instruct"],
description="List of model IDs to serve"
models: list[str] = Field(
default_factory=lambda: ["meta-llama/Llama-3.2-1B-Instruct"], description="List of model IDs to serve"
)
# Non-Pydantic fields
requests_received: List[Dict] = Field(default_factory=list, exclude=True)
requests_received: list[dict] = Field(default_factory=list, exclude=True)
server: Any = Field(default=None, exclude=True)
server_thread: Any = Field(default=None, exclude=True)
@ -201,53 +204,57 @@ class MockVLLMServer(MockServerBase):
"""Suppress HTTP server logs."""
pass
def log_request(self, code='-', size='-'):
def log_request(self, code="-", size="-"):
"""Log incoming requests for debugging."""
print(f"[DEBUG] Mock vLLM received: {self.command} {self.path} -> {code}")
def do_GET(self):
def do_GET(self): # noqa: N802
"""Handle GET requests (models list, health check)."""
# Log GET requests too
server_self.requests_received.append({
'path': self.path,
'method': 'GET',
'timestamp': time.time(),
})
server_self.requests_received.append(
{
"path": self.path,
"method": "GET",
"timestamp": time.time(),
}
)
if self.path == '/v1/models':
if self.path == "/v1/models":
response = self._create_models_list_response()
self._send_json_response(200, response)
elif self.path == '/health' or self.path == '/v1/health':
elif self.path == "/health" or self.path == "/v1/health":
self._send_json_response(200, {"status": "healthy"})
else:
self.send_response(404)
self.end_headers()
def do_POST(self):
def do_POST(self): # noqa: N802
"""Handle POST requests (chat/text completions)."""
content_length = int(self.headers.get('Content-Length', 0))
body = self.rfile.read(content_length) if content_length > 0 else b'{}'
content_length = int(self.headers.get("Content-Length", 0))
body = self.rfile.read(content_length) if content_length > 0 else b"{}"
try:
request_data = json.loads(body)
except:
except Exception:
request_data = {}
# Log the request
server_self.requests_received.append({
'path': self.path,
'body': request_data,
'timestamp': time.time(),
})
server_self.requests_received.append(
{
"path": self.path,
"body": request_data,
"timestamp": time.time(),
}
)
# Route to appropriate handler
if '/chat/completions' in self.path:
if "/chat/completions" in self.path:
response = self._create_chat_completion_response(request_data)
self._send_json_response(200, response)
elif '/completions' in self.path:
elif "/completions" in self.path:
response = self._create_text_completion_response(request_data)
self._send_json_response(200, response)
@ -259,7 +266,7 @@ class MockVLLMServer(MockServerBase):
# **TO MODIFY RESPONSES:** Edit these methods
# ----------------------------------------------------------------
def _create_models_list_response(self) -> Dict:
def _create_models_list_response(self) -> dict:
"""Create OpenAI models list response with configured models."""
return {
"object": "list",
@ -271,10 +278,10 @@ class MockVLLMServer(MockServerBase):
"owned_by": "meta",
}
for model_id in server_self.models
]
],
}
def _create_chat_completion_response(self, request_data: Dict) -> Dict:
def _create_chat_completion_response(self, request_data: dict) -> dict:
"""
Create OpenAI ChatCompletion response.
@ -285,16 +292,18 @@ class MockVLLMServer(MockServerBase):
"object": "chat.completion",
"created": int(time.time()),
"model": request_data.get("model", "meta-llama/Llama-3.2-1B-Instruct"),
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "This is a test response from mock vLLM server.",
"tool_calls": None,
},
"logprobs": None,
"finish_reason": "stop",
}],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "This is a test response from mock vLLM server.",
"tool_calls": None,
},
"logprobs": None,
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 25,
"completion_tokens": 15,
@ -305,7 +314,7 @@ class MockVLLMServer(MockServerBase):
"service_tier": None,
}
def _create_text_completion_response(self, request_data: Dict) -> Dict:
def _create_text_completion_response(self, request_data: dict) -> dict:
"""
Create OpenAI Completion response.
@ -316,12 +325,14 @@ class MockVLLMServer(MockServerBase):
"object": "text_completion",
"created": int(time.time()),
"model": request_data.get("model", "meta-llama/Llama-3.2-1B-Instruct"),
"choices": [{
"text": "This is a test completion.",
"index": 0,
"logprobs": None,
"finish_reason": "stop",
}],
"choices": [
{
"text": "This is a test completion.",
"index": 0,
"logprobs": None,
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 8,
@ -331,10 +342,10 @@ class MockVLLMServer(MockServerBase):
"system_fingerprint": None,
}
def _send_json_response(self, status_code: int, data: Dict):
def _send_json_response(self, status_code: int, data: dict):
"""Helper to send JSON response."""
self.send_response(status_code)
self.send_header('Content-Type', 'application/json')
self.send_header("Content-Type", "application/json")
self.end_headers()
self.wfile.write(json.dumps(data).encode())
@ -348,7 +359,7 @@ class MockVLLMServer(MockServerBase):
"""
# Create handler and start the HTTP server
handler_class = self._create_handler_class()
self.server = http.server.HTTPServer(('localhost', self.port), handler_class)
self.server = http.server.HTTPServer(("localhost", self.port), handler_class)
self.server_thread = threading.Thread(target=self.server.serve_forever, daemon=True)
self.server_thread.start()
@ -356,12 +367,12 @@ class MockVLLMServer(MockServerBase):
for _ in range(10):
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
result = sock.connect_ex(('localhost', self.port))
result = sock.connect_ex(("localhost", self.port))
sock.close()
if result == 0:
# Port is listening
return
except:
except Exception:
pass
await asyncio.sleep(0.1)
@ -381,7 +392,6 @@ class MockVLLMServer(MockServerBase):
"""Get number of requests received."""
return len(self.requests_received)
def get_all_requests(self) -> List[Dict]:
def get_all_requests(self) -> list[dict]:
"""Get all received requests with their bodies."""
return self.requests_received

View file

@ -34,7 +34,7 @@ import os
import socket
import subprocess
import time
from typing import Any, Dict, List
from typing import Any
import pytest
import requests
@ -44,17 +44,17 @@ from pydantic import BaseModel, Field
# Mock servers are in the mocking/ subdirectory
from .mocking import (
MockOTLPCollector,
MockVLLMServer,
MockServerConfig,
MockVLLMServer,
start_mock_servers_async,
stop_mock_servers,
)
# ============================================================================
# DATA MODELS
# ============================================================================
class TelemetryTestCase(BaseModel):
"""
Pydantic model defining expected telemetry for an API call.
@ -65,7 +65,7 @@ class TelemetryTestCase(BaseModel):
name: str = Field(description="Unique test case identifier")
http_method: str = Field(description="HTTP method (GET, POST, etc.)")
api_path: str = Field(description="API path (e.g., '/v1/models')")
request_body: Dict[str, Any] | None = Field(default=None)
request_body: dict[str, Any] | None = Field(default=None)
expected_http_status: int = Field(default=200)
expected_trace_exports: int = Field(default=1, description="Minimum number of trace exports expected")
expected_metric_exports: int = Field(default=0, description="Minimum number of metric exports expected")
@ -103,6 +103,7 @@ TEST_CASES = [
# TEST INFRASTRUCTURE
# ============================================================================
class TelemetryTestRunner:
"""
Executes TelemetryTestCase instances against real Llama Stack.
@ -160,14 +161,16 @@ class TelemetryTestRunner:
metrics_exported = new_metrics >= test_case.expected_metric_exports
if verbose:
print(f" Expected: >={test_case.expected_trace_exports} trace exports, >={test_case.expected_metric_exports} metric exports")
print(
f" Expected: >={test_case.expected_trace_exports} trace exports, >={test_case.expected_metric_exports} metric exports"
)
print(f" Actual: {new_traces} trace exports, {new_metrics} metric exports")
result = status_match and traces_exported and metrics_exported
print(f" Result: {'PASS' if result else 'FAIL'}")
return status_match and traces_exported and metrics_exported
def run_all_test_cases(self, test_cases: List[TelemetryTestCase], verbose: bool = True) -> Dict[str, bool]:
def run_all_test_cases(self, test_cases: list[TelemetryTestCase], verbose: bool = True) -> dict[str, bool]:
"""Run all test cases and return results."""
results = {}
for test_case in test_cases:
@ -179,11 +182,12 @@ class TelemetryTestRunner:
# HELPER FUNCTIONS
# ============================================================================
def is_port_available(port: int) -> bool:
"""Check if a TCP port is available for binding."""
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(('localhost', port))
sock.bind(("localhost", port))
return True
except OSError:
return False
@ -193,6 +197,7 @@ def is_port_available(port: int) -> bool:
# PYTEST FIXTURES
# ============================================================================
@pytest.fixture(scope="module")
def mock_servers():
"""
@ -214,7 +219,7 @@ def mock_servers():
# init_kwargs={"port": 9000, "param": "value"},
# ),
# ========================================================================
MOCK_SERVERS = [
mock_servers_config = [
MockServerConfig(
name="Mock OTLP Collector",
server_class=MockOTLPCollector,
@ -232,7 +237,7 @@ def mock_servers():
]
# Start all servers in parallel
servers = asyncio.run(start_mock_servers_async(MOCK_SERVERS))
servers = asyncio.run(start_mock_servers_async(mock_servers_config))
# Verify vLLM models
models_response = requests.get("http://localhost:8000/v1/models", timeout=1)
@ -270,7 +275,7 @@ def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server):
config_dir = tmp_path_factory.mktemp("otel-stack-config")
# Ensure mock vLLM is ready and accessible before starting Llama Stack
print(f"\n[INFO] Verifying mock vLLM is accessible at http://localhost:8000...")
print("\n[INFO] Verifying mock vLLM is accessible at http://localhost:8000...")
try:
vllm_models = requests.get("http://localhost:8000/v1/models", timeout=2)
print(f"[INFO] Mock vLLM models endpoint response: {vllm_models.status_code}")
@ -336,9 +341,12 @@ def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server):
# Start server with automatic instrumentation
cmd = [
"opentelemetry-instrument", # ← Automatic instrumentation wrapper
"llama", "stack", "run",
"llama",
"stack",
"run",
str(config_file),
"--port", str(port),
"--port",
str(port),
]
print(f"\n[INFO] Starting Llama Stack with OTel instrumentation on port {port}")
@ -370,14 +378,14 @@ def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server):
time.sleep(1)
yield {
'base_url': base_url,
'port': port,
'collector': mock_otlp_collector,
'vllm_server': mock_vllm_server,
"base_url": base_url,
"port": port,
"collector": mock_otlp_collector,
"vllm_server": mock_vllm_server,
}
# Cleanup
print(f"\n[INFO] Stopping Llama Stack server")
print("\n[INFO] Stopping Llama Stack server")
process.terminate()
try:
process.wait(timeout=5)
@ -391,6 +399,7 @@ def llama_stack_server(tmp_path_factory, mock_otlp_collector, mock_vllm_server):
# **TO ADD NEW E2E TESTS:** Add methods to this class
# ============================================================================
@pytest.mark.slow
class TestOTelE2E:
"""
@ -405,7 +414,7 @@ class TestOTelE2E:
def test_server_starts_with_auto_instrumentation(self, llama_stack_server):
"""Verify server starts successfully with opentelemetry-instrument."""
base_url = llama_stack_server['base_url']
base_url = llama_stack_server["base_url"]
# Try different health check endpoints
health_endpoints = ["/health", "/v1/health", "/"]
@ -432,8 +441,8 @@ class TestOTelE2E:
This executes all test cases defined in TEST_CASES list.
**TO ADD MORE TESTS:** Add to TEST_CASES at top of file
"""
base_url = llama_stack_server['base_url']
collector = llama_stack_server['collector']
base_url = llama_stack_server["base_url"]
collector = llama_stack_server["collector"]
# Create test runner
runner = TelemetryTestRunner(base_url, collector)
@ -442,9 +451,9 @@ class TestOTelE2E:
results = runner.run_all_test_cases(TEST_CASES, verbose=True)
# Print summary
print(f"\n{'='*50}")
print(f"TEST CASE SUMMARY")
print(f"{'='*50}")
print(f"\n{'=' * 50}")
print("TEST CASE SUMMARY")
print(f"{'=' * 50}")
passed = sum(1 for p in results.values() if p)
total = len(results)
print(f"Passed: {passed}/{total}\n")
@ -452,4 +461,4 @@ class TestOTelE2E:
for name, result in results.items():
status = "[PASS]" if result else "[FAIL]"
print(f" {status} {name}")
print(f"{'='*50}\n")
print(f"{'=' * 50}\n")

View file

@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import pytest
import llama_stack.providers.inline.telemetry.meta_reference.telemetry as telemetry_module
@ -38,7 +36,7 @@ def test_warns_when_traces_endpoints_missing(monkeypatch: pytest.MonkeyPatch, ca
monkeypatch.delenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", raising=False)
monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False)
caplog.set_level(logging.WARNING)
caplog.set_level("WARNING")
config = _make_config_with_sinks(TelemetrySink.OTEL_TRACE)
telemetry_module.TelemetryAdapter(config=config, deps={})
@ -57,7 +55,7 @@ def test_warns_when_metrics_endpoints_missing(monkeypatch: pytest.MonkeyPatch, c
monkeypatch.delenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", raising=False)
monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False)
caplog.set_level(logging.WARNING)
caplog.set_level("WARNING")
config = _make_config_with_sinks(TelemetrySink.OTEL_METRIC)
telemetry_module.TelemetryAdapter(config=config, deps={})
@ -76,7 +74,7 @@ def test_no_warning_when_traces_endpoints_present(monkeypatch: pytest.MonkeyPatc
monkeypatch.setenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "https://otel.example:4318/v1/traces")
monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://otel.example:4318")
caplog.set_level(logging.WARNING)
caplog.set_level("WARNING")
config = _make_config_with_sinks(TelemetrySink.OTEL_TRACE)
telemetry_module.TelemetryAdapter(config=config, deps={})
@ -91,7 +89,7 @@ def test_no_warning_when_metrics_endpoints_present(monkeypatch: pytest.MonkeyPat
monkeypatch.setenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", "https://otel.example:4318/v1/metrics")
monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://otel.example:4318")
caplog.set_level(logging.WARNING)
caplog.set_level("WARNING")
config = _make_config_with_sinks(TelemetrySink.OTEL_METRIC)
telemetry_module.TelemetryAdapter(config=config, deps={})

View file

@ -97,8 +97,7 @@ class TestOTelTelemetryProviderTracerAPI:
def test_get_tracer_with_version(self, otel_provider):
"""Test that get_tracer works with version parameter."""
tracer = otel_provider.get_tracer(
instrumenting_module_name="test.module",
instrumenting_library_version="1.0.0"
instrumenting_module_name="test.module", instrumenting_library_version="1.0.0"
)
assert tracer is not None
@ -107,8 +106,7 @@ class TestOTelTelemetryProviderTracerAPI:
def test_get_tracer_with_attributes(self, otel_provider):
"""Test that get_tracer works with attributes."""
tracer = otel_provider.get_tracer(
instrumenting_module_name="test.module",
attributes={"component": "test", "tier": "backend"}
instrumenting_module_name="test.module", attributes={"component": "test", "tier": "backend"}
)
assert tracer is not None
@ -117,8 +115,7 @@ class TestOTelTelemetryProviderTracerAPI:
def test_get_tracer_with_schema_url(self, otel_provider):
"""Test that get_tracer works with schema URL."""
tracer = otel_provider.get_tracer(
instrumenting_module_name="test.module",
schema_url="https://example.com/schema"
instrumenting_module_name="test.module", schema_url="https://example.com/schema"
)
assert tracer is not None
@ -136,10 +133,7 @@ class TestOTelTelemetryProviderTracerAPI:
"""Test that tracer can create spans with attributes."""
tracer = otel_provider.get_tracer("test.module")
with tracer.start_as_current_span(
"test.operation",
attributes={"user.id": "123", "request.id": "abc"}
) as span:
with tracer.start_as_current_span("test.operation", attributes={"user.id": "123", "request.id": "abc"}) as span:
assert span is not None
assert span.is_recording()
@ -170,30 +164,21 @@ class TestOTelTelemetryProviderMeterAPI:
def test_get_meter_with_version(self, otel_provider):
"""Test that get_meter works with version parameter."""
meter = otel_provider.get_meter(
name="test.meter",
version="1.0.0"
)
meter = otel_provider.get_meter(name="test.meter", version="1.0.0")
assert meter is not None
assert isinstance(meter, Meter)
def test_get_meter_with_attributes(self, otel_provider):
"""Test that get_meter works with attributes."""
meter = otel_provider.get_meter(
name="test.meter",
attributes={"service": "test", "env": "dev"}
)
meter = otel_provider.get_meter(name="test.meter", attributes={"service": "test", "env": "dev"})
assert meter is not None
assert isinstance(meter, Meter)
def test_get_meter_with_schema_url(self, otel_provider):
"""Test that get_meter works with schema URL."""
meter = otel_provider.get_meter(
name="test.meter",
schema_url="https://example.com/schema"
)
meter = otel_provider.get_meter(name="test.meter", schema_url="https://example.com/schema")
assert meter is not None
assert isinstance(meter, Meter)
@ -202,11 +187,7 @@ class TestOTelTelemetryProviderMeterAPI:
"""Test that meter can create counters."""
meter = otel_provider.get_meter("test.meter")
counter = meter.create_counter(
"test.requests.total",
unit="requests",
description="Total requests"
)
counter = meter.create_counter("test.requests.total", unit="requests", description="Total requests")
assert counter is not None
# Test that counter can be used
@ -216,11 +197,7 @@ class TestOTelTelemetryProviderMeterAPI:
"""Test that meter can create histograms."""
meter = otel_provider.get_meter("test.meter")
histogram = meter.create_histogram(
"test.request.duration",
unit="ms",
description="Request duration"
)
histogram = meter.create_histogram("test.request.duration", unit="ms", description="Request duration")
assert histogram is not None
# Test that histogram can be used
@ -231,9 +208,7 @@ class TestOTelTelemetryProviderMeterAPI:
meter = otel_provider.get_meter("test.meter")
up_down_counter = meter.create_up_down_counter(
"test.active.connections",
unit="connections",
description="Active connections"
"test.active.connections", unit="connections", description="Active connections"
)
assert up_down_counter is not None
@ -249,10 +224,7 @@ class TestOTelTelemetryProviderMeterAPI:
return [{"attributes": {"host": "localhost"}, "value": 42.0}]
gauge = meter.create_observable_gauge(
"test.memory.usage",
callbacks=[gauge_callback],
unit="bytes",
description="Memory usage"
"test.memory.usage", callbacks=[gauge_callback], unit="bytes", description="Memory usage"
)
assert gauge is not None
@ -305,22 +277,14 @@ class TestOTelTelemetryProviderNativeUsage:
meter = otel_provider.get_meter("llama_stack.metrics")
# Create various instruments
request_counter = meter.create_counter(
"llama.requests.total",
unit="requests",
description="Total requests"
)
request_counter = meter.create_counter("llama.requests.total", unit="requests", description="Total requests")
latency_histogram = meter.create_histogram(
"llama.inference.duration",
unit="ms",
description="Inference duration"
"llama.inference.duration", unit="ms", description="Inference duration"
)
active_sessions = meter.create_up_down_counter(
"llama.sessions.active",
unit="sessions",
description="Active sessions"
"llama.sessions.active", unit="sessions", description="Active sessions"
)
# Use the instruments
@ -333,6 +297,7 @@ class TestOTelTelemetryProviderNativeUsage:
def test_concurrent_tracer_usage(self, otel_provider):
"""Test that multiple threads can use tracers concurrently."""
def create_spans(thread_id):
tracer = otel_provider.get_tracer(f"test.module.{thread_id}")
for i in range(10):
@ -348,6 +313,7 @@ class TestOTelTelemetryProviderNativeUsage:
def test_concurrent_meter_usage(self, otel_provider):
"""Test that multiple threads can use meters concurrently."""
def record_metrics(thread_id):
meter = otel_provider.get_meter(f"test.meter.{thread_id}")
counter = meter.create_counter(f"test.counter.{thread_id}")