mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
fix(context): prevent provider data leak between streaming requests (#3924)
## Summary - `preserve_contexts_async_generator` left `PROVIDER_DATA_VAR` (and other context vars) populated after a streaming generator completed on HEAD~1, so the asyncio context for request N+1 started with request N's provider payload. - FastAPI dependencies and middleware execute before `request_provider_data_context` rebinds the header data, meaning auth/logging hooks could observe a prior tenant's credentials or treat them as authenticated. Traces and any background work that inspects the context outside the `with` block leak as well—this is a real security regression, not just a CLI artifact. - The wrapper now restores each tracked `ContextVar` to the value it held before the iteration (falling back to clearing when necessary) after every yield and when the generator terminates, so provider data is wiped while callers that set their own defaults keep them. ## Test Plan - `uv run pytest tests/unit/core/test_provider_data_context.py -q` - `uv run pytest tests/unit/distribution/test_context.py -q` Both suites fail on HEAD~1 and pass with this change.
This commit is contained in:
parent
bf091306fe
commit
2c56a8560d
2 changed files with 114 additions and 11 deletions
|
|
@ -7,6 +7,10 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
from contextvars import ContextVar
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import CURRENT_TRACE_CONTEXT
|
||||
|
||||
_MISSING = object()
|
||||
|
||||
|
||||
def preserve_contexts_async_generator[T](
|
||||
gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
|
||||
|
|
@ -21,20 +25,60 @@ def preserve_contexts_async_generator[T](
|
|||
|
||||
async def wrapper() -> AsyncGenerator[T, None]:
|
||||
while True:
|
||||
previous_values: dict[ContextVar, object] = {}
|
||||
tokens: dict[ContextVar, object] = {}
|
||||
|
||||
# Restore ALL context values before any await and capture previous state
|
||||
# This is needed to propagate context across async generator boundaries
|
||||
for context_var in context_vars:
|
||||
try:
|
||||
previous_values[context_var] = context_var.get()
|
||||
except LookupError:
|
||||
previous_values[context_var] = _MISSING
|
||||
tokens[context_var] = context_var.set(initial_context_values[context_var.name])
|
||||
|
||||
def _restore_context_var(context_var: ContextVar, *, _tokens=tokens, _prev=previous_values) -> None:
|
||||
token = _tokens.get(context_var)
|
||||
previous_value = _prev.get(context_var, _MISSING)
|
||||
if token is not None:
|
||||
try:
|
||||
context_var.reset(token)
|
||||
return
|
||||
except (RuntimeError, ValueError):
|
||||
pass
|
||||
|
||||
if previous_value is _MISSING:
|
||||
context_var.set(None)
|
||||
else:
|
||||
context_var.set(previous_value)
|
||||
|
||||
try:
|
||||
# Restore context values before any await
|
||||
for context_var in context_vars:
|
||||
context_var.set(initial_context_values[context_var.name])
|
||||
|
||||
item = await gen.__anext__()
|
||||
|
||||
# Update our tracked values with any changes made during this iteration
|
||||
for context_var in context_vars:
|
||||
initial_context_values[context_var.name] = context_var.get()
|
||||
|
||||
yield item
|
||||
|
||||
except StopAsyncIteration:
|
||||
# Restore all context vars before exiting to prevent leaks
|
||||
# Use _restore_context_var for all vars to properly restore to previous values
|
||||
for context_var in context_vars:
|
||||
_restore_context_var(context_var)
|
||||
break
|
||||
except Exception:
|
||||
# Restore all context vars on exception
|
||||
for context_var in context_vars:
|
||||
_restore_context_var(context_var)
|
||||
raise
|
||||
|
||||
try:
|
||||
yield item
|
||||
# Update our tracked values with any changes made during this iteration
|
||||
# Only for non-trace context vars - trace context must persist across yields
|
||||
# to allow nested span tracking for telemetry
|
||||
for context_var in context_vars:
|
||||
if context_var is not CURRENT_TRACE_CONTEXT:
|
||||
initial_context_values[context_var.name] = context_var.get()
|
||||
finally:
|
||||
# Restore non-trace context vars after each yield to prevent leaks between requests
|
||||
# CURRENT_TRACE_CONTEXT is NOT restored here to preserve telemetry span stack
|
||||
for context_var in context_vars:
|
||||
if context_var is not CURRENT_TRACE_CONTEXT:
|
||||
_restore_context_var(context_var)
|
||||
|
||||
return wrapper()
|
||||
|
|
|
|||
59
tests/unit/core/test_provider_data_context.py
Normal file
59
tests/unit/core/test_provider_data_context.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
|
||||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
|
||||
# Define provider data context variable and context manager locally
|
||||
PROVIDER_DATA_VAR = ContextVar("provider_data", default=None)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def request_provider_data_context(headers):
|
||||
val = headers.get("X-LlamaStack-Provider-Data")
|
||||
provider_data = json.loads(val) if val else {}
|
||||
token = PROVIDER_DATA_VAR.set(provider_data)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
PROVIDER_DATA_VAR.reset(token)
|
||||
|
||||
|
||||
def create_sse_event(data):
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
async def sse_generator(event_gen_coroutine):
|
||||
event_gen = await event_gen_coroutine
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
async def async_event_gen():
|
||||
async def event_gen():
|
||||
yield PROVIDER_DATA_VAR.get()
|
||||
|
||||
return event_gen()
|
||||
|
||||
|
||||
async def test_provider_data_context_cleared_between_sse_requests():
|
||||
headers = {"X-LlamaStack-Provider-Data": json.dumps({"api_key": "abc"})}
|
||||
with request_provider_data_context(headers):
|
||||
gen1 = preserve_contexts_async_generator(sse_generator(async_event_gen()), [PROVIDER_DATA_VAR])
|
||||
|
||||
events1 = [event async for event in gen1]
|
||||
assert events1 == [create_sse_event({"api_key": "abc"})]
|
||||
assert PROVIDER_DATA_VAR.get() is None
|
||||
|
||||
gen2 = preserve_contexts_async_generator(sse_generator(async_event_gen()), [PROVIDER_DATA_VAR])
|
||||
events2 = [event async for event in gen2]
|
||||
assert events2 == [create_sse_event(None)]
|
||||
assert PROVIDER_DATA_VAR.get() is None
|
||||
Loading…
Add table
Add a link
Reference in a new issue