diff --git a/src/llama_stack/core/utils/context.py b/src/llama_stack/core/utils/context.py index 9221979b9..0c3e41f00 100644 --- a/src/llama_stack/core/utils/context.py +++ b/src/llama_stack/core/utils/context.py @@ -30,22 +30,10 @@ def preserve_contexts_async_generator[T]( # This is needed to propagate context across async generator boundaries for context_var in context_vars: try: - current_value = context_var.get() - previous_values[context_var] = current_value - # Only set if value actually changed to avoid creating unnecessary tokens - target_value = initial_context_values[context_var.name] - if current_value != target_value: - tokens[context_var] = context_var.set(target_value) - else: - tokens[context_var] = None + previous_values[context_var] = context_var.get() except LookupError: previous_values[context_var] = _MISSING - # Var was unset, set it to initial value - target_value = initial_context_values[context_var.name] - if target_value is not None: - tokens[context_var] = context_var.set(target_value) - else: - tokens[context_var] = None + tokens[context_var] = context_var.set(initial_context_values[context_var.name]) def _restore_context_var(context_var: ContextVar, *, _tokens=tokens, _prev=previous_values) -> None: token = _tokens.get(context_var) @@ -57,15 +45,10 @@ def preserve_contexts_async_generator[T]( except (RuntimeError, ValueError): pass - # Manual restoration when token reset fails - # Avoid creating unnecessary tokens to prevent accumulation that can cause deadlocks - if previous_value is not _MISSING: - try: - current_value = context_var.get() - if current_value != previous_value: - context_var.set(previous_value) - except LookupError: - context_var.set(previous_value) + if previous_value is _MISSING: + context_var.set(None) + else: + context_var.set(previous_value) try: item = await gen.__anext__()