mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-01 17:24:32 +00:00
address feedback
This commit is contained in:
parent
a900740e30
commit
714c09cd53
4 changed files with 47 additions and 30 deletions
|
|
@ -13,7 +13,7 @@ import re
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncGenerator, Optional, TypeVar, Union, get_args, get_origin
|
||||
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
|
|
@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help
|
|||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.request_headers import (
|
||||
_provider_data_var,
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
|
|
@ -44,6 +44,7 @@ from llama_stack.distribution.stack import (
|
|||
redact_sensitive_fields,
|
||||
replace_env_vars,
|
||||
)
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.distribution.utils.exec import in_notebook
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
CURRENT_TRACE_CONTEXT,
|
||||
|
|
@ -384,7 +385,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
yield sse_event.encode("utf-8")
|
||||
|
||||
try:
|
||||
wrapped_gen = self._preserve_contexts_async_generator(gen())
|
||||
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
|
|
@ -418,27 +419,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
return await response.parse()
|
||||
|
||||
def _preserve_contexts_async_generator(self, gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
|
||||
"""
|
||||
Wraps an async generator to preserve both tracing and headers context variables across iterations.
|
||||
This is needed because we start a new asyncio event loop for each request, and we need to preserve the context
|
||||
across the event loop boundary.
|
||||
"""
|
||||
tracing_context_value = CURRENT_TRACE_CONTEXT.get()
|
||||
headers_context_value = _provider_data_var.get()
|
||||
|
||||
async def wrapper():
|
||||
while True:
|
||||
_ = CURRENT_TRACE_CONTEXT.set(tracing_context_value)
|
||||
_ = _provider_data_var.set(headers_context_value)
|
||||
try:
|
||||
item = await gen.__anext__()
|
||||
yield item
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
return wrapper()
|
||||
|
||||
def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
|
||||
if not body:
|
||||
return {}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue