diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 8915daf5a..bc496cca5 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -32,7 +32,7 @@ from termcolor import cprint from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.datatypes import Api -from llama_stack.distribution.request_headers import set_request_provider_data +from llama_stack.distribution.request_headers import request_provider_data_context from llama_stack.distribution.resolver import ProviderRegistry from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.stack import ( @@ -262,21 +262,25 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): if not self.endpoint_impls: raise ValueError("Client not initialized") + # Create headers with provider data if available + headers = {} if self.provider_data: - set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}) + headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data) - if stream: - response = await self._call_streaming( - cast_to=cast_to, - options=options, - stream_cls=stream_cls, - ) - else: - response = await self._call_non_streaming( - cast_to=cast_to, - options=options, - ) - return response + # Use context manager for provider data + with request_provider_data_context(headers): + if stream: + response = await self._call_streaming( + cast_to=cast_to, + options=options, + stream_cls=stream_cls, + ) + else: + response = await self._call_non_streaming( + cast_to=cast_to, + options=options, + ) + return response def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]: """Find the matching endpoint implementation for a given method and path. diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index 2a9bc622a..87850f752 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -4,16 +4,35 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import contextvars import json import logging -import threading -from typing import Any, Dict +from typing import Any, ContextManager, Dict, Optional from .utils.dynamic import instantiate_class_type log = logging.getLogger(__name__) -_THREAD_LOCAL = threading.local() +# Context variable for request provider data +_provider_data_var = contextvars.ContextVar("provider_data", default=None) + + +class RequestProviderDataContext(ContextManager): + """Context manager for request provider data""" + + def __init__(self, provider_data: Optional[Dict[str, Any]] = None): + self.provider_data = provider_data + self.token = None + + def __enter__(self): + # Save the current value and set the new one + self.token = _provider_data_var.set(self.provider_data) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Restore the previous value + if self.token is not None: + _provider_data_var.reset(self.token) class NeedsRequestProviderData: @@ -26,7 +45,7 @@ class NeedsRequestProviderData: if not validator_class: raise ValueError(f"Provider {provider_type} does not have a validator") - val = getattr(_THREAD_LOCAL, "provider_data_header_value", None) + val = _provider_data_var.get() if not val: return None @@ -36,25 +55,32 @@ class NeedsRequestProviderData: return provider_data except Exception as e: log.error(f"Error parsing provider data: {e}") + return None -def set_request_provider_data(headers: Dict[str, str]): +def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]: + """Parse provider data from request headers""" keys = [ "X-LlamaStack-Provider-Data", "x-llamastack-provider-data", ] + val = None for key in keys: val = headers.get(key, None) if val: break if not val: - return + return None try: - val = json.loads(val) + return json.loads(val) except json.JSONDecodeError: - log.error("Provider data not encoded as a JSON object!", val) - return + log.error("Provider data not encoded as a JSON object!") + return None - _THREAD_LOCAL.provider_data_header_value = val + +def request_provider_data_context(headers: Dict[str, str]) -> ContextManager: + """Context manager that sets request provider data from headers for the duration of the context""" + provider_data = parse_request_provider_data(headers) + return RequestProviderDataContext(provider_data) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index c4ef79a69..9d4a29a31 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -29,7 +29,7 @@ from typing_extensions import Annotated from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.distribution import builtin_automatically_routed_apis -from llama_stack.distribution.request_headers import set_request_provider_data +from llama_stack.distribution.request_headers import request_provider_data_context from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.stack import ( construct_stack, @@ -223,18 +223,18 @@ async def sse_generator(event_gen): def create_dynamic_typed_route(func: Any, method: str, route: str): async def endpoint(request: Request, **kwargs): - set_request_provider_data(request.headers) - - is_streaming = is_streaming_request(func.__name__, request, **kwargs) - try: - if is_streaming: - return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream") - else: - value = func(**kwargs) - return await maybe_await(value) - except Exception as e: - traceback.print_exception(e) - raise translate_exception(e) from e + # Use context manager for request provider data + with request_provider_data_context(request.headers): + is_streaming = is_streaming_request(func.__name__, request, **kwargs) + try: + if is_streaming: + return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream") + else: + value = func(**kwargs) + return await maybe_await(value) + except Exception as e: + traceback.print_exception(e) + raise translate_exception(e) from e sig = inspect.signature(func)