fix(context): restore context to previous values after streaming

This commit is contained in:
Ashwin Bharambe 2025-10-27 21:25:57 -07:00
parent 75cdc4dad2
commit 907ba5aecf

View file

@ -8,6 +8,9 @@ from collections.abc import AsyncGenerator
from contextvars import ContextVar from contextvars import ContextVar
_MISSING = object()
def preserve_contexts_async_generator[T]( def preserve_contexts_async_generator[T](
gen: AsyncGenerator[T, None], context_vars: list[ContextVar] gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
) -> AsyncGenerator[T, None]: ) -> AsyncGenerator[T, None]:
@ -21,17 +24,43 @@ def preserve_contexts_async_generator[T](
async def wrapper() -> AsyncGenerator[T, None]: async def wrapper() -> AsyncGenerator[T, None]:
while True: while True:
# Restore context values before any await previous_values: dict[ContextVar, object] = {}
tokens: dict[ContextVar, object] = {}
# Restore context values before any await and capture previous state
for context_var in context_vars: for context_var in context_vars:
context_var.set(initial_context_values[context_var.name]) try:
previous_values[context_var] = context_var.get()
except LookupError:
previous_values[context_var] = _MISSING
tokens[context_var] = context_var.set(initial_context_values[context_var.name])
def _restore_context_var(context_var: ContextVar) -> None:
token = tokens.get(context_var)
previous_value = previous_values.get(context_var, _MISSING)
if token is not None:
try:
context_var.reset(token)
return
except (RuntimeError, ValueError):
pass
if previous_value is _MISSING:
context_var.set(None)
else:
context_var.set(previous_value)
try: try:
item = await gen.__anext__() item = await gen.__anext__()
except StopAsyncIteration: except StopAsyncIteration:
# Clear context vars before exiting to prevent leaks # Restore context vars before exiting to prevent leaks
for context_var in context_vars: for context_var in context_vars:
context_var.set(None) _restore_context_var(context_var)
break break
except Exception:
for context_var in context_vars:
_restore_context_var(context_var)
raise
try: try:
yield item yield item
@ -39,8 +68,8 @@ def preserve_contexts_async_generator[T](
for context_var in context_vars: for context_var in context_vars:
initial_context_values[context_var.name] = context_var.get() initial_context_values[context_var.name] = context_var.get()
finally: finally:
# Clear context vars after each yield to prevent leaks between requests # Restore context vars after each yield to prevent leaks between requests
for context_var in context_vars: for context_var in context_vars:
context_var.set(None) _restore_context_var(context_var)
return wrapper() return wrapper()