llama-stack-mirror/llama_stack/providers/utils/telemetry/tracing.py
grs 68d8f2186f
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.12, agents) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.12, providers) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.13, inference) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.12, inference) (push) Failing after 12s
Integration Tests / test-matrix (http, 3.12, vector_io) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.12, datasets) (push) Failing after 14s
Integration Tests / test-matrix (library, 3.12, providers) (push) Failing after 11s
Integration Tests / test-matrix (http, 3.13, datasets) (push) Failing after 12s
Integration Tests / test-matrix (http, 3.13, tool_runtime) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.13, vector_io) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.13, agents) (push) Failing after 23s
Integration Tests / test-matrix (library, 3.12, agents) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.12, inspect) (push) Failing after 23s
Integration Tests / test-matrix (library, 3.12, datasets) (push) Failing after 13s
Integration Tests / test-matrix (http, 3.12, scoring) (push) Failing after 13s
Integration Tests / test-matrix (http, 3.13, scoring) (push) Failing after 22s
Integration Tests / test-matrix (library, 3.12, inference) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.12, post_training) (push) Failing after 12s
Integration Tests / test-matrix (library, 3.12, inspect) (push) Failing after 22s
Integration Tests / test-matrix (library, 3.13, inspect) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.12, tool_runtime) (push) Failing after 14s
Integration Tests / test-matrix (http, 3.13, inspect) (push) Failing after 11s
Integration Tests / test-matrix (http, 3.13, providers) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.12, post_training) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.12, scoring) (push) Failing after 12s
Integration Tests / test-matrix (library, 3.12, vector_io) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.12, tool_runtime) (push) Failing after 20s
Integration Tests / test-matrix (library, 3.13, agents) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.13, post_training) (push) Failing after 12s
Integration Tests / test-matrix (library, 3.13, inference) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.13, datasets) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.13, post_training) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.13, tool_runtime) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.13, providers) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.13, vector_io) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.13, scoring) (push) Failing after 10s
Python Package Build Test / build (3.12) (push) Failing after 7s
Test External Providers / test-external-providers (venv) (push) Failing after 8s
Unit Tests / unit-tests (3.13) (push) Failing after 9s
Python Package Build Test / build (3.13) (push) Failing after 32s
Unit Tests / unit-tests (3.12) (push) Failing after 48s
Pre-commit / pre-commit (push) Successful in 1m32s
fix: fix test of root span to match what is being set (#2494)
# What does this PR do?

I get errors when trying to query spans. It appears to be a result of
traces being inserted where there is no root_span_id which causes a
pydantic validation error on trying to load the data for a query
response (and in any case having no span referenced undermines the
purpose of the trace). The root cause as far as I can see is an invalid
test in the code that inserts the trace, where it is testing for the
string "true" against an object set to the python value True.

<!-- If resolving an issue, uncomment and update the line below -->
Closes #2493 

## Test Plan
<!-- Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.* -->
With this change I can query spans.

Signed-off-by: Gordon Sim <gsim@redhat.com>
2025-06-26 11:41:35 -04:00

345 lines
10 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import contextvars
import logging
import queue
import random
import threading
from collections.abc import Callable
from datetime import UTC, datetime
from functools import wraps
from typing import Any
from llama_stack.apis.telemetry import (
LogSeverity,
Span,
SpanEndPayload,
SpanStartPayload,
SpanStatus,
StructuredLogEvent,
Telemetry,
UnstructuredLogEvent,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
logger = get_logger(__name__, category="core")
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
# The logical root span may not be visible to this process if a parent context
# is passed in. The local root span is the first local span in a trace.
LOCAL_ROOT_SPAN_MARKER = "__local_root_span__"
def trace_id_to_str(trace_id: int) -> str:
"""Convenience trace ID formatting method
Args:
trace_id: Trace ID int
Returns:
The trace ID as 32-byte hexadecimal string
"""
return format(trace_id, "032x")
def span_id_to_str(span_id: int) -> str:
"""Convenience span ID formatting method
Args:
span_id: Span ID int
Returns:
The span ID as 16-byte hexadecimal string
"""
return format(span_id, "016x")
def generate_span_id() -> str:
span_id = random.getrandbits(64)
while span_id == INVALID_SPAN_ID:
span_id = random.getrandbits(64)
return span_id_to_str(span_id)
def generate_trace_id() -> str:
trace_id = random.getrandbits(128)
while trace_id == INVALID_TRACE_ID:
trace_id = random.getrandbits(128)
return trace_id_to_str(trace_id)
CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
BACKGROUND_LOGGER = None
class BackgroundLogger:
def __init__(self, api: Telemetry, capacity: int = 1000):
self.api = api
self.log_queue = queue.Queue(maxsize=capacity)
self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
self.worker_thread.start()
def log_event(self, event):
try:
self.log_queue.put_nowait(event)
except queue.Full:
logger.error("Log queue is full, dropping event")
def _process_logs(self):
while True:
try:
event = self.log_queue.get()
# figure out how to use a thread's native loop
asyncio.run(self.api.log_event(event))
except Exception:
import traceback
traceback.print_exc()
print("Error processing log event")
finally:
self.log_queue.task_done()
def __del__(self):
self.log_queue.join()
class TraceContext:
spans: list[Span] = []
def __init__(self, logger: BackgroundLogger, trace_id: str):
self.logger = logger
self.trace_id = trace_id
def push_span(self, name: str, attributes: dict[str, Any] = None) -> Span:
current_span = self.get_current_span()
span = Span(
span_id=generate_span_id(),
trace_id=self.trace_id,
name=name,
start_time=datetime.now(UTC),
parent_span_id=current_span.span_id if current_span else None,
attributes=attributes,
)
self.logger.log_event(
StructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=span.start_time,
attributes=span.attributes,
payload=SpanStartPayload(
name=span.name,
parent_span_id=span.parent_span_id,
),
)
)
self.spans.append(span)
return span
def pop_span(self, status: SpanStatus = SpanStatus.OK):
span = self.spans.pop()
if span is not None:
self.logger.log_event(
StructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=span.start_time,
attributes=span.attributes,
payload=SpanEndPayload(
status=status,
),
)
)
def get_current_span(self):
return self.spans[-1] if self.spans else None
def setup_logger(api: Telemetry, level: int = logging.INFO):
global BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
BACKGROUND_LOGGER = BackgroundLogger(api)
root_logger = logging.getLogger()
root_logger.setLevel(level)
root_logger.addHandler(TelemetryHandler())
async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceContext:
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
return
trace_id = generate_trace_id()
context = TraceContext(BACKGROUND_LOGGER, trace_id)
# Mark this span as the root for the trace for now. The processing of
# traceparent context if supplied comes later and will result in the
# ROOT_SPAN_MARKERS being removed. Also mark this is the 'local' root,
# i.e. the root of the spans originating in this process as this is
# needed to ensure that we insert this 'local' root span's id into
# the trace record in sqlite store.
attributes = dict.fromkeys(ROOT_SPAN_MARKERS, True) | {LOCAL_ROOT_SPAN_MARKER: True} | (attributes or {})
context.push_span(name, attributes)
CURRENT_TRACE_CONTEXT.set(context)
return context
async def end_trace(status: SpanStatus = SpanStatus.OK):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if context is None:
logger.debug("No trace context to end")
return
context.pop_span(status)
CURRENT_TRACE_CONTEXT.set(None)
def severity(levelname: str) -> LogSeverity:
if levelname == "DEBUG":
return LogSeverity.DEBUG
elif levelname == "INFO":
return LogSeverity.INFO
elif levelname == "WARNING":
return LogSeverity.WARN
elif levelname == "ERROR":
return LogSeverity.ERROR
elif levelname == "CRITICAL":
return LogSeverity.CRITICAL
else:
raise ValueError(f"Unknown log level: {levelname}")
# TODO: ideally, the actual emitting should be done inside a separate daemon
# process completely isolated from the server
class TelemetryHandler(logging.Handler):
def emit(self, record: logging.LogRecord):
# horrendous hack to avoid logging from asyncio and getting into an infinite loop
if record.module in ("asyncio", "selector_events"):
return
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
raise RuntimeError("Telemetry API not initialized")
context = CURRENT_TRACE_CONTEXT.get()
if context is None:
return
span = context.get_current_span()
if span is None:
return
BACKGROUND_LOGGER.log_event(
UnstructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=datetime.now(UTC),
message=self.format(record),
severity=severity(record.levelname),
)
)
def close(self):
pass
class SpanContextManager:
def __init__(self, name: str, attributes: dict[str, Any] = None):
self.name = name
self.attributes = attributes
self.span = None
def __enter__(self):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to push span")
return self
self.span = context.push_span(self.name, self.attributes)
return self
def __exit__(self, exc_type, exc_value, traceback):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to pop span")
return
context.pop_span()
def set_attribute(self, key: str, value: Any):
if self.span:
if self.span.attributes is None:
self.span.attributes = {}
self.span.attributes[key] = serialize_value(value)
async def __aenter__(self):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to push span")
return self
self.span = context.push_span(self.name, self.attributes)
return self
async def __aexit__(self, exc_type, exc_value, traceback):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to pop span")
return
context.pop_span()
def __call__(self, func: Callable):
@wraps(func)
def sync_wrapper(*args, **kwargs):
with self:
return func(*args, **kwargs)
@wraps(func)
async def async_wrapper(*args, **kwargs):
async with self:
return await func(*args, **kwargs)
@wraps(func)
def wrapper(*args, **kwargs):
if asyncio.iscoroutinefunction(func):
return async_wrapper(*args, **kwargs)
else:
return sync_wrapper(*args, **kwargs)
return wrapper
def span(name: str, attributes: dict[str, Any] = None):
return SpanContextManager(name, attributes)
def get_current_span() -> Span | None:
global CURRENT_TRACE_CONTEXT
if CURRENT_TRACE_CONTEXT is None:
logger.debug("No trace context to get current span")
return None
context = CURRENT_TRACE_CONTEXT.get()
if context:
return context.get_current_span()
return None