From 6a849c3b186b0c3128b04b6461923a84f82d722c Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 27 Oct 2025 22:23:07 -0700 Subject: [PATCH] fixes --- src/llama_stack/core/utils/context.py | 7 +++---- tests/unit/core/test_provider_data_context.py | 21 +++++++++---------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/llama_stack/core/utils/context.py b/src/llama_stack/core/utils/context.py index 89fdf0d6f..87ad553e9 100644 --- a/src/llama_stack/core/utils/context.py +++ b/src/llama_stack/core/utils/context.py @@ -7,7 +7,6 @@ from collections.abc import AsyncGenerator from contextvars import ContextVar - _MISSING = object() @@ -35,9 +34,9 @@ def preserve_contexts_async_generator[T]( previous_values[context_var] = _MISSING tokens[context_var] = context_var.set(initial_context_values[context_var.name]) - def _restore_context_var(context_var: ContextVar) -> None: - token = tokens.get(context_var) - previous_value = previous_values.get(context_var, _MISSING) + def _restore_context_var(context_var: ContextVar, *, _tokens=tokens, _prev=previous_values) -> None: + token = _tokens.get(context_var) + previous_value = _prev.get(context_var, _MISSING) if token is not None: try: context_var.reset(token) diff --git a/tests/unit/core/test_provider_data_context.py b/tests/unit/core/test_provider_data_context.py index 06faa59ae..a45805863 100644 --- a/tests/unit/core/test_provider_data_context.py +++ b/tests/unit/core/test_provider_data_context.py @@ -1,8 +1,13 @@ -import json +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + import asyncio -import pytest -from contextvars import ContextVar +import json from contextlib import contextmanager +from contextvars import ContextVar from llama_stack.core.utils.context import preserve_contexts_async_generator @@ -39,22 +44,16 @@ async def async_event_gen(): return event_gen() -@pytest.mark.asyncio async def test_provider_data_context_cleared_between_sse_requests(): headers = {"X-LlamaStack-Provider-Data": json.dumps({"api_key": "abc"})} with request_provider_data_context(headers): - gen1 = preserve_contexts_async_generator( - sse_generator(async_event_gen()), [PROVIDER_DATA_VAR] - ) + gen1 = preserve_contexts_async_generator(sse_generator(async_event_gen()), [PROVIDER_DATA_VAR]) events1 = [event async for event in gen1] assert events1 == [create_sse_event({"api_key": "abc"})] assert PROVIDER_DATA_VAR.get() is None - gen2 = preserve_contexts_async_generator( - sse_generator(async_event_gen()), [PROVIDER_DATA_VAR] - ) + gen2 = preserve_contexts_async_generator(sse_generator(async_event_gen()), [PROVIDER_DATA_VAR]) events2 = [event async for event in gen2] assert events2 == [create_sse_event(None)] assert PROVIDER_DATA_VAR.get() is None -