This commit is contained in:
Ashwin Bharambe 2025-10-27 22:23:07 -07:00
parent 907ba5aecf
commit 6a849c3b18
2 changed files with 13 additions and 15 deletions

View file

@ -7,7 +7,6 @@
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextvars import ContextVar from contextvars import ContextVar
_MISSING = object() _MISSING = object()
@ -35,9 +34,9 @@ def preserve_contexts_async_generator[T](
previous_values[context_var] = _MISSING previous_values[context_var] = _MISSING
tokens[context_var] = context_var.set(initial_context_values[context_var.name]) tokens[context_var] = context_var.set(initial_context_values[context_var.name])
def _restore_context_var(context_var: ContextVar) -> None: def _restore_context_var(context_var: ContextVar, *, _tokens=tokens, _prev=previous_values) -> None:
token = tokens.get(context_var) token = _tokens.get(context_var)
previous_value = previous_values.get(context_var, _MISSING) previous_value = _prev.get(context_var, _MISSING)
if token is not None: if token is not None:
try: try:
context_var.reset(token) context_var.reset(token)

View file

@ -1,8 +1,13 @@
import json # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio import asyncio
import pytest import json
from contextvars import ContextVar
from contextlib import contextmanager from contextlib import contextmanager
from contextvars import ContextVar
from llama_stack.core.utils.context import preserve_contexts_async_generator from llama_stack.core.utils.context import preserve_contexts_async_generator
@ -39,22 +44,16 @@ async def async_event_gen():
return event_gen() return event_gen()
@pytest.mark.asyncio
async def test_provider_data_context_cleared_between_sse_requests(): async def test_provider_data_context_cleared_between_sse_requests():
headers = {"X-LlamaStack-Provider-Data": json.dumps({"api_key": "abc"})} headers = {"X-LlamaStack-Provider-Data": json.dumps({"api_key": "abc"})}
with request_provider_data_context(headers): with request_provider_data_context(headers):
gen1 = preserve_contexts_async_generator( gen1 = preserve_contexts_async_generator(sse_generator(async_event_gen()), [PROVIDER_DATA_VAR])
sse_generator(async_event_gen()), [PROVIDER_DATA_VAR]
)
events1 = [event async for event in gen1] events1 = [event async for event in gen1]
assert events1 == [create_sse_event({"api_key": "abc"})] assert events1 == [create_sse_event({"api_key": "abc"})]
assert PROVIDER_DATA_VAR.get() is None assert PROVIDER_DATA_VAR.get() is None
gen2 = preserve_contexts_async_generator( gen2 = preserve_contexts_async_generator(sse_generator(async_event_gen()), [PROVIDER_DATA_VAR])
sse_generator(async_event_gen()), [PROVIDER_DATA_VAR]
)
events2 = [event async for event in gen2] events2 = [event async for event in gen2]
assert events2 == [create_sse_event(None)] assert events2 == [create_sse_event(None)]
assert PROVIDER_DATA_VAR.get() is None assert PROVIDER_DATA_VAR.get() is None