diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index a8b855f4d..ca2ff5c97 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -250,6 +250,8 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: await log_request_pre_validation(request) test_context_token = None + test_context_var = None + reset_test_context_fn = None # Use context manager with both provider data and auth attributes with request_provider_data_context(request.headers, user): @@ -261,13 +263,18 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: ) test_context_token = sync_test_context_from_provider_data() + test_context_var = TEST_CONTEXT + reset_test_context_fn = reset_test_context is_streaming = is_streaming_request(func.__name__, request, **kwargs) try: if is_streaming: + context_vars = [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR] + if test_context_var is not None: + context_vars.append(test_context_var) gen = preserve_contexts_async_generator( - sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR, TEST_CONTEXT] + sse_generator(func(**kwargs)), context_vars ) return StreamingResponse(gen, media_type="text/event-stream") else: @@ -287,8 +294,8 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}") raise translate_exception(e) from e finally: - if test_context_token is not None: - reset_test_context(test_context_token) + if test_context_token is not None and reset_test_context_fn is not None: + reset_test_context_fn(test_context_token) sig = inspect.signature(func)