mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
final fix, pretty subtle
This commit is contained in:
parent
6a849c3b18
commit
bc7d93fb39
1 changed files with 15 additions and 5 deletions
|
|
@ -7,6 +7,8 @@
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT
|
||||||
|
|
||||||
_MISSING = object()
|
_MISSING = object()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -26,7 +28,8 @@ def preserve_contexts_async_generator[T](
|
||||||
previous_values: dict[ContextVar, object] = {}
|
previous_values: dict[ContextVar, object] = {}
|
||||||
tokens: dict[ContextVar, object] = {}
|
tokens: dict[ContextVar, object] = {}
|
||||||
|
|
||||||
# Restore context values before any await and capture previous state
|
# Restore ALL context values before any await and capture previous state
|
||||||
|
# This is needed to propagate context across async generator boundaries
|
||||||
for context_var in context_vars:
|
for context_var in context_vars:
|
||||||
try:
|
try:
|
||||||
previous_values[context_var] = context_var.get()
|
previous_values[context_var] = context_var.get()
|
||||||
|
|
@ -52,11 +55,13 @@ def preserve_contexts_async_generator[T](
|
||||||
try:
|
try:
|
||||||
item = await gen.__anext__()
|
item = await gen.__anext__()
|
||||||
except StopAsyncIteration:
|
except StopAsyncIteration:
|
||||||
# Restore context vars before exiting to prevent leaks
|
# Restore all context vars before exiting to prevent leaks
|
||||||
|
# Use _restore_context_var for all vars to properly restore to previous values
|
||||||
for context_var in context_vars:
|
for context_var in context_vars:
|
||||||
_restore_context_var(context_var)
|
_restore_context_var(context_var)
|
||||||
break
|
break
|
||||||
except Exception:
|
except Exception:
|
||||||
|
# Restore all context vars on exception
|
||||||
for context_var in context_vars:
|
for context_var in context_vars:
|
||||||
_restore_context_var(context_var)
|
_restore_context_var(context_var)
|
||||||
raise
|
raise
|
||||||
|
|
@ -64,11 +69,16 @@ def preserve_contexts_async_generator[T](
|
||||||
try:
|
try:
|
||||||
yield item
|
yield item
|
||||||
# Update our tracked values with any changes made during this iteration
|
# Update our tracked values with any changes made during this iteration
|
||||||
|
# Only for non-trace context vars - trace context must persist across yields
|
||||||
|
# to allow nested span tracking for telemetry
|
||||||
for context_var in context_vars:
|
for context_var in context_vars:
|
||||||
initial_context_values[context_var.name] = context_var.get()
|
if context_var is not CURRENT_TRACE_CONTEXT:
|
||||||
|
initial_context_values[context_var.name] = context_var.get()
|
||||||
finally:
|
finally:
|
||||||
# Restore context vars after each yield to prevent leaks between requests
|
# Restore non-trace context vars after each yield to prevent leaks between requests
|
||||||
|
# CURRENT_TRACE_CONTEXT is NOT restored here to preserve telemetry span stack
|
||||||
for context_var in context_vars:
|
for context_var in context_vars:
|
||||||
_restore_context_var(context_var)
|
if context_var is not CURRENT_TRACE_CONTEXT:
|
||||||
|
_restore_context_var(context_var)
|
||||||
|
|
||||||
return wrapper()
|
return wrapper()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue