feat(major): move new telemetry architecture into new provider

This commit is contained in:
Emilio Garcia 2025-10-01 11:54:14 -04:00
parent ce3a804893
commit e45e77f7b0
10 changed files with 207 additions and 52 deletions

View file

@ -62,18 +62,10 @@ from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_dis
from llama_stack.core.utils.context import preserve_contexts_async_generator from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
TelemetryAdapter,
)
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
setup_logger,
)
from .auth import AuthenticationMiddleware from .auth import AuthenticationMiddleware
from .quota import QuotaMiddleware from .quota import QuotaMiddleware
from .tracing import TracingMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -243,7 +235,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
try: try:
if is_streaming: if is_streaming:
gen = preserve_contexts_async_generator( gen = preserve_contexts_async_generator(
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR] sse_generator(func(**kwargs)), [PROVIDER_DATA_VAR]
) )
return StreamingResponse(gen, media_type="text/event-stream") return StreamingResponse(gen, media_type="text/event-stream")
else: else:
@ -288,8 +280,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
] ]
) )
route_handler.__signature__ = sig.replace(parameters=new_params) setattr(route_handler, "__signature__", sig.replace(parameters=new_params))
return route_handler return route_handler
@ -351,11 +342,12 @@ def create_app(
if config_file is None: if config_file is None:
raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set") raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
config_file = resolve_config_or_distro(config_file, Mode.RUN) config_path = resolve_config_or_distro(config_file, Mode.RUN)
# Load and process configuration # Load and process configuration
logger_config = None logger_config = None
with open(config_file) as fp:
with open(config_path) as fp:
config_contents = yaml.safe_load(fp) config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg) logger_config = LoggingConfig(**cfg)
@ -387,7 +379,7 @@ def create_app(
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware) app.add_middleware(ClientVersionMiddleware)
impls = app.stack.impls impls = app.stack.get_impls()
if config.server.auth: if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}") logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
@ -429,11 +421,7 @@ def create_app(
app.add_middleware(CORSMiddleware, **cors_config.model_dump()) app.add_middleware(CORSMiddleware, **cors_config.model_dump())
if Api.telemetry in impls: if Api.telemetry in impls:
setup_logger(impls[Api.telemetry]) impls[Api.telemetry].fastapi_middleware(app)
if impls[Api.telemetry].fastapi_middleware:
impls[Api.telemetry].fastapi_middleware(app)
else:
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
# Load external APIs if configured # Load external APIs if configured
external_apis = load_external_apis(config) external_apis = load_external_apis(config)
@ -442,7 +430,7 @@ def create_app(
if config.apis: if config.apis:
apis_to_serve = set(config.apis) apis_to_serve = set(config.apis)
else: else:
apis_to_serve = set(impls.keys()) apis_to_serve = {api.value for api in impls.keys()}
for inf in builtin_automatically_routed_apis(): for inf in builtin_automatically_routed_apis():
# if we do not serve the corresponding router API, we should not serve the routing table API # if we do not serve the corresponding router API, we should not serve the routing table API
@ -470,7 +458,8 @@ def create_app(
impl_method = getattr(impl, route.name) impl_method = getattr(impl, route.name)
# Filter out HEAD method since it's automatically handled by FastAPI for GET routes # Filter out HEAD method since it's automatically handled by FastAPI for GET routes
available_methods = [m for m in route.methods if m != "HEAD"] route_methods = route.methods or []
available_methods = [m for m in route_methods if m != "HEAD"]
if not available_methods: if not available_methods:
raise ValueError(f"No methods found for {route.name} on {impl}") raise ValueError(f"No methods found for {route.name} on {impl}")
method = available_methods[0] method = available_methods[0]
@ -491,8 +480,6 @@ def create_app(
app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler)
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
return app return app
@ -530,8 +517,8 @@ def main(args: argparse.Namespace | None = None):
logger.error(f"Error creating app: {str(e)}") logger.error(f"Error creating app: {str(e)}")
sys.exit(1) sys.exit(1)
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN) config_path = resolve_config_or_distro(config_or_distro, Mode.RUN)
with open(config_file) as fp: with open(config_path) as fp:
config_contents = yaml.safe_load(fp) config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg) logger_config = LoggingConfig(**cfg)

View file

@ -359,6 +359,13 @@ class Stack:
await refresh_registry_once(impls) await refresh_registry_once(impls)
self.impls = impls self.impls = impls
def get_impls(self) -> dict[Api, Any]:
    """Return the stored implementations, or an empty dict when there are none.

    Safe accessor for ``self.impls``: it never raises and never returns
    ``None``, so callers can treat the result as a mapping without a prior check.
    """
    current = self.impls
    return {} if current is None else current
def create_registry_refresh_task(self): def create_registry_refresh_task(self):
assert self.impls is not None, "Must call initialize() before starting" assert self.impls is not None, "Must call initialize() before starting"

View file

@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import abstractmethod
from fastapi import FastAPI
from pydantic import BaseModel
class TelemetryProvider(BaseModel):
    """
    TelemetryProvider standardizes how telemetry is provided to the application.

    Subclasses implement ``fastapi_middleware`` to attach their instrumentation
    to a FastAPI application.
    """

    @abstractmethod
    def fastapi_middleware(self, app: FastAPI, *args, **kwargs):
        """
        Injects FastAPI middleware that instruments the application for telemetry.

        :param app: the FastAPI application to instrument.
        :param args: extra positional arguments; their meaning is left to the implementation.
        :param kwargs: extra keyword arguments; their meaning is left to the implementation.
        """
        # NOTE(review): whether @abstractmethod is enforced here depends on the
        # metaclass that BaseModel supplies - confirm.
        ...

View file

@ -1,20 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from aiohttp import hdrs from aiohttp import hdrs
from typing import Any
from llama_stack.apis.datatypes import Api
from llama_stack.core.external import ExternalApiSpec from llama_stack.core.external import ExternalApiSpec
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace
logger = get_logger(name=__name__, category="core::server")
logger = get_logger(name=__name__, category="telemetry::meta_reference")
class TracingMiddleware: class TracingMiddleware:
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]): def __init__(
self,
app,
impls: dict[Api, Any],
external_apis: dict[str, ExternalApiSpec],
):
self.app = app self.app = app
self.impls = impls self.impls = impls
self.external_apis = external_apis self.external_apis = external_apis
@ -34,7 +36,8 @@ class TracingMiddleware:
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
if not hasattr(self, "route_impls"): if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls, self.external_apis) external_api_map = {Api(api_name): spec for api_name, spec in self.external_apis.items()}
self.route_impls = initialize_route_impls(self.impls, external_api_map)
try: try:
_, _, route_path, webmethod = find_matching_route( _, _, route_path, webmethod = find_matching_route(

View file

@ -7,8 +7,7 @@
import datetime import datetime
import os import os
import threading import threading
import logging from typing import Any, cast
from typing import Any
from fastapi import FastAPI from fastapi import FastAPI
@ -22,7 +21,12 @@ from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.attributes import service_attributes from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from opentelemetry.util.types import Attributes
from llama_stack.core.external import ExternalApiSpec
from llama_stack.core.server.tracing import TelemetryProvider
from llama_stack.providers.inline.telemetry.meta_reference.middleware import TracingMiddleware
from llama_stack.apis.telemetry import ( from llama_stack.apis.telemetry import (
Event, Event,
@ -73,7 +77,7 @@ def is_tracing_enabled(tracer):
return span.is_recording() return span.is_recording()
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): class TelemetryAdapter(TelemetryDatasetMixin, Telemetry, TelemetryProvider):
def __init__(self, config: TelemetryConfig, deps: dict[Api, Any]) -> None: def __init__(self, config: TelemetryConfig, deps: dict[Api, Any]) -> None:
self.config = config self.config = config
self.datasetio_api = deps.get(Api.datasetio) self.datasetio_api = deps.get(Api.datasetio)
@ -266,12 +270,13 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
# Log to OpenTelemetry meter if available # Log to OpenTelemetry meter if available
if self.meter is None: if self.meter is None:
return return
normalized_attributes = self._normalize_attributes(event.attributes)
if isinstance(event.value, int): if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit) counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes) counter.add(event.value, attributes=normalized_attributes)
elif isinstance(event.value, float): elif isinstance(event.value, float):
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit) up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
up_down_counter.add(event.value, attributes=event.attributes) up_down_counter.add(event.value, attributes=normalized_attributes)
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter: def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
assert self.meter is not None assert self.meter is not None
@ -287,18 +292,17 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
with self._lock: with self._lock:
span_id = int(event.span_id, 16) span_id = int(event.span_id, 16)
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
if event.attributes is None: event_attributes = dict(event.attributes or {})
event.attributes = {} event_attributes["__ttl__"] = ttl_seconds
event.attributes["__ttl__"] = ttl_seconds
# Extract these W3C trace context attributes so they are not written to # Extract these W3C trace context attributes so they are not written to
# underlying storage, as we just need them to propagate the trace context. # underlying storage, as we just need them to propagate the trace context.
traceparent = event.attributes.pop("traceparent", None) traceparent = event_attributes.pop("traceparent", None)
tracestate = event.attributes.pop("tracestate", None) tracestate = event_attributes.pop("tracestate", None)
if traceparent: if traceparent:
# If we have a traceparent header value, we're not the root span. # If we have a traceparent header value, we're not the root span.
for root_attribute in ROOT_SPAN_MARKERS: for root_attribute in ROOT_SPAN_MARKERS:
event.attributes.pop(root_attribute, None) event_attributes.pop(root_attribute, None)
if isinstance(event.payload, SpanStartPayload): if isinstance(event.payload, SpanStartPayload):
# Check if span already exists to prevent duplicates # Check if span already exists to prevent duplicates
@ -309,7 +313,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
if event.payload.parent_span_id: if event.payload.parent_span_id:
parent_span_id = int(event.payload.parent_span_id, 16) parent_span_id = int(event.payload.parent_span_id, 16)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
context = trace.set_span_in_context(parent_span) if parent_span:
context = trace.set_span_in_context(parent_span)
elif traceparent: elif traceparent:
carrier = { carrier = {
"traceparent": traceparent, "traceparent": traceparent,
@ -320,15 +325,15 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
span = tracer.start_span( span = tracer.start_span(
name=event.payload.name, name=event.payload.name,
context=context, context=context,
attributes=event.attributes or {}, attributes=self._normalize_attributes(event_attributes),
) )
_GLOBAL_STORAGE["active_spans"][span_id] = span _GLOBAL_STORAGE["active_spans"][span_id] = span
elif isinstance(event.payload, SpanEndPayload): elif isinstance(event.payload, SpanEndPayload):
span = _GLOBAL_STORAGE["active_spans"].get(span_id) span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span: if span:
if event.attributes: if event_attributes:
span.set_attributes(event.attributes) span.set_attributes(self._normalize_attributes(event_attributes))
status = ( status = (
trace.Status(status_code=trace.StatusCode.OK) trace.Status(status_code=trace.StatusCode.OK)
@ -377,5 +382,14 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
) )
) )
def fastapi_middleware(
    self,
    app: FastAPI,
    impls: dict[Api, Any],
    external_apis: dict[str, ExternalApiSpec],
):
    """
    Register the tracing middleware on the given FastAPI application.

    :param app: the FastAPI application to instrument.
    :param impls: API implementations handed to ``TracingMiddleware``.
    :param external_apis: external API specs handed to ``TracingMiddleware``.
    """
    # Merely constructing TracingMiddleware(app, ...) and dropping the instance
    # leaves the application untouched. The middleware has to be added to the
    # app's middleware stack so that requests actually pass through it.
    app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
@staticmethod
def _normalize_attributes(attributes: dict[str, Any] | None) -> Attributes:
    """
    Return a shallow copy of ``attributes`` typed as OpenTelemetry ``Attributes``.

    ``None`` and empty mappings both yield a new empty dict. The ``cast`` only
    informs the type checker; values are not converted or validated.
    """
    if attributes:
        copied: dict[str, Any] = dict(attributes)
    else:
        copied = {}
    return cast(Attributes, copied)

View file

@ -0,0 +1,26 @@
# Open Telemetry Native Instrumentation
This instrumentation package is simple and follows the expected OpenTelemetry standards. It injects middleware for distributed tracing at all ingress and egress points of the application, and can be tuned and configured with OTEL environment variables.
## Set Up
First, bootstrap and install all necessary libraries for OpenTelemetry:
```
uv run opentelemetry-bootstrap -a requirements | uv pip install --requirement -
```
Then, run with automatic code injection:
```
uv run opentelemetry-instrument llama stack run --config myconfig.yaml
```
### Excluded FastAPI URLs
```
export OTEL_PYTHON_FASTAPI_EXCLUDED_URLS="client/.*/info,healthcheck"
```
#### Environment Variables
Environment Variables: https://opentelemetry.io/docs/specs/otel/configuration/sdk-environment-variables/

View file

@ -0,0 +1,31 @@
from typing import Literal
from pydantic import BaseModel, Field
# Literal aliases for the values accepted by OTelTelemetryConfig.span_processor.
type BatchSpanProcessor = Literal["batch"]
type SimpleSpanProcessor = Literal["simple"]
class OTelTelemetryConfig(BaseModel):
    """
    The configuration for the OpenTelemetry telemetry provider.
    Most configuration is set using environment variables.
    See https://opentelemetry.io/docs/specs/otel/configuration/sdk-environment-variables/ for more information.
    """

    # Required: there is no default service name.
    service_name: str = Field(
        description=(
            "The name of the service to be monitored.\n"
            "Is overridden by the OTEL_SERVICE_NAME or OTEL_RESOURCE_ATTRIBUTES environment variables."
        ),
    )
    # Optional: an explicit default of None is needed, otherwise a Field
    # without a default is still required even though its type allows None.
    service_version: str | None = Field(
        default=None,
        description=(
            "The version of the service to be monitored.\n"
            "Is overriden by the OTEL_RESOURCE_ATTRIBUTES environment variable."
        ),
    )
    # Optional, for the same reason as service_version.
    deployment_environment: str | None = Field(
        default=None,
        description=(
            "The name of the environment of the service to be monitored.\n"
            "Is overriden by the OTEL_RESOURCE_ATTRIBUTES environment variable."
        ),
    )
    # Selects the span processor; defaults to "batch". None is also accepted.
    span_processor: BatchSpanProcessor | SimpleSpanProcessor | None = Field(
        default="batch",
        description=(
            "The span processor to use.\n"
            "Is overriden by the OTEL_SPAN_PROCESSOR environment variable."
        ),
    )

View file

@ -0,0 +1,63 @@
import os
from opentelemetry import trace, metrics
from opentelemetry.sdk.resources import Attributes, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SimpleSpanProcessor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from llama_stack.core.telemetry.tracing import TelemetryProvider
from llama_stack.log import get_logger
from .config import OTelTelemetryConfig
from fastapi import FastAPI
# Module-level logger for this provider, registered under the "telemetry::otel" category.
logger = get_logger(name=__name__, category="telemetry::otel")
class OTelTelemetryProvider(TelemetryProvider):
    """
    A simple Open Telemetry native telemetry provider.

    Constructing an instance builds an OpenTelemetry resource from the
    configuration and installs a global tracer provider and a global meter
    provider.
    """

    # Declared as a model field so the configuration can be stored on the
    # instance when the base class is a pydantic model.
    config: OTelTelemetryConfig

    def __init__(self, config: OTelTelemetryConfig):
        """
        Set up tracing and metrics providers from ``config``.

        :param config: service identity and span-processor selection.
        """
        # Initialise through the base class instead of assigning self.config
        # directly: a pydantic model rejects attribute assignment for names
        # that are not declared fields, and needs its own initialiser to run
        # before instance attributes can be used.
        super().__init__(config=config)

        # Only keep resource attributes that were actually configured.
        attributes: Attributes = {
            key: value
            for key, value in {
                "service.name": self.config.service_name,
                "service.version": self.config.service_version,
                "deployment.environment": self.config.deployment_environment,
            }.items()
            if value is not None
        }

        resource = Resource.create(attributes)

        # Configure the tracer provider
        tracer_provider = TracerProvider(resource=resource)
        trace.set_tracer_provider(tracer_provider)

        otlp_span_exporter = OTLPSpanExporter()

        # Configure the span processor
        # Enable batching of spans to reduce the number of requests to the collector
        if self.config.span_processor == "batch":
            tracer_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
        elif self.config.span_processor == "simple":
            tracer_provider.add_span_processor(SimpleSpanProcessor(otlp_span_exporter))
        # Any other value (including None) attaches no span processor.

        # NOTE(review): the meter provider is created without any metric
        # reader here - confirm how metrics are expected to leave the process.
        meter_provider = MeterProvider(resource=resource)
        metrics.set_meter_provider(meter_provider)

        # Do not fail the application, but warn the user if the endpoints are not set properly
        if not os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"):
            if not os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"):
                logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_TRACES_ENDPOINT is not set. Traces will not be exported.")
            if not os.environ.get("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT"):
                logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_METRICS_ENDPOINT is not set. Metrics will not be exported.")

    def fastapi_middleware(self, app: FastAPI):
        """
        Instrument the FastAPI application with the OpenTelemetry FastAPI instrumentor.

        :param app: the FastAPI application to instrument.
        """
        FastAPIInstrumentor.instrument_app(app)