mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 04:00:42 +00:00
fix(context): prevent provider data leak between streaming requests
The preserve_contexts_async_generator function was not cleaning up context variables after streaming iterations, causing PROVIDER_DATA_VAR to leak between sequential requests. Provider credentials or configuration from one request could persist and leak into subsequent requests. Root cause: Context variables were set at the start of each iteration but never cleared afterward. When generators were consumed outside their original context manager (after the with block exited), the context values remained set indefinitely. The fix clears context variables by setting them to None after each yield and when the generator terminates. This works reliably across all scenarios including when the library client wraps async generators for sync consumption (which creates new asyncio Contexts per iteration). Direct value setting avoids Context-scoped token issues that would occur with token-based reset. Added unit and integration tests that verify context isolation.
This commit is contained in:
parent
471b1b248b
commit
3ecb043d59
4 changed files with 345 additions and 9 deletions
60
tests/unit/core/test_provider_data_context.py
Normal file
60
tests/unit/core/test_provider_data_context.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import json
|
||||
import asyncio
|
||||
import pytest
|
||||
from contextvars import ContextVar
|
||||
from contextlib import contextmanager
|
||||
|
||||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
|
||||
# Define provider data context variable and context manager locally
|
||||
PROVIDER_DATA_VAR = ContextVar("provider_data", default=None)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def request_provider_data_context(headers):
|
||||
val = headers.get("X-LlamaStack-Provider-Data")
|
||||
provider_data = json.loads(val) if val else {}
|
||||
token = PROVIDER_DATA_VAR.set(provider_data)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
PROVIDER_DATA_VAR.reset(token)
|
||||
|
||||
|
||||
def create_sse_event(data):
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
async def sse_generator(event_gen_coroutine):
|
||||
event_gen = await event_gen_coroutine
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
async def async_event_gen():
|
||||
async def event_gen():
|
||||
yield PROVIDER_DATA_VAR.get()
|
||||
|
||||
return event_gen()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_data_context_cleared_between_sse_requests():
|
||||
headers = {"X-LlamaStack-Provider-Data": json.dumps({"api_key": "abc"})}
|
||||
with request_provider_data_context(headers):
|
||||
gen1 = preserve_contexts_async_generator(
|
||||
sse_generator(async_event_gen()), [PROVIDER_DATA_VAR]
|
||||
)
|
||||
|
||||
events1 = [event async for event in gen1]
|
||||
assert events1 == [create_sse_event({"api_key": "abc"})]
|
||||
assert PROVIDER_DATA_VAR.get() is None
|
||||
|
||||
gen2 = preserve_contexts_async_generator(
|
||||
sse_generator(async_event_gen()), [PROVIDER_DATA_VAR]
|
||||
)
|
||||
events2 = [event async for event in gen2]
|
||||
assert events2 == [create_sse_event(None)]
|
||||
assert PROVIDER_DATA_VAR.get() is None
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue