From 5eb15684b4e36bbc480ae8616c9b2a41dc4a95e7 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 21 Mar 2025 15:41:26 -0700 Subject: [PATCH] feat: use same trace ids in stack and otel (#1759) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? 1) Uses otel compatible id generation for stack 2) Stack starts returning trace id info in the header of response 3) We inject the same trace id that we have into otel in order to force it to use our trace ids. ## Test Plan ``` curl -i --request POST \ --url http://localhost:8321/v1/inference/chat-completion \ --header 'content-type: application/json' \ --data '{ "model_id": "meta-llama/Llama-3.1-70B-Instruct", "messages": [ { "role": "user", "content": { "type": "text", "text": "where do humans live" } } ], "stream": false }' HTTP/1.1 200 OK date: Fri, 21 Mar 2025 21:51:19 GMT server: uvicorn content-length: 1712 content-type: application/json x-trace-id: 595101ede31ece116ebe35b26d67e8cf {"metrics":[{"metric":"prompt_tokens","value":10,"unit":null},{"metric":"completion_tokens","value":320,"unit":null},{"metric":"total_tokens","value":330,"unit":null}],"completion_message":{"role":"assistant","content":"Humans live on the planet Earth, specifically on its landmasses and in its oceans. Here's a breakdown of where humans live:\n\n1. **Continents:** Humans inhabit all seven continents:\n\t* Africa\n\t* Antarctica ( temporary residents, mostly scientists and researchers)\n\t* Asia\n\t* Australia\n\t* Europe\n\t* North America\n\t* South America\n2. **Countries:** There are 196 countries recognized by the United Nations, and humans live in almost all of them.\n3. **Cities and towns:** Many humans live in urban areas, such as cities and towns, which are often located near coastlines, rivers, or other bodies of water.\n4. **Rural areas:** Some humans live in rural areas, such as villages, farms, and countryside.\n5. **Islands:** Humans inhabit many islands around the world, including tropical islands, island nations, and islands in the Arctic and Antarctic regions.\n6. **Underwater habitats:** A few humans live in underwater habitats, such as research stations and submarines.\n7. **Space:** A small number of humans have lived in space, including astronauts on the International Space Station and those who have visited the Moon.\n\nIn terms of specific environments, humans live in a wide range of ecosystems, including:\n\n* Deserts\n* Forests\n* Grasslands\n* Mountains\n* Oceans\n* Rivers\n* Tundras\n* Wetlands\n\nOverall, humans are incredibly adaptable and can be found living in almost every corner of the globe.","stop_reason":"end_of_turn","tool_calls":[]},"logprobs":null} ``` Same trace id in Jaeger and sqlite: ![Screenshot 2025-03-21 at 2 51 53 PM](https://github.com/user-attachments/assets/38cc04b0-568c-4b9d-bccd-d3b90e581c27) ![Screenshot 2025-03-21 at 2 52 38 PM](https://github.com/user-attachments/assets/722383ad-6305-4020-8a1c-6cfdf381c25f) --- llama_stack/distribution/server/server.py | 12 ++++- .../meta_reference/sqlite_span_processor.py | 9 ++-- .../telemetry/meta_reference/telemetry.py | 36 +++++++------ .../providers/utils/telemetry/tracing.py | 50 +++++++++++++++---- 4 files changed, 73 insertions(+), 34 deletions(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index dd430dbcd..39de1e4df 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -237,9 +237,17 @@ class TracingMiddleware: # Use the matched template or original path trace_path = route_template or path - await start_trace(trace_path, {"__location__": "server", "raw_path": path}) + trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path}) + + async def send_with_trace_id(message): + if message["type"] == "http.response.start": + headers = message.get("headers", []) + headers.append([b"x-trace-id", str(trace_context.trace_id).encode()]) + message["headers"] = headers + await send(message) + try: - return await self.app(scope, receive, send) + return await self.app(scope, receive, send_with_trace_id) finally: await end_trace() diff --git a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py index 5ed586fce..e9a003db6 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py @@ -12,6 +12,7 @@ from datetime import datetime, timezone from opentelemetry.sdk.trace import SpanProcessor from opentelemetry.trace import Span +from opentelemetry.trace.span import format_span_id, format_trace_id class SQLiteSpanProcessor(SpanProcessor): @@ -100,14 +101,14 @@ class SQLiteSpanProcessor(SpanProcessor): conn = self._get_connection() cursor = conn.cursor() - trace_id = format(span.get_span_context().trace_id, "032x") - span_id = format(span.get_span_context().span_id, "016x") + trace_id = format_trace_id(span.get_span_context().trace_id) + span_id = format_span_id(span.get_span_context().span_id) service_name = span.resource.attributes.get("service.name", "unknown") parent_span_id = None parent_context = span.parent if parent_context: - parent_span_id = format(parent_context.span_id, "016x") + parent_span_id = format_span_id(parent_context.span_id) # Insert into traces cursor.execute( @@ -123,7 +124,7 @@ class SQLiteSpanProcessor(SpanProcessor): ( trace_id, service_name, - (span_id if not parent_span_id else None), + (span_id if span.attributes.get("__root_span__") == "true" else None), datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(), datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(), ), diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index 46a88a7b8..181bfda9b 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -54,16 +54,6 @@ _global_lock = threading.Lock() _TRACER_PROVIDER = None -def string_to_trace_id(s: str) -> int: - # Convert the string to bytes and then to an integer - return int.from_bytes(s.encode(), byteorder="big", signed=False) - - -def string_to_span_id(s: str) -> int: - # Use only the first 8 bytes (64 bits) for span ID - return int.from_bytes(s.encode()[:8], byteorder="big", signed=False) - - def is_tracing_enabled(tracer): with tracer.start_as_current_span("check_tracing") as span: return span.is_recording() @@ -137,7 +127,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None: with self._lock: # Use global storage instead of instance storage - span_id = string_to_span_id(event.span_id) + span_id = event.span_id span = _GLOBAL_STORAGE["active_spans"].get(span_id) if span: @@ -197,8 +187,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None: with self._lock: - span_id = string_to_span_id(event.span_id) - trace_id = string_to_trace_id(event.trace_id) + span_id = int(event.span_id, 16) tracer = trace.get_tracer(__name__) if event.attributes is None: event.attributes = {} @@ -209,14 +198,23 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): if span_id in _GLOBAL_STORAGE["active_spans"]: return - parent_span = None + context = None if event.payload.parent_span_id: - parent_span_id = string_to_span_id(event.payload.parent_span_id) + parent_span_id = int(event.payload.parent_span_id, 16) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) - - context = trace.Context(trace_id=trace_id) - if parent_span: - context = trace.set_span_in_context(parent_span, context) + context = trace.set_span_in_context(parent_span) + else: + context = trace.set_span_in_context( + trace.NonRecordingSpan( + trace.SpanContext( + trace_id=int(event.trace_id, 16), + span_id=span_id, + is_remote=False, + trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED), + ) + ) + ) + event.attributes["__root_span__"] = "true" span = tracer.start_span( name=event.payload.name, diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 607d1a918..3d5c717d6 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -5,12 +5,11 @@ # the root directory of this source tree. import asyncio -import base64 import contextvars import logging import queue +import random import threading -import uuid from datetime import datetime, timezone from functools import wraps from typing import Any, Callable, Dict, List, Optional @@ -31,11 +30,44 @@ from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value logger = get_logger(__name__, category="core") -def generate_short_uuid(len: int = 8): - full_uuid = uuid.uuid4() - uuid_bytes = full_uuid.bytes - encoded = base64.urlsafe_b64encode(uuid_bytes) - return encoded.rstrip(b"=").decode("ascii")[:len] +INVALID_SPAN_ID = 0x0000000000000000 +INVALID_TRACE_ID = 0x00000000000000000000000000000000 + + +def trace_id_to_str(trace_id: int) -> str: + """Convenience trace ID formatting method + Args: + trace_id: Trace ID int + + Returns: + The trace ID as 32-byte hexadecimal string + """ + return format(trace_id, "032x") + + +def span_id_to_str(span_id: int) -> str: + """Convenience span ID formatting method + Args: + span_id: Span ID int + + Returns: + The span ID as 16-byte hexadecimal string + """ + return format(span_id, "016x") + + +def generate_span_id() -> str: + span_id = random.getrandbits(64) + while span_id == INVALID_SPAN_ID: + span_id = random.getrandbits(64) + return span_id_to_str(span_id) + + +def generate_trace_id() -> str: + trace_id = random.getrandbits(128) + while trace_id == INVALID_TRACE_ID: + trace_id = random.getrandbits(128) + return trace_id_to_str(trace_id) CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None) @@ -83,7 +115,7 @@ class TraceContext: def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span: current_span = self.get_current_span() span = Span( - span_id=generate_short_uuid(), + span_id=generate_span_id(), trace_id=self.trace_id, name=name, start_time=datetime.now(timezone.utc), @@ -143,7 +175,7 @@ async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceCont logger.debug("No Telemetry implementation set. Skipping trace initialization...") return - trace_id = generate_short_uuid(16) + trace_id = generate_trace_id() context = TraceContext(BACKGROUND_LOGGER, trace_id) context.push_span(name, {"__root__": True, **(attributes or {})})