final fix, pretty subtle

This commit is contained in:
Ashwin Bharambe 2025-10-27 22:57:00 -07:00
parent 6a849c3b18
commit bc7d93fb39

View file

@ -7,6 +7,8 @@
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextvars import ContextVar from contextvars import ContextVar
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT
_MISSING = object() _MISSING = object()
@ -26,7 +28,8 @@ def preserve_contexts_async_generator[T](
previous_values: dict[ContextVar, object] = {} previous_values: dict[ContextVar, object] = {}
tokens: dict[ContextVar, object] = {} tokens: dict[ContextVar, object] = {}
# Restore context values before any await and capture previous state # Restore ALL context values before any await and capture previous state
# This is needed to propagate context across async generator boundaries
for context_var in context_vars: for context_var in context_vars:
try: try:
previous_values[context_var] = context_var.get() previous_values[context_var] = context_var.get()
@ -52,11 +55,13 @@ def preserve_contexts_async_generator[T](
try: try:
item = await gen.__anext__() item = await gen.__anext__()
except StopAsyncIteration: except StopAsyncIteration:
# Restore context vars before exiting to prevent leaks # Restore all context vars before exiting to prevent leaks
# Use _restore_context_var for all vars to properly restore to previous values
for context_var in context_vars: for context_var in context_vars:
_restore_context_var(context_var) _restore_context_var(context_var)
break break
except Exception: except Exception:
# Restore all context vars on exception
for context_var in context_vars: for context_var in context_vars:
_restore_context_var(context_var) _restore_context_var(context_var)
raise raise
@ -64,11 +69,16 @@ def preserve_contexts_async_generator[T](
try: try:
yield item yield item
# Update our tracked values with any changes made during this iteration # Update our tracked values with any changes made during this iteration
# Only for non-trace context vars - trace context must persist across yields
# to allow nested span tracking for telemetry
for context_var in context_vars: for context_var in context_vars:
initial_context_values[context_var.name] = context_var.get() if context_var is not CURRENT_TRACE_CONTEXT:
initial_context_values[context_var.name] = context_var.get()
finally: finally:
# Restore context vars after each yield to prevent leaks between requests # Restore non-trace context vars after each yield to prevent leaks between requests
# CURRENT_TRACE_CONTEXT is NOT restored here to preserve telemetry span stack
for context_var in context_vars: for context_var in context_vars:
_restore_context_var(context_var) if context_var is not CURRENT_TRACE_CONTEXT:
_restore_context_var(context_var)
return wrapper() return wrapper()