diff --git a/src/llama_stack/core/utils/context.py b/src/llama_stack/core/utils/context.py
index 24b249890..af5a625e5 100644
--- a/src/llama_stack/core/utils/context.py
+++ b/src/llama_stack/core/utils/context.py
@@ -21,20 +21,26 @@ def preserve_contexts_async_generator[T](
     async def wrapper() -> AsyncGenerator[T, None]:
         while True:
+            # Restore context values before any await
+            for context_var in context_vars:
+                context_var.set(initial_context_values[context_var.name])
+
             try:
-                # Restore context values before any await
-                for context_var in context_vars:
-                    context_var.set(initial_context_values[context_var.name])
-
                 item = await gen.__anext__()
+            except StopAsyncIteration:
+                # Clear context vars before exiting to prevent leaks
+                for context_var in context_vars:
+                    context_var.set(None)
+                break
+            try:
+                yield item
                 # Update our tracked values with any changes made during this iteration
                 for context_var in context_vars:
                     initial_context_values[context_var.name] = context_var.get()
-
-                yield item
-
-            except StopAsyncIteration:
-                break
+            finally:
+                # Clear context vars after each yield to prevent leaks between requests
+                for context_var in context_vars:
+                    context_var.set(None)
 
     return wrapper()
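The fix above relies on the fact that an async generator's body runs in the consuming task's context, so any ContextVar.set() it performs (such as the restore step) stays visible in that task after the stream ends unless it is cleared explicitly. A minimal standard-library sketch of that behavior, for review context only; the names below are illustrative and nothing here is part of the patch:

import asyncio
from contextvars import ContextVar

REQUEST_VAR: ContextVar[dict | None] = ContextVar("REQUEST_VAR", default=None)


async def stream_without_cleanup():
    # Simulates the wrapper restoring per-request data before yielding.
    REQUEST_VAR.set({"api_key": "request-1"})
    yield "chunk"


async def main() -> None:
    async for _ in stream_without_cleanup():
        pass
    # The per-request value is still visible in this task after the stream
    # finishes; the patched wrapper resets it to None instead.
    print(REQUEST_VAR.get())  # {'api_key': 'request-1'}


asyncio.run(main())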
+ """ + from llama_stack.core.request_headers import PROVIDER_DATA_VAR, request_provider_data_context + from llama_stack.core.utils.context import preserve_contexts_async_generator + + async def mock_streaming_provider(): + """Simulates a streaming provider that reads PROVIDER_DATA_VAR""" + provider_data = PROVIDER_DATA_VAR.get() + yield {"provider_data": provider_data, "chunk": 1} + + async def sse_generator(gen): + """Simulates the SSE generator in the server""" + async for item in gen: + yield f"data: {json.dumps(item)}\n\n" + + # Request 1: Set provider data to {"key": "value1"} + headers1 = {"X-LlamaStack-Provider-Data": json.dumps({"key": "value1"})} + with request_provider_data_context(headers1): + gen1 = preserve_contexts_async_generator( + sse_generator(mock_streaming_provider()), + [PROVIDER_DATA_VAR] + ) + + chunks1 = [chunk async for chunk in gen1] + data1 = json.loads(chunks1[0].split("data: ")[1]) + assert data1["provider_data"] == {"key": "value1"} + + # Context should be cleared after consuming the generator + leaked_data = PROVIDER_DATA_VAR.get() + assert leaked_data is None, f"Context leaked after request 1: {leaked_data}" + + # Request 2: Set different provider data {"key": "value2"} + headers2 = {"X-LlamaStack-Provider-Data": json.dumps({"key": "value2"})} + with request_provider_data_context(headers2): + gen2 = preserve_contexts_async_generator( + sse_generator(mock_streaming_provider()), + [PROVIDER_DATA_VAR] + ) + + chunks2 = [chunk async for chunk in gen2] + data2 = json.loads(chunks2[0].split("data: ")[1]) + assert data2["provider_data"] == {"key": "value2"} + + leaked_data2 = PROVIDER_DATA_VAR.get() + assert leaked_data2 is None, f"Context leaked after request 2: {leaked_data2}" + + # Request 3: No provider data + gen3 = preserve_contexts_async_generator( + sse_generator(mock_streaming_provider()), + [PROVIDER_DATA_VAR] + ) + + chunks3 = [chunk async for chunk in gen3] + data3 = json.loads(chunks3[0].split("data: ")[1]) + assert data3["provider_data"] is None + + +@pytest.mark.skipif( + True, + reason="Requires custom test provider with context echo capability" +) +def test_provider_data_isolation_with_server(llama_stack_client): + """ + Server-based test for context isolation (currently skipped). + + Requires a test inference provider that echoes back PROVIDER_DATA_VAR + in streaming responses to verify proper isolation. 
+ """ + response1 = llama_stack_client.inference.chat_completion( + model_id="context-echo-model", + messages=[{"role": "user", "content": "test"}], + stream=True, + extra_headers={ + "X-LlamaStack-Provider-Data": json.dumps({"test_key": "value1"}) + }, + ) + + chunks1 = [] + for chunk in response1: + if chunk.choices and chunk.choices[0].delta.content: + chunks1.append(chunk.choices[0].delta.content) + + response1_data = json.loads("".join(chunks1)) + assert response1_data["provider_data"] == {"test_key": "value1"} + + response2 = llama_stack_client.inference.chat_completion( + model_id="context-echo-model", + messages=[{"role": "user", "content": "test"}], + stream=True, + extra_headers={ + "X-LlamaStack-Provider-Data": json.dumps({"test_key": "value2"}) + }, + ) + + chunks2 = [] + for chunk in response2: + if chunk.choices and chunk.choices[0].delta.content: + chunks2.append(chunk.choices[0].delta.content) + + response2_data = json.loads("".join(chunks2)) + assert response2_data["provider_data"] == {"test_key": "value2"} diff --git a/tests/integration/fixtures/context_echo_provider.py b/tests/integration/fixtures/context_echo_provider.py new file mode 100644 index 000000000..c9ee04f7e --- /dev/null +++ b/tests/integration/fixtures/context_echo_provider.py @@ -0,0 +1,153 @@ +""" +Test-only inference provider that echoes PROVIDER_DATA_VAR in responses. + +This provider is used to test context isolation between requests in end-to-end +scenarios with a real server. +""" + +import json +from typing import AsyncIterator +from pydantic import BaseModel + +from llama_stack.apis.inference import ( + Inference, + OpenAIChatCompletion, + OpenAIChatCompletionChunk, + OpenAIChatCompletionRequestWithExtraBody, + OpenAICompletion, + OpenAICompletionRequestWithExtraBody, + OpenAIEmbeddingsRequestWithExtraBody, + OpenAIEmbeddingsResponse, +) +from llama_stack.apis.models import Model +from llama_stack.core.request_headers import PROVIDER_DATA_VAR +from llama_stack_client.types.inference_chat_completion_chunk import ( + ChatCompletionChunkChoice, + ChatCompletionChunkChoiceDelta, +) + + +class ContextEchoConfig(BaseModel): + """Minimal config for the test provider.""" + pass + + +class ContextEchoInferenceProvider(Inference): + """ + Test-only provider that echoes the current PROVIDER_DATA_VAR value. + + Used to detect context leaks between streaming requests in end-to-end tests. 
+ """ + + def __init__(self, config: ContextEchoConfig) -> None: + self.config = config + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def register_model(self, model: Model) -> Model: + return model + + async def unregister_model(self, model_id: str) -> None: + pass + + async def list_models(self) -> list[Model]: + return [] + + async def openai_embeddings( + self, + params: OpenAIEmbeddingsRequestWithExtraBody, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError("Embeddings not supported by test provider") + + async def openai_completion( + self, + params: OpenAICompletionRequestWithExtraBody, + ) -> OpenAICompletion: + raise NotImplementedError("Use openai_chat_completion instead") + + async def openai_chat_completion( + self, + params: OpenAIChatCompletionRequestWithExtraBody, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: + """Echo the provider data context back in streaming chunks.""" + + async def stream_with_context(): + # Read the current provider data from context + # This is the KEY part - if context leaks, this will show old data + provider_data = PROVIDER_DATA_VAR.get() + + # Create a JSON message with the provider data + # The test will parse this to verify correct isolation + message = json.dumps({ + "provider_data": provider_data, + "test_marker": "context_echo" + }) + + # Yield a chunk with the provider data + yield OpenAIChatCompletionChunk( + id="context-echo-1", + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkChoiceDelta( + content=message, + role="assistant", + ), + index=0, + finish_reason=None, + ) + ], + created=0, + model=params.model, + object="chat.completion.chunk", + ) + + # Final chunk with finish_reason + yield OpenAIChatCompletionChunk( + id="context-echo-2", + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkChoiceDelta(), + index=0, + finish_reason="stop", + ) + ], + created=0, + model=params.model, + object="chat.completion.chunk", + ) + + if params.stream: + return stream_with_context() + else: + # Non-streaming fallback + provider_data = PROVIDER_DATA_VAR.get() + message_content = json.dumps({ + "provider_data": provider_data, + "test_marker": "context_echo" + }) + + from llama_stack_client.types.inference_chat_completion import ( + ChatCompletionChoice, + ChatCompletionMessage, + ) + + return OpenAIChatCompletion( + id="context-echo", + choices=[ + ChatCompletionChoice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + content=message_content, + role="assistant", + ), + ) + ], + created=0, + model=params.model, + object="chat.completion", + ) diff --git a/tests/unit/core/test_provider_data_context.py b/tests/unit/core/test_provider_data_context.py new file mode 100644 index 000000000..06faa59ae --- /dev/null +++ b/tests/unit/core/test_provider_data_context.py @@ -0,0 +1,60 @@ +import json +import asyncio +import pytest +from contextvars import ContextVar +from contextlib import contextmanager + +from llama_stack.core.utils.context import preserve_contexts_async_generator + +# Define provider data context variable and context manager locally +PROVIDER_DATA_VAR = ContextVar("provider_data", default=None) + + +@contextmanager +def request_provider_data_context(headers): + val = headers.get("X-LlamaStack-Provider-Data") + provider_data = json.loads(val) if val else {} + token = PROVIDER_DATA_VAR.set(provider_data) + try: + yield + finally: + PROVIDER_DATA_VAR.reset(token) + + +def create_sse_event(data): + return 
f"data: {json.dumps(data)}\n\n" + + +async def sse_generator(event_gen_coroutine): + event_gen = await event_gen_coroutine + async for item in event_gen: + yield create_sse_event(item) + await asyncio.sleep(0) + + +async def async_event_gen(): + async def event_gen(): + yield PROVIDER_DATA_VAR.get() + + return event_gen() + + +@pytest.mark.asyncio +async def test_provider_data_context_cleared_between_sse_requests(): + headers = {"X-LlamaStack-Provider-Data": json.dumps({"api_key": "abc"})} + with request_provider_data_context(headers): + gen1 = preserve_contexts_async_generator( + sse_generator(async_event_gen()), [PROVIDER_DATA_VAR] + ) + + events1 = [event async for event in gen1] + assert events1 == [create_sse_event({"api_key": "abc"})] + assert PROVIDER_DATA_VAR.get() is None + + gen2 = preserve_contexts_async_generator( + sse_generator(async_event_gen()), [PROVIDER_DATA_VAR] + ) + events2 = [event async for event in gen2] + assert events2 == [create_sse_event(None)] + assert PROVIDER_DATA_VAR.get() is None +