mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
final fix, pretty subtle
This commit is contained in:
parent
6a849c3b18
commit
bc7d93fb39
1 changed file with 15 additions and 5 deletions
|
|
@@ -7,6 +7,8 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
from contextvars import ContextVar
|
||||
|
||||
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT
|
||||
|
||||
_MISSING = object()
|
||||
|
||||
|
||||
|
|
@@ -26,7 +28,8 @@ def preserve_contexts_async_generator[T](
|
|||
previous_values: dict[ContextVar, object] = {}
|
||||
tokens: dict[ContextVar, object] = {}
|
||||
|
||||
# Restore context values before any await and capture previous state
|
||||
# Restore ALL context values before any await and capture previous state
|
||||
# This is needed to propagate context across async generator boundaries
|
||||
for context_var in context_vars:
|
||||
try:
|
||||
previous_values[context_var] = context_var.get()
|
||||
|
|
@@ -52,11 +55,13 @@ def preserve_contexts_async_generator[T](
|
|||
try:
|
||||
item = await gen.__anext__()
|
||||
except StopAsyncIteration:
|
||||
# Restore context vars before exiting to prevent leaks
|
||||
# Restore all context vars before exiting to prevent leaks
|
||||
# Use _restore_context_var for all vars to properly restore to previous values
|
||||
for context_var in context_vars:
|
||||
_restore_context_var(context_var)
|
||||
break
|
||||
except Exception:
|
||||
# Restore all context vars on exception
|
||||
for context_var in context_vars:
|
||||
_restore_context_var(context_var)
|
||||
raise
|
||||
|
|
@@ -64,11 +69,16 @@ def preserve_contexts_async_generator[T](
|
|||
try:
|
||||
yield item
|
||||
# Update our tracked values with any changes made during this iteration
|
||||
# Only for non-trace context vars - trace context must persist across yields
|
||||
# to allow nested span tracking for telemetry
|
||||
for context_var in context_vars:
|
||||
initial_context_values[context_var.name] = context_var.get()
|
||||
if context_var is not CURRENT_TRACE_CONTEXT:
|
||||
initial_context_values[context_var.name] = context_var.get()
|
||||
finally:
|
||||
# Restore context vars after each yield to prevent leaks between requests
|
||||
# Restore non-trace context vars after each yield to prevent leaks between requests
|
||||
# CURRENT_TRACE_CONTEXT is NOT restored here to preserve telemetry span stack
|
||||
for context_var in context_vars:
|
||||
_restore_context_var(context_var)
|
||||
if context_var is not CURRENT_TRACE_CONTEXT:
|
||||
_restore_context_var(context_var)
|
||||
|
||||
return wrapper()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue