mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
address feedback

This commit is contained in:
parent a900740e30
commit 714c09cd53

4 changed files with 47 additions and 30 deletions

llama_stack/distribution/library_client.py

@@ -13,7 +13,7 @@ import re
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, AsyncGenerator, Optional, TypeVar, Union, get_args, get_origin
+from typing import Any, Optional, TypeVar, Union, get_args, get_origin

 import httpx
 import yaml
@@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Api
 from llama_stack.distribution.request_headers import (
-    _provider_data_var,
+    PROVIDER_DATA_VAR,
     request_provider_data_context,
 )
 from llama_stack.distribution.resolver import ProviderRegistry
@@ -44,6 +44,7 @@ from llama_stack.distribution.stack import (
     redact_sensitive_fields,
     replace_env_vars,
 )
+from llama_stack.distribution.utils.context import preserve_contexts_async_generator
 from llama_stack.distribution.utils.exec import in_notebook
 from llama_stack.providers.utils.telemetry.tracing import (
     CURRENT_TRACE_CONTEXT,
@@ -384,7 +385,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 yield sse_event.encode("utf-8")

         try:
-            wrapped_gen = self._preserve_contexts_async_generator(gen())
+            wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
         finally:
             await end_trace()

@@ -418,27 +419,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             )
             return await response.parse()

-    def _preserve_contexts_async_generator(self, gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
-        """
-        Wraps an async generator to preserve both tracing and headers context variables across iterations.
-        This is needed because we start a new asyncio event loop for each request, and we need to preserve the context
-        across the event loop boundary.
-        """
-        tracing_context_value = CURRENT_TRACE_CONTEXT.get()
-        headers_context_value = _provider_data_var.get()
-
-        async def wrapper():
-            while True:
-                _ = CURRENT_TRACE_CONTEXT.set(tracing_context_value)
-                _ = _provider_data_var.set(headers_context_value)
-                try:
-                    item = await gen.__anext__()
-                    yield item
-                except StopAsyncIteration:
-                    break
-
-        return wrapper()
-
     def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
         if not body:
             return {}
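For background, a minimal standalone sketch (not part of the commit; all names are illustrative) of the failure mode that both the removed private method and the new shared helper address: a value set on a ContextVar inside one asyncio task is invisible to sibling tasks, so a generator advanced from freshly created tasks or event loops loses its tracing and provider-data context unless each step re-applies the captured values.

```python
import asyncio
from contextvars import ContextVar

trace_var: ContextVar = ContextVar("trace", default=None)


async def set_it() -> None:
    trace_var.set("trace-1")  # set inside one task


async def read_it():
    return trace_var.get()  # read from a sibling task


async def main() -> None:
    await asyncio.create_task(set_it())
    # Each task runs in a copy of its creator's context, so the value set
    # inside set_it() never propagates here: this prints None.
    print(await asyncio.create_task(read_it()))


asyncio.run(main())
```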
llama_stack/distribution/request_headers.py

@@ -14,7 +14,7 @@ from .utils.dynamic import instantiate_class_type
 log = logging.getLogger(__name__)

 # Context variable for request provider data
-_provider_data_var = contextvars.ContextVar("provider_data", default=None)
+PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)


 class RequestProviderDataContext(ContextManager):
@@ -26,13 +26,13 @@ class RequestProviderDataContext(ContextManager):

     def __enter__(self):
         # Save the current value and set the new one
-        self.token = _provider_data_var.set(self.provider_data)
+        self.token = PROVIDER_DATA_VAR.set(self.provider_data)
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
         # Restore the previous value
         if self.token is not None:
-            _provider_data_var.reset(self.token)
+            PROVIDER_DATA_VAR.reset(self.token)


 class NeedsRequestProviderData:
@@ -45,7 +45,7 @@ class NeedsRequestProviderData:
         if not validator_class:
            raise ValueError(f"Provider {provider_type} does not have a validator")

-        val = _provider_data_var.get()
+        val = PROVIDER_DATA_VAR.get()
         if not val:
             return None

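Aside: RequestProviderDataContext relies on the standard set()/reset() token protocol of contextvars. A small sketch (not from the commit; demo_var is illustrative) of why the token is saved in __enter__ and passed back in __exit__:

```python
from contextvars import ContextVar

demo_var: ContextVar = ContextVar("demo", default=None)

# set() returns a Token that records the variable's previous value.
token = demo_var.set({"api_key": "xyz"})
assert demo_var.get() == {"api_key": "xyz"}

# reset(token) restores exactly the recorded previous value, which nests
# safely even if the variable was set again in between.
demo_var.reset(token)
assert demo_var.get() is None
```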
llama_stack/distribution/server/server.py

@@ -28,7 +28,7 @@ from typing_extensions import Annotated
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import (
-    preserve_headers_context_async_generator,
+    PROVIDER_DATA_VAR,
     request_provider_data_context,
 )
 from llama_stack.distribution.resolver import InvalidProviderError
@@ -38,6 +38,7 @@ from llama_stack.distribution.stack import (
     replace_env_vars,
     validate_env_pair,
 )
+from llama_stack.distribution.utils.context import preserve_contexts_async_generator
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api
 from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
@@ -45,6 +46,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
     TelemetryAdapter,
 )
 from llama_stack.providers.utils.telemetry.tracing import (
+    CURRENT_TRACE_CONTEXT,
     end_trace,
     setup_logger,
     start_trace,
@@ -182,7 +184,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):

         try:
             if is_streaming:
-                gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
+                gen = preserve_contexts_async_generator(
+                    sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
+                )
                 return StreamingResponse(gen, media_type="text/event-stream")
             else:
                 value = func(**kwargs)
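To make the streaming path above concrete, here is a hypothetical minimal endpoint (FastAPI assumed; TRACE_VAR and event_stream are illustrative, not code from this commit) wiring the same wrapper in front of StreamingResponse:

```python
from contextvars import ContextVar

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

from llama_stack.distribution.utils.context import preserve_contexts_async_generator

TRACE_VAR: ContextVar = ContextVar("trace", default=None)  # stand-in for CURRENT_TRACE_CONTEXT
app = FastAPI()


async def event_stream():
    for i in range(3):
        # Without the wrapper, TRACE_VAR could read back as None here once
        # the ASGI server drives this generator from a different task.
        yield f"data: item={i} trace={TRACE_VAR.get()}\n\n"


@app.get("/stream")
async def stream():
    TRACE_VAR.set("trace-1")
    gen = preserve_contexts_async_generator(event_stream(), [TRACE_VAR])
    return StreamingResponse(gen, media_type="text/event-stream")
```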
llama_stack/distribution/utils/context.py (new file)

@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from contextvars import ContextVar
+from typing import AsyncGenerator, List, TypeVar
+
+T = TypeVar("T")
+
+
+def preserve_contexts_async_generator(
+    gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
+) -> AsyncGenerator[T, None]:
+    """
+    Wraps an async generator to preserve the given context variables across iterations.
+    This is needed because we start a new asyncio event loop for each request, and we need to preserve the context
+    across the event loop boundary.
+    """
+    context_values = [context_var.get() for context_var in context_vars]
+
+    async def wrapper():
+        while True:
+            for context_var, context_value in zip(context_vars, context_values, strict=False):
+                _ = context_var.set(context_value)
+            try:
+                item = await gen.__anext__()
+                yield item
+            except StopAsyncIteration:
+                break
+
+    return wrapper()
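Finally, a minimal usage sketch of the new helper (not part of the commit; REQUEST_ID and numbers are illustrative): the wrapper captures the listed variables' values once at wrap time and re-applies them before every __anext__ call.

```python
import asyncio
from contextvars import ContextVar

from llama_stack.distribution.utils.context import preserve_contexts_async_generator

REQUEST_ID: ContextVar = ContextVar("request_id", default=None)


async def numbers():
    for i in range(3):
        # The wrapper re-applies REQUEST_ID before each step, so this reads
        # "abc-123" even when iteration is driven from a fresh task or loop.
        yield f"request={REQUEST_ID.get()} item={i}"


async def main() -> None:
    REQUEST_ID.set("abc-123")
    async for line in preserve_contexts_async_generator(numbers(), [REQUEST_ID]):
        print(line)


asyncio.run(main())
```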