diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index bc496cca5..ab8ff60fa 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -32,7 +32,10 @@ from termcolor import cprint
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Api
-from llama_stack.distribution.request_headers import request_provider_data_context
+from llama_stack.distribution.request_headers import (
+    preserve_headers_context_async_generator,
+    request_provider_data_context,
+)
 from llama_stack.distribution.resolver import ProviderRegistry
 from llama_stack.distribution.server.endpoints import get_all_api_endpoints
 from llama_stack.distribution.stack import (
@@ -378,9 +381,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             finally:
                 await end_trace()
 
+        # Wrap the generator to preserve the request headers context across iterations
+        wrapped_gen = preserve_headers_context_async_generator(gen())
+
         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
-            content=gen(),
+            content=wrapped_gen,
             headers={
                 "Content-Type": "application/json",
             },
diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py
index 87850f752..f617ce945 100644
--- a/llama_stack/distribution/request_headers.py
+++ b/llama_stack/distribution/request_headers.py
@@ -7,7 +7,7 @@
 import contextvars
 import json
 import logging
-from typing import Any, ContextManager, Dict, Optional
+from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
 
 from .utils.dynamic import instantiate_class_type
 
@@ -35,6 +35,35 @@ class RequestProviderDataContext(ContextManager):
         _provider_data_var.reset(self.token)
 
 
+T = TypeVar("T")
+
+
+def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
+    """
+    Wraps an async generator to preserve request headers context variables across iterations.
+
+    This ensures that context variables set during generator creation are
+    available during each iteration of the generator, even if the original
+    context manager has exited.
+    """
+    # Capture the context value eagerly, while the request context manager is
+    # still active; an async generator body would not run until first __anext__()
+    context_value = _provider_data_var.get()
+
+    async def wrapper() -> AsyncGenerator[T, None]:
+        while True:
+            # Restore the captured value before resuming the wrapped generator,
+            # so provider code sees the request headers on every step
+            _provider_data_var.set(context_value)
+            try:
+                item = await gen.__anext__()
+                yield item
+            except StopAsyncIteration:
+                break
+
+    return wrapper()
+
+
 class NeedsRequestProviderData:
     def get_request_provider_data(self) -> Any:
         spec = self.__provider_spec__
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 9d4a29a31..842798614 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -29,7 +29,10 @@ from typing_extensions import Annotated
 
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
-from llama_stack.distribution.request_headers import request_provider_data_context
+from llama_stack.distribution.request_headers import (
+    preserve_headers_context_async_generator,
+    request_provider_data_context,
+)
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.stack import (
     construct_stack,
@@ -203,7 +206,9 @@ async def maybe_await(value):
 async def sse_generator(event_gen):
     try:
         event_gen = await event_gen
-        async for item in event_gen:
+        # Wrap the generator to preserve the request headers context across iterations
+        wrapped_gen = preserve_headers_context_async_generator(event_gen)
+        async for item in wrapped_gen:
             yield create_sse_event(item)
             await asyncio.sleep(0.01)
     except asyncio.CancelledError:
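Sanity check (illustrative, not part of the patch): the snippet below mirrors the wrapper's shape with a local context var, showing that a value captured at wrap time is still visible inside the generator body after the setting scope has exited. Names such as _demo_var, preserve_context_async_generator, and producer are stand-ins, not identifiers from this repo.

# demo_context_preservation.py -- minimal sketch, not repo code
import asyncio
import contextvars
from typing import AsyncGenerator, TypeVar

T = TypeVar("T")

# Stand-in for _provider_data_var in request_headers.py
_demo_var: contextvars.ContextVar = contextvars.ContextVar("provider_data", default=None)


def preserve_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
    # Same shape as the patched helper: capture eagerly, restore before each step
    context_value = _demo_var.get()

    async def wrapper() -> AsyncGenerator[T, None]:
        while True:
            _demo_var.set(context_value)
            try:
                yield await gen.__anext__()
            except StopAsyncIteration:
                break

    return wrapper()


async def producer() -> AsyncGenerator[str, None]:
    # Plays the role of a provider that reads the headers on every step
    for _ in range(3):
        yield f"provider sees: {_demo_var.get()}"


async def main() -> None:
    token = _demo_var.set({"api_key": "secret"})
    wrapped = preserve_context_async_generator(producer())
    _demo_var.reset(token)  # simulate the request context manager exiting

    # Without the wrapper, producer() would see None here; with it,
    # each step prints: provider sees: {'api_key': 'secret'}
    async for item in wrapped:
        print(item)


asyncio.run(main())

Note the helper is a plain def rather than an async def: an async generator's body does not execute until the first __anext__() call, which in both call sites happens after the request context manager has exited, so the capture must run at wrap time.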