consolidate context var restores

This commit is contained in:
Dinesh Yeduguru 2025-03-11 16:14:46 -07:00
parent 9d4716521d
commit 1ddc026355
2 changed files with 36 additions and 44 deletions

View file

@ -13,7 +13,7 @@ import re
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
from typing import Any, AsyncGenerator, Optional, TypeVar, Union, get_args, get_origin
import httpx
import yaml
@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import (
preserve_headers_context_async_generator,
_provider_data_var,
request_provider_data_context,
)
from llama_stack.distribution.resolver import ProviderRegistry
@ -159,11 +159,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
def sync_generator():
try:
global trace_context
async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs))
while True:
if trace_context:
CURRENT_TRACE_CONTEXT.set(trace_context)
chunk = loop.run_until_complete(async_stream.__anext__())
yield chunk
except StopAsyncIteration:
@ -380,20 +377,21 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = self._convert_body(path, options.method, body)
await start_trace(options.url, {"__location__": "library_client"})
async def gen():
await start_trace(options.url, {"__location__": "library_client"})
global trace_context
trace_context = CURRENT_TRACE_CONTEXT.get()
try:
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))
sse_event = f"data: {data}\n\n"
yield sse_event.encode("utf-8")
finally:
await end_trace()
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))
sse_event = f"data: {data}\n\n"
yield sse_event.encode("utf-8")
# Wrap the generator to preserve context across iterations
wrapped_gen = preserve_headers_context_async_generator(gen())
try:
# Combine both context preservations in a single pass
wrapped_gen = self._preserve_contexts_async_generator(gen())
finally:
await end_trace()
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=wrapped_gen,
@ -424,6 +422,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return await response.parse()
def _preserve_contexts_async_generator(self, gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
"""
Wraps an async generator to preserve both tracing and headers context variables across iterations.
This is needed because we start a new asyncio event loop for each request, and we need to preserve the context
across the event loop boundary.
"""
tracing_context_value = CURRENT_TRACE_CONTEXT.get()
headers_context_value = _provider_data_var.get()
async def wrapper():
while True:
_ = CURRENT_TRACE_CONTEXT.set(tracing_context_value)
_ = _provider_data_var.set(headers_context_value)
try:
item = await gen.__anext__()
yield item
except StopAsyncIteration:
break
return wrapper()
def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
if not body:
return {}

View file

@ -7,7 +7,7 @@
import contextvars
import json
import logging
from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
from typing import Any, ContextManager, Dict, Optional
from .utils.dynamic import instantiate_class_type
@ -35,33 +35,6 @@ class RequestProviderDataContext(ContextManager):
_provider_data_var.reset(self.token)
T = TypeVar("T")
def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
"""
Wraps an async generator to preserve request headers context variables across iterations.
This ensures that context variables set during generator creation are
available during each iteration of the generator, even if the original
context manager has exited.
"""
# Capture the current context value right now
context_value = _provider_data_var.get()
async def wrapper():
while True:
# Set context before each anext() call
_ = _provider_data_var.set(context_value)
try:
item = await gen.__anext__()
yield item
except StopAsyncIteration:
break
return wrapper()
class NeedsRequestProviderData:
def get_request_provider_data(self) -> Any:
spec = self.__provider_spec__