fix: context token explosion bug fixed

This commit is contained in:
Emilio Garcia 2025-11-18 16:37:11 -05:00
parent 5984ae6a76
commit 914bf84a1d

View file

@ -30,10 +30,22 @@ def preserve_contexts_async_generator[T](
# This is needed to propagate context across async generator boundaries # This is needed to propagate context across async generator boundaries
for context_var in context_vars: for context_var in context_vars:
try: try:
previous_values[context_var] = context_var.get() current_value = context_var.get()
previous_values[context_var] = current_value
# Only set if value actually changed to avoid creating unnecessary tokens
target_value = initial_context_values[context_var.name]
if current_value != target_value:
tokens[context_var] = context_var.set(target_value)
else:
tokens[context_var] = None
except LookupError: except LookupError:
previous_values[context_var] = _MISSING previous_values[context_var] = _MISSING
tokens[context_var] = context_var.set(initial_context_values[context_var.name]) # Var was unset, set it to initial value
target_value = initial_context_values[context_var.name]
if target_value is not None:
tokens[context_var] = context_var.set(target_value)
else:
tokens[context_var] = None
def _restore_context_var(context_var: ContextVar, *, _tokens=tokens, _prev=previous_values) -> None: def _restore_context_var(context_var: ContextVar, *, _tokens=tokens, _prev=previous_values) -> None:
token = _tokens.get(context_var) token = _tokens.get(context_var)
@ -45,10 +57,15 @@ def preserve_contexts_async_generator[T](
except (RuntimeError, ValueError): except (RuntimeError, ValueError):
pass pass
if previous_value is _MISSING: # Manual restoration when token reset fails
context_var.set(None) # Avoid creating unnecessary tokens to prevent accumulation that can cause deadlocks
else: if previous_value is not _MISSING:
context_var.set(previous_value) try:
current_value = context_var.get()
if current_value != previous_value:
context_var.set(previous_value)
except LookupError:
context_var.set(previous_value)
try: try:
item = await gen.__anext__() item = await gen.__anext__()