# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import contextvars
import logging  # allow-direct-logging
import queue
import secrets
import sys
import threading
import time
from collections.abc import Callable
from datetime import UTC, datetime
from functools import wraps
from typing import Any

from llama_stack.apis.telemetry import (
    Event,
    LogSeverity,
    Span,
    SpanEndPayload,
    SpanStartPayload,
    SpanStatus,
    StructuredLogEvent,
    Telemetry,
    UnstructuredLogEvent,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value

logger = get_logger(__name__, category="core")

# Fallback logger that does NOT propagate to TelemetryHandler to avoid recursion
_fallback_logger = logging.getLogger("llama_stack.telemetry.background")
# Configure it only when no handler is attached yet, so this setup is applied
# at most once per logger instance.
if not _fallback_logger.handlers:
    _fallback_logger.propagate = False
    _fallback_logger.setLevel(logging.ERROR)
    _fallback_handler = logging.StreamHandler(sys.stderr)
    _fallback_handler.setLevel(logging.ERROR)
    _fallback_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
    _fallback_logger.addHandler(_fallback_handler)

# All-zero IDs are treated as invalid; the ID generators never return them.
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
# Attribute keys set to True on the first span of a trace.
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
# The logical root span may not be visible to this process if a parent context
# is passed in. The local root span is the first local span in a trace.
LOCAL_ROOT_SPAN_MARKER = "__local_root_span__"


def trace_id_to_str(trace_id: int) -> str:
    """Convenience trace ID formatting method

    Args:
        trace_id: Trace ID int

    Returns:
        The trace ID as a zero-padded, 32-character lowercase hexadecimal string
    """
    return format(trace_id, "032x")


def span_id_to_str(span_id: int) -> str:
    """Convenience span ID formatting method

    Args:
        span_id: Span ID int

    Returns:
        The span ID as a zero-padded, 16-character lowercase hexadecimal string
    """
    return format(span_id, "016x")


def generate_span_id() -> str:
    """Return a random 64-bit span ID as hex, never the all-zero invalid ID."""
    span_id = secrets.randbits(64)
    while span_id == INVALID_SPAN_ID:
        span_id = secrets.randbits(64)
    return span_id_to_str(span_id)


def generate_trace_id() -> str:
    """Return a random 128-bit trace ID as hex, never the all-zero invalid ID."""
    trace_id = secrets.randbits(128)
    while trace_id == INVALID_TRACE_ID:
        trace_id = secrets.randbits(128)
    return trace_id_to_str(trace_id)


# Holds the active TraceContext (or None) for the current execution context.
CURRENT_TRACE_CONTEXT: "contextvars.ContextVar[TraceContext | None]" = contextvars.ContextVar(
    "trace_context", default=None
)
# Process-wide background logger; created by setup_logger().
BACKGROUND_LOGGER: "BackgroundLogger | None" = None

# Minimum number of seconds between two "queue is full" notices.
LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS = 60.0


class BackgroundLogger:
    """Queue telemetry events and forward them to a Telemetry API from a daemon thread."""

    def __init__(self, api: Telemetry, capacity: int = 100000):
        """Create the bounded event queue and start the worker thread.

        Args:
            api: Telemetry implementation whose ``log_event`` coroutine receives the events.
            capacity: Maximum number of queued events; further events are dropped.
        """
        self.api = api
        self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity)
        # Initialise all state before starting the worker so the instance is
        # complete by the time a second thread exists.
        self._last_queue_full_log_time: float = 0.0
        self._dropped_since_last_notice: int = 0
        self.worker_thread = threading.Thread(target=self._worker, daemon=True)
        self.worker_thread.start()

    def log_event(self, event):
        """Enqueue ``event`` without blocking; drop it if the queue is full."""
        try:
            self.log_queue.put_nowait(event)
        except queue.Full:
            # Aggregate drops and emit at most once per interval via fallback logger
            self._dropped_since_last_notice += 1
            current_time = time.time()
            if current_time - self._last_queue_full_log_time >= LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS:
                _fallback_logger.error(
                    "Log queue is full; dropped %d events since last notice",
                    self._dropped_since_last_notice,
                )
                self._last_queue_full_log_time = current_time
                self._dropped_since_last_notice = 0

    def _worker(self):
        """Thread target: run the event-processing coroutine on a dedicated event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self._process_logs())

    async def _process_logs(self):
        """Consume queued events forever, forwarding each one to ``self.api``."""
        while True:
            # Blocking get() is acceptable here: this loop runs on the worker
            # thread's own event loop. It stays outside the try block so that
            # task_done() is only called for an item that was actually dequeued.
            event = self.log_queue.get()
            try:
                await self.api.log_event(event)
            except Exception:
                # Report through the non-propagating fallback logger so the
                # failure is not fed back into the telemetry pipeline.
                _fallback_logger.exception("Error processing log event")
            finally:
                self.log_queue.task_done()

    def __del__(self):
        # Wait until every queued event has been marked done by the worker.
        self.log_queue.join()


def enqueue_event(event: Event) -> None:
    """Enqueue a telemetry event to the background logger if available.

    This provides a non-blocking path for routers and other hot paths to
    submit telemetry without awaiting the Telemetry API, reducing contention
    with the main event loop.

    Raises:
        RuntimeError: If setup_logger() has not created the background logger yet.
    """
    if BACKGROUND_LOGGER is None:
        raise RuntimeError("Telemetry API not initialized")
    BACKGROUND_LOGGER.log_event(event)


class TraceContext:
    """A single trace: its ID plus the stack of currently open spans."""

    def __init__(self, logger: BackgroundLogger, trace_id: str):
        self.logger = logger
        self.trace_id = trace_id
        # Per-instance span stack. A class-level list would be shared by every
        # TraceContext, mixing the spans of unrelated traces.
        self.spans: list[Span] = []

    def push_span(self, name: str, attributes: dict[str, Any] | None = None) -> Span:
        """Open a span under the current one, emit its start event and return it.

        Args:
            name: Span name.
            attributes: Optional attributes attached to the span and its start event.

        Returns:
            The newly created span, now at the top of the span stack.
        """
        current_span = self.get_current_span()
        span = Span(
            span_id=generate_span_id(),
            trace_id=self.trace_id,
            name=name,
            start_time=datetime.now(UTC),
            parent_span_id=current_span.span_id if current_span else None,
            attributes=attributes,
        )
        self.logger.log_event(
            StructuredLogEvent(
                trace_id=span.trace_id,
                span_id=span.span_id,
                timestamp=span.start_time,
                attributes=span.attributes,
                payload=SpanStartPayload(
                    name=span.name,
                    parent_span_id=span.parent_span_id,
                ),
            )
        )
        self.spans.append(span)
        return span

    def pop_span(self, status: SpanStatus = SpanStatus.OK):
        """Close the most recently opened span and emit its end event.

        Does nothing (apart from a debug message) when no span is open, so
        unbalanced pops cannot crash the instrumented code path.

        Args:
            status: Status recorded in the span-end payload.
        """
        if not self.spans:
            logger.debug("No span to pop")
            return
        span = self.spans.pop()
        self.logger.log_event(
            StructuredLogEvent(
                trace_id=span.trace_id,
                span_id=span.span_id,
                # NOTE(review): the end event carries the span's *start* time as
                # its timestamp; confirm whether consumers expect the end time.
                timestamp=span.start_time,
                attributes=span.attributes,
                payload=SpanEndPayload(
                    status=status,
                ),
            )
        )

    def get_current_span(self):
        """Return the innermost open span, or None when the stack is empty."""
        return self.spans[-1] if self.spans else None


def setup_logger(api: Telemetry, level: int = logging.INFO):
    """Create the global background logger and hook telemetry into the root logger.

    Safe to call more than once: the background logger is created only on the
    first call and the TelemetryHandler is attached to the root logger only if
    one is not already present, so log records are not duplicated.

    Args:
        api: Telemetry implementation that receives the events.
        level: Level applied to the root logger.
    """
    global BACKGROUND_LOGGER

    if BACKGROUND_LOGGER is None:
        BACKGROUND_LOGGER = BackgroundLogger(api)
    root_logger = logging.getLogger()
    root_logger.setLevel(level)
    if not any(isinstance(handler, TelemetryHandler) for handler in root_logger.handlers):
        root_logger.addHandler(TelemetryHandler())


async def start_trace(name: str, attributes: dict[str, Any] | None = None) -> TraceContext | None:
    """Start a new trace with a root span and make it the current trace context.

    Args:
        name: Name of the root span.
        attributes: Optional attributes; they override the root markers on key clash.

    Returns:
        The new TraceContext, or None when no Telemetry implementation is set.
    """
    if BACKGROUND_LOGGER is None:
        logger.debug("No Telemetry implementation set. Skipping trace initialization...")
        return None

    trace_id = generate_trace_id()
    context = TraceContext(BACKGROUND_LOGGER, trace_id)
    # Mark this span as the root for the trace for now. The processing of
    # traceparent context if supplied comes later and will result in the
    # ROOT_SPAN_MARKERS being removed. Also mark this is the 'local' root,
    # i.e. the root of the spans originating in this process as this is
    # needed to ensure that we insert this 'local' root span's id into
    # the trace record in sqlite store.
    root_attributes = dict.fromkeys(ROOT_SPAN_MARKERS, True) | {LOCAL_ROOT_SPAN_MARKER: True} | (attributes or {})
    context.push_span(name, root_attributes)

    CURRENT_TRACE_CONTEXT.set(context)
    return context


async def end_trace(status: SpanStatus = SpanStatus.OK):
    """Close the current span of the active trace and clear the trace context.

    Args:
        status: Status recorded for the span being closed.
    """
    context = CURRENT_TRACE_CONTEXT.get()
    if context is None:
        logger.debug("No trace context to end")
        return

    context.pop_span(status)
    CURRENT_TRACE_CONTEXT.set(None)


# Standard logging level names mapped to their telemetry severities.
_SEVERITY_BY_LEVELNAME = {
    "DEBUG": LogSeverity.DEBUG,
    "INFO": LogSeverity.INFO,
    "WARNING": LogSeverity.WARN,
    "ERROR": LogSeverity.ERROR,
    "CRITICAL": LogSeverity.CRITICAL,
}


def severity(levelname: str) -> LogSeverity:
    """Translate a ``logging`` level name into a LogSeverity.

    Raises:
        ValueError: If ``levelname`` is not one of the five standard level names.
    """
    try:
        return _SEVERITY_BY_LEVELNAME[levelname]
    except KeyError:
        raise ValueError(f"Unknown log level: {levelname}") from None


# TODO: ideally, the actual emitting should be done inside a separate daemon
# process completely isolated from the server
class TelemetryHandler(logging.Handler):
    """Logging handler that forwards records to telemetry as unstructured log events."""

    def emit(self, record: logging.LogRecord):
        """Enqueue ``record`` against the current span; skip it when no span is active."""
        # horrendous hack to avoid logging from asyncio and getting into an infinite loop
        if record.module in ("asyncio", "selector_events"):
            return

        context = CURRENT_TRACE_CONTEXT.get()
        if context is None:
            return

        span = context.get_current_span()
        if span is None:
            return

        try:
            enqueue_event(
                UnstructuredLogEvent(
                    trace_id=span.trace_id,
                    span_id=span.span_id,
                    timestamp=datetime.now(UTC),
                    message=self.format(record),
                    severity=severity(record.levelname),
                )
            )
        except RecursionError:
            raise
        except Exception:
            # A failure while emitting (e.g. an unknown custom level name) must
            # not propagate into the code that issued the log call.
            self.handleError(record)

    def close(self):
        # No-op override of logging.Handler.close().
        pass


class SpanContextManager:
    """Open a span on enter and close it on exit; usable as a sync/async context manager or decorator."""

    def __init__(self, name: str, attributes: dict[str, Any] | None = None):
        self.name = name
        self.attributes = attributes
        # Span opened by the most recent enter; None until a span is pushed.
        self.span = None

    def _push(self):
        """Push a span onto the current trace context, if there is one."""
        context = CURRENT_TRACE_CONTEXT.get()
        if context is None:
            logger.debug("No trace context to push span")
            return self

        self.span = context.push_span(self.name, self.attributes)
        return self

    def _pop(self):
        """Pop the top span of the current trace context, if there is one."""
        # NOTE(review): the span is always closed with the default status, even
        # when an exception is propagating; confirm whether that is intended.
        context = CURRENT_TRACE_CONTEXT.get()
        if context is None:
            logger.debug("No trace context to pop span")
            return

        context.pop_span()

    def __enter__(self):
        return self._push()

    def __exit__(self, exc_type, exc_value, traceback):
        self._pop()

    def set_attribute(self, key: str, value: Any):
        """Set a serialized attribute on the open span; no-op when no span was opened."""
        if self.span:
            if self.span.attributes is None:
                self.span.attributes = {}
            self.span.attributes[key] = serialize_value(value)

    async def __aenter__(self):
        return self._push()

    async def __aexit__(self, exc_type, exc_value, traceback):
        self._pop()

    def __call__(self, func: Callable):
        """Decorate ``func`` so that each call runs inside this span.

        Coroutine functions get an ``async def`` wrapper, so the decorated
        object is still recognised as a coroutine function by introspection;
        every other callable gets a plain synchronous wrapper.
        """
        if asyncio.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                async with self:
                    return await func(*args, **kwargs)

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return sync_wrapper


def span(name: str, attributes: dict[str, Any] | None = None):
    """Return a SpanContextManager for ``name`` with optional ``attributes``."""
    return SpanContextManager(name, attributes)


def get_current_span() -> Span | None:
    """Return the innermost open span of the current trace, or None if there is none."""
    context = CURRENT_TRACE_CONTEXT.get()
    if context is None:
        return None
    return context.get_current_span()