From 3ecb043d59efd50cb5741846343b4549f744a276 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 27 Oct 2025 13:00:46 -0700 Subject: [PATCH] fix(context): prevent provider data leak between streaming requests The preserve_contexts_async_generator function was not cleaning up context variables after streaming iterations, causing PROVIDER_DATA_VAR to leak between sequential requests. Provider credentials or configuration from one request could persist and leak into subsequent requests. Root cause: Context variables were set at the start of each iteration but never cleared afterward. When generators were consumed outside their original context manager (after the with block exited), the context values remained set indefinitely. The fix clears context variables by setting them to None after each yield and when the generator terminates. This works reliably across all scenarios including when the library client wraps async generators for sync consumption (which creates new asyncio Contexts per iteration). Direct value setting avoids Context-scoped token issues that would occur with token-based reset. Added unit and integration tests that verify context isolation. --- src/llama_stack/core/utils/context.py | 24 +-- .../core/test_provider_data_context_leak.py | 117 ++++++++++++++ .../fixtures/context_echo_provider.py | 153 ++++++++++++++++++ tests/unit/core/test_provider_data_context.py | 60 +++++++ 4 files changed, 345 insertions(+), 9 deletions(-) create mode 100644 tests/integration/core/test_provider_data_context_leak.py create mode 100644 tests/integration/fixtures/context_echo_provider.py create mode 100644 tests/unit/core/test_provider_data_context.py diff --git a/src/llama_stack/core/utils/context.py b/src/llama_stack/core/utils/context.py index 24b249890..af5a625e5 100644 --- a/src/llama_stack/core/utils/context.py +++ b/src/llama_stack/core/utils/context.py @@ -21,20 +21,26 @@ def preserve_contexts_async_generator[T]( async def wrapper() -> AsyncGenerator[T, None]: while True: + # Restore context values before any await + for context_var in context_vars: + context_var.set(initial_context_values[context_var.name]) + try: - # Restore context values before any await - for context_var in context_vars: - context_var.set(initial_context_values[context_var.name]) - item = await gen.__anext__() + except StopAsyncIteration: + # Clear context vars before exiting to prevent leaks + for context_var in context_vars: + context_var.set(None) + break + try: + yield item # Update our tracked values with any changes made during this iteration for context_var in context_vars: initial_context_values[context_var.name] = context_var.get() - - yield item - - except StopAsyncIteration: - break + finally: + # Clear context vars after each yield to prevent leaks between requests + for context_var in context_vars: + context_var.set(None) return wrapper() diff --git a/tests/integration/core/test_provider_data_context_leak.py b/tests/integration/core/test_provider_data_context_leak.py new file mode 100644 index 000000000..d4011b140 --- /dev/null +++ b/tests/integration/core/test_provider_data_context_leak.py @@ -0,0 +1,117 @@ +""" +Integration test for provider data context isolation in streaming requests. + +This test verifies that PROVIDER_DATA_VAR doesn't leak between sequential +streaming requests, ensuring provider credentials and configuration are +properly isolated between requests. +""" + +import json +import pytest + + +@pytest.mark.asyncio +async def test_provider_data_isolation_library_client(): + """ + Verifies that provider data context is properly isolated between + sequential streaming requests and cleaned up after each request. + """ + from llama_stack.core.request_headers import PROVIDER_DATA_VAR, request_provider_data_context + from llama_stack.core.utils.context import preserve_contexts_async_generator + + async def mock_streaming_provider(): + """Simulates a streaming provider that reads PROVIDER_DATA_VAR""" + provider_data = PROVIDER_DATA_VAR.get() + yield {"provider_data": provider_data, "chunk": 1} + + async def sse_generator(gen): + """Simulates the SSE generator in the server""" + async for item in gen: + yield f"data: {json.dumps(item)}\n\n" + + # Request 1: Set provider data to {"key": "value1"} + headers1 = {"X-LlamaStack-Provider-Data": json.dumps({"key": "value1"})} + with request_provider_data_context(headers1): + gen1 = preserve_contexts_async_generator( + sse_generator(mock_streaming_provider()), + [PROVIDER_DATA_VAR] + ) + + chunks1 = [chunk async for chunk in gen1] + data1 = json.loads(chunks1[0].split("data: ")[1]) + assert data1["provider_data"] == {"key": "value1"} + + # Context should be cleared after consuming the generator + leaked_data = PROVIDER_DATA_VAR.get() + assert leaked_data is None, f"Context leaked after request 1: {leaked_data}" + + # Request 2: Set different provider data {"key": "value2"} + headers2 = {"X-LlamaStack-Provider-Data": json.dumps({"key": "value2"})} + with request_provider_data_context(headers2): + gen2 = preserve_contexts_async_generator( + sse_generator(mock_streaming_provider()), + [PROVIDER_DATA_VAR] + ) + + chunks2 = [chunk async for chunk in gen2] + data2 = json.loads(chunks2[0].split("data: ")[1]) + assert data2["provider_data"] == {"key": "value2"} + + leaked_data2 = PROVIDER_DATA_VAR.get() + assert leaked_data2 is None, f"Context leaked after request 2: {leaked_data2}" + + # Request 3: No provider data + gen3 = preserve_contexts_async_generator( + sse_generator(mock_streaming_provider()), + [PROVIDER_DATA_VAR] + ) + + chunks3 = [chunk async for chunk in gen3] + data3 = json.loads(chunks3[0].split("data: ")[1]) + assert data3["provider_data"] is None + + +@pytest.mark.skipif( + True, + reason="Requires custom test provider with context echo capability" +) +def test_provider_data_isolation_with_server(llama_stack_client): + """ + Server-based test for context isolation (currently skipped). + + Requires a test inference provider that echoes back PROVIDER_DATA_VAR + in streaming responses to verify proper isolation. + """ + response1 = llama_stack_client.inference.chat_completion( + model_id="context-echo-model", + messages=[{"role": "user", "content": "test"}], + stream=True, + extra_headers={ + "X-LlamaStack-Provider-Data": json.dumps({"test_key": "value1"}) + }, + ) + + chunks1 = [] + for chunk in response1: + if chunk.choices and chunk.choices[0].delta.content: + chunks1.append(chunk.choices[0].delta.content) + + response1_data = json.loads("".join(chunks1)) + assert response1_data["provider_data"] == {"test_key": "value1"} + + response2 = llama_stack_client.inference.chat_completion( + model_id="context-echo-model", + messages=[{"role": "user", "content": "test"}], + stream=True, + extra_headers={ + "X-LlamaStack-Provider-Data": json.dumps({"test_key": "value2"}) + }, + ) + + chunks2 = [] + for chunk in response2: + if chunk.choices and chunk.choices[0].delta.content: + chunks2.append(chunk.choices[0].delta.content) + + response2_data = json.loads("".join(chunks2)) + assert response2_data["provider_data"] == {"test_key": "value2"} diff --git a/tests/integration/fixtures/context_echo_provider.py b/tests/integration/fixtures/context_echo_provider.py new file mode 100644 index 000000000..c9ee04f7e --- /dev/null +++ b/tests/integration/fixtures/context_echo_provider.py @@ -0,0 +1,153 @@ +""" +Test-only inference provider that echoes PROVIDER_DATA_VAR in responses. + +This provider is used to test context isolation between requests in end-to-end +scenarios with a real server. +""" + +import json +from typing import AsyncIterator +from pydantic import BaseModel + +from llama_stack.apis.inference import ( + Inference, + OpenAIChatCompletion, + OpenAIChatCompletionChunk, + OpenAIChatCompletionRequestWithExtraBody, + OpenAICompletion, + OpenAICompletionRequestWithExtraBody, + OpenAIEmbeddingsRequestWithExtraBody, + OpenAIEmbeddingsResponse, +) +from llama_stack.apis.models import Model +from llama_stack.core.request_headers import PROVIDER_DATA_VAR +from llama_stack_client.types.inference_chat_completion_chunk import ( + ChatCompletionChunkChoice, + ChatCompletionChunkChoiceDelta, +) + + +class ContextEchoConfig(BaseModel): + """Minimal config for the test provider.""" + pass + + +class ContextEchoInferenceProvider(Inference): + """ + Test-only provider that echoes the current PROVIDER_DATA_VAR value. + + Used to detect context leaks between streaming requests in end-to-end tests. + """ + + def __init__(self, config: ContextEchoConfig) -> None: + self.config = config + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def register_model(self, model: Model) -> Model: + return model + + async def unregister_model(self, model_id: str) -> None: + pass + + async def list_models(self) -> list[Model]: + return [] + + async def openai_embeddings( + self, + params: OpenAIEmbeddingsRequestWithExtraBody, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError("Embeddings not supported by test provider") + + async def openai_completion( + self, + params: OpenAICompletionRequestWithExtraBody, + ) -> OpenAICompletion: + raise NotImplementedError("Use openai_chat_completion instead") + + async def openai_chat_completion( + self, + params: OpenAIChatCompletionRequestWithExtraBody, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: + """Echo the provider data context back in streaming chunks.""" + + async def stream_with_context(): + # Read the current provider data from context + # This is the KEY part - if context leaks, this will show old data + provider_data = PROVIDER_DATA_VAR.get() + + # Create a JSON message with the provider data + # The test will parse this to verify correct isolation + message = json.dumps({ + "provider_data": provider_data, + "test_marker": "context_echo" + }) + + # Yield a chunk with the provider data + yield OpenAIChatCompletionChunk( + id="context-echo-1", + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkChoiceDelta( + content=message, + role="assistant", + ), + index=0, + finish_reason=None, + ) + ], + created=0, + model=params.model, + object="chat.completion.chunk", + ) + + # Final chunk with finish_reason + yield OpenAIChatCompletionChunk( + id="context-echo-2", + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkChoiceDelta(), + index=0, + finish_reason="stop", + ) + ], + created=0, + model=params.model, + object="chat.completion.chunk", + ) + + if params.stream: + return stream_with_context() + else: + # Non-streaming fallback + provider_data = PROVIDER_DATA_VAR.get() + message_content = json.dumps({ + "provider_data": provider_data, + "test_marker": "context_echo" + }) + + from llama_stack_client.types.inference_chat_completion import ( + ChatCompletionChoice, + ChatCompletionMessage, + ) + + return OpenAIChatCompletion( + id="context-echo", + choices=[ + ChatCompletionChoice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + content=message_content, + role="assistant", + ), + ) + ], + created=0, + model=params.model, + object="chat.completion", + ) diff --git a/tests/unit/core/test_provider_data_context.py b/tests/unit/core/test_provider_data_context.py new file mode 100644 index 000000000..06faa59ae --- /dev/null +++ b/tests/unit/core/test_provider_data_context.py @@ -0,0 +1,60 @@ +import json +import asyncio +import pytest +from contextvars import ContextVar +from contextlib import contextmanager + +from llama_stack.core.utils.context import preserve_contexts_async_generator + +# Define provider data context variable and context manager locally +PROVIDER_DATA_VAR = ContextVar("provider_data", default=None) + + +@contextmanager +def request_provider_data_context(headers): + val = headers.get("X-LlamaStack-Provider-Data") + provider_data = json.loads(val) if val else {} + token = PROVIDER_DATA_VAR.set(provider_data) + try: + yield + finally: + PROVIDER_DATA_VAR.reset(token) + + +def create_sse_event(data): + return f"data: {json.dumps(data)}\n\n" + + +async def sse_generator(event_gen_coroutine): + event_gen = await event_gen_coroutine + async for item in event_gen: + yield create_sse_event(item) + await asyncio.sleep(0) + + +async def async_event_gen(): + async def event_gen(): + yield PROVIDER_DATA_VAR.get() + + return event_gen() + + +@pytest.mark.asyncio +async def test_provider_data_context_cleared_between_sse_requests(): + headers = {"X-LlamaStack-Provider-Data": json.dumps({"api_key": "abc"})} + with request_provider_data_context(headers): + gen1 = preserve_contexts_async_generator( + sse_generator(async_event_gen()), [PROVIDER_DATA_VAR] + ) + + events1 = [event async for event in gen1] + assert events1 == [create_sse_event({"api_key": "abc"})] + assert PROVIDER_DATA_VAR.get() is None + + gen2 = preserve_contexts_async_generator( + sse_generator(async_event_gen()), [PROVIDER_DATA_VAR] + ) + events2 = [event async for event in gen2] + assert events2 == [create_sse_event(None)] + assert PROVIDER_DATA_VAR.get() is None +