mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
# What does this PR do? This PR fixes some of the issues with our telemetry setup to enable logs to be delivered to OpenTelemetry and Jaeger. Main fixes: 1) Updates the OpenTelemetry provider to use the latest OTLP exporters instead of the deprecated ones. 2) Adds a tracing middleware, which injects a trace into each HTTP request that the server receives; this becomes the root trace. Previously, we did this in the create_dynamic_route method, which is not the actual execution flow but more of a configuration step, and this caused the traces to end prematurely. Through the middleware, we plug in the trace start and end at the right location. 3) We manage our own methods to create traces and spans, and this does not fit well with the OpenTelemetry SDK, since it does not provide a way to take in traces and spans that are already created; it expects us to use the SDK to create them. For now, I have a hacky approach of just maintaining a map from our internal telemetry objects to the OpenTelemetry-specific ones. This is not the ideal solution. I will explore other ways to get around this issue. For now, to have something that works, I am going to keep this as is. Addresses: #509
248 lines
7.1 KiB
Python
248 lines
7.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
import base64
|
|
import logging
|
|
import queue
|
|
import threading
|
|
import uuid
|
|
from datetime import datetime
|
|
from functools import wraps
|
|
from typing import Any, Callable, Dict, List
|
|
|
|
|
|
from llama_stack.apis.telemetry import * # noqa: F403
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
|
|
|
|
def generate_short_uuid(len: int = 8):
    """Return a short random identifier derived from a fresh UUID4.

    The UUID's raw bytes are URL-safe base64 encoded, the ``=`` padding is
    removed, and the text is cut down to at most ``len`` characters.

    Args:
        len: Maximum number of characters to keep. The name shadows the
            builtin inside this function only.

    Returns:
        An ASCII string made of URL-safe base64 characters.
    """
    raw = uuid.uuid4().bytes
    text = base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")
    return text[:len]
|
|
|
|
|
|
# Module-level state shared by the helpers in this file.
# CURRENT_TRACE_CONTEXT: assigned by start_trace() and reset to None by end_trace().
# BACKGROUND_LOGGER: assigned by setup_logger(); None until then.
CURRENT_TRACE_CONTEXT = None
|
|
BACKGROUND_LOGGER = None
|
|
|
|
|
|
class BackgroundLogger:
    """Hands telemetry events to a Telemetry API from a daemon worker thread.

    Events are buffered in a bounded queue. When the queue is full, new
    events are dropped instead of blocking the caller.
    """

    def __init__(self, api: Telemetry, capacity: int = 1000):
        """Create the bounded queue and start the worker thread.

        Args:
            api: Object whose ``log_event`` coroutine receives each event.
            capacity: Maximum number of events held in the queue.
        """
        self.api = api
        self.log_queue = queue.Queue(maxsize=capacity)
        worker = threading.Thread(target=self._process_logs, daemon=True)
        self.worker_thread = worker
        worker.start()

    def log_event(self, event):
        """Enqueue ``event`` without blocking; drop it when the queue is full."""
        try:
            self.log_queue.put(event, block=False)
        except queue.Full:
            log.error("Log queue is full, dropping event")

    def _process_logs(self):
        """Worker loop: take events off the queue and send them one by one."""
        while True:
            try:
                pending = self.log_queue.get()
                # figure out how to use a thread's native loop
                delivery = self.api.log_event(pending)
                asyncio.run(delivery)
            except Exception:
                # Reported with print/traceback rather than the logging module.
                import traceback

                traceback.print_exc()
                print("Error processing log event")
            finally:
                # Always mark the item done so that join() can return.
                self.log_queue.task_done()

    def __del__(self):
        """Wait until every queued event has been marked as processed."""
        self.log_queue.join()
|
|
|
|
|
|
class TraceContext:
    """Stack of the currently open spans of one trace.

    Span start and end events are reported through the supplied logger.
    """

    def __init__(self, logger: BackgroundLogger, trace_id: str):
        """Bind the context to a logger and a trace id.

        Args:
            logger: Receives a ``StructuredLogEvent`` for every span start/end.
            trace_id: Identifier stamped on every span created by this context.
        """
        self.logger = logger
        self.trace_id = trace_id
        # Per-instance stack. A class-level list would be shared by every
        # TraceContext, letting separate traces push/pop each other's spans.
        self.spans: List[Span] = []

    def push_span(self, name: str, attributes: Dict[str, Any] = None):
        """Open a new span as a child of the current one and log its start.

        Args:
            name: Name of the span.
            attributes: Optional attributes attached to the span and to its
                start event.
        """
        parent = self.get_current_span()
        span = Span(
            span_id=generate_short_uuid(),
            trace_id=self.trace_id,
            name=name,
            start_time=datetime.now(),
            parent_span_id=parent.span_id if parent else None,
            attributes=attributes,
        )

        self.logger.log_event(
            StructuredLogEvent(
                trace_id=span.trace_id,
                span_id=span.span_id,
                timestamp=span.start_time,
                attributes=span.attributes,
                payload=SpanStartPayload(
                    name=span.name,
                    parent_span_id=span.parent_span_id,
                ),
            )
        )

        self.spans.append(span)

    def pop_span(self, status: SpanStatus = SpanStatus.OK):
        """Close the most recently opened span and log its end.

        Does nothing when no span is open, so an unbalanced pop cannot raise
        from cleanup paths such as ``__exit__``.

        Args:
            status: Status recorded in the span-end payload.
        """
        if not self.spans:
            return

        span = self.spans.pop()
        # NOTE(review): the end event reuses the span's start_time as its
        # timestamp -- confirm whether consumers expect the end time here.
        self.logger.log_event(
            StructuredLogEvent(
                trace_id=span.trace_id,
                span_id=span.span_id,
                timestamp=span.start_time,
                attributes=span.attributes,
                payload=SpanEndPayload(
                    status=status,
                ),
            )
        )

    def get_current_span(self):
        """Return the innermost open span, or None when no span is open."""
        return self.spans[-1] if self.spans else None
|
|
|
|
|
|
def setup_logger(api: Telemetry, level: int = logging.INFO):
    """Install telemetry forwarding on the root logger.

    Creates the module-wide ``BackgroundLogger`` for ``api``, sets the root
    logger's level, and attaches a ``TelemetryHandler`` to it.

    Args:
        api: Telemetry implementation handed to the background logger.
        level: Level applied to the root logger.
    """
    global BACKGROUND_LOGGER

    BACKGROUND_LOGGER = BackgroundLogger(api)

    root_logger = logging.getLogger()
    root_logger.setLevel(level)
    telemetry_handler = TelemetryHandler()
    root_logger.addHandler(telemetry_handler)
|
|
|
|
|
|
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
    """Begin a new trace and make it the current one.

    A root span called ``name`` is pushed with ``__root__`` set to True in
    its attributes, merged with any caller-supplied attributes.

    Args:
        name: Name of the root span.
        attributes: Optional extra attributes for the root span.

    Returns:
        The new ``TraceContext``, or None when no background logger has been
        set up.
    """
    global CURRENT_TRACE_CONTEXT

    background_logger = BACKGROUND_LOGGER
    if background_logger is None:
        log.info("No Telemetry implementation set. Skipping trace initialization...")
        return None

    new_context = TraceContext(background_logger, generate_short_uuid(16))
    root_attributes = {"__root__": True, **(attributes or {})}
    new_context.push_span(name, root_attributes)

    CURRENT_TRACE_CONTEXT = new_context
    return new_context
|
|
|
|
|
|
async def end_trace(status: SpanStatus = SpanStatus.OK):
    """Finish the current trace, if there is one.

    Pops one span from the current trace context with ``status`` and then
    clears the module-level current context.

    Args:
        status: Status recorded on the popped span.
    """
    global CURRENT_TRACE_CONTEXT

    active = CURRENT_TRACE_CONTEXT
    if active is not None:
        active.pop_span(status)
        CURRENT_TRACE_CONTEXT = None
|
|
|
|
|
|
def severity(levelname: str) -> LogSeverity:
    """Translate a ``logging`` level name into a ``LogSeverity`` member.

    Args:
        levelname: One of "DEBUG", "INFO", "WARNING", "ERROR" or "CRITICAL".

    Returns:
        The matching ``LogSeverity`` ("WARNING" maps to ``LogSeverity.WARN``).

    Raises:
        ValueError: If ``levelname`` is not one of the names listed above.
    """
    if levelname == "DEBUG":
        return LogSeverity.DEBUG
    if levelname == "INFO":
        return LogSeverity.INFO
    if levelname == "WARNING":
        return LogSeverity.WARN
    if levelname == "ERROR":
        return LogSeverity.ERROR
    if levelname == "CRITICAL":
        return LogSeverity.CRITICAL
    raise ValueError(f"Unknown log level: {levelname}")
|
|
|
|
|
|
# TODO: ideally, the actual emitting should be done inside a separate daemon
# process completely isolated from the server
class TelemetryHandler(logging.Handler):
    """Logging handler that forwards records to the telemetry background logger.

    A record is forwarded only while a trace context with an open span is
    current; the resulting event carries that span's trace and span ids.
    """

    def emit(self, record: logging.LogRecord):
        """Queue ``record`` as an ``UnstructuredLogEvent`` on the current span.

        Args:
            record: The log record being handled.

        Raises:
            RuntimeError: If no background logger has been set up.
        """
        # horrendous hack to avoid logging from asyncio and getting into an infinite loop
        if record.module in ("asyncio", "selector_events"):
            return

        if BACKGROUND_LOGGER is None:
            raise RuntimeError("Telemetry API not initialized")

        context = CURRENT_TRACE_CONTEXT
        if context is None:
            return

        span = context.get_current_span()
        if span is None:
            return

        try:
            BACKGROUND_LOGGER.log_event(
                UnstructuredLogEvent(
                    trace_id=span.trace_id,
                    span_id=span.span_id,
                    timestamp=datetime.now(),
                    message=self.format(record),
                    severity=severity(record.levelname),
                )
            )
        except Exception:
            # severity() raises ValueError for level names it does not know
            # (e.g. custom levels), and formatting can fail too. A handler
            # must not let that escape into the code that made the logging
            # call, so report it the standard logging way instead.
            self.handleError(record)

    def close(self):
        """No-op: this handler does no cleanup of its own on close."""
        pass
|
|
|
|
|
|
class SpanContextManager:
    """Opens a span on the current trace context for a block or a function.

    Usable as a sync or async context manager and as a decorator for sync
    or coroutine functions. When no trace context is current, entering and
    exiting do nothing.
    """

    def __init__(self, name: str, attributes: Dict[str, Any] = None):
        """Remember the span name and optional attributes.

        Args:
            name: Name given to the span when it is pushed.
            attributes: Optional attributes passed along with the span.
        """
        self.name = name
        self.attributes = attributes

    def __enter__(self):
        """Push a span onto the current trace context, if any; return self."""
        context = CURRENT_TRACE_CONTEXT
        if context:
            context.push_span(self.name, self.attributes)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Pop a span from the current trace context, if any.

        The context is looked up again here rather than remembered from
        ``__enter__``. Exceptions from the body are not suppressed.
        """
        context = CURRENT_TRACE_CONTEXT
        if context:
            context.pop_span()

    async def __aenter__(self):
        """Async counterpart of ``__enter__``."""
        return self.__enter__()

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Async counterpart of ``__exit__``."""
        self.__exit__(exc_type, exc_value, traceback)

    def __call__(self, func: Callable):
        """Decorate ``func`` so that every call runs inside this span.

        Coroutine functions get an ``async def`` wrapper, so the decorated
        function is still recognised as a coroutine function by
        introspection; other callables get a plain wrapper.

        Args:
            func: The function or coroutine function to wrap.

        Returns:
            The wrapped callable, carrying ``func``'s metadata.
        """
        if asyncio.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                async with self:
                    return await func(*args, **kwargs)

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return sync_wrapper
|
|
|
|
|
|
def span(name: str, attributes: Dict[str, Any] = None):
    """Build a ``SpanContextManager`` for ``name``.

    Args:
        name: Name of the span.
        attributes: Optional attributes for the span.

    Returns:
        A new ``SpanContextManager`` built from ``name`` and ``attributes``.
    """
    manager = SpanContextManager(name, attributes)
    return manager
|