mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
consolidate context var restores
This commit is contained in:
parent
9d4716521d
commit
1ddc026355
2 changed files with 36 additions and 44 deletions
|
@ -13,7 +13,7 @@ import re
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
|
from typing import Any, AsyncGenerator, Optional, TypeVar, Union, get_args, get_origin
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help
|
||||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||||
from llama_stack.distribution.datatypes import Api
|
from llama_stack.distribution.datatypes import Api
|
||||||
from llama_stack.distribution.request_headers import (
|
from llama_stack.distribution.request_headers import (
|
||||||
preserve_headers_context_async_generator,
|
_provider_data_var,
|
||||||
request_provider_data_context,
|
request_provider_data_context,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.resolver import ProviderRegistry
|
from llama_stack.distribution.resolver import ProviderRegistry
|
||||||
|
@ -159,11 +159,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||||
|
|
||||||
def sync_generator():
|
def sync_generator():
|
||||||
try:
|
try:
|
||||||
global trace_context
|
|
||||||
async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs))
|
async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs))
|
||||||
while True:
|
while True:
|
||||||
if trace_context:
|
|
||||||
CURRENT_TRACE_CONTEXT.set(trace_context)
|
|
||||||
chunk = loop.run_until_complete(async_stream.__anext__())
|
chunk = loop.run_until_complete(async_stream.__anext__())
|
||||||
yield chunk
|
yield chunk
|
||||||
except StopAsyncIteration:
|
except StopAsyncIteration:
|
||||||
|
@ -380,20 +377,21 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
|
|
||||||
body = self._convert_body(path, options.method, body)
|
body = self._convert_body(path, options.method, body)
|
||||||
|
|
||||||
|
await start_trace(options.url, {"__location__": "library_client"})
|
||||||
|
|
||||||
async def gen():
|
async def gen():
|
||||||
await start_trace(options.url, {"__location__": "library_client"})
|
async for chunk in await func(**body):
|
||||||
global trace_context
|
data = json.dumps(convert_pydantic_to_json_value(chunk))
|
||||||
trace_context = CURRENT_TRACE_CONTEXT.get()
|
sse_event = f"data: {data}\n\n"
|
||||||
try:
|
yield sse_event.encode("utf-8")
|
||||||
async for chunk in await func(**body):
|
|
||||||
data = json.dumps(convert_pydantic_to_json_value(chunk))
|
|
||||||
sse_event = f"data: {data}\n\n"
|
|
||||||
yield sse_event.encode("utf-8")
|
|
||||||
finally:
|
|
||||||
await end_trace()
|
|
||||||
|
|
||||||
# Wrap the generator to preserve context across iterations
|
# Wrap the generator to preserve context across iterations
|
||||||
wrapped_gen = preserve_headers_context_async_generator(gen())
|
try:
|
||||||
|
# Combine both context preservations in a single pass
|
||||||
|
wrapped_gen = self._preserve_contexts_async_generator(gen())
|
||||||
|
finally:
|
||||||
|
await end_trace()
|
||||||
|
|
||||||
mock_response = httpx.Response(
|
mock_response = httpx.Response(
|
||||||
status_code=httpx.codes.OK,
|
status_code=httpx.codes.OK,
|
||||||
content=wrapped_gen,
|
content=wrapped_gen,
|
||||||
|
@ -424,6 +422,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
)
|
)
|
||||||
return await response.parse()
|
return await response.parse()
|
||||||
|
|
||||||
|
def _preserve_contexts_async_generator(self, gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
    """
    Wrap an async generator so the tracing and headers context variables are preserved across iterations.

    The values of ``CURRENT_TRACE_CONTEXT`` and ``_provider_data_var`` are read once,
    when this method is called, and written back before every ``__anext__()`` call
    on ``gen``. This is needed because we start a new asyncio event loop for each
    request, and we need to preserve the context across the event loop boundary.

    Args:
        gen: The async generator to wrap. Its items are forwarded unchanged.

    Returns:
        An async generator that yields the same items as ``gen``. When the returned
        generator finishes, fails, or is closed early, ``gen`` is closed as well.
    """
    # Snapshot both context values now, while the caller's context is still active.
    tracing_context_value = CURRENT_TRACE_CONTEXT.get()
    headers_context_value = _provider_data_var.get()

    async def wrapper():
        try:
            while True:
                # Re-apply the snapshot before each step of the wrapped generator.
                CURRENT_TRACE_CONTEXT.set(tracing_context_value)
                _provider_data_var.set(headers_context_value)
                try:
                    item = await gen.__anext__()
                except StopAsyncIteration:
                    break
                # Yield outside the try so only the wrapped generator's own
                # exhaustion ends the loop.
                yield item
        finally:
            # Close the wrapped generator deterministically so its cleanup code
            # runs even when the consumer stops early; closing an already
            # finished generator is a no-op.
            await gen.aclose()

    return wrapper()
|
||||||
|
|
||||||
def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
|
def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
|
||||||
if not body:
|
if not body:
|
||||||
return {}
|
return {}
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
import contextvars
|
import contextvars
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
|
from typing import Any, ContextManager, Dict, Optional
|
||||||
|
|
||||||
from .utils.dynamic import instantiate_class_type
|
from .utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
|
@ -35,33 +35,6 @@ class RequestProviderDataContext(ContextManager):
|
||||||
_provider_data_var.reset(self.token)
|
_provider_data_var.reset(self.token)
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
    """
    Wrap an async generator so the request headers context variable is preserved across iterations.

    The value of ``_provider_data_var`` is read once, when this function is called,
    and written back before every ``__anext__()`` call on ``gen``. This ensures the
    value set during generator creation is available during each iteration, even
    if the original context manager has exited.

    Args:
        gen: The async generator to wrap. Its items are forwarded unchanged.

    Returns:
        An async generator that yields the same items as ``gen``. When the returned
        generator finishes, fails, or is closed early, ``gen`` is closed as well.
    """
    # Capture the current context value right now
    context_value = _provider_data_var.get()

    async def wrapper():
        try:
            while True:
                # Set context before each anext() call
                _provider_data_var.set(context_value)
                try:
                    item = await gen.__anext__()
                except StopAsyncIteration:
                    break
                # Yield outside the try so only the wrapped generator's own
                # exhaustion ends the loop.
                yield item
        finally:
            # Close the wrapped generator deterministically so its cleanup code
            # runs even when the consumer stops early; closing an already
            # finished generator is a no-op.
            await gen.aclose()

    return wrapper()
|
|
||||||
|
|
||||||
|
|
||||||
class NeedsRequestProviderData:
|
class NeedsRequestProviderData:
|
||||||
def get_request_provider_data(self) -> Any:
|
def get_request_provider_data(self) -> Any:
|
||||||
spec = self.__provider_spec__
|
spec = self.__provider_spec__
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue