From 714c09cd53e66068f26b77dcdb6c74802ee362d1 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 11 Mar 2025 18:04:21 -0700 Subject: [PATCH] address feedback --- llama_stack/distribution/library_client.py | 28 +++-------------- llama_stack/distribution/request_headers.py | 8 ++--- llama_stack/distribution/server/server.py | 8 +++-- llama_stack/distribution/utils/context.py | 33 +++++++++++++++++++++ 4 files changed, 47 insertions(+), 30 deletions(-) create mode 100644 llama_stack/distribution/utils/context.py diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 9c7b116d4..82650ea40 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -13,7 +13,7 @@ import re from concurrent.futures import ThreadPoolExecutor from enum import Enum from pathlib import Path -from typing import Any, AsyncGenerator, Optional, TypeVar, Union, get_args, get_origin +from typing import Any, Optional, TypeVar, Union, get_args, get_origin import httpx import yaml @@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.datatypes import Api from llama_stack.distribution.request_headers import ( - _provider_data_var, + PROVIDER_DATA_VAR, request_provider_data_context, ) from llama_stack.distribution.resolver import ProviderRegistry @@ -44,6 +44,7 @@ from llama_stack.distribution.stack import ( redact_sensitive_fields, replace_env_vars, ) +from llama_stack.distribution.utils.context import preserve_contexts_async_generator from llama_stack.distribution.utils.exec import in_notebook from llama_stack.providers.utils.telemetry.tracing import ( CURRENT_TRACE_CONTEXT, @@ -384,7 +385,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): yield sse_event.encode("utf-8") try: - wrapped_gen = self._preserve_contexts_async_generator(gen()) + wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]) finally: await end_trace() @@ -418,27 +419,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): ) return await response.parse() - def _preserve_contexts_async_generator(self, gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]: - """ - Wraps an async generator to preserve both tracing and headers context variables across iterations. - This is needed because we start a new asyncio event loop for each request, and we need to preserve the context - across the event loop boundary. - """ - tracing_context_value = CURRENT_TRACE_CONTEXT.get() - headers_context_value = _provider_data_var.get() - - async def wrapper(): - while True: - _ = CURRENT_TRACE_CONTEXT.set(tracing_context_value) - _ = _provider_data_var.set(headers_context_value) - try: - item = await gen.__anext__() - yield item - except StopAsyncIteration: - break - - return wrapper() - def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict: if not body: return {} diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index 87850f752..8709fc040 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -14,7 +14,7 @@ from .utils.dynamic import instantiate_class_type log = logging.getLogger(__name__) # Context variable for request provider data -_provider_data_var = contextvars.ContextVar("provider_data", default=None) +PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) class RequestProviderDataContext(ContextManager): @@ -26,13 +26,13 @@ class RequestProviderDataContext(ContextManager): def __enter__(self): # Save the current value and set the new one - self.token = _provider_data_var.set(self.provider_data) + self.token = PROVIDER_DATA_VAR.set(self.provider_data) return self def __exit__(self, exc_type, exc_val, exc_tb): # Restore the previous value if self.token is not None: - _provider_data_var.reset(self.token) + PROVIDER_DATA_VAR.reset(self.token) class NeedsRequestProviderData: @@ -45,7 +45,7 @@ class NeedsRequestProviderData: if not validator_class: raise ValueError(f"Provider {provider_type} does not have a validator") - val = _provider_data_var.get() + val = PROVIDER_DATA_VAR.get() if not val: return None diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index ea8723365..b1ec508a5 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -28,7 +28,7 @@ from typing_extensions import Annotated from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.request_headers import ( - preserve_headers_context_async_generator, + PROVIDER_DATA_VAR, request_provider_data_context, ) from llama_stack.distribution.resolver import InvalidProviderError @@ -38,6 +38,7 @@ from llama_stack.distribution.stack import ( replace_env_vars, validate_env_pair, ) +from llama_stack.distribution.utils.context import preserve_contexts_async_generator from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig @@ -45,6 +46,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import ( TelemetryAdapter, ) from llama_stack.providers.utils.telemetry.tracing import ( + CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace, @@ -182,7 +184,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str): try: if is_streaming: - gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs))) + gen = preserve_contexts_async_generator( + sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR] + ) return StreamingResponse(gen, media_type="text/event-stream") else: value = func(**kwargs) diff --git a/llama_stack/distribution/utils/context.py b/llama_stack/distribution/utils/context.py new file mode 100644 index 000000000..a76edb60c --- /dev/null +++ b/llama_stack/distribution/utils/context.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from contextvars import ContextVar +from typing import AsyncGenerator, List, TypeVar + +T = TypeVar("T") + + +def preserve_contexts_async_generator( + gen: AsyncGenerator[T, None], context_vars: List[ContextVar] +) -> AsyncGenerator[T, None]: + """ + Wraps an async generator to preserve both tracing and headers context variables across iterations. + This is needed because we start a new asyncio event loop for each request, and we need to preserve the context + across the event loop boundary. + """ + context_values = [context_var.get() for context_var in context_vars] + + async def wrapper(): + while True: + for context_var, context_value in zip(context_vars, context_values, strict=False): + _ = context_var.set(context_value) + try: + item = await gen.__anext__() + yield item + except StopAsyncIteration: + break + + return wrapper()