diff --git a/src/llama_stack/core/utils/context.py b/src/llama_stack/core/utils/context.py index af5a625e5..89fdf0d6f 100644 --- a/src/llama_stack/core/utils/context.py +++ b/src/llama_stack/core/utils/context.py @@ -8,6 +8,9 @@ from collections.abc import AsyncGenerator from contextvars import ContextVar +_MISSING = object() + + def preserve_contexts_async_generator[T]( gen: AsyncGenerator[T, None], context_vars: list[ContextVar] ) -> AsyncGenerator[T, None]: @@ -21,17 +24,43 @@ def preserve_contexts_async_generator[T]( async def wrapper() -> AsyncGenerator[T, None]: while True: - # Restore context values before any await + previous_values: dict[ContextVar, object] = {} + tokens: dict[ContextVar, object] = {} + + # Restore context values before any await and capture previous state for context_var in context_vars: - context_var.set(initial_context_values[context_var.name]) + try: + previous_values[context_var] = context_var.get() + except LookupError: + previous_values[context_var] = _MISSING + tokens[context_var] = context_var.set(initial_context_values[context_var.name]) + + def _restore_context_var(context_var: ContextVar) -> None: + token = tokens.get(context_var) + previous_value = previous_values.get(context_var, _MISSING) + if token is not None: + try: + context_var.reset(token) + return + except (RuntimeError, ValueError): + pass + + if previous_value is _MISSING: + context_var.set(None) + else: + context_var.set(previous_value) try: item = await gen.__anext__() except StopAsyncIteration: - # Clear context vars before exiting to prevent leaks + # Restore context vars before exiting to prevent leaks for context_var in context_vars: - context_var.set(None) + _restore_context_var(context_var) break + except Exception: + for context_var in context_vars: + _restore_context_var(context_var) + raise try: yield item @@ -39,8 +68,8 @@ def preserve_contexts_async_generator[T]( for context_var in context_vars: initial_context_values[context_var.name] = context_var.get() finally: - # Clear context vars after each yield to prevent leaks between requests + # Restore context vars after each yield to prevent leaks between requests for context_var in context_vars: - context_var.set(None) + _restore_context_var(context_var) return wrapper()