diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 82650ea40..15c4fe6ea 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -376,18 +376,17 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): body = self._convert_body(path, options.method, body) - await start_trace(options.url, {"__location__": "library_client"}) - async def gen(): - async for chunk in await func(**body): - data = json.dumps(convert_pydantic_to_json_value(chunk)) - sse_event = f"data: {data}\n\n" - yield sse_event.encode("utf-8") + await start_trace(options.url, {"__location__": "library_client"}) + try: + async for chunk in await func(**body): + data = json.dumps(convert_pydantic_to_json_value(chunk)) + sse_event = f"data: {data}\n\n" + yield sse_event.encode("utf-8") + finally: + await end_trace() - try: - wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]) - finally: - await end_trace() + wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]) mock_response = httpx.Response( status_code=httpx.codes.OK, diff --git a/llama_stack/distribution/utils/context.py b/llama_stack/distribution/utils/context.py index a76edb60c..107ce7127 100644 --- a/llama_stack/distribution/utils/context.py +++ b/llama_stack/distribution/utils/context.py @@ -14,19 +14,19 @@ def preserve_contexts_async_generator( gen: AsyncGenerator[T, None], context_vars: List[ContextVar] ) -> AsyncGenerator[T, None]: """ - Wraps an async generator to preserve both tracing and headers context variables across iterations. - This is needed because we start a new asyncio event loop for each request, and we need to preserve the context - across the event loop boundary. + Wraps an async generator to preserve context variables across iterations. + This is needed because we start a new asyncio event loop for each streaming request, + and we need to preserve the context across the event loop boundary. """ - context_values = [context_var.get() for context_var in context_vars] async def wrapper(): while True: - for context_var, context_value in zip(context_vars, context_values, strict=False): - _ = context_var.set(context_value) try: item = await gen.__anext__() + context_values = {context_var.name: context_var.get() for context_var in context_vars} yield item + for context_var in context_vars: + _ = context_var.set(context_values[context_var.name]) except StopAsyncIteration: break diff --git a/llama_stack/distribution/utils/tests/test_context.py b/llama_stack/distribution/utils/tests/test_context.py new file mode 100644 index 000000000..84944bfe8 --- /dev/null +++ b/llama_stack/distribution/utils/tests/test_context.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from contextvars import ContextVar + +import pytest + +from llama_stack.distribution.utils.context import preserve_contexts_async_generator + + +@pytest.mark.asyncio +async def test_preserve_contexts_with_exception(): + # Create context variable + context_var = ContextVar("exception_var", default="initial") + token = context_var.set("start_value") + + # Create an async generator that raises an exception + async def exception_generator(): + yield context_var.get() + context_var.set("modified") + raise ValueError("Test exception") + yield None # This will never be reached + + # Wrap the generator + wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var]) + + # First iteration should work + value = await wrapped_gen.__anext__() + assert value == "start_value" + + # Second iteration should raise the exception + with pytest.raises(ValueError, match="Test exception"): + await wrapped_gen.__anext__() + + # Clean up + context_var.reset(token) + + +@pytest.mark.asyncio +async def test_preserve_contexts_empty_generator(): + # Create context variable + context_var = ContextVar("empty_var", default="initial") + token = context_var.set("value") + + # Create an empty async generator + async def empty_generator(): + if False: # This condition ensures the generator yields nothing + yield None + + # Wrap the generator + wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var]) + + # The generator should raise StopAsyncIteration immediately + with pytest.raises(StopAsyncIteration): + await wrapped_gen.__anext__() + + # Context variable should remain unchanged + assert context_var.get() == "value" + + # Clean up + context_var.reset(token) + + +@pytest.mark.asyncio +async def test_preserve_contexts_across_event_loops(): + """ + Test that context variables are preserved across event loop boundaries with nested generators. + This simulates the real-world scenario where: + 1. A new event loop is created for each streaming request + 2. The async generator runs inside that loop + 3. There are multiple levels of nested generators + 4. Context needs to be preserved across these boundaries + """ + # Create context variables + request_id = ContextVar("request_id", default=None) + user_id = ContextVar("user_id", default=None) + + # Set initial values + + # Results container to verify values across thread boundaries + results = [] + + # Inner-most generator (level 2) + async def inner_generator(): + # Should have the context from the outer scope + yield (1, request_id.get(), user_id.get()) + + # Modify one context variable + user_id.set("user-modified") + + # Should reflect the modification + yield (2, request_id.get(), user_id.get()) + + # Middle generator (level 1) + async def middle_generator(): + inner_gen = inner_generator() + + # Forward the first yield from inner + item = await inner_gen.__anext__() + yield item + + # Forward the second yield from inner + item = await inner_gen.__anext__() + yield item + + request_id.set("req-modified") + + # Add our own yield with both modified variables + yield (3, request_id.get(), user_id.get()) + + # Function to run in a separate thread with a new event loop + def run_in_new_loop(): + # Create a new event loop for this thread + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + # Outer generator (runs in the new loop) + async def outer_generator(): + request_id.set("req-12345") + user_id.set("user-6789") + # Wrap the middle generator + wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id]) + + # Process all items from the middle generator + async for item in wrapped_gen: + # Store results for verification + results.append(item) + + # Run the outer generator in the new loop + loop.run_until_complete(outer_generator()) + finally: + loop.close() + + # Run the generator chain in a separate thread with a new event loop + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(run_in_new_loop) + future.result() # Wait for completion + + # Verify the results + assert len(results) == 3 + + # First yield should have original values + assert results[0] == (1, "req-12345", "user-6789") + + # Second yield should have modified user_id + assert results[1] == (2, "req-12345", "user-modified") + + # Third yield should have both modified values + assert results[2] == (3, "req-modified", "user-modified")