diff --git a/tests/integration/core/test_provider_data_context_leak.py b/tests/integration/core/test_provider_data_context_leak.py
deleted file mode 100644
index d4011b140..000000000
--- a/tests/integration/core/test_provider_data_context_leak.py
+++ /dev/null
@@ -1,117 +0,0 @@
-"""
-Integration test for provider data context isolation in streaming requests.
-
-This test verifies that PROVIDER_DATA_VAR doesn't leak between sequential
-streaming requests, ensuring provider credentials and configuration are
-properly isolated between requests.
-"""
-
-import json
-import pytest
-
-
-@pytest.mark.asyncio
-async def test_provider_data_isolation_library_client():
-    """
-    Verifies that provider data context is properly isolated between
-    sequential streaming requests and cleaned up after each request.
-    """
-    from llama_stack.core.request_headers import PROVIDER_DATA_VAR, request_provider_data_context
-    from llama_stack.core.utils.context import preserve_contexts_async_generator
-
-    async def mock_streaming_provider():
-        """Simulates a streaming provider that reads PROVIDER_DATA_VAR"""
-        provider_data = PROVIDER_DATA_VAR.get()
-        yield {"provider_data": provider_data, "chunk": 1}
-
-    async def sse_generator(gen):
-        """Simulates the SSE generator in the server"""
-        async for item in gen:
-            yield f"data: {json.dumps(item)}\n\n"
-
-    # Request 1: Set provider data to {"key": "value1"}
-    headers1 = {"X-LlamaStack-Provider-Data": json.dumps({"key": "value1"})}
-    with request_provider_data_context(headers1):
-        gen1 = preserve_contexts_async_generator(
-            sse_generator(mock_streaming_provider()),
-            [PROVIDER_DATA_VAR]
-        )
-
-    chunks1 = [chunk async for chunk in gen1]
-    data1 = json.loads(chunks1[0].split("data: ")[1])
-    assert data1["provider_data"] == {"key": "value1"}
-
-    # Context should be cleared after consuming the generator
-    leaked_data = PROVIDER_DATA_VAR.get()
-    assert leaked_data is None, f"Context leaked after request 1: {leaked_data}"
-
-    # Request 2: Set different provider data {"key": "value2"}
-    headers2 = {"X-LlamaStack-Provider-Data": json.dumps({"key": "value2"})}
-    with request_provider_data_context(headers2):
-        gen2 = preserve_contexts_async_generator(
-            sse_generator(mock_streaming_provider()),
-            [PROVIDER_DATA_VAR]
-        )
-
-    chunks2 = [chunk async for chunk in gen2]
-    data2 = json.loads(chunks2[0].split("data: ")[1])
-    assert data2["provider_data"] == {"key": "value2"}
-
-    leaked_data2 = PROVIDER_DATA_VAR.get()
-    assert leaked_data2 is None, f"Context leaked after request 2: {leaked_data2}"
-
-    # Request 3: No provider data
-    gen3 = preserve_contexts_async_generator(
-        sse_generator(mock_streaming_provider()),
-        [PROVIDER_DATA_VAR]
-    )
-
-    chunks3 = [chunk async for chunk in gen3]
-    data3 = json.loads(chunks3[0].split("data: ")[1])
-    assert data3["provider_data"] is None
-
-
-@pytest.mark.skipif(
-    True,
-    reason="Requires custom test provider with context echo capability"
-)
-def test_provider_data_isolation_with_server(llama_stack_client):
-    """
-    Server-based test for context isolation (currently skipped).
-
-    Requires a test inference provider that echoes back PROVIDER_DATA_VAR
-    in streaming responses to verify proper isolation.
-    """
-    response1 = llama_stack_client.inference.chat_completion(
-        model_id="context-echo-model",
-        messages=[{"role": "user", "content": "test"}],
-        stream=True,
-        extra_headers={
-            "X-LlamaStack-Provider-Data": json.dumps({"test_key": "value1"})
-        },
-    )
-
-    chunks1 = []
-    for chunk in response1:
-        if chunk.choices and chunk.choices[0].delta.content:
-            chunks1.append(chunk.choices[0].delta.content)
-
-    response1_data = json.loads("".join(chunks1))
-    assert response1_data["provider_data"] == {"test_key": "value1"}
-
-    response2 = llama_stack_client.inference.chat_completion(
-        model_id="context-echo-model",
-        messages=[{"role": "user", "content": "test"}],
-        stream=True,
-        extra_headers={
-            "X-LlamaStack-Provider-Data": json.dumps({"test_key": "value2"})
-        },
-    )
-
-    chunks2 = []
-    for chunk in response2:
-        if chunk.choices and chunk.choices[0].delta.content:
-            chunks2.append(chunk.choices[0].delta.content)
-
-    response2_data = json.loads("".join(chunks2))
-    assert response2_data["provider_data"] == {"test_key": "value2"}
diff --git a/tests/integration/fixtures/context_echo_provider.py b/tests/integration/fixtures/context_echo_provider.py
deleted file mode 100644
index c9ee04f7e..000000000
--- a/tests/integration/fixtures/context_echo_provider.py
+++ /dev/null
@@ -1,153 +0,0 @@
-"""
-Test-only inference provider that echoes PROVIDER_DATA_VAR in responses.
-
-This provider is used to test context isolation between requests in end-to-end
-scenarios with a real server.
-"""
-
-import json
-from typing import AsyncIterator
-from pydantic import BaseModel
-
-from llama_stack.apis.inference import (
-    Inference,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAIChatCompletionRequestWithExtraBody,
-    OpenAICompletion,
-    OpenAICompletionRequestWithExtraBody,
-    OpenAIEmbeddingsRequestWithExtraBody,
-    OpenAIEmbeddingsResponse,
-)
-from llama_stack.apis.models import Model
-from llama_stack.core.request_headers import PROVIDER_DATA_VAR
-from llama_stack_client.types.inference_chat_completion_chunk import (
-    ChatCompletionChunkChoice,
-    ChatCompletionChunkChoiceDelta,
-)
-
-
-class ContextEchoConfig(BaseModel):
-    """Minimal config for the test provider."""
-    pass
-
-
-class ContextEchoInferenceProvider(Inference):
-    """
-    Test-only provider that echoes the current PROVIDER_DATA_VAR value.
-
-    Used to detect context leaks between streaming requests in end-to-end tests.
-    """
-
-    def __init__(self, config: ContextEchoConfig) -> None:
-        self.config = config
-
-    async def initialize(self) -> None:
-        pass
-
-    async def shutdown(self) -> None:
-        pass
-
-    async def register_model(self, model: Model) -> Model:
-        return model
-
-    async def unregister_model(self, model_id: str) -> None:
-        pass
-
-    async def list_models(self) -> list[Model]:
-        return []
-
-    async def openai_embeddings(
-        self,
-        params: OpenAIEmbeddingsRequestWithExtraBody,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError("Embeddings not supported by test provider")
-
-    async def openai_completion(
-        self,
-        params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
-        raise NotImplementedError("Use openai_chat_completion instead")
-
-    async def openai_chat_completion(
-        self,
-        params: OpenAIChatCompletionRequestWithExtraBody,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        """Echo the provider data context back in streaming chunks."""
-
-        async def stream_with_context():
-            # Read the current provider data from context
-            # This is the KEY part - if context leaks, this will show old data
-            provider_data = PROVIDER_DATA_VAR.get()
-
-            # Create a JSON message with the provider data
-            # The test will parse this to verify correct isolation
-            message = json.dumps({
-                "provider_data": provider_data,
-                "test_marker": "context_echo"
-            })
-
-            # Yield a chunk with the provider data
-            yield OpenAIChatCompletionChunk(
-                id="context-echo-1",
-                choices=[
-                    ChatCompletionChunkChoice(
-                        delta=ChatCompletionChunkChoiceDelta(
-                            content=message,
-                            role="assistant",
-                        ),
-                        index=0,
-                        finish_reason=None,
-                    )
-                ],
-                created=0,
-                model=params.model,
-                object="chat.completion.chunk",
-            )
-
-            # Final chunk with finish_reason
-            yield OpenAIChatCompletionChunk(
-                id="context-echo-2",
-                choices=[
-                    ChatCompletionChunkChoice(
-                        delta=ChatCompletionChunkChoiceDelta(),
-                        index=0,
-                        finish_reason="stop",
-                    )
-                ],
-                created=0,
-                model=params.model,
-                object="chat.completion.chunk",
-            )
-
-        if params.stream:
-            return stream_with_context()
-        else:
-            # Non-streaming fallback
-            provider_data = PROVIDER_DATA_VAR.get()
-            message_content = json.dumps({
-                "provider_data": provider_data,
-                "test_marker": "context_echo"
-            })
-
-            from llama_stack_client.types.inference_chat_completion import (
-                ChatCompletionChoice,
-                ChatCompletionMessage,
-            )
-
-            return OpenAIChatCompletion(
-                id="context-echo",
-                choices=[
-                    ChatCompletionChoice(
-                        finish_reason="stop",
-                        index=0,
-                        message=ChatCompletionMessage(
-                            content=message_content,
-                            role="assistant",
-                        ),
-                    )
-                ],
-                created=0,
-                model=params.model,
-                object="chat.completion",
-            )