Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-11 19:56:03 +00:00
test(core): remove provider data context leak integration test
parent 3ecb043d59
commit 75cdc4dad2
2 changed files with 0 additions and 270 deletions
@@ -1,117 +0,0 @@
"""
Integration test for provider data context isolation in streaming requests.

This test verifies that PROVIDER_DATA_VAR doesn't leak between sequential
streaming requests, ensuring provider credentials and configuration are
properly isolated between requests.
"""

import json
import pytest


@pytest.mark.asyncio
async def test_provider_data_isolation_library_client():
    """
    Verifies that provider data context is properly isolated between
    sequential streaming requests and cleaned up after each request.
    """
    from llama_stack.core.request_headers import PROVIDER_DATA_VAR, request_provider_data_context
    from llama_stack.core.utils.context import preserve_contexts_async_generator

    async def mock_streaming_provider():
        """Simulates a streaming provider that reads PROVIDER_DATA_VAR"""
        provider_data = PROVIDER_DATA_VAR.get()
        yield {"provider_data": provider_data, "chunk": 1}

    async def sse_generator(gen):
        """Simulates the SSE generator in the server"""
        async for item in gen:
            yield f"data: {json.dumps(item)}\n\n"

    # Request 1: Set provider data to {"key": "value1"}
    headers1 = {"X-LlamaStack-Provider-Data": json.dumps({"key": "value1"})}
    with request_provider_data_context(headers1):
        gen1 = preserve_contexts_async_generator(
            sse_generator(mock_streaming_provider()),
            [PROVIDER_DATA_VAR]
        )

    chunks1 = [chunk async for chunk in gen1]
    data1 = json.loads(chunks1[0].split("data: ")[1])
    assert data1["provider_data"] == {"key": "value1"}

    # Context should be cleared after consuming the generator
    leaked_data = PROVIDER_DATA_VAR.get()
    assert leaked_data is None, f"Context leaked after request 1: {leaked_data}"

    # Request 2: Set different provider data {"key": "value2"}
    headers2 = {"X-LlamaStack-Provider-Data": json.dumps({"key": "value2"})}
    with request_provider_data_context(headers2):
        gen2 = preserve_contexts_async_generator(
            sse_generator(mock_streaming_provider()),
            [PROVIDER_DATA_VAR]
        )

    chunks2 = [chunk async for chunk in gen2]
    data2 = json.loads(chunks2[0].split("data: ")[1])
    assert data2["provider_data"] == {"key": "value2"}

    leaked_data2 = PROVIDER_DATA_VAR.get()
    assert leaked_data2 is None, f"Context leaked after request 2: {leaked_data2}"

    # Request 3: No provider data
    gen3 = preserve_contexts_async_generator(
        sse_generator(mock_streaming_provider()),
        [PROVIDER_DATA_VAR]
    )

    chunks3 = [chunk async for chunk in gen3]
    data3 = json.loads(chunks3[0].split("data: ")[1])
    assert data3["provider_data"] is None


@pytest.mark.skipif(
    True,
    reason="Requires custom test provider with context echo capability"
)
def test_provider_data_isolation_with_server(llama_stack_client):
    """
    Server-based test for context isolation (currently skipped).

    Requires a test inference provider that echoes back PROVIDER_DATA_VAR
    in streaming responses to verify proper isolation.
    """
    response1 = llama_stack_client.inference.chat_completion(
        model_id="context-echo-model",
        messages=[{"role": "user", "content": "test"}],
        stream=True,
        extra_headers={
            "X-LlamaStack-Provider-Data": json.dumps({"test_key": "value1"})
        },
    )

    chunks1 = []
    for chunk in response1:
        if chunk.choices and chunk.choices[0].delta.content:
            chunks1.append(chunk.choices[0].delta.content)

    response1_data = json.loads("".join(chunks1))
    assert response1_data["provider_data"] == {"test_key": "value1"}

    response2 = llama_stack_client.inference.chat_completion(
        model_id="context-echo-model",
        messages=[{"role": "user", "content": "test"}],
        stream=True,
        extra_headers={
            "X-LlamaStack-Provider-Data": json.dumps({"test_key": "value2"})
        },
    )

    chunks2 = []
    for chunk in response2:
        if chunk.choices and chunk.choices[0].delta.content:
            chunks2.append(chunk.choices[0].delta.content)

    response2_data = json.loads("".join(chunks2))
    assert response2_data["provider_data"] == {"test_key": "value2"}
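
The behavior this removed test checks can be reproduced with nothing but the standard library. The sketch below uses hypothetical stand-ins (preserve_var, PROVIDER_DATA) to illustrate what preserve_contexts_async_generator is expected to do; it is a simplified illustration, not the actual llama_stack implementation.

import asyncio
import contextvars
import json

# Hypothetical stand-in for llama_stack's PROVIDER_DATA_VAR.
PROVIDER_DATA = contextvars.ContextVar("provider_data", default=None)


def preserve_var(gen, var):
    """Capture var's current value and re-apply it only while the wrapped generator runs."""
    captured = var.get()

    async def wrapper():
        it = gen.__aiter__()
        while True:
            token = var.set(captured)        # restore the request's value
            try:
                item = await it.__anext__()  # the wrapped generator runs with the captured value
            except StopAsyncIteration:
                return
            finally:
                var.reset(token)             # never leak it past this step
            yield item

    return wrapper()


async def provider_stream():
    # Simulates a streaming provider that reads the context var,
    # like mock_streaming_provider() in the test above.
    yield {"provider_data": PROVIDER_DATA.get(), "chunk": 1}


async def main():
    token = PROVIDER_DATA.set({"key": "value1"})          # request scope begins
    gen = preserve_var(provider_stream(), PROVIDER_DATA)  # snapshot taken here
    PROVIDER_DATA.reset(token)                            # request scope ends before streaming

    chunks = [chunk async for chunk in gen]
    print(json.dumps(chunks[0]))                          # still shows {"key": "value1"}
    assert PROVIDER_DATA.get() is None                    # nothing leaked afterwards


asyncio.run(main())
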
@@ -1,153 +0,0 @@
"""
Test-only inference provider that echoes PROVIDER_DATA_VAR in responses.

This provider is used to test context isolation between requests in end-to-end
scenarios with a real server.
"""

import json
from typing import AsyncIterator
from pydantic import BaseModel

from llama_stack.apis.inference import (
    Inference,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAICompletion,
    OpenAICompletionRequestWithExtraBody,
    OpenAIEmbeddingsRequestWithExtraBody,
    OpenAIEmbeddingsResponse,
)
from llama_stack.apis.models import Model
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
from llama_stack_client.types.inference_chat_completion_chunk import (
    ChatCompletionChunkChoice,
    ChatCompletionChunkChoiceDelta,
)


class ContextEchoConfig(BaseModel):
    """Minimal config for the test provider."""
    pass


class ContextEchoInferenceProvider(Inference):
    """
    Test-only provider that echoes the current PROVIDER_DATA_VAR value.

    Used to detect context leaks between streaming requests in end-to-end tests.
    """

    def __init__(self, config: ContextEchoConfig) -> None:
        self.config = config

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def register_model(self, model: Model) -> Model:
        return model

    async def unregister_model(self, model_id: str) -> None:
        pass

    async def list_models(self) -> list[Model]:
        return []

    async def openai_embeddings(
        self,
        params: OpenAIEmbeddingsRequestWithExtraBody,
    ) -> OpenAIEmbeddingsResponse:
        raise NotImplementedError("Embeddings not supported by test provider")

    async def openai_completion(
        self,
        params: OpenAICompletionRequestWithExtraBody,
    ) -> OpenAICompletion:
        raise NotImplementedError("Use openai_chat_completion instead")

    async def openai_chat_completion(
        self,
        params: OpenAIChatCompletionRequestWithExtraBody,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        """Echo the provider data context back in streaming chunks."""

        async def stream_with_context():
            # Read the current provider data from context
            # This is the KEY part - if context leaks, this will show old data
            provider_data = PROVIDER_DATA_VAR.get()

            # Create a JSON message with the provider data
            # The test will parse this to verify correct isolation
            message = json.dumps({
                "provider_data": provider_data,
                "test_marker": "context_echo"
            })

            # Yield a chunk with the provider data
            yield OpenAIChatCompletionChunk(
                id="context-echo-1",
                choices=[
                    ChatCompletionChunkChoice(
                        delta=ChatCompletionChunkChoiceDelta(
                            content=message,
                            role="assistant",
                        ),
                        index=0,
                        finish_reason=None,
                    )
                ],
                created=0,
                model=params.model,
                object="chat.completion.chunk",
            )

            # Final chunk with finish_reason
            yield OpenAIChatCompletionChunk(
                id="context-echo-2",
                choices=[
                    ChatCompletionChunkChoice(
                        delta=ChatCompletionChunkChoiceDelta(),
                        index=0,
                        finish_reason="stop",
                    )
                ],
                created=0,
                model=params.model,
                object="chat.completion.chunk",
            )

        if params.stream:
            return stream_with_context()
        else:
            # Non-streaming fallback
            provider_data = PROVIDER_DATA_VAR.get()
            message_content = json.dumps({
                "provider_data": provider_data,
                "test_marker": "context_echo"
            })

            from llama_stack_client.types.inference_chat_completion import (
                ChatCompletionChoice,
                ChatCompletionMessage,
            )

            return OpenAIChatCompletion(
                id="context-echo",
                choices=[
                    ChatCompletionChoice(
                        finish_reason="stop",
                        index=0,
                        message=ChatCompletionMessage(
                            content=message_content,
                            role="assistant",
                        ),
                    )
                ],
                created=0,
                model=params.model,
                object="chat.completion",
            )
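
Taken together, the two removed files exercised the server-side pattern sketched below: the provider reads PROVIDER_DATA_VAR while streaming, and the server sets that variable per request and preserves it into the SSE generator. The call shapes for request_provider_data_context and preserve_contexts_async_generator follow the test above; handle_streaming_request and provider_stream are hypothetical names, and this is a rough sketch rather than the actual llama_stack server code.

import json

from llama_stack.core.request_headers import PROVIDER_DATA_VAR, request_provider_data_context
from llama_stack.core.utils.context import preserve_contexts_async_generator


async def handle_streaming_request(headers: dict, provider_stream):
    # headers carries X-LlamaStack-Provider-Data; provider_stream is a
    # hypothetical async iterator of chunks from an inference provider.

    async def sse(gen):
        # Same SSE framing the test's sse_generator simulates.
        async for item in gen:
            yield f"data: {json.dumps(item)}\n\n"

    # Parse the provider-data header and set PROVIDER_DATA_VAR for this request.
    with request_provider_data_context(headers):
        # Capture the listed context vars now, so the generator still sees this
        # request's values when it is consumed after the handler returns, and
        # so nothing leaks into the next request.
        return preserve_contexts_async_generator(sse(provider_stream), [PROVIDER_DATA_VAR])
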