# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import asyncio import contextvars import logging import queue import random import threading from collections.abc import Callable from datetime import UTC, datetime from functools import wraps from typing import Any from llama_stack.apis.telemetry import ( LogSeverity, Span, SpanEndPayload, SpanStartPayload, SpanStatus, StructuredLogEvent, Telemetry, UnstructuredLogEvent, ) from llama_stack.log import get_logger from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value logger = get_logger(__name__, category="core") INVALID_SPAN_ID = 0x0000000000000000 INVALID_TRACE_ID = 0x00000000000000000000000000000000 ROOT_SPAN_MARKERS = ["__root__", "__root_span__"] # The logical root span may not be visible to this process if a parent context # is passed in. The local root span is the first local span in a trace. LOCAL_ROOT_SPAN_MARKER = "__local_root_span__" def trace_id_to_str(trace_id: int) -> str: """Convenience trace ID formatting method Args: trace_id: Trace ID int Returns: The trace ID as 32-byte hexadecimal string """ return format(trace_id, "032x") def span_id_to_str(span_id: int) -> str: """Convenience span ID formatting method Args: span_id: Span ID int Returns: The span ID as 16-byte hexadecimal string """ return format(span_id, "016x") def generate_span_id() -> str: span_id = random.getrandbits(64) while span_id == INVALID_SPAN_ID: span_id = random.getrandbits(64) return span_id_to_str(span_id) def generate_trace_id() -> str: trace_id = random.getrandbits(128) while trace_id == INVALID_TRACE_ID: trace_id = random.getrandbits(128) return trace_id_to_str(trace_id) CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None) BACKGROUND_LOGGER = None class BackgroundLogger: def __init__(self, api: Telemetry, capacity: int = 1000): self.api = api self.log_queue = queue.Queue(maxsize=capacity) self.worker_thread = threading.Thread(target=self._process_logs, daemon=True) self.worker_thread.start() def log_event(self, event): try: self.log_queue.put_nowait(event) except queue.Full: logger.error("Log queue is full, dropping event") def _process_logs(self): while True: try: event = self.log_queue.get() # figure out how to use a thread's native loop asyncio.run(self.api.log_event(event)) except Exception: import traceback traceback.print_exc() print("Error processing log event") finally: self.log_queue.task_done() def __del__(self): self.log_queue.join() class TraceContext: spans: list[Span] = [] def __init__(self, logger: BackgroundLogger, trace_id: str): self.logger = logger self.trace_id = trace_id def push_span(self, name: str, attributes: dict[str, Any] = None) -> Span: current_span = self.get_current_span() span = Span( span_id=generate_span_id(), trace_id=self.trace_id, name=name, start_time=datetime.now(UTC), parent_span_id=current_span.span_id if current_span else None, attributes=attributes, ) self.logger.log_event( StructuredLogEvent( trace_id=span.trace_id, span_id=span.span_id, timestamp=span.start_time, attributes=span.attributes, payload=SpanStartPayload( name=span.name, parent_span_id=span.parent_span_id, ), ) ) self.spans.append(span) return span def pop_span(self, status: SpanStatus = SpanStatus.OK): span = self.spans.pop() if span is not None: self.logger.log_event( StructuredLogEvent( trace_id=span.trace_id, span_id=span.span_id, timestamp=span.start_time, attributes=span.attributes, payload=SpanEndPayload( status=status, ), ) ) def get_current_span(self): return self.spans[-1] if self.spans else None def setup_logger(api: Telemetry, level: int = logging.INFO): global BACKGROUND_LOGGER if BACKGROUND_LOGGER is None: BACKGROUND_LOGGER = BackgroundLogger(api) root_logger = logging.getLogger() root_logger.setLevel(level) root_logger.addHandler(TelemetryHandler()) async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceContext: global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER if BACKGROUND_LOGGER is None: logger.debug("No Telemetry implementation set. Skipping trace initialization...") return trace_id = generate_trace_id() context = TraceContext(BACKGROUND_LOGGER, trace_id) # Mark this span as the root for the trace for now. The processing of # traceparent context if supplied comes later and will result in the # ROOT_SPAN_MARKERS being removed. Also mark this is the 'local' root, # i.e. the root of the spans originating in this process as this is # needed to ensure that we insert this 'local' root span's id into # the trace record in sqlite store. attributes = dict.fromkeys(ROOT_SPAN_MARKERS, True) | {LOCAL_ROOT_SPAN_MARKER: True} | (attributes or {}) context.push_span(name, attributes) CURRENT_TRACE_CONTEXT.set(context) return context async def end_trace(status: SpanStatus = SpanStatus.OK): global CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT.get() if context is None: logger.debug("No trace context to end") return context.pop_span(status) CURRENT_TRACE_CONTEXT.set(None) def severity(levelname: str) -> LogSeverity: if levelname == "DEBUG": return LogSeverity.DEBUG elif levelname == "INFO": return LogSeverity.INFO elif levelname == "WARNING": return LogSeverity.WARN elif levelname == "ERROR": return LogSeverity.ERROR elif levelname == "CRITICAL": return LogSeverity.CRITICAL else: raise ValueError(f"Unknown log level: {levelname}") # TODO: ideally, the actual emitting should be done inside a separate daemon # process completely isolated from the server class TelemetryHandler(logging.Handler): def emit(self, record: logging.LogRecord): # horrendous hack to avoid logging from asyncio and getting into an infinite loop if record.module in ("asyncio", "selector_events"): return global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER if BACKGROUND_LOGGER is None: raise RuntimeError("Telemetry API not initialized") context = CURRENT_TRACE_CONTEXT.get() if context is None: return span = context.get_current_span() if span is None: return BACKGROUND_LOGGER.log_event( UnstructuredLogEvent( trace_id=span.trace_id, span_id=span.span_id, timestamp=datetime.now(UTC), message=self.format(record), severity=severity(record.levelname), ) ) def close(self): pass class SpanContextManager: def __init__(self, name: str, attributes: dict[str, Any] = None): self.name = name self.attributes = attributes self.span = None def __enter__(self): global CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT.get() if not context: logger.debug("No trace context to push span") return self self.span = context.push_span(self.name, self.attributes) return self def __exit__(self, exc_type, exc_value, traceback): global CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT.get() if not context: logger.debug("No trace context to pop span") return context.pop_span() def set_attribute(self, key: str, value: Any): if self.span: if self.span.attributes is None: self.span.attributes = {} self.span.attributes[key] = serialize_value(value) async def __aenter__(self): global CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT.get() if not context: logger.debug("No trace context to push span") return self self.span = context.push_span(self.name, self.attributes) return self async def __aexit__(self, exc_type, exc_value, traceback): global CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT.get() if not context: logger.debug("No trace context to pop span") return context.pop_span() def __call__(self, func: Callable): @wraps(func) def sync_wrapper(*args, **kwargs): with self: return func(*args, **kwargs) @wraps(func) async def async_wrapper(*args, **kwargs): async with self: return await func(*args, **kwargs) @wraps(func) def wrapper(*args, **kwargs): if asyncio.iscoroutinefunction(func): return async_wrapper(*args, **kwargs) else: return sync_wrapper(*args, **kwargs) return wrapper def span(name: str, attributes: dict[str, Any] = None): return SpanContextManager(name, attributes) def get_current_span() -> Span | None: global CURRENT_TRACE_CONTEXT if CURRENT_TRACE_CONTEXT is None: logger.debug("No trace context to get current span") return None context = CURRENT_TRACE_CONTEXT.get() if context: return context.get_current_span() return None