consolidate context var restores

Dinesh Yeduguru 2025-03-11 16:14:46 -07:00
parent 9d4716521d
commit 1ddc026355
2 changed files with 36 additions and 44 deletions

View file

@@ -13,7 +13,7 @@ import re
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, Optional, TypeVar, Union, get_args, get_origin
+from typing import Any, AsyncGenerator, Optional, TypeVar, Union, get_args, get_origin
 
 import httpx
 import yaml
@@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Api
 from llama_stack.distribution.request_headers import (
-    preserve_headers_context_async_generator,
+    _provider_data_var,
     request_provider_data_context,
 )
 from llama_stack.distribution.resolver import ProviderRegistry
@@ -159,11 +159,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         def sync_generator():
             try:
-                global trace_context
                 async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs))
                 while True:
-                    if trace_context:
-                        CURRENT_TRACE_CONTEXT.set(trace_context)
                     chunk = loop.run_until_complete(async_stream.__anext__())
                     yield chunk
             except StopAsyncIteration:
@@ -380,20 +377,21 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         body = self._convert_body(path, options.method, body)
 
+        await start_trace(options.url, {"__location__": "library_client"})
+
         async def gen():
-            await start_trace(options.url, {"__location__": "library_client"})
-            global trace_context
-            trace_context = CURRENT_TRACE_CONTEXT.get()
-            try:
-                async for chunk in await func(**body):
-                    data = json.dumps(convert_pydantic_to_json_value(chunk))
-                    sse_event = f"data: {data}\n\n"
-                    yield sse_event.encode("utf-8")
-            finally:
-                await end_trace()
+            async for chunk in await func(**body):
+                data = json.dumps(convert_pydantic_to_json_value(chunk))
+                sse_event = f"data: {data}\n\n"
+                yield sse_event.encode("utf-8")
 
         # Wrap the generator to preserve context across iterations
-        wrapped_gen = preserve_headers_context_async_generator(gen())
+        try:
+            # Combine both context preservations in a single pass
+            wrapped_gen = self._preserve_contexts_async_generator(gen())
+        finally:
+            await end_trace()
 
         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
             content=wrapped_gen,
@@ -424,6 +422,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         )
         return await response.parse()
 
+    def _preserve_contexts_async_generator(self, gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
+        """
+        Wraps an async generator to preserve both tracing and headers context variables across iterations.
+
+        This is needed because we start a new asyncio event loop for each request, and we need to preserve the context
+        across the event loop boundary.
+        """
+        tracing_context_value = CURRENT_TRACE_CONTEXT.get()
+        headers_context_value = _provider_data_var.get()
+
+        async def wrapper():
+            while True:
+                _ = CURRENT_TRACE_CONTEXT.set(tracing_context_value)
+                _ = _provider_data_var.set(headers_context_value)
+                try:
+                    item = await gen.__anext__()
+                    yield item
+                except StopAsyncIteration:
+                    break
+
+        return wrapper()
+
     def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
         if not body:
             return {}
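The new _preserve_contexts_async_generator above captures both context values once, at wrapping time, and re-applies them before every __anext__() call. A rough standalone sketch of the same pattern; the names TRACE_VAR, HEADERS_VAR, preserve_contexts, produce, and setup are illustrative stand-ins, not the library's API:

import asyncio
import contextvars
from typing import Any, AsyncGenerator, Dict, Optional, TypeVar

T = TypeVar("T")

TRACE_VAR: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar("trace", default=None)
HEADERS_VAR: contextvars.ContextVar[Optional[Dict[str, Any]]] = contextvars.ContextVar("headers", default=None)


def preserve_contexts(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
    # Capture the values visible where the generator is wrapped.
    trace_value = TRACE_VAR.get()
    headers_value = HEADERS_VAR.get()

    async def wrapper() -> AsyncGenerator[T, None]:
        while True:
            # Restore both values before each step, since each step may run
            # in a fresh task (or event loop) with its own context copy.
            TRACE_VAR.set(trace_value)
            HEADERS_VAR.set(headers_value)
            try:
                item = await gen.__anext__()
            except StopAsyncIteration:
                break
            yield item

    return wrapper()


async def produce() -> AsyncGenerator[str, None]:
    for i in range(3):
        # Both values are visible on every iteration thanks to the wrapper.
        yield f"chunk {i}: trace={TRACE_VAR.get()} headers={HEADERS_VAR.get()}"


async def setup() -> AsyncGenerator[str, None]:
    # Set the vars inside one task (as start_trace / request headers would)
    # and wrap the generator while they are still visible.
    TRACE_VAR.set("trace-123")
    HEADERS_VAR.set({"x-llamastack-provider-data": "{}"})
    return preserve_contexts(produce())


loop = asyncio.new_event_loop()
wrapped = loop.run_until_complete(setup())
while True:
    try:
        # Each step runs in a new task whose context lacks the values;
        # the wrapper re-applies them before delegating to the generator.
        print(loop.run_until_complete(wrapped.__anext__()))
    except StopAsyncIteration:
        break
loop.close()

Driving the wrapped generator with separate run_until_complete() calls mirrors how sync_generator in the diff consumes the async stream, which is exactly the boundary the wrapper is there to bridge.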

View file

@@ -7,7 +7,7 @@
 import contextvars
 import json
 import logging
-from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
+from typing import Any, ContextManager, Dict, Optional
 
 from .utils.dynamic import instantiate_class_type
@@ -35,33 +35,6 @@ class RequestProviderDataContext(ContextManager):
         _provider_data_var.reset(self.token)
 
-T = TypeVar("T")
-
-
-def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
-    """
-    Wraps an async generator to preserve request headers context variables across iterations.
-
-    This ensures that context variables set during generator creation are
-    available during each iteration of the generator, even if the original
-    context manager has exited.
-    """
-    # Capture the current context value right now
-    context_value = _provider_data_var.get()
-
-    async def wrapper():
-        while True:
-            # Set context before each anext() call
-            _ = _provider_data_var.set(context_value)
-            try:
-                item = await gen.__anext__()
-                yield item
-            except StopAsyncIteration:
-                break
-
-    return wrapper()
-
 
 class NeedsRequestProviderData:
     def get_request_provider_data(self) -> Any:
         spec = self.__provider_spec__
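For reference, RequestProviderDataContext (partially shown above) relies on the standard ContextVar token pattern: set() returns a token that reset() uses to restore the previous value when the context manager exits. A minimal sketch of that pattern, not the library's actual implementation; the class name ProviderDataContext and the example payload are illustrative:

import contextvars
from typing import Any, Dict, Optional

_provider_data_var: contextvars.ContextVar[Optional[Dict[str, Any]]] = contextvars.ContextVar(
    "provider_data", default=None
)


class ProviderDataContext:
    """Make provider data visible for the duration of a request, then restore the prior value."""

    def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
        self.provider_data = provider_data
        self.token: Optional[contextvars.Token] = None

    def __enter__(self) -> "ProviderDataContext":
        # set() returns a token remembering whatever value was active before.
        self.token = _provider_data_var.set(self.provider_data)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # reset() restores the previous value, even if the request failed.
        if self.token is not None:
            _provider_data_var.reset(self.token)


# Usage: provider data is only visible inside the with-block.
with ProviderDataContext({"api_key": "example"}):
    assert _provider_data_var.get() == {"api_key": "example"}
assert _provider_data_var.get() is None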