preserve context across async generator boundaries

This commit is contained in:
Ashwin Bharambe 2025-03-07 16:16:29 -08:00
parent 4a894b925d
commit 21769648a6
3 changed files with 41 additions and 5 deletions

View file

@ -32,7 +32,10 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import request_provider_data_context from llama_stack.distribution.request_headers import (
preserve_headers_context_async_generator,
request_provider_data_context,
)
from llama_stack.distribution.resolver import ProviderRegistry from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
@ -378,9 +381,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
finally: finally:
await end_trace() await end_trace()
# Wrap the generator to preserve context across iterations
wrapped_gen = preserve_headers_context_async_generator(gen())
mock_response = httpx.Response( mock_response = httpx.Response(
status_code=httpx.codes.OK, status_code=httpx.codes.OK,
content=gen(), content=wrapped_gen,
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
}, },

View file

@ -7,7 +7,7 @@
import contextvars import contextvars
import json import json
import logging import logging
from typing import Any, ContextManager, Dict, Optional from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
from .utils.dynamic import instantiate_class_type from .utils.dynamic import instantiate_class_type
@ -35,6 +35,31 @@ class RequestProviderDataContext(ContextManager):
_provider_data_var.reset(self.token) _provider_data_var.reset(self.token)
T = TypeVar("T")
async def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
"""
Wraps an async generator to preserve request headers context variables across iterations.
This ensures that context variables set during generator creation are
available during each iteration of the generator, even if the original
context manager has exited.
"""
# Capture the current context value
context_value = _provider_data_var.get()
# Create a wrapper that restores context for each iteration
async for item in gen:
# Save the current token to restore later
token = _provider_data_var.set(context_value)
try:
yield item
finally:
# Restore the previous value
_provider_data_var.reset(token)
class NeedsRequestProviderData: class NeedsRequestProviderData:
def get_request_provider_data(self) -> Any: def get_request_provider_data(self) -> Any:
spec = self.__provider_spec__ spec = self.__provider_spec__

View file

@ -29,7 +29,10 @@ from typing_extensions import Annotated
from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import request_provider_data_context from llama_stack.distribution.request_headers import (
preserve_headers_context_async_generator,
request_provider_data_context,
)
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
construct_stack, construct_stack,
@ -203,7 +206,9 @@ async def maybe_await(value):
async def sse_generator(event_gen): async def sse_generator(event_gen):
try: try:
event_gen = await event_gen event_gen = await event_gen
async for item in event_gen: # Wrap the generator to preserve context across iterations
wrapped_gen = preserve_headers_context_async_generator(event_gen)
async for item in wrapped_gen:
yield create_sse_event(item) yield create_sse_event(item)
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
except asyncio.CancelledError: except asyncio.CancelledError: