address feedback

commit 714c09cd53 (parent a900740e30)
Dinesh Yeduguru, 2025-03-11 18:04:21 -07:00
4 changed files with 47 additions and 30 deletions

llama_stack/distribution/library_client.py

@@ -13,7 +13,7 @@ import re
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, AsyncGenerator, Optional, TypeVar, Union, get_args, get_origin
+from typing import Any, Optional, TypeVar, Union, get_args, get_origin
 
 import httpx
 import yaml
@@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Api
 from llama_stack.distribution.request_headers import (
-    _provider_data_var,
+    PROVIDER_DATA_VAR,
     request_provider_data_context,
 )
 from llama_stack.distribution.resolver import ProviderRegistry
@@ -44,6 +44,7 @@ from llama_stack.distribution.stack import (
     redact_sensitive_fields,
     replace_env_vars,
 )
+from llama_stack.distribution.utils.context import preserve_contexts_async_generator
 from llama_stack.distribution.utils.exec import in_notebook
 from llama_stack.providers.utils.telemetry.tracing import (
     CURRENT_TRACE_CONTEXT,
@@ -384,7 +385,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                     yield sse_event.encode("utf-8")
 
             try:
-                wrapped_gen = self._preserve_contexts_async_generator(gen())
+                wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
             finally:
                 await end_trace()
@@ -418,27 +419,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         )
         return await response.parse()
 
-    def _preserve_contexts_async_generator(self, gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
-        """
-        Wraps an async generator to preserve both tracing and headers context variables across iterations.
-        This is needed because we start a new asyncio event loop for each request, and we need to preserve the context
-        across the event loop boundary.
-        """
-        tracing_context_value = CURRENT_TRACE_CONTEXT.get()
-        headers_context_value = _provider_data_var.get()
-
-        async def wrapper():
-            while True:
-                _ = CURRENT_TRACE_CONTEXT.set(tracing_context_value)
-                _ = _provider_data_var.set(headers_context_value)
-                try:
-                    item = await gen.__anext__()
-                    yield item
-                except StopAsyncIteration:
-                    break
-
-        return wrapper()
-
     def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
         if not body:
             return {}
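For context on why this change works: the library client drives each step of the stream from a fresh asyncio event loop, and a value set on a ContextVar inside one loop is not visible from the next. A minimal sketch of that failure mode, with every name (REQUEST_ID, numbers, seed_and_create, step) invented for illustration:

import asyncio
from contextvars import ContextVar

REQUEST_ID: ContextVar[str] = ContextVar("request_id", default="unset")


async def numbers():
    # The generator body reads the context var on every iteration.
    for i in range(2):
        yield f"{REQUEST_ID.get()}:{i}"


async def seed_and_create():
    # Set inside the first event loop's task context.
    REQUEST_ID.set("req-42")
    return numbers()


async def step(gen):
    return await gen.__anext__()


gen = asyncio.run(seed_and_create())
# Each step runs in a brand-new event loop whose context is copied from
# the main thread, where REQUEST_ID was never set, so the value is lost.
print(asyncio.run(step(gen)))  # prints "unset:0"
print(asyncio.run(step(gen)))  # prints "unset:1"

The old private method fixed this for exactly two hard-coded variables; the shared helper replacing it takes any list of ContextVars, as shown in the new context.py file below.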

llama_stack/distribution/request_headers.py

@@ -14,7 +14,7 @@ from .utils.dynamic import instantiate_class_type
 log = logging.getLogger(__name__)
 
 # Context variable for request provider data
-_provider_data_var = contextvars.ContextVar("provider_data", default=None)
+PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
 
 
 class RequestProviderDataContext(ContextManager):
@@ -26,13 +26,13 @@ class RequestProviderDataContext(ContextManager):
     def __enter__(self):
         # Save the current value and set the new one
-        self.token = _provider_data_var.set(self.provider_data)
+        self.token = PROVIDER_DATA_VAR.set(self.provider_data)
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         # Restore the previous value
         if self.token is not None:
-            _provider_data_var.reset(self.token)
+            PROVIDER_DATA_VAR.reset(self.token)
 
 
 class NeedsRequestProviderData:
@@ -45,7 +45,7 @@ class NeedsRequestProviderData:
         if not validator_class:
             raise ValueError(f"Provider {provider_type} does not have a validator")
 
-        val = _provider_data_var.get()
+        val = PROVIDER_DATA_VAR.get()
         if not val:
             return None
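An aside on the pattern this class implements (unchanged by the rename): ContextVar.set returns a token, and reset(token) restores whatever value was in effect before the set, which makes the context manager safe to nest. A self-contained sketch, where ProviderDataScope is a hypothetical stand-in for RequestProviderDataContext:

import contextvars

PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)


class ProviderDataScope:
    # Same save/set/restore shape as RequestProviderDataContext above.
    def __init__(self, provider_data):
        self.provider_data = provider_data
        self.token = None

    def __enter__(self):
        self.token = PROVIDER_DATA_VAR.set(self.provider_data)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.token is not None:
            PROVIDER_DATA_VAR.reset(self.token)


with ProviderDataScope({"user": "outer"}):
    with ProviderDataScope({"user": "inner"}):
        assert PROVIDER_DATA_VAR.get() == {"user": "inner"}
    # reset(token) restores the outer value, not the default.
    assert PROVIDER_DATA_VAR.get() == {"user": "outer"}
assert PROVIDER_DATA_VAR.get() is None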

llama_stack/distribution/server/server.py

@@ -28,7 +28,7 @@ from typing_extensions import Annotated
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import (
-    preserve_headers_context_async_generator,
+    PROVIDER_DATA_VAR,
     request_provider_data_context,
 )
 from llama_stack.distribution.resolver import InvalidProviderError
@@ -38,6 +38,7 @@ from llama_stack.distribution.stack import (
     replace_env_vars,
     validate_env_pair,
 )
+from llama_stack.distribution.utils.context import preserve_contexts_async_generator
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api
 from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
@@ -45,6 +46,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
     TelemetryAdapter,
 )
 from llama_stack.providers.utils.telemetry.tracing import (
+    CURRENT_TRACE_CONTEXT,
     end_trace,
     setup_logger,
     start_trace,
@@ -182,7 +184,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
         try:
             if is_streaming:
-                gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
+                gen = preserve_contexts_async_generator(
+                    sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
+                )
                 return StreamingResponse(gen, media_type="text/event-stream")
             else:
                 value = func(**kwargs)
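A rough sketch of where the wrapper now sits in the server flow, assuming a FastAPI app like the one this file builds. The endpoint, the TENANT var, and sse_events are invented for illustration; in the real code the context vars are populated by request_provider_data_context and the tracing setup:

from contextvars import ContextVar

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

from llama_stack.distribution.utils.context import preserve_contexts_async_generator

TENANT: ContextVar[str] = ContextVar("tenant", default="anonymous")
app = FastAPI()


async def sse_events():
    for i in range(3):
        # The body iterator may be driven outside the handler's context,
        # so the wrapper re-sets TENANT before each iteration.
        yield f"data: {TENANT.get()}:{i}\n\n"


@app.get("/stream")
async def stream(tenant: str):
    TENANT.set(tenant)
    gen = preserve_contexts_async_generator(sse_events(), [TENANT])
    return StreamingResponse(gen, media_type="text/event-stream")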

llama_stack/distribution/utils/context.py (new file)

@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from contextvars import ContextVar
+from typing import AsyncGenerator, List, TypeVar
+
+T = TypeVar("T")
+
+
+def preserve_contexts_async_generator(
+    gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
+) -> AsyncGenerator[T, None]:
+    """
+    Wraps an async generator to preserve the given context variables across iterations.
+    This is needed because we start a new asyncio event loop for each request, and we
+    need to preserve the context across the event loop boundary.
+    """
+    context_values = [context_var.get() for context_var in context_vars]
+
+    async def wrapper():
+        while True:
+            for context_var, context_value in zip(context_vars, context_values, strict=False):
+                _ = context_var.set(context_value)
+            try:
+                item = await gen.__anext__()
+                yield item
+            except StopAsyncIteration:
+                break
+
+    return wrapper()
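A minimal check of the new helper's behavior. REQUEST_ID, items, make_wrapped, and step are hypothetical names; the point is that the generator is wrapped while the value is still visible, then driven from fresh event loops, mirroring the library client:

import asyncio
from contextvars import ContextVar

from llama_stack.distribution.utils.context import preserve_contexts_async_generator

REQUEST_ID: ContextVar[str] = ContextVar("request_id", default="unset")


async def items():
    for i in range(2):
        yield f"{REQUEST_ID.get()}:{i}"


async def make_wrapped():
    REQUEST_ID.set("req-42")
    # The current values of the listed vars are captured here, at wrap time.
    return preserve_contexts_async_generator(items(), [REQUEST_ID])


async def step(gen):
    return await gen.__anext__()


wrapped = asyncio.run(make_wrapped())
# Each iteration runs in a brand-new event loop, but the wrapper re-sets
# REQUEST_ID before delegating, so the inner generator still sees the value.
print(asyncio.run(step(wrapped)))  # prints "req-42:0"
print(asyncio.run(step(wrapped)))  # prints "req-42:1"

Without the wrapper, the same sequence prints "unset:0" and "unset:1", as in the sketch after the library client changes above.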