diff --git a/src/llama_stack/core/utils/context.py b/src/llama_stack/core/utils/context.py index 87ad553e9..e7c61a8ed 100644 --- a/src/llama_stack/core/utils/context.py +++ b/src/llama_stack/core/utils/context.py @@ -7,6 +7,8 @@ from collections.abc import AsyncGenerator from contextvars import ContextVar +from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT + _MISSING = object() @@ -26,7 +28,8 @@ def preserve_contexts_async_generator[T]( previous_values: dict[ContextVar, object] = {} tokens: dict[ContextVar, object] = {} - # Restore context values before any await and capture previous state + # Restore ALL context values before any await and capture previous state + # This is needed to propagate context across async generator boundaries for context_var in context_vars: try: previous_values[context_var] = context_var.get() @@ -52,11 +55,13 @@ def preserve_contexts_async_generator[T]( try: item = await gen.__anext__() except StopAsyncIteration: - # Restore context vars before exiting to prevent leaks + # Restore all context vars before exiting to prevent leaks + # Use _restore_context_var for all vars to properly restore to previous values for context_var in context_vars: _restore_context_var(context_var) break except Exception: + # Restore all context vars on exception for context_var in context_vars: _restore_context_var(context_var) raise @@ -64,11 +69,16 @@ def preserve_contexts_async_generator[T]( try: yield item # Update our tracked values with any changes made during this iteration + # Only for non-trace context vars - trace context must persist across yields + # to allow nested span tracking for telemetry for context_var in context_vars: - initial_context_values[context_var.name] = context_var.get() + if context_var is not CURRENT_TRACE_CONTEXT: + initial_context_values[context_var.name] = context_var.get() finally: - # Restore context vars after each yield to prevent leaks between requests + # Restore non-trace context vars after each yield to prevent leaks between requests + # CURRENT_TRACE_CONTEXT is NOT restored here to preserve telemetry span stack for context_var in context_vars: - _restore_context_var(context_var) + if context_var is not CURRENT_TRACE_CONTEXT: + _restore_context_var(context_var) return wrapper()