This commit is contained in:
Ashwin Bharambe 2025-10-27 22:23:07 -07:00
parent 907ba5aecf
commit 6a849c3b18
2 changed files with 13 additions and 15 deletions

View file

@ -7,7 +7,6 @@
from collections.abc import AsyncGenerator
from contextvars import ContextVar
_MISSING = object()
@ -35,9 +34,9 @@ def preserve_contexts_async_generator[T](
previous_values[context_var] = _MISSING
tokens[context_var] = context_var.set(initial_context_values[context_var.name])
def _restore_context_var(context_var: ContextVar) -> None:
token = tokens.get(context_var)
previous_value = previous_values.get(context_var, _MISSING)
def _restore_context_var(context_var: ContextVar, *, _tokens=tokens, _prev=previous_values) -> None:
token = _tokens.get(context_var)
previous_value = _prev.get(context_var, _MISSING)
if token is not None:
try:
context_var.reset(token)

View file

@ -1,8 +1,13 @@
import json
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import pytest
from contextvars import ContextVar
import json
from contextlib import contextmanager
from contextvars import ContextVar
from llama_stack.core.utils.context import preserve_contexts_async_generator
@ -39,22 +44,16 @@ async def async_event_gen():
return event_gen()
@pytest.mark.asyncio
async def test_provider_data_context_cleared_between_sse_requests():
headers = {"X-LlamaStack-Provider-Data": json.dumps({"api_key": "abc"})}
with request_provider_data_context(headers):
gen1 = preserve_contexts_async_generator(
sse_generator(async_event_gen()), [PROVIDER_DATA_VAR]
)
gen1 = preserve_contexts_async_generator(sse_generator(async_event_gen()), [PROVIDER_DATA_VAR])
events1 = [event async for event in gen1]
assert events1 == [create_sse_event({"api_key": "abc"})]
assert PROVIDER_DATA_VAR.get() is None
gen2 = preserve_contexts_async_generator(
sse_generator(async_event_gen()), [PROVIDER_DATA_VAR]
)
gen2 = preserve_contexts_async_generator(sse_generator(async_event_gen()), [PROVIDER_DATA_VAR])
events2 = [event async for event in gen2]
assert events2 == [create_sse_event(None)]
assert PROVIDER_DATA_VAR.get() is None