diff --git a/src/llama_stack/core/utils/context.py b/src/llama_stack/core/utils/context.py
index 24b249890..e7c61a8ed 100644
--- a/src/llama_stack/core/utils/context.py
+++ b/src/llama_stack/core/utils/context.py
@@ -7,6 +7,10 @@
 from collections.abc import AsyncGenerator
 from contextvars import ContextVar
 
+from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT
+
+_MISSING = object()
+
 
 def preserve_contexts_async_generator[T](
     gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
@@ -21,20 +25,72 @@ def preserve_contexts_async_generator[T](
 
     async def wrapper() -> AsyncGenerator[T, None]:
         while True:
+            previous_values: dict[ContextVar, object] = {}
+            tokens: dict[ContextVar, object] = {}
+
+            # Restore ALL context values before any await and capture previous state
+            # This is needed to propagate context across async generator boundaries
+            for context_var in context_vars:
+                try:
+                    previous_values[context_var] = context_var.get()
+                except LookupError:
+                    previous_values[context_var] = _MISSING
+                tokens[context_var] = context_var.set(initial_context_values[context_var.name])
+
+            def _restore_context_var(context_var: ContextVar, *, _tokens=tokens, _prev=previous_values) -> None:
+                """Put ``context_var`` back to the value it held before this iteration.
+
+                ``_tokens`` and ``_prev`` are bound as keyword defaults at definition
+                time, so the helper always uses this iteration's dicts.
+                """
+                token = _tokens.get(context_var)
+                previous_value = _prev.get(context_var, _MISSING)
+                if token is not None:
+                    try:
+                        context_var.reset(token)
+                        return
+                    except (RuntimeError, ValueError):
+                        # reset() rejects a token that was already used or that was
+                        # created in a different Context; fall back to set() below.
+                        pass
+
+                if previous_value is _MISSING:
+                    # The variable was unset before; set() cannot unset it, so None stands in.
+                    context_var.set(None)
+                else:
+                    context_var.set(previous_value)
+
             try:
-                # Restore context values before any await
-                for context_var in context_vars:
-                    context_var.set(initial_context_values[context_var.name])
-
                 item = await gen.__anext__()
-
-                # Update our tracked values with any changes made during this iteration
-                for context_var in context_vars:
-                    initial_context_values[context_var.name] = context_var.get()
-
-                yield item
-
             except StopAsyncIteration:
+                # Restore all context vars before exiting to prevent leaks
+                # Use _restore_context_var for all vars to properly restore to previous values
+                for context_var in context_vars:
+                    _restore_context_var(context_var)
                 break
+            except BaseException:
+                # Restore all context vars on any failure. BaseException rather than
+                # Exception so that asyncio.CancelledError raised while awaiting the
+                # wrapped generator also restores state before propagating.
+                for context_var in context_vars:
+                    _restore_context_var(context_var)
+                raise
+
+            try:
+                yield item
+                # Update our tracked values with any changes made during this iteration
+                # Only for non-trace context vars - trace context must persist across yields
+                # to allow nested span tracking for telemetry
+                # NOTE(review): this runs once the consumer resumes the generator, so it
+                # records the value current at that point - confirm that is intended.
+                for context_var in context_vars:
+                    if context_var is not CURRENT_TRACE_CONTEXT:
+                        initial_context_values[context_var.name] = context_var.get()
+            finally:
+                # Restore non-trace context vars after each yield to prevent leaks between requests
+                # CURRENT_TRACE_CONTEXT is NOT restored here to preserve telemetry span stack
+                for context_var in context_vars:
+                    if context_var is not CURRENT_TRACE_CONTEXT:
+                        _restore_context_var(context_var)
 
     return wrapper()
diff --git a/tests/unit/core/test_provider_data_context.py b/tests/unit/core/test_provider_data_context.py
new file mode 100644
index 000000000..a45805863
--- /dev/null
+++ b/tests/unit/core/test_provider_data_context.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+import json
+from contextlib import contextmanager
+from contextvars import ContextVar
+
+from llama_stack.core.utils.context import preserve_contexts_async_generator
+
+# Define provider data context variable and context manager locally
+PROVIDER_DATA_VAR = ContextVar("provider_data", default=None)
+
+
+@contextmanager
+def request_provider_data_context(headers):
+    """Set PROVIDER_DATA_VAR from the X-LlamaStack-Provider-Data header for the duration of the block."""
+    val = headers.get("X-LlamaStack-Provider-Data")
+    provider_data = json.loads(val) if val else {}
+    token = PROVIDER_DATA_VAR.set(provider_data)
+    try:
+        yield
+    finally:
+        PROVIDER_DATA_VAR.reset(token)
+
+
+def create_sse_event(data):
+    """Format ``data`` as a single ``data: <json>`` event terminated by a blank line."""
+    return f"data: {json.dumps(data)}\n\n"
+
+
+async def sse_generator(event_gen_coroutine):
+    """Await the coroutine to obtain an event generator and yield each item as an SSE string."""
+    event_gen = await event_gen_coroutine
+    async for item in event_gen:
+        yield create_sse_event(item)
+        # sleep(0) hands control back to the event loop between events.
+        await asyncio.sleep(0)
+
+
+async def async_event_gen():
+    """Return an async generator that yields the current PROVIDER_DATA_VAR value once."""
+
+    async def event_gen():
+        yield PROVIDER_DATA_VAR.get()
+
+    return event_gen()
+
+
+async def test_provider_data_context_cleared_between_sse_requests():
+    """Provider data captured for one streamed request must not be visible to the next one."""
+    headers = {"X-LlamaStack-Provider-Data": json.dumps({"api_key": "abc"})}
+    with request_provider_data_context(headers):
+        gen1 = preserve_contexts_async_generator(sse_generator(async_event_gen()), [PROVIDER_DATA_VAR])
+
+    # gen1 is consumed only after the request context manager has exited.
+    events1 = [event async for event in gen1]
+    assert events1 == [create_sse_event({"api_key": "abc"})]
+    assert PROVIDER_DATA_VAR.get() is None
+
+    gen2 = preserve_contexts_async_generator(sse_generator(async_event_gen()), [PROVIDER_DATA_VAR])
+    events2 = [event async for event in gen2]
+    assert events2 == [create_sse_event(None)]
+    assert PROVIDER_DATA_VAR.get() is None