feat: use same trace ids in stack and otel (#1759)

# What does this PR do?
1) Uses otel compatible id generation for stack
2) Stack starts returning trace id info in the header of response
3) We inject the same trace id that we have into otel in order to force
it to use our trace ids.

## Test Plan
```
 curl -i --request POST \
  --url http://localhost:8321/v1/inference/chat-completion \
  --header 'content-type: application/json' \
  --data '{
  "model_id": "meta-llama/Llama-3.1-70B-Instruct",
  "messages": [
    {
      "role": "user",
      "content": {
        "type": "text",
        "text": "where do humans live"
      }
    }
  ],
  "stream": false
}'
HTTP/1.1 200 OK
date: Fri, 21 Mar 2025 21:51:19 GMT
server: uvicorn
content-length: 1712
content-type: application/json
x-trace-id: 595101ede31ece116ebe35b26d67e8cf

{"metrics":[{"metric":"prompt_tokens","value":10,"unit":null},{"metric":"completion_tokens","value":320,"unit":null},{"metric":"total_tokens","value":330,"unit":null}],"completion_message":{"role":"assistant","content":"Humans live on the planet Earth, specifically on its landmasses and in its oceans. Here's a breakdown of where humans live:\n\n1. **Continents:** Humans inhabit all seven continents:\n\t* Africa\n\t* Antarctica ( temporary residents, mostly scientists and researchers)\n\t* Asia\n\t* Australia\n\t* Europe\n\t* North America\n\t* South America\n2. **Countries:** There are 196 countries recognized by the United Nations, and humans live in almost all of them.\n3. **Cities and towns:** Many humans live in urban areas, such as cities and towns, which are often located near coastlines, rivers, or other bodies of water.\n4. **Rural areas:** Some humans live in rural areas, such as villages, farms, and countryside.\n5. **Islands:** Humans inhabit many islands around the world, including tropical islands, island nations, and islands in the Arctic and Antarctic regions.\n6. **Underwater habitats:** A few humans live in underwater habitats, such as research stations and submarines.\n7. **Space:** A small number of humans have lived in space, including astronauts on the International Space Station and those who have visited the Moon.\n\nIn terms of specific environments, humans live in a wide range of ecosystems, including:\n\n* Deserts\n* Forests\n* Grasslands\n* Mountains\n* Oceans\n* Rivers\n* Tundras\n* Wetlands\n\nOverall, humans are incredibly adaptable and can be found living in almost every corner of the globe.","stop_reason":"end_of_turn","tool_calls":[]},"logprobs":null}
```

Same trace id in Jaeger and sqlite:

![Screenshot 2025-03-21 at 2 51
53 PM](https://github.com/user-attachments/assets/38cc04b0-568c-4b9d-bccd-d3b90e581c27)
![Screenshot 2025-03-21 at 2 52
38 PM](https://github.com/user-attachments/assets/722383ad-6305-4020-8a1c-6cfdf381c25f)
This commit is contained in:
Dinesh Yeduguru 2025-03-21 15:41:26 -07:00 committed by GitHub
parent b9fbfed216
commit 5eb15684b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 73 additions and 34 deletions

View file

@ -237,9 +237,17 @@ class TracingMiddleware:
# Use the matched template or original path
trace_path = route_template or path
await start_trace(trace_path, {"__location__": "server", "raw_path": path})
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
async def send_with_trace_id(message):
if message["type"] == "http.response.start":
headers = message.get("headers", [])
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
message["headers"] = headers
await send(message)
try:
return await self.app(scope, receive, send)
return await self.app(scope, receive, send_with_trace_id)
finally:
await end_trace()

View file

@ -12,6 +12,7 @@ from datetime import datetime, timezone
from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span
from opentelemetry.trace.span import format_span_id, format_trace_id
class SQLiteSpanProcessor(SpanProcessor):
@ -100,14 +101,14 @@ class SQLiteSpanProcessor(SpanProcessor):
conn = self._get_connection()
cursor = conn.cursor()
trace_id = format(span.get_span_context().trace_id, "032x")
span_id = format(span.get_span_context().span_id, "016x")
trace_id = format_trace_id(span.get_span_context().trace_id)
span_id = format_span_id(span.get_span_context().span_id)
service_name = span.resource.attributes.get("service.name", "unknown")
parent_span_id = None
parent_context = span.parent
if parent_context:
parent_span_id = format(parent_context.span_id, "016x")
parent_span_id = format_span_id(parent_context.span_id)
# Insert into traces
cursor.execute(
@ -123,7 +124,7 @@ class SQLiteSpanProcessor(SpanProcessor):
(
trace_id,
service_name,
(span_id if not parent_span_id else None),
(span_id if span.attributes.get("__root_span__") == "true" else None),
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
),

View file

@ -54,16 +54,6 @@ _global_lock = threading.Lock()
_TRACER_PROVIDER = None
def string_to_trace_id(s: str) -> int:
# Convert the string to bytes and then to an integer
return int.from_bytes(s.encode(), byteorder="big", signed=False)
def string_to_span_id(s: str) -> int:
# Use only the first 8 bytes (64 bits) for span ID
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording()
@ -137,7 +127,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
# Use global storage instead of instance storage
span_id = string_to_span_id(event.span_id)
span_id = event.span_id
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
@ -197,8 +187,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
span_id = string_to_span_id(event.span_id)
trace_id = string_to_trace_id(event.trace_id)
span_id = int(event.span_id, 16)
tracer = trace.get_tracer(__name__)
if event.attributes is None:
event.attributes = {}
@ -209,14 +198,23 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
if span_id in _GLOBAL_STORAGE["active_spans"]:
return
parent_span = None
context = None
if event.payload.parent_span_id:
parent_span_id = string_to_span_id(event.payload.parent_span_id)
parent_span_id = int(event.payload.parent_span_id, 16)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
context = trace.Context(trace_id=trace_id)
if parent_span:
context = trace.set_span_in_context(parent_span, context)
context = trace.set_span_in_context(parent_span)
else:
context = trace.set_span_in_context(
trace.NonRecordingSpan(
trace.SpanContext(
trace_id=int(event.trace_id, 16),
span_id=span_id,
is_remote=False,
trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED),
)
)
)
event.attributes["__root_span__"] = "true"
span = tracer.start_span(
name=event.payload.name,

View file

@ -5,12 +5,11 @@
# the root directory of this source tree.
import asyncio
import base64
import contextvars
import logging
import queue
import random
import threading
import uuid
from datetime import datetime, timezone
from functools import wraps
from typing import Any, Callable, Dict, List, Optional
@ -31,11 +30,44 @@ from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
logger = get_logger(__name__, category="core")
def generate_short_uuid(len: int = 8):
full_uuid = uuid.uuid4()
uuid_bytes = full_uuid.bytes
encoded = base64.urlsafe_b64encode(uuid_bytes)
return encoded.rstrip(b"=").decode("ascii")[:len]
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
def trace_id_to_str(trace_id: int) -> str:
"""Convenience trace ID formatting method
Args:
trace_id: Trace ID int
Returns:
The trace ID as 32-byte hexadecimal string
"""
return format(trace_id, "032x")
def span_id_to_str(span_id: int) -> str:
"""Convenience span ID formatting method
Args:
span_id: Span ID int
Returns:
The span ID as 16-byte hexadecimal string
"""
return format(span_id, "016x")
def generate_span_id() -> str:
span_id = random.getrandbits(64)
while span_id == INVALID_SPAN_ID:
span_id = random.getrandbits(64)
return span_id_to_str(span_id)
def generate_trace_id() -> str:
trace_id = random.getrandbits(128)
while trace_id == INVALID_TRACE_ID:
trace_id = random.getrandbits(128)
return trace_id_to_str(trace_id)
CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
@ -83,7 +115,7 @@ class TraceContext:
def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span:
current_span = self.get_current_span()
span = Span(
span_id=generate_short_uuid(),
span_id=generate_span_id(),
trace_id=self.trace_id,
name=name,
start_time=datetime.now(timezone.utc),
@ -143,7 +175,7 @@ async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceCont
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
return
trace_id = generate_short_uuid(16)
trace_id = generate_trace_id()
context = TraceContext(BACKGROUND_LOGGER, trace_id)
context.push_span(name, {"__root__": True, **(attributes or {})})