address feedback

This commit is contained in:
Dinesh Yeduguru 2025-03-11 18:04:21 -07:00
parent a900740e30
commit 714c09cd53
4 changed files with 47 additions and 30 deletions

View file

@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from contextvars import ContextVar
from typing import AsyncGenerator, List, TypeVar
T = TypeVar("T")
def preserve_contexts_async_generator(
gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
) -> AsyncGenerator[T, None]:
"""
Wraps an async generator to preserve both tracing and headers context variables across iterations.
This is needed because we start a new asyncio event loop for each request, and we need to preserve the context
across the event loop boundary.
"""
context_values = [context_var.get() for context_var in context_vars]
async def wrapper():
while True:
for context_var, context_value in zip(context_vars, context_values, strict=False):
_ = context_var.set(context_value)
try:
item = await gen.__anext__()
yield item
except StopAsyncIteration:
break
return wrapper()