mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-11 20:40:40 +00:00)
preserve context across async generator boundaries
commit 21769648a6 (parent 4a894b925d)
3 changed files with 41 additions and 5 deletions
@@ -32,7 +32,10 @@ from termcolor import cprint
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Api
-from llama_stack.distribution.request_headers import request_provider_data_context
+from llama_stack.distribution.request_headers import (
+    preserve_headers_context_async_generator,
+    request_provider_data_context,
+)
 from llama_stack.distribution.resolver import ProviderRegistry
 from llama_stack.distribution.server.endpoints import get_all_api_endpoints
 from llama_stack.distribution.stack import (
@@ -378,9 +381,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             finally:
                 await end_trace()
 
+        # Wrap the generator to preserve context across iterations
+        wrapped_gen = preserve_headers_context_async_generator(gen())
+
         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
-            content=gen(),
+            content=wrapped_gen,
             headers={
                 "Content-Type": "application/json",
             },
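
Why the library client (`AsyncLlamaStackAsLibraryClient` above) needs this wrapping: `request_provider_data_context` sets a `contextvars.ContextVar` for the duration of a `with` block, but httpx pulls items from the generator only after that block has exited, by which point the variable has been reset. A minimal, self-contained sketch of that failure mode (the `ContextVar` and generator below are illustrative stand-ins, not the real llama-stack objects):

```python
import asyncio
import contextvars

_var: contextvars.ContextVar = contextvars.ContextVar("provider_data", default=None)

async def stream():
    for _ in range(2):
        # This body runs only when the consumer iterates -- after the reset below.
        yield _var.get()

async def main() -> None:
    token = _var.set({"x-api-key": "demo"})  # what the context manager does on enter
    gen = stream()                           # generator created while the context is set
    _var.reset(token)                        # what the context manager does on exit
    print([item async for item in gen])      # [None, None]: the value was lost

asyncio.run(main())
```
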
@@ -7,7 +7,7 @@
 import contextvars
 import json
 import logging
-from typing import Any, ContextManager, Dict, Optional
+from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
 
 from .utils.dynamic import instantiate_class_type
 
@@ -35,6 +35,31 @@ class RequestProviderDataContext(ContextManager):
         _provider_data_var.reset(self.token)
 
 
+T = TypeVar("T")
+
+
+async def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
+    """
+    Wraps an async generator to preserve request headers context variables across iterations.
+
+    This ensures that context variables set during generator creation are
+    available during each iteration of the generator, even if the original
+    context manager has exited.
+    """
+    # Capture the current context value
+    context_value = _provider_data_var.get()
+
+    # Create a wrapper that restores context for each iteration
+    async for item in gen:
+        # Save the current token to restore later
+        token = _provider_data_var.set(context_value)
+        try:
+            yield item
+        finally:
+            # Restore the previous value
+            _provider_data_var.reset(token)
+
+
 class NeedsRequestProviderData:
     def get_request_provider_data(self) -> Any:
         spec = self.__provider_spec__
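
The technique added here in `llama_stack.distribution.request_headers` can be exercised in isolation. The sketch below mirrors `preserve_headers_context_async_generator` with a local `ContextVar`; as a simplification it captures the value eagerly at wrap time (the function above reads it on first iteration), and all names are illustrative:

```python
import asyncio
import contextvars
from typing import AsyncGenerator, TypeVar

T = TypeVar("T")
_var: contextvars.ContextVar = contextvars.ContextVar("demo_var", default=None)

def preserve_demo_context(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
    # Capture eagerly, while the request context is still active.
    value = _var.get()

    async def wrapper() -> AsyncGenerator[T, None]:
        async for item in gen:
            token = _var.set(value)  # restore the captured value around each yield
            try:
                yield item
            finally:
                _var.reset(token)

    return wrapper()

async def stream() -> AsyncGenerator[str, None]:
    yield "a"
    yield "b"

async def main() -> None:
    token = _var.set("request-scoped")
    wrapped = preserve_demo_context(stream())
    _var.reset(token)  # simulate the original context manager exiting
    async for _ in wrapped:
        print(_var.get())  # prints "request-scoped" for each item

asyncio.run(main())
```

Because context-variable sets inside a generator frame are visible to the caller until reset, the wrapper's set-before-yield / reset-after-resume pattern makes the captured value visible to whatever code processes each item.
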
@@ -29,7 +29,10 @@ from typing_extensions import Annotated
 
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
-from llama_stack.distribution.request_headers import request_provider_data_context
+from llama_stack.distribution.request_headers import (
+    preserve_headers_context_async_generator,
+    request_provider_data_context,
+)
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.stack import (
     construct_stack,
@@ -203,7 +206,9 @@ async def maybe_await(value):
 async def sse_generator(event_gen):
     try:
         event_gen = await event_gen
-        async for item in event_gen:
+        # Wrap the generator to preserve context across iterations
+        wrapped_gen = preserve_headers_context_async_generator(event_gen)
+        async for item in wrapped_gen:
             yield create_sse_event(item)
             await asyncio.sleep(0.01)
     except asyncio.CancelledError:
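
On the server side, each yielded item is framed as a Server-Sent Event: a `data:` line terminated by a blank line. A rough sketch of an event formatter in the spirit of the `create_sse_event` referenced above (the actual helper in llama-stack may handle more types and differ in signature):

```python
import json

def create_sse_event_sketch(data) -> str:
    # SSE wire format: each event is "data: <payload>" followed by a blank line.
    payload = data if isinstance(data, str) else json.dumps(data)
    return f"data: {payload}\n\n"
```
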