diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index ee95f3b37..c5d6b6af5 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -13,7 +13,7 @@ import re from concurrent.futures import ThreadPoolExecutor from enum import Enum from pathlib import Path -from typing import Any, Optional, TypeVar, Union, get_args, get_origin +from typing import Any, AsyncGenerator, Optional, TypeVar, Union, get_args, get_origin import httpx import yaml @@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.datatypes import Api from llama_stack.distribution.request_headers import ( - preserve_headers_context_async_generator, + _provider_data_var, request_provider_data_context, ) from llama_stack.distribution.resolver import ProviderRegistry @@ -159,11 +159,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient): def sync_generator(): try: - global trace_context async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs)) while True: - if trace_context: - CURRENT_TRACE_CONTEXT.set(trace_context) chunk = loop.run_until_complete(async_stream.__anext__()) yield chunk except StopAsyncIteration: @@ -380,20 +377,21 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): body = self._convert_body(path, options.method, body) + await start_trace(options.url, {"__location__": "library_client"}) + async def gen(): - await start_trace(options.url, {"__location__": "library_client"}) - global trace_context - trace_context = CURRENT_TRACE_CONTEXT.get() - try: - async for chunk in await func(**body): - data = json.dumps(convert_pydantic_to_json_value(chunk)) - sse_event = f"data: {data}\n\n" - yield sse_event.encode("utf-8") - finally: - await end_trace() + async for chunk in await func(**body): + data = json.dumps(convert_pydantic_to_json_value(chunk)) + sse_event = f"data: {data}\n\n" + yield sse_event.encode("utf-8") # Wrap the generator to preserve context across iterations - wrapped_gen = preserve_headers_context_async_generator(gen()) + try: + # Combine both context preservations in a single pass + wrapped_gen = self._preserve_contexts_async_generator(gen()) + finally: + await end_trace() + mock_response = httpx.Response( status_code=httpx.codes.OK, content=wrapped_gen, @@ -424,6 +422,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): ) return await response.parse() + def _preserve_contexts_async_generator(self, gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]: + """ + Wraps an async generator to preserve both tracing and headers context variables across iterations. + This is needed because we start a new asyncio event loop for each request, and we need to preserve the context + across the event loop boundary. + """ + tracing_context_value = CURRENT_TRACE_CONTEXT.get() + headers_context_value = _provider_data_var.get() + + async def wrapper(): + while True: + _ = CURRENT_TRACE_CONTEXT.set(tracing_context_value) + _ = _provider_data_var.set(headers_context_value) + try: + item = await gen.__anext__() + yield item + except StopAsyncIteration: + break + + return wrapper() + def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict: if not body: return {} diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index 19afae59b..87850f752 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -7,7 +7,7 @@ import contextvars import json import logging -from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar +from typing import Any, ContextManager, Dict, Optional from .utils.dynamic import instantiate_class_type @@ -35,33 +35,6 @@ class RequestProviderDataContext(ContextManager): _provider_data_var.reset(self.token) -T = TypeVar("T") - - -def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]: - """ - Wraps an async generator to preserve request headers context variables across iterations. - - This ensures that context variables set during generator creation are - available during each iteration of the generator, even if the original - context manager has exited. - """ - # Capture the current context value right now - context_value = _provider_data_var.get() - - async def wrapper(): - while True: - # Set context before each anext() call - _ = _provider_data_var.set(context_value) - try: - item = await gen.__anext__() - yield item - except StopAsyncIteration: - break - - return wrapper() - - class NeedsRequestProviderData: def get_request_provider_data(self) -> Any: spec = self.__provider_spec__