fix: tracing fixes for trace context propogation across coroutines (#1522)

# What does this PR do?
This PR has two fixes needed for correct trace context propagation
across asycnio boundary
Fix 1: Start using context vars to store the global trace context.
This is needed since we cannot use the same trace context across
coroutines since the state is shared. each coroutine
should have its own trace context so that each of it can start storing
its state correctly.
Fix 2: Start a new span for each new coroutines started for running
shields to keep the span tree clean


## Test Plan

### Integration tests with server
LLAMA_STACK_DISABLE_VERSION_CHECK=true llama stack run
~/.llama/distributions/together/together-run.yaml
LLAMA_STACK_CONFIG=http://localhost:8321 pytest -s --safety-shield
meta-llama/Llama-Guard-3-8B --text-model
meta-llama/Llama-3.1-8B-Instruct
server logs:
https://gist.github.com/dineshyv/51ac5d9864ed031d0d89ce77352821fe
test logs:
https://gist.github.com/dineshyv/e66acc1c4648a42f1854600609c467f3
 
### Integration tests with library client
LLAMA_STACK_CONFIG=fireworks pytest -s --safety-shield
meta-llama/Llama-Guard-3-8B --text-model
meta-llama/Llama-3.1-8B-Instruct

logs: https://gist.github.com/dineshyv/ca160696a0b167223378673fb1dcefb8

### Apps test with server:
```
LLAMA_STACK_DISABLE_VERSION_CHECK=true llama stack run ~/.llama/distributions/together/together-run.yaml
python -m examples.agents.e2e_loop_with_client_tools localhost 8321
```
server logs:
https://gist.github.com/dineshyv/1717a572d8f7c14279c36123b79c5797
app logs:
https://gist.github.com/dineshyv/44167e9f57806a0ba3b710c32aec02f8
This commit is contained in:
Dinesh Yeduguru 2025-03-11 07:12:48 -07:00 committed by GitHub
parent e3edca7739
commit ead9397e22
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 54 additions and 35 deletions

View file

@ -6,6 +6,7 @@
import asyncio
import base64
import contextvars
import logging
import queue
import threading
@ -24,9 +25,10 @@ from llama_stack.apis.telemetry import (
Telemetry,
UnstructuredLogEvent,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
log = logging.getLogger(__name__)
logger = get_logger(__name__, category="core")
def generate_short_uuid(len: int = 8):
@ -36,7 +38,7 @@ def generate_short_uuid(len: int = 8):
return encoded.rstrip(b"=").decode("ascii")[:len]
CURRENT_TRACE_CONTEXT = None
CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
BACKGROUND_LOGGER = None
@ -51,7 +53,7 @@ class BackgroundLogger:
try:
self.log_queue.put_nowait(event)
except queue.Full:
log.error("Log queue is full, dropping event")
logger.error("Log queue is full, dropping event")
def _process_logs(self):
while True:
@ -129,35 +131,36 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
if BACKGROUND_LOGGER is None:
BACKGROUND_LOGGER = BackgroundLogger(api)
logger = logging.getLogger()
logger.setLevel(level)
logger.addHandler(TelemetryHandler())
root_logger = logging.getLogger()
root_logger.setLevel(level)
root_logger.addHandler(TelemetryHandler())
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
log.info("No Telemetry implementation set. Skipping trace initialization...")
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
return
trace_id = generate_short_uuid(16)
context = TraceContext(BACKGROUND_LOGGER, trace_id)
context.push_span(name, {"__root__": True, **(attributes or {})})
CURRENT_TRACE_CONTEXT = context
CURRENT_TRACE_CONTEXT.set(context)
return context
async def end_trace(status: SpanStatus = SpanStatus.OK):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if context is None:
logger.debug("No trace context to end")
return
context.pop_span(status)
CURRENT_TRACE_CONTEXT = None
CURRENT_TRACE_CONTEXT.set(None)
def severity(levelname: str) -> LogSeverity:
@ -188,7 +191,7 @@ class TelemetryHandler(logging.Handler):
if BACKGROUND_LOGGER is None:
raise RuntimeError("Telemetry API not initialized")
context = CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if context is None:
return
@ -218,16 +221,22 @@ class SpanContextManager:
def __enter__(self):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
self.span = context.push_span(self.name, self.attributes)
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to push span")
return self
self.span = context.push_span(self.name, self.attributes)
return self
def __exit__(self, exc_type, exc_value, traceback):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.pop_span()
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to pop span")
return
context.pop_span()
def set_attribute(self, key: str, value: Any):
if self.span:
@ -237,16 +246,22 @@ class SpanContextManager:
async def __aenter__(self):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
self.span = context.push_span(self.name, self.attributes)
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to push span")
return self
self.span = context.push_span(self.name, self.attributes)
return self
async def __aexit__(self, exc_type, exc_value, traceback):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.pop_span()
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to pop span")
return
context.pop_span()
def __call__(self, func: Callable):
@wraps(func)
@ -275,7 +290,11 @@ def span(name: str, attributes: Dict[str, Any] = None):
def get_current_span() -> Optional[Span]:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if CURRENT_TRACE_CONTEXT is None:
logger.debug("No trace context to get current span")
return None
context = CURRENT_TRACE_CONTEXT.get()
if context:
return context.get_current_span()
return None