mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-24 00:47:00 +00:00
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Python Package Build Test / build (3.13) (push) Failing after 2s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 19s
API Conformance Tests / check-schema-compatibility (push) Successful in 8s
Test External API and Providers / test-external (venv) (push) Failing after 3s
Unit Tests / unit-tests (3.12) (push) Failing after 3s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 17s
Unit Tests / unit-tests (3.13) (push) Failing after 14s
Vector IO Integration Tests / test-matrix (push) Failing after 18s
Python Package Build Test / build (3.12) (push) Failing after 18s
UI Tests / ui-tests (22) (push) Successful in 53s
Pre-commit / pre-commit (push) Successful in 1m14s
# What does this PR do? Switches from `random.getrandbits` to `secrets.randbits` in the telemetry module. <!-- If resolving an issue, uncomment and update the line below --> Closes #3553 ## Test Plan Unit tests from scripts/unit-tests.sh were run to verify that the tests still pass. Signed-off-by: Doug Edgar <dedgar@redhat.com>
384 lines
12 KiB
Python
384 lines
12 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
import contextvars
import inspect
import logging  # allow-direct-logging
import queue
import secrets
import sys
import threading
import time
from collections.abc import Callable
from datetime import UTC, datetime
from functools import wraps
from typing import Any
|
|
|
|
from llama_stack.apis.telemetry import (
|
|
Event,
|
|
LogSeverity,
|
|
Span,
|
|
SpanEndPayload,
|
|
SpanStartPayload,
|
|
SpanStatus,
|
|
StructuredLogEvent,
|
|
Telemetry,
|
|
UnstructuredLogEvent,
|
|
)
|
|
from llama_stack.log import get_logger
|
|
from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
|
|
|
|
# Module-level logger for diagnostics emitted by this telemetry module.
logger = get_logger(__name__, category="core")
|
|
|
|
# Fallback logger that does NOT propagate to TelemetryHandler to avoid recursion.
# It writes ERROR-level records straight to stderr. Configuration happens only
# once: if the named logger already has handlers (e.g. the module was imported
# again), the existing setup is left untouched.
_fallback_logger = logging.getLogger("llama_stack.telemetry.background")
if not _fallback_logger.handlers:
    _fallback_handler = logging.StreamHandler(sys.stderr)
    _fallback_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
    _fallback_handler.setLevel(logging.ERROR)

    _fallback_logger.setLevel(logging.ERROR)
    _fallback_logger.propagate = False
    _fallback_logger.addHandler(_fallback_handler)
|
|
|
|
|
|
# All-zero IDs are reserved as "invalid"; generate_span_id() / generate_trace_id()
# re-draw whenever the random source produces one of these values.
INVALID_SPAN_ID = 0  # 64-bit span ID, all zero bits
INVALID_TRACE_ID = 0  # 128-bit trace ID, all zero bits

# Attribute keys that start_trace() sets on the first span of a new trace to flag
# it as the trace root.
ROOT_SPAN_MARKERS = [
    "__root__",
    "__root_span__",
]

# The logical root span may not be visible to this process if a parent context
# is passed in. The local root span is the first local span in a trace.
LOCAL_ROOT_SPAN_MARKER = "__local_root_span__"
|
|
|
|
|
|
def trace_id_to_str(trace_id: int) -> str:
    """Format a trace ID as a zero-padded lowercase hexadecimal string.

    Args:
        trace_id: Trace ID as a non-negative int (128 bits for IDs produced by
            generate_trace_id()).

    Returns:
        The trace ID as a 32-character hex string (16 bytes). Values wider than
        128 bits are not truncated and produce a longer string.
    """
    return format(trace_id, "032x")
|
|
|
|
|
|
def span_id_to_str(span_id: int) -> str:
    """Format a span ID as a zero-padded lowercase hexadecimal string.

    Args:
        span_id: Span ID as a non-negative int (64 bits for IDs produced by
            generate_span_id()).

    Returns:
        The span ID as a 16-character hex string (8 bytes). Values wider than
        64 bits are not truncated and produce a longer string.
    """
    return format(span_id, "016x")
|
|
|
|
|
|
def generate_span_id() -> str:
    """Return a new random span ID as a 16-character hex string.

    Draws 64 bits from the ``secrets`` CSPRNG and draws again whenever the
    result equals ``INVALID_SPAN_ID``.
    """
    while True:
        candidate = secrets.randbits(64)
        if candidate != INVALID_SPAN_ID:
            return span_id_to_str(candidate)
|
|
|
|
|
|
def generate_trace_id() -> str:
    """Return a new random trace ID as a 32-character hex string.

    Draws 128 bits from the ``secrets`` CSPRNG and draws again whenever the
    result equals ``INVALID_TRACE_ID``.
    """
    while True:
        candidate = secrets.randbits(128)
        if candidate != INVALID_TRACE_ID:
            return trace_id_to_str(candidate)
|
|
|
|
|
|
# Active TraceContext for the current execution context; None when no trace is open.
CURRENT_TRACE_CONTEXT: contextvars.ContextVar = contextvars.ContextVar("trace_context", default=None)
# Process-wide BackgroundLogger; stays None until setup_logger() creates it.
BACKGROUND_LOGGER: "BackgroundLogger | None" = None

# Minimum spacing, in seconds, between "log queue is full" notices.
LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS: float = 60.0
|
|
|
|
|
|
class BackgroundLogger:
    """Forward telemetry events to a Telemetry API from a dedicated daemon thread.

    Producers call :meth:`log_event`, which never blocks: events are put into a
    bounded ``queue.Queue`` and dropped when it is full. Drops are reported on the
    fallback logger at most once per ``LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS``. A
    daemon worker thread runs a private asyncio event loop and awaits
    ``api.log_event`` for each queued event, one at a time, in FIFO order.
    """

    def __init__(self, api: Telemetry, capacity: int = 100000):
        """
        Args:
            api: Telemetry implementation whose ``log_event`` coroutine receives events.
            capacity: Maximum number of queued events before new ones are dropped
                (``queue.Queue`` semantics: a value <= 0 means unbounded).
        """
        self.api = api
        self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity)
        # Initialise all state before starting the worker so the object is fully
        # constructed before another thread can run against it.
        self._last_queue_full_log_time: float = 0.0
        self._dropped_since_last_notice: int = 0
        self.worker_thread = threading.Thread(target=self._worker, daemon=True)
        self.worker_thread.start()

    def log_event(self, event):
        """Queue ``event`` for asynchronous delivery without blocking.

        If the queue is full, the event is dropped and counted. The running drop
        count is reported via the fallback logger (which does not propagate to
        TelemetryHandler) at most once per LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS.
        """
        try:
            self.log_queue.put_nowait(event)
        except queue.Full:
            self._dropped_since_last_notice += 1
            current_time = time.time()
            if current_time - self._last_queue_full_log_time >= LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS:
                _fallback_logger.error(
                    "Log queue is full; dropped %d events since last notice",
                    self._dropped_since_last_notice,
                )
                self._last_queue_full_log_time = current_time
                self._dropped_since_last_notice = 0

    def _worker(self):
        """Thread entry point: run :meth:`_process_logs` on a private event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self._process_logs())

    async def _process_logs(self):
        """Consume the queue forever, awaiting ``api.log_event`` for each event."""
        while True:
            # NOTE(review): get() blocks this thread's private loop while idle, so
            # tasks that api.log_event schedules on this loop only make progress
            # while log_event is being awaited. Confirm that is acceptable.
            event = self.log_queue.get()
            try:
                await self.api.log_event(event)
            except Exception:
                # Report via the non-propagating fallback logger, not print(): it goes
                # to stderr with the traceback and cannot re-enter TelemetryHandler.
                _fallback_logger.exception("Error processing log event")
            finally:
                # Paired with the successful get() above, so join() stays accurate.
                self.log_queue.task_done()

    def __del__(self):
        """Best-effort flush: wait for queued events to be processed.

        Skipped when there is nobody to drain the queue. This covers a
        half-constructed object, a worker thread that is not alive, and
        interpreter finalization, where the daemon worker may no longer run.
        In those cases ``Queue.join()`` could block indefinitely.
        """
        worker = getattr(self, "worker_thread", None)
        if worker is None or not worker.is_alive() or sys.is_finalizing():
            return
        self.log_queue.join()
|
|
|
|
|
|
def enqueue_event(event: Event) -> None:
    """Hand a telemetry event to the process-wide background logger.

    This lets routers and other hot paths submit telemetry without awaiting the
    Telemetry API, so they do not contend with the main event loop.

    Raises:
        RuntimeError: if setup_logger() has not created the background logger yet.
    """
    background_logger = BACKGROUND_LOGGER
    if background_logger is not None:
        background_logger.log_event(event)
        return
    raise RuntimeError("Telemetry API not initialized")
|
|
|
|
|
|
class TraceContext:
    """Stack of open spans belonging to a single trace.

    Each push/pop emits a span-start / span-end ``StructuredLogEvent`` through the
    given BackgroundLogger. The most recently pushed, still-open span becomes the
    parent of the next pushed span.
    """

    # Declared for type checkers only; the list itself is created per instance.
    spans: list[Span]

    def __init__(self, logger: BackgroundLogger, trace_id: str):
        self.logger = logger
        self.trace_id = trace_id
        # Per-instance stack. A class-level ``spans = []`` default would be a single
        # list shared by every TraceContext, so concurrent traces would pick up each
        # other's spans as parents and pop each other's spans.
        self.spans = []

    def push_span(self, name: str, attributes: dict[str, Any] | None = None) -> Span:
        """Open a child of the current span (or a top-level span) and emit its start event.

        Args:
            name: Span name.
            attributes: Optional span attributes, stored on the span and sent with
                the start event.

        Returns:
            The newly created span, now at the top of the stack.
        """
        current_span = self.get_current_span()
        span = Span(
            span_id=generate_span_id(),
            trace_id=self.trace_id,
            name=name,
            start_time=datetime.now(UTC),
            parent_span_id=current_span.span_id if current_span else None,
            attributes=attributes,
        )

        self.logger.log_event(
            StructuredLogEvent(
                trace_id=span.trace_id,
                span_id=span.span_id,
                timestamp=span.start_time,
                attributes=span.attributes,
                payload=SpanStartPayload(
                    name=span.name,
                    parent_span_id=span.parent_span_id,
                ),
            )
        )

        self.spans.append(span)
        return span

    def pop_span(self, status: SpanStatus = SpanStatus.OK):
        """Close the innermost open span and emit its end event with ``status``.

        An unbalanced pop (no open span) is logged at debug level and ignored
        instead of raising IndexError.
        """
        if not self.spans:
            logger.debug("No span to pop")
            return

        span = self.spans.pop()
        # NOTE(review): the end event carries the span's start_time as its timestamp;
        # confirm downstream consumers do not expect the end time here.
        self.logger.log_event(
            StructuredLogEvent(
                trace_id=span.trace_id,
                span_id=span.span_id,
                timestamp=span.start_time,
                attributes=span.attributes,
                payload=SpanEndPayload(
                    status=status,
                ),
            )
        )

    def get_current_span(self):
        """Return the innermost open span, or None if no span is open."""
        return self.spans[-1] if self.spans else None
|
|
|
|
|
|
def setup_logger(api: Telemetry, level: int = logging.INFO):
    """Install telemetry logging for this process.

    On the first call this creates the process-wide BackgroundLogger; later calls
    keep the existing one and ignore ``api``. Every call sets the root logger's
    level to ``level``. A TelemetryHandler is attached to the root logger only if
    one is not already present, so repeated calls do not forward each log record
    several times.

    Args:
        api: Telemetry implementation that receives events (used on first call only).
        level: Level applied to the root logger.
    """
    global BACKGROUND_LOGGER

    if BACKGROUND_LOGGER is None:
        BACKGROUND_LOGGER = BackgroundLogger(api)

    root_logger = logging.getLogger()
    root_logger.setLevel(level)
    if not any(isinstance(handler, TelemetryHandler) for handler in root_logger.handlers):
        root_logger.addHandler(TelemetryHandler())
|
|
|
|
|
|
async def start_trace(name: str, attributes: dict[str, Any] | None = None) -> TraceContext | None:
    """Begin a new trace with a fresh trace ID and open its root span.

    The new TraceContext is stored in CURRENT_TRACE_CONTEXT.

    Args:
        name: Name of the root span.
        attributes: Extra root-span attributes. They are merged after the root
            markers, so caller-supplied keys win.

    Returns:
        The new TraceContext, or None (after a debug log) when setup_logger() has
        not created the background logger yet.
    """
    if BACKGROUND_LOGGER is None:
        logger.debug("No Telemetry implementation set. Skipping trace initialization...")
        return None

    trace_id = generate_trace_id()
    context = TraceContext(BACKGROUND_LOGGER, trace_id)
    # Mark this span as the root for the trace for now. The processing of
    # traceparent context if supplied comes later and will result in the
    # ROOT_SPAN_MARKERS being removed. Also mark this is the 'local' root,
    # i.e. the root of the spans originating in this process as this is
    # needed to ensure that we insert this 'local' root span's id into
    # the trace record in sqlite store.
    attributes = dict.fromkeys(ROOT_SPAN_MARKERS, True) | {LOCAL_ROOT_SPAN_MARKER: True} | (attributes or {})
    context.push_span(name, attributes)

    CURRENT_TRACE_CONTEXT.set(context)
    return context
|
|
|
|
|
|
async def end_trace(status: SpanStatus = SpanStatus.OK):
    """Close the active trace's innermost open span with ``status`` and clear the trace.

    When no trace is active, only a debug message is logged.
    """
    context = CURRENT_TRACE_CONTEXT.get()
    if context is not None:
        context.pop_span(status)
        CURRENT_TRACE_CONTEXT.set(None)
        return
    logger.debug("No trace context to end")
|
|
|
|
|
|
def severity(levelname: str) -> LogSeverity:
|
|
if levelname == "DEBUG":
|
|
return LogSeverity.DEBUG
|
|
elif levelname == "INFO":
|
|
return LogSeverity.INFO
|
|
elif levelname == "WARNING":
|
|
return LogSeverity.WARN
|
|
elif levelname == "ERROR":
|
|
return LogSeverity.ERROR
|
|
elif levelname == "CRITICAL":
|
|
return LogSeverity.CRITICAL
|
|
else:
|
|
raise ValueError(f"Unknown log level: {levelname}")
|
|
|
|
|
|
# TODO: ideally, the actual emitting should be done inside a separate daemon
|
|
# process completely isolated from the server
|
|
class TelemetryHandler(logging.Handler):
|
|
def emit(self, record: logging.LogRecord):
|
|
# horrendous hack to avoid logging from asyncio and getting into an infinite loop
|
|
if record.module in ("asyncio", "selector_events"):
|
|
return
|
|
|
|
global CURRENT_TRACE_CONTEXT
|
|
context = CURRENT_TRACE_CONTEXT.get()
|
|
if context is None:
|
|
return
|
|
|
|
span = context.get_current_span()
|
|
if span is None:
|
|
return
|
|
|
|
enqueue_event(
|
|
UnstructuredLogEvent(
|
|
trace_id=span.trace_id,
|
|
span_id=span.span_id,
|
|
timestamp=datetime.now(UTC),
|
|
message=self.format(record),
|
|
severity=severity(record.levelname),
|
|
)
|
|
)
|
|
|
|
def close(self):
|
|
pass
|
|
|
|
|
|
class SpanContextManager:
|
|
def __init__(self, name: str, attributes: dict[str, Any] = None):
|
|
self.name = name
|
|
self.attributes = attributes
|
|
self.span = None
|
|
|
|
def __enter__(self):
|
|
global CURRENT_TRACE_CONTEXT
|
|
context = CURRENT_TRACE_CONTEXT.get()
|
|
if not context:
|
|
logger.debug("No trace context to push span")
|
|
return self
|
|
|
|
self.span = context.push_span(self.name, self.attributes)
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
global CURRENT_TRACE_CONTEXT
|
|
context = CURRENT_TRACE_CONTEXT.get()
|
|
if not context:
|
|
logger.debug("No trace context to pop span")
|
|
return
|
|
|
|
context.pop_span()
|
|
|
|
def set_attribute(self, key: str, value: Any):
|
|
if self.span:
|
|
if self.span.attributes is None:
|
|
self.span.attributes = {}
|
|
self.span.attributes[key] = serialize_value(value)
|
|
|
|
async def __aenter__(self):
|
|
global CURRENT_TRACE_CONTEXT
|
|
context = CURRENT_TRACE_CONTEXT.get()
|
|
if not context:
|
|
logger.debug("No trace context to push span")
|
|
return self
|
|
|
|
self.span = context.push_span(self.name, self.attributes)
|
|
return self
|
|
|
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
global CURRENT_TRACE_CONTEXT
|
|
context = CURRENT_TRACE_CONTEXT.get()
|
|
if not context:
|
|
logger.debug("No trace context to pop span")
|
|
return
|
|
|
|
context.pop_span()
|
|
|
|
def __call__(self, func: Callable):
|
|
@wraps(func)
|
|
def sync_wrapper(*args, **kwargs):
|
|
with self:
|
|
return func(*args, **kwargs)
|
|
|
|
@wraps(func)
|
|
async def async_wrapper(*args, **kwargs):
|
|
async with self:
|
|
return await func(*args, **kwargs)
|
|
|
|
@wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
if asyncio.iscoroutinefunction(func):
|
|
return async_wrapper(*args, **kwargs)
|
|
else:
|
|
return sync_wrapper(*args, **kwargs)
|
|
|
|
return wrapper
|
|
|
|
|
|
def span(name: str, attributes: dict[str, Any] = None):
    """Build a SpanContextManager named ``name``, usable as a context manager or decorator."""
    manager = SpanContextManager(name=name, attributes=attributes)
    return manager
|
|
|
|
|
|
def get_current_span() -> Span | None:
    """Return the innermost open span of the active trace, or None if there is none.

    CURRENT_TRACE_CONTEXT is always a ContextVar, so the old
    ``CURRENT_TRACE_CONTEXT is None`` check could never fire. The check that
    matters is whether the variable currently holds a TraceContext.
    """
    context = CURRENT_TRACE_CONTEXT.get()
    if not context:
        return None
    return context.get_current_span()
|