From 9d4716521d37c4967e9c6d813fda8e5d84d1c87c Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Tue, 11 Mar 2025 15:02:34 -0700
Subject: [PATCH] restore trace context in event loop of lib cli

---
 llama_stack/distribution/library_client.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index 5dc70bb67..ee95f3b37 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -46,6 +46,7 @@ from llama_stack.distribution.stack import (
 )
 from llama_stack.distribution.utils.exec import in_notebook
 from llama_stack.providers.utils.telemetry.tracing import (
+    CURRENT_TRACE_CONTEXT,
     end_trace,
     setup_logger,
     start_trace,
@@ -55,6 +56,8 @@ logger = logging.getLogger(__name__)
 
 T = TypeVar("T")
 
+trace_context = None
+
 
 def convert_pydantic_to_json_value(value: Any) -> Any:
     if isinstance(value, Enum):
@@ -156,8 +159,11 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
 
         def sync_generator():
             try:
+                global trace_context
                 async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs))
                 while True:
+                    if trace_context:
+                        CURRENT_TRACE_CONTEXT.set(trace_context)
                     chunk = loop.run_until_complete(async_stream.__anext__())
                     yield chunk
             except StopAsyncIteration:
@@ -376,6 +382,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 
         async def gen():
             await start_trace(options.url, {"__location__": "library_client"})
+            global trace_context
+            trace_context = CURRENT_TRACE_CONTEXT.get()
             try:
                 async for chunk in await func(**body):
                     data = json.dumps(convert_pydantic_to_json_value(chunk))
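
Note: the patch appears to rely on the fact that a contextvars value set inside a task on the event loop does not propagate back to the synchronous caller that drives the loop, so `gen()` captures CURRENT_TRACE_CONTEXT into the module-level `trace_context` and `sync_generator()` re-sets it before each iteration. The following standalone sketch, which is not part of the patch, illustrates that pattern; the names `trace_var`, `captured_trace`, and `open_stream` are hypothetical stand-ins rather than llama_stack APIs.

import asyncio
import contextvars

# Hypothetical stand-ins for CURRENT_TRACE_CONTEXT and the module-level
# trace_context introduced by the patch.
trace_var = contextvars.ContextVar("trace_var", default=None)
captured_trace = None


async def open_stream():
    # Plays the role of the code wrapped by start_trace(): it sets the context
    # variable inside the event loop and captures its value so the synchronous
    # consumer can restore it later.
    global captured_trace
    trace_var.set("trace-123")
    captured_trace = trace_var.get()

    async def stream():
        for i in range(2):
            yield f"chunk-{i}"

    return stream()


def sync_generator(loop):
    stream = loop.run_until_complete(open_stream())
    while True:
        # Without this restore, trace_var.get() below would return None: the
        # set() above ran inside a task, which works on its own copy of the
        # context and never writes back to the caller's context.
        if captured_trace is not None:
            trace_var.set(captured_trace)
        try:
            chunk = loop.run_until_complete(stream.__anext__())
        except StopAsyncIteration:
            break
        yield chunk, trace_var.get()


if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    try:
        for chunk, trace in sync_generator(loop):
            print(chunk, trace)  # "chunk-0 trace-123", then "chunk-1 trace-123"
    finally:
        loop.close()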