mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00
feat: use same trace ids in stack and otel
This commit is contained in:
parent
baf68c665c
commit
a7de2f3ce4
4 changed files with 73 additions and 34 deletions
|
@ -237,9 +237,17 @@ class TracingMiddleware:
|
|||
# Use the matched template or original path
|
||||
trace_path = route_template or path
|
||||
|
||||
await start_trace(trace_path, {"__location__": "server", "raw_path": path})
|
||||
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
|
||||
|
||||
async def send_wrapper(message):
|
||||
if message["type"] == "http.response.start":
|
||||
headers = message.get("headers", [])
|
||||
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
|
||||
message["headers"] = headers
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
return await self.app(scope, receive, send)
|
||||
return await self.app(scope, receive, send_wrapper)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ from datetime import datetime, timezone
|
|||
|
||||
from opentelemetry.sdk.trace import SpanProcessor
|
||||
from opentelemetry.trace import Span
|
||||
from opentelemetry.trace.span import format_span_id, format_trace_id
|
||||
|
||||
|
||||
class SQLiteSpanProcessor(SpanProcessor):
|
||||
|
@ -100,14 +101,14 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
trace_id = format(span.get_span_context().trace_id, "032x")
|
||||
span_id = format(span.get_span_context().span_id, "016x")
|
||||
trace_id = format_trace_id(span.get_span_context().trace_id)
|
||||
span_id = format_span_id(span.get_span_context().span_id)
|
||||
service_name = span.resource.attributes.get("service.name", "unknown")
|
||||
|
||||
parent_span_id = None
|
||||
parent_context = span.parent
|
||||
if parent_context:
|
||||
parent_span_id = format(parent_context.span_id, "016x")
|
||||
parent_span_id = format_span_id(parent_context.span_id)
|
||||
|
||||
# Insert into traces
|
||||
cursor.execute(
|
||||
|
@ -123,7 +124,7 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
(
|
||||
trace_id,
|
||||
service_name,
|
||||
(span_id if not parent_span_id else None),
|
||||
(span_id if span.attributes.get("__root_span__") == "true" else None),
|
||||
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
|
||||
),
|
||||
|
|
|
@ -54,16 +54,6 @@ _global_lock = threading.Lock()
|
|||
_TRACER_PROVIDER = None
|
||||
|
||||
|
||||
def string_to_trace_id(s: str) -> int:
|
||||
# Convert the string to bytes and then to an integer
|
||||
return int.from_bytes(s.encode(), byteorder="big", signed=False)
|
||||
|
||||
|
||||
def string_to_span_id(s: str) -> int:
|
||||
# Use only the first 8 bytes (64 bits) for span ID
|
||||
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
|
||||
|
||||
|
||||
def is_tracing_enabled(tracer):
|
||||
with tracer.start_as_current_span("check_tracing") as span:
|
||||
return span.is_recording()
|
||||
|
@ -136,7 +126,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
# Use global storage instead of instance storage
|
||||
span_id = string_to_span_id(event.span_id)
|
||||
span_id = event.span_id
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
|
||||
if span:
|
||||
|
@ -196,8 +186,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
|
||||
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
span_id = string_to_span_id(event.span_id)
|
||||
trace_id = string_to_trace_id(event.trace_id)
|
||||
span_id = int(event.span_id, 16)
|
||||
tracer = trace.get_tracer(__name__)
|
||||
if event.attributes is None:
|
||||
event.attributes = {}
|
||||
|
@ -208,14 +197,23 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
if span_id in _GLOBAL_STORAGE["active_spans"]:
|
||||
return
|
||||
|
||||
parent_span = None
|
||||
context = None
|
||||
if event.payload.parent_span_id:
|
||||
parent_span_id = string_to_span_id(event.payload.parent_span_id)
|
||||
parent_span_id = int(event.payload.parent_span_id, 16)
|
||||
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||
|
||||
context = trace.Context(trace_id=trace_id)
|
||||
if parent_span:
|
||||
context = trace.set_span_in_context(parent_span, context)
|
||||
context = trace.set_span_in_context(parent_span)
|
||||
else:
|
||||
context = trace.set_span_in_context(
|
||||
trace.NonRecordingSpan(
|
||||
trace.SpanContext(
|
||||
trace_id=int(event.trace_id, 16),
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED),
|
||||
)
|
||||
)
|
||||
)
|
||||
event.attributes["__root_span__"] = "true"
|
||||
|
||||
span = tracer.start_span(
|
||||
name=event.payload.name,
|
||||
|
|
|
@ -5,12 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import contextvars
|
||||
import logging
|
||||
import queue
|
||||
import random
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
@ -31,11 +30,44 @@ from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
|
|||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def generate_short_uuid(len: int = 8):
|
||||
full_uuid = uuid.uuid4()
|
||||
uuid_bytes = full_uuid.bytes
|
||||
encoded = base64.urlsafe_b64encode(uuid_bytes)
|
||||
return encoded.rstrip(b"=").decode("ascii")[:len]
|
||||
INVALID_SPAN_ID = 0x0000000000000000
|
||||
INVALID_TRACE_ID = 0x00000000000000000000000000000000
|
||||
|
||||
|
||||
def format_trace_id(trace_id: int) -> str:
|
||||
"""Convenience trace ID formatting method
|
||||
Args:
|
||||
trace_id: Trace ID int
|
||||
|
||||
Returns:
|
||||
The trace ID as 32-byte hexadecimal string
|
||||
"""
|
||||
return format(trace_id, "032x")
|
||||
|
||||
|
||||
def format_span_id(span_id: int) -> str:
|
||||
"""Convenience span ID formatting method
|
||||
Args:
|
||||
span_id: Span ID int
|
||||
|
||||
Returns:
|
||||
The span ID as 16-byte hexadecimal string
|
||||
"""
|
||||
return format(span_id, "016x")
|
||||
|
||||
|
||||
def generate_span_id() -> str:
|
||||
span_id = random.getrandbits(64)
|
||||
while span_id == INVALID_SPAN_ID:
|
||||
span_id = random.getrandbits(64)
|
||||
return format_span_id(span_id)
|
||||
|
||||
|
||||
def generate_trace_id() -> str:
|
||||
trace_id = random.getrandbits(128)
|
||||
while trace_id == INVALID_TRACE_ID:
|
||||
trace_id = random.getrandbits(128)
|
||||
return format_trace_id(trace_id)
|
||||
|
||||
|
||||
CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
|
||||
|
@ -83,7 +115,7 @@ class TraceContext:
|
|||
def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span:
|
||||
current_span = self.get_current_span()
|
||||
span = Span(
|
||||
span_id=generate_short_uuid(),
|
||||
span_id=generate_span_id(),
|
||||
trace_id=self.trace_id,
|
||||
name=name,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
|
@ -143,7 +175,7 @@ async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceCont
|
|||
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
|
||||
return
|
||||
|
||||
trace_id = generate_short_uuid(16)
|
||||
trace_id = generate_trace_id()
|
||||
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
||||
context.push_span(name, {"__root__": True, **(attributes or {})})
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue