diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index 8915daf5a..ab8ff60fa 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -32,7 +32,10 @@ from termcolor import cprint
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Api
-from llama_stack.distribution.request_headers import set_request_provider_data
+from llama_stack.distribution.request_headers import (
+    preserve_headers_context_async_generator,
+    request_provider_data_context,
+)
 from llama_stack.distribution.resolver import ProviderRegistry
 from llama_stack.distribution.server.endpoints import get_all_api_endpoints
 from llama_stack.distribution.stack import (
@@ -262,21 +265,25 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if not self.endpoint_impls:
             raise ValueError("Client not initialized")
 
+        # Create headers with provider data if available
+        headers = {}
         if self.provider_data:
-            set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)})
+            headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
 
-        if stream:
-            response = await self._call_streaming(
-                cast_to=cast_to,
-                options=options,
-                stream_cls=stream_cls,
-            )
-        else:
-            response = await self._call_non_streaming(
-                cast_to=cast_to,
-                options=options,
-            )
-        return response
+        # Use context manager for provider data
+        with request_provider_data_context(headers):
+            if stream:
+                response = await self._call_streaming(
+                    cast_to=cast_to,
+                    options=options,
+                    stream_cls=stream_cls,
+                )
+            else:
+                response = await self._call_non_streaming(
+                    cast_to=cast_to,
+                    options=options,
+                )
+            return response
 
     def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
         """Find the matching endpoint implementation for a given method and path.
@@ -374,9 +381,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             finally:
                 await end_trace()
 
+        # Wrap the generator to preserve context across iterations
+        wrapped_gen = preserve_headers_context_async_generator(gen())
+
         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
-            content=gen(),
+            content=wrapped_gen,
             headers={
                 "Content-Type": "application/json",
            },
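Why the generator must be wrapped above: a ContextVar set inside a `with` block is no longer visible once the block exits, but httpx consumes the streaming body exactly then. A minimal, runnable sketch of the technique (names here are illustrative, not llama-stack's API):

```python
import asyncio
import contextvars

_var: contextvars.ContextVar = contextvars.ContextVar("demo", default=None)


def preserve_value(gen):
    # Capture the value at creation time; re-apply it before every step.
    captured = _var.get()

    async def wrapper():
        while True:
            # Note: this set() runs in the consumer's context, so the value
            # persists there after iteration -- a known trade-off of this approach.
            _var.set(captured)
            try:
                yield await gen.__anext__()
            except StopAsyncIteration:
                break

    return wrapper()


async def produce():
    for _ in range(2):
        yield _var.get()  # reads whatever the *current* context holds


async def main():
    token = _var.set("per-request value")
    wrapped = preserve_value(produce())
    _var.reset(token)  # the "request" is over before iteration starts
    print([item async for item in wrapped])  # ['per-request value', 'per-request value']


asyncio.run(main())
```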
diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py
index 2a9bc622a..19afae59b 100644
--- a/llama_stack/distribution/request_headers.py
+++ b/llama_stack/distribution/request_headers.py
@@ -4,16 +4,62 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import contextvars
 import json
 import logging
-import threading
-from typing import Any, Dict
+from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
 
 from .utils.dynamic import instantiate_class_type
 
 log = logging.getLogger(__name__)
 
-_THREAD_LOCAL = threading.local()
+# Context variable for request provider data
+_provider_data_var = contextvars.ContextVar("provider_data", default=None)
+
+
+class RequestProviderDataContext(ContextManager):
+    """Context manager for request provider data"""
+
+    def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
+        self.provider_data = provider_data
+        self.token = None
+
+    def __enter__(self):
+        # Save the current value and set the new one
+        self.token = _provider_data_var.set(self.provider_data)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # Restore the previous value
+        if self.token is not None:
+            _provider_data_var.reset(self.token)
+
+
+T = TypeVar("T")
+
+
+def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
+    """
+    Wraps an async generator to preserve request headers context variables across iterations.
+
+    This ensures that context variables set during generator creation are
+    available during each iteration of the generator, even if the original
+    context manager has exited.
+    """
+    # Capture the current context value right now
+    context_value = _provider_data_var.get()
+
+    async def wrapper():
+        while True:
+            # Set context before each anext() call
+            _ = _provider_data_var.set(context_value)
+            try:
+                item = await gen.__anext__()
+                yield item
+            except StopAsyncIteration:
+                break
+
+    return wrapper()
 
 
 class NeedsRequestProviderData:
@@ -26,7 +72,7 @@ class NeedsRequestProviderData:
         if not validator_class:
             raise ValueError(f"Provider {provider_type} does not have a validator")
 
-        val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
+        val = _provider_data_var.get()
         if not val:
             return None
 
@@ -36,25 +82,32 @@
             return provider_data
         except Exception as e:
             log.error(f"Error parsing provider data: {e}")
+            return None
 
 
-def set_request_provider_data(headers: Dict[str, str]):
+def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
+    """Parse provider data from request headers"""
     keys = [
         "X-LlamaStack-Provider-Data",
         "x-llamastack-provider-data",
     ]
+    val = None
     for key in keys:
         val = headers.get(key, None)
         if val:
             break
 
     if not val:
-        return
+        return None
 
     try:
-        val = json.loads(val)
+        return json.loads(val)
     except json.JSONDecodeError:
-        log.error("Provider data not encoded as a JSON object!", val)
-        return
+        log.error("Provider data not encoded as a JSON object!")
+        return None
 
-    _THREAD_LOCAL.provider_data_header_value = val
+
+def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
+    """Context manager that sets request provider data from headers for the duration of the context"""
+    provider_data = parse_request_provider_data(headers)
+    return RequestProviderDataContext(provider_data)
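A quick usage sketch of the two helpers this file now exports (hypothetical caller code; only the names shown in the diff are assumed):

```python
import json

from llama_stack.distribution.request_headers import request_provider_data_context

# Incoming request headers carrying per-request provider credentials
headers = {"X-LlamaStack-Provider-Data": json.dumps({"together_api_key": "sk-example"})}

with request_provider_data_context(headers):
    # Any NeedsRequestProviderData provider resolved inside this block
    # sees the per-request keys via the context variable.
    ...
# On exit, the previous value is restored via the ContextVar token.
```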
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index c4ef79a69..347d88a2c 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -29,7 +29,10 @@ from typing_extensions import Annotated
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
-from llama_stack.distribution.request_headers import set_request_provider_data
+from llama_stack.distribution.request_headers import (
+    preserve_headers_context_async_generator,
+    request_provider_data_context,
+)
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.stack import (
     construct_stack,
@@ -202,16 +205,15 @@ async def maybe_await(value):
 
 async def sse_generator(event_gen):
     try:
         event_gen = await event_gen
         async for item in event_gen:
             yield create_sse_event(item)
             await asyncio.sleep(0.01)
     except asyncio.CancelledError:
         logger.info("Generator cancelled")
         await event_gen.aclose()
     except Exception as e:
-        logger.exception(f"Error in sse_generator: {e}")
-        logger.exception(f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
+        logger.exception("Error in sse_generator")
         yield create_sse_event(
             {
                 "error": {
@@ -223,18 +225,20 @@
 
 def create_dynamic_typed_route(func: Any, method: str, route: str):
     async def endpoint(request: Request, **kwargs):
-        set_request_provider_data(request.headers)
+        # Use context manager for request provider data
+        with request_provider_data_context(request.headers):
+            is_streaming = is_streaming_request(func.__name__, request, **kwargs)
 
-        is_streaming = is_streaming_request(func.__name__, request, **kwargs)
-        try:
-            if is_streaming:
-                return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
-            else:
-                value = func(**kwargs)
-                return await maybe_await(value)
-        except Exception as e:
-            traceback.print_exception(e)
-            raise translate_exception(e) from e
+            try:
+                if is_streaming:
+                    gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
+                    return StreamingResponse(gen, media_type="text/event-stream")
+                else:
+                    value = func(**kwargs)
+                    return await maybe_await(value)
+            except Exception as e:
+                logger.exception("Error executing endpoint %s %s", method, route)
+                raise translate_exception(e) from e
 
     sig = inspect.signature(func)
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index ec68fb556..4acbe43f8 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -70,8 +70,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         pass
 
     def _get_api_key(self) -> str:
-        if self.config.api_key is not None:
-            return self.config.api_key.get_secret_value()
+        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
+        if config_api_key:
+            return config_api_key
         else:
             provider_data = self.get_request_provider_data()
             if provider_data is None or not provider_data.fireworks_api_key:
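Both adapter changes implement the same lookup order (static config first, then per-request provider data). Isolated as a standalone sketch, with `Config` and `resolve_api_key` as illustrative names rather than either adapter's actual code:

```python
from typing import Optional

from pydantic import BaseModel, SecretStr


class Config(BaseModel):
    api_key: Optional[SecretStr] = None


def resolve_api_key(config: Config, request_key: Optional[str]) -> str:
    # Prefer the static config value; treat an empty secret as unset.
    config_api_key = config.api_key.get_secret_value() if config.api_key else None
    if config_api_key:
        return config_api_key
    if not request_key:
        raise ValueError("API key not provided in config or request headers")
    return request_key


resolve_api_key(Config(), request_key="sk-from-header")  # -> "sk-from-header"
```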
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 2046d4aae..dfc9ae6d3 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -93,8 +93,9 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
 
     def _get_client(self) -> Together:
         together_api_key = None
-        if self.config.api_key is not None:
-            together_api_key = self.config.api_key.get_secret_value()
+        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
+        if config_api_key:
+            together_api_key = config_api_key
         else:
             provider_data = self.get_request_provider_data()
             if provider_data is None or not provider_data.together_api_key:
diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py
index 6a75b3adf..e410039e7 100644
--- a/tests/integration/fixtures/common.py
+++ b/tests/integration/fixtures/common.py
@@ -42,7 +42,7 @@ def provider_data():
     for key, value in keymap.items():
         if os.environ.get(key):
             provider_data[value] = os.environ[key]
-    return provider_data if len(provider_data) > 0 else None
+    return provider_data
 
 
 @pytest.fixture(scope="session")
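The fixture change is deliberate: returning a possibly-empty dict instead of None lets callers serialize and forward provider data unconditionally. A small sketch of that downstream simplification (the `keymap` here is a stand-in for the fixture's real env-var mapping):

```python
import json
import os

# Stand-in for the fixture's env-var-to-provider-data mapping
keymap = {
    "TOGETHER_API_KEY": "together_api_key",
    "FIREWORKS_API_KEY": "fireworks_api_key",
}

provider_data = {field: os.environ[env] for env, field in keymap.items() if os.environ.get(env)}

# An empty dict serializes to "{}" and flows through the provider-data
# header without a None check anywhere downstream.
header_value = json.dumps(provider_data)
```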