consolidate context var restores

This commit is contained in:
Dinesh Yeduguru 2025-03-11 16:14:46 -07:00
parent 9d4716521d
commit 1ddc026355
2 changed files with 36 additions and 44 deletions

View file

@ -13,7 +13,7 @@ import re
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
from typing import Any, AsyncGenerator, Optional, TypeVar, Union, get_args, get_origin
import httpx
import yaml
@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import (
preserve_headers_context_async_generator,
_provider_data_var,
request_provider_data_context,
)
from llama_stack.distribution.resolver import ProviderRegistry
@ -159,11 +159,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
def sync_generator():
try:
global trace_context
async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs))
while True:
if trace_context:
CURRENT_TRACE_CONTEXT.set(trace_context)
chunk = loop.run_until_complete(async_stream.__anext__())
yield chunk
except StopAsyncIteration:
@ -380,20 +377,21 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = self._convert_body(path, options.method, body)
await start_trace(options.url, {"__location__": "library_client"})
async def gen():
await start_trace(options.url, {"__location__": "library_client"})
global trace_context
trace_context = CURRENT_TRACE_CONTEXT.get()
try:
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))
sse_event = f"data: {data}\n\n"
yield sse_event.encode("utf-8")
finally:
await end_trace()
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))
sse_event = f"data: {data}\n\n"
yield sse_event.encode("utf-8")
# Wrap the generator to preserve context across iterations
wrapped_gen = preserve_headers_context_async_generator(gen())
try:
# Combine both context preservations in a single pass
wrapped_gen = self._preserve_contexts_async_generator(gen())
finally:
await end_trace()
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=wrapped_gen,
@ -424,6 +422,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return await response.parse()
def _preserve_contexts_async_generator(self, gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
    """
    Wrap an async generator so the tracing and headers context variables are restored on every iteration.

    The values of ``CURRENT_TRACE_CONTEXT`` and ``_provider_data_var`` are captured once, when this
    method is called, and re-applied immediately before each item is pulled from ``gen``. This is
    needed because we start a new asyncio event loop for each request, and we need to preserve the
    context across the event loop boundary.

    Args:
        gen: The async generator to wrap. It is closed when the returned wrapper finishes, is
            closed early by its consumer, or fails with an exception.

    Returns:
        An async generator yielding exactly the items produced by ``gen``.
    """
    # Snapshot both context variables at wrap time; these are the values re-applied per iteration.
    tracing_context_value = CURRENT_TRACE_CONTEXT.get()
    headers_context_value = _provider_data_var.get()

    async def wrapper() -> AsyncGenerator[T, None]:
        try:
            while True:
                # Re-apply before every step: the consumer may drive each __anext__() from a
                # context in which these variables are unset or hold different values.
                CURRENT_TRACE_CONTEXT.set(tracing_context_value)
                _provider_data_var.set(headers_context_value)
                try:
                    # Keep the try body minimal so only exhaustion of `gen` ends the loop.
                    item = await gen.__anext__()
                except StopAsyncIteration:
                    break
                yield item
        finally:
            # Deterministically run the wrapped generator's cleanup instead of deferring it to
            # garbage collection. This is a no-op when `gen` is already exhausted.
            await gen.aclose()

    return wrapper()
def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
if not body:
return {}