Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-11 20:40:40 +00:00
preserve context across async generator boundaries
parent 4a894b925d · commit 21769648a6
3 changed files with 41 additions and 5 deletions
llama_stack/distribution/library_client.py
@@ -32,7 +32,10 @@ from termcolor import cprint
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Api
-from llama_stack.distribution.request_headers import request_provider_data_context
+from llama_stack.distribution.request_headers import (
+    preserve_headers_context_async_generator,
+    request_provider_data_context,
+)
 from llama_stack.distribution.resolver import ProviderRegistry
 from llama_stack.distribution.server.endpoints import get_all_api_endpoints
 from llama_stack.distribution.stack import (
@@ -378,9 +381,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             finally:
                 await end_trace()
 
+        # Wrap the generator to preserve context across iterations
+        wrapped_gen = preserve_headers_context_async_generator(gen())
+
         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
-            content=gen(),
+            content=wrapped_gen,
             headers={
                 "Content-Type": "application/json",
             },
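Why `content=gen()` loses the request headers: `httpx.Response` consumes its content lazily, after the `request_provider_data_context` opened for the call has already exited, and an async generator's body runs in whatever context is current when it is resumed, not the one in which the generator was created. A minimal sketch of that failure mode, with a hypothetical `provider_data` variable standing in for `_provider_data_var`:

```python
import asyncio
import contextvars

# Hypothetical stand-in for _provider_data_var.
provider_data = contextvars.ContextVar("provider_data", default=None)


async def stream():
    # An async generator body runs in the caller's context at each resume,
    # so it sees whatever provider_data holds at iteration time.
    yield provider_data.get()


async def main():
    token = provider_data.set({"api_key": "secret"})
    gen = stream()              # created while provider_data is set
    provider_data.reset(token)  # the request context manager exits here
    print([item async for item in gen])  # [None] -- the value is gone


asyncio.run(main())
```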
llama_stack/distribution/request_headers.py
@@ -7,7 +7,7 @@
 import contextvars
 import json
 import logging
-from typing import Any, ContextManager, Dict, Optional
+from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
 
 from .utils.dynamic import instantiate_class_type
 
@@ -35,6 +35,31 @@ class RequestProviderDataContext(ContextManager):
         _provider_data_var.reset(self.token)
 
 
+T = TypeVar("T")
+
+
+async def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
+    """
+    Wraps an async generator to preserve request headers context variables across iterations.
+
+    This ensures that context variables set during generator creation are
+    available during each iteration of the generator, even if the original
+    context manager has exited.
+    """
+    # Capture the current context value
+    context_value = _provider_data_var.get()
+
+    # Create a wrapper that restores context for each iteration
+    async for item in gen:
+        # Save the current token to restore later
+        token = _provider_data_var.set(context_value)
+        try:
+            yield item
+        finally:
+            # Restore the previous value
+            _provider_data_var.reset(token)
+
+
 class NeedsRequestProviderData:
     def get_request_provider_data(self) -> Any:
         spec = self.__provider_spec__
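The helper above re-sets `_provider_data_var` around each step of the wrapped generator. Below is a condensed, self-contained sketch of the same technique, not the repo's code: the names are hypothetical, the capture happens eagerly in a plain function (so the sketch works even when the first iteration occurs after the request context has exited), and the captured value is restored before each resume of the source generator:

```python
import asyncio
import contextvars
from typing import AsyncGenerator, TypeVar

T = TypeVar("T")

# Hypothetical stand-in for _provider_data_var.
provider_data = contextvars.ContextVar("provider_data", default=None)


def preserve_context(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
    # Capture eagerly, while the request context is still live. (An async
    # generator function would defer this line until the first iteration.)
    captured = provider_data.get()

    async def wrapper() -> AsyncGenerator[T, None]:
        while True:
            # Restore the captured value before resuming the source, so
            # code inside the generator body observes it.
            token = provider_data.set(captured)
            try:
                item = await gen.__anext__()
            except StopAsyncIteration:
                return
            finally:
                provider_data.reset(token)
            yield item

    return wrapper()


async def stream() -> AsyncGenerator[str, None]:
    yield f"provider_data={provider_data.get()}"


async def main() -> None:
    token = provider_data.set({"api_key": "secret"})
    wrapped = preserve_context(stream())  # wrap while the context is set
    provider_data.reset(token)            # the request context exits here
    print([item async for item in wrapped])  # ["provider_data={'api_key': 'secret'}"]


asyncio.run(main())
```

Resetting through the token after every step keeps the ambient context clean between yields, mirroring the try/finally in the committed helper.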
llama_stack/distribution/server/server.py
@@ -29,7 +29,10 @@ from typing_extensions import Annotated
 
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
-from llama_stack.distribution.request_headers import request_provider_data_context
+from llama_stack.distribution.request_headers import (
+    preserve_headers_context_async_generator,
+    request_provider_data_context,
+)
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.stack import (
     construct_stack,
@@ -203,7 +206,9 @@ async def maybe_await(value):
 async def sse_generator(event_gen):
     try:
         event_gen = await event_gen
-        async for item in event_gen:
+        # Wrap the generator to preserve context across iterations
+        wrapped_gen = preserve_headers_context_async_generator(event_gen)
+        async for item in wrapped_gen:
             yield create_sse_event(item)
             await asyncio.sleep(0.01)
     except asyncio.CancelledError:
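The server-side change follows the same pattern as the library client: `sse_generator` is presumably handed to Starlette's StreamingResponse, which only begins iterating it after the route handler, and with it the `request_provider_data_context` it opened, has returned, so without the wrap each SSE event would be produced with the provider-data context variable already reset.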