mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 18:50:44 +00:00
feat: use same trace ids in stack and otel
This commit is contained in:
parent
baf68c665c
commit
a7de2f3ce4
4 changed files with 73 additions and 34 deletions
|
@ -237,9 +237,17 @@ class TracingMiddleware:
|
||||||
# Use the matched template or original path
|
# Use the matched template or original path
|
||||||
trace_path = route_template or path
|
trace_path = route_template or path
|
||||||
|
|
||||||
await start_trace(trace_path, {"__location__": "server", "raw_path": path})
|
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
|
||||||
|
|
||||||
|
async def send_wrapper(message):
|
||||||
|
if message["type"] == "http.response.start":
|
||||||
|
headers = message.get("headers", [])
|
||||||
|
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
|
||||||
|
message["headers"] = headers
|
||||||
|
await send(message)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await self.app(scope, receive, send)
|
return await self.app(scope, receive, send_wrapper)
|
||||||
finally:
|
finally:
|
||||||
await end_trace()
|
await end_trace()
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ from datetime import datetime, timezone
|
||||||
|
|
||||||
from opentelemetry.sdk.trace import SpanProcessor
|
from opentelemetry.sdk.trace import SpanProcessor
|
||||||
from opentelemetry.trace import Span
|
from opentelemetry.trace import Span
|
||||||
|
from opentelemetry.trace.span import format_span_id, format_trace_id
|
||||||
|
|
||||||
|
|
||||||
class SQLiteSpanProcessor(SpanProcessor):
|
class SQLiteSpanProcessor(SpanProcessor):
|
||||||
|
@ -100,14 +101,14 @@ class SQLiteSpanProcessor(SpanProcessor):
|
||||||
conn = self._get_connection()
|
conn = self._get_connection()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
trace_id = format(span.get_span_context().trace_id, "032x")
|
trace_id = format_trace_id(span.get_span_context().trace_id)
|
||||||
span_id = format(span.get_span_context().span_id, "016x")
|
span_id = format_span_id(span.get_span_context().span_id)
|
||||||
service_name = span.resource.attributes.get("service.name", "unknown")
|
service_name = span.resource.attributes.get("service.name", "unknown")
|
||||||
|
|
||||||
parent_span_id = None
|
parent_span_id = None
|
||||||
parent_context = span.parent
|
parent_context = span.parent
|
||||||
if parent_context:
|
if parent_context:
|
||||||
parent_span_id = format(parent_context.span_id, "016x")
|
parent_span_id = format_span_id(parent_context.span_id)
|
||||||
|
|
||||||
# Insert into traces
|
# Insert into traces
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
|
@ -123,7 +124,7 @@ class SQLiteSpanProcessor(SpanProcessor):
|
||||||
(
|
(
|
||||||
trace_id,
|
trace_id,
|
||||||
service_name,
|
service_name,
|
||||||
(span_id if not parent_span_id else None),
|
(span_id if span.attributes.get("__root_span__") == "true" else None),
|
||||||
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
|
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
|
||||||
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
|
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
|
||||||
),
|
),
|
||||||
|
|
|
@ -54,16 +54,6 @@ _global_lock = threading.Lock()
|
||||||
_TRACER_PROVIDER = None
|
_TRACER_PROVIDER = None
|
||||||
|
|
||||||
|
|
||||||
def string_to_trace_id(s: str) -> int:
|
|
||||||
# Convert the string to bytes and then to an integer
|
|
||||||
return int.from_bytes(s.encode(), byteorder="big", signed=False)
|
|
||||||
|
|
||||||
|
|
||||||
def string_to_span_id(s: str) -> int:
|
|
||||||
# Use only the first 8 bytes (64 bits) for span ID
|
|
||||||
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
|
|
||||||
|
|
||||||
|
|
||||||
def is_tracing_enabled(tracer):
|
def is_tracing_enabled(tracer):
|
||||||
with tracer.start_as_current_span("check_tracing") as span:
|
with tracer.start_as_current_span("check_tracing") as span:
|
||||||
return span.is_recording()
|
return span.is_recording()
|
||||||
|
@ -136,7 +126,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
||||||
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
|
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
# Use global storage instead of instance storage
|
# Use global storage instead of instance storage
|
||||||
span_id = string_to_span_id(event.span_id)
|
span_id = event.span_id
|
||||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||||
|
|
||||||
if span:
|
if span:
|
||||||
|
@ -196,8 +186,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
||||||
|
|
||||||
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
span_id = string_to_span_id(event.span_id)
|
span_id = int(event.span_id, 16)
|
||||||
trace_id = string_to_trace_id(event.trace_id)
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
if event.attributes is None:
|
if event.attributes is None:
|
||||||
event.attributes = {}
|
event.attributes = {}
|
||||||
|
@ -208,14 +197,23 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
||||||
if span_id in _GLOBAL_STORAGE["active_spans"]:
|
if span_id in _GLOBAL_STORAGE["active_spans"]:
|
||||||
return
|
return
|
||||||
|
|
||||||
parent_span = None
|
context = None
|
||||||
if event.payload.parent_span_id:
|
if event.payload.parent_span_id:
|
||||||
parent_span_id = string_to_span_id(event.payload.parent_span_id)
|
parent_span_id = int(event.payload.parent_span_id, 16)
|
||||||
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||||
|
context = trace.set_span_in_context(parent_span)
|
||||||
context = trace.Context(trace_id=trace_id)
|
else:
|
||||||
if parent_span:
|
context = trace.set_span_in_context(
|
||||||
context = trace.set_span_in_context(parent_span, context)
|
trace.NonRecordingSpan(
|
||||||
|
trace.SpanContext(
|
||||||
|
trace_id=int(event.trace_id, 16),
|
||||||
|
span_id=span_id,
|
||||||
|
is_remote=False,
|
||||||
|
trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
event.attributes["__root_span__"] = "true"
|
||||||
|
|
||||||
span = tracer.start_span(
|
span = tracer.start_span(
|
||||||
name=event.payload.name,
|
name=event.payload.name,
|
||||||
|
|
|
@ -5,12 +5,11 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
|
||||||
import contextvars
|
import contextvars
|
||||||
import logging
|
import logging
|
||||||
import queue
|
import queue
|
||||||
|
import random
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
@ -31,11 +30,44 @@ from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
|
||||||
logger = get_logger(__name__, category="core")
|
logger = get_logger(__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
def generate_short_uuid(len: int = 8):
|
INVALID_SPAN_ID = 0x0000000000000000
|
||||||
full_uuid = uuid.uuid4()
|
INVALID_TRACE_ID = 0x00000000000000000000000000000000
|
||||||
uuid_bytes = full_uuid.bytes
|
|
||||||
encoded = base64.urlsafe_b64encode(uuid_bytes)
|
|
||||||
return encoded.rstrip(b"=").decode("ascii")[:len]
|
def format_trace_id(trace_id: int) -> str:
|
||||||
|
"""Convenience trace ID formatting method
|
||||||
|
Args:
|
||||||
|
trace_id: Trace ID int
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The trace ID as 32-byte hexadecimal string
|
||||||
|
"""
|
||||||
|
return format(trace_id, "032x")
|
||||||
|
|
||||||
|
|
||||||
|
def format_span_id(span_id: int) -> str:
|
||||||
|
"""Convenience span ID formatting method
|
||||||
|
Args:
|
||||||
|
span_id: Span ID int
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The span ID as 16-byte hexadecimal string
|
||||||
|
"""
|
||||||
|
return format(span_id, "016x")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_span_id() -> str:
|
||||||
|
span_id = random.getrandbits(64)
|
||||||
|
while span_id == INVALID_SPAN_ID:
|
||||||
|
span_id = random.getrandbits(64)
|
||||||
|
return format_span_id(span_id)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_trace_id() -> str:
|
||||||
|
trace_id = random.getrandbits(128)
|
||||||
|
while trace_id == INVALID_TRACE_ID:
|
||||||
|
trace_id = random.getrandbits(128)
|
||||||
|
return format_trace_id(trace_id)
|
||||||
|
|
||||||
|
|
||||||
CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
|
CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
|
||||||
|
@ -83,7 +115,7 @@ class TraceContext:
|
||||||
def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span:
|
def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span:
|
||||||
current_span = self.get_current_span()
|
current_span = self.get_current_span()
|
||||||
span = Span(
|
span = Span(
|
||||||
span_id=generate_short_uuid(),
|
span_id=generate_span_id(),
|
||||||
trace_id=self.trace_id,
|
trace_id=self.trace_id,
|
||||||
name=name,
|
name=name,
|
||||||
start_time=datetime.now(timezone.utc),
|
start_time=datetime.now(timezone.utc),
|
||||||
|
@ -143,7 +175,7 @@ async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceCont
|
||||||
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
|
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
|
||||||
return
|
return
|
||||||
|
|
||||||
trace_id = generate_short_uuid(16)
|
trace_id = generate_trace_id()
|
||||||
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
||||||
context.push_span(name, {"__root__": True, **(attributes or {})})
|
context.push_span(name, {"__root__": True, **(attributes or {})})
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue