preserve context across async generator boundaries

This commit is contained in:
Ashwin Bharambe 2025-03-07 16:16:29 -08:00
parent 4a894b925d
commit 21769648a6
3 changed files with 41 additions and 5 deletions

View file

@@ -32,7 +32,10 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import request_provider_data_context
from llama_stack.distribution.request_headers import (
preserve_headers_context_async_generator,
request_provider_data_context,
)
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import (
@@ -378,9 +381,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
finally:
await end_trace()
# Wrap the generator to preserve context across iterations
wrapped_gen = preserve_headers_context_async_generator(gen())
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=gen(),
content=wrapped_gen,
headers={
"Content-Type": "application/json",
},

View file

@@ -7,7 +7,7 @@
import contextvars
import json
import logging
from typing import Any, ContextManager, Dict, Optional
from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
from .utils.dynamic import instantiate_class_type
@@ -35,6 +35,31 @@ class RequestProviderDataContext(ContextManager):
_provider_data_var.reset(self.token)
T = TypeVar("T")


def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
    """
    Wrap an async generator so the request-headers context variable is
    preserved across iterations.

    The context value is captured eagerly, at the moment this function is
    called (i.e. while the request's context manager is still active), and
    re-installed around every ``yield`` of the wrapped generator.

    This must be a plain function returning an async generator rather than an
    ``async def`` generator itself: an async generator's body does not run
    until its first ``__anext__``, which may happen only *after* the original
    context manager has exited — at which point ``_provider_data_var.get()``
    would no longer see the request's value.

    Args:
        gen: The async generator whose iterations need the request context.

    Returns:
        An async generator yielding the same items as ``gen``, with the
        captured context value installed during each iteration.
    """
    # Capture the current value NOW, at call time — not at first iteration.
    context_value = _provider_data_var.get()

    async def wrapper() -> AsyncGenerator[T, None]:
        async for item in gen:
            # Re-install the captured value for the duration of this yield.
            token = _provider_data_var.set(context_value)
            try:
                yield item
            finally:
                # Restore whatever value was current before this iteration,
                # so we don't leak the request context into unrelated code.
                _provider_data_var.reset(token)

    return wrapper()
class NeedsRequestProviderData:
def get_request_provider_data(self) -> Any:
spec = self.__provider_spec__

View file

@@ -29,7 +29,10 @@ from typing_extensions import Annotated
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import request_provider_data_context
from llama_stack.distribution.request_headers import (
preserve_headers_context_async_generator,
request_provider_data_context,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import (
construct_stack,
@@ -203,7 +206,9 @@ async def maybe_await(value):
async def sse_generator(event_gen):
try:
event_gen = await event_gen
async for item in event_gen:
# Wrap the generator to preserve context across iterations
wrapped_gen = preserve_headers_context_async_generator(event_gen)
async for item in wrapped_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError: