From 194d12b304cd0ec68f79235beea0a7fb2cbb16b9 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Tue, 14 Jan 2025 10:58:46 -0800
Subject: [PATCH] [bugfix] fix streaming GeneratorExit exception with LlamaStackAsLibraryClient (#760)

# What does this PR do?

#### Issue
- Using a Jupyter notebook with `LlamaStackAsLibraryClient` + streaming raises an exception:
```
Exception ignored in:
Traceback (most recent call last):
  File "/opt/anaconda3/envs/fresh/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 404, in __aiter__
    yield part
RuntimeError: async generator ignored GeneratorExit
```
- Reproduce with https://github.com/meta-llama/llama-stack/blob/notebook-streaming-debug/inline.ipynb

#### Fix
- The issue likely comes from `stream_across_asyncio_run_boundary` closing the connection too soon when running inside a Jupyter environment.
- This PR uses an alternative way to convert the `AsyncStream` into the `SyncStream` return type: the sync `LlamaStackAsLibraryClient` drives `AsyncLlamaStackAsLibraryClient` (which calls the async impls under the hood) on a dedicated event loop. A minimal sketch of the pattern appears at the end of this description.

#### Additional changes
- Moved the tracing logic into `AsyncLlamaStackAsLibraryClient.request` so that streaming and non-streaming requests through `LlamaStackAsLibraryClient` share the same code.

## Test Plan
- Tested with together, fireworks, and ollama, streaming and non-streaming, using the notebook in https://github.com/meta-llama/llama-stack/blob/notebook-streaming-debug/inline.ipynb (a rough usage sketch also follows the diff below).
- Note: the kernel must be restarted and `pip install -e .` re-run in the Jupyter interpreter for local code changes to take effect.

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
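#### Pattern sketch (illustrative only)

A minimal, self-contained illustration of the approach: drive the async iterator from a dedicated event loop and only close that loop once the stream is exhausted, so the underlying connection is not torn down while chunks are still pending. The names below (`fake_stream`, `to_sync_generator`) are hypothetical and not part of the llama-stack API; the real change is in `LlamaStackAsLibraryClient.request` in the diff.

```python
import asyncio
from typing import AsyncIterator, Iterator


async def fake_stream() -> AsyncIterator[str]:
    # Stand-in for the AsyncStream returned by the async client.
    for chunk in ["hello", " ", "world"]:
        await asyncio.sleep(0)
        yield chunk


def to_sync_generator(async_iter: AsyncIterator[str]) -> Iterator[str]:
    # A fresh event loop owned by this generator; it is closed only after the
    # async iterator is exhausted, which avoids the GeneratorExit symptom above.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        while True:
            try:
                yield loop.run_until_complete(async_iter.__anext__())
            except StopAsyncIteration:
                break
    finally:
        loop.close()


for chunk in to_sync_generator(fake_stream()):
    print(chunk, end="")
```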
---
 llama_stack/distribution/library_client.py | 130 +++++----------------
 1 file changed, 31 insertions(+), 99 deletions(-)

diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index 50af2cdea..0c124e64b 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -9,12 +9,10 @@ import inspect
 import json
 import logging
 import os
-import queue
-import threading
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
+from typing import Any, get_args, get_origin, Optional, TypeVar
 
 import httpx
 import yaml
@@ -64,71 +62,6 @@ def in_notebook():
     return True
 
 
-def stream_across_asyncio_run_boundary(
-    async_gen_maker,
-    pool_executor: ThreadPoolExecutor,
-    path: Optional[str] = None,
-    provider_data: Optional[dict[str, Any]] = None,
-) -> Generator[T, None, None]:
-    result_queue = queue.Queue()
-    stop_event = threading.Event()
-
-    async def consumer():
-        # make sure we make the generator in the event loop context
-        gen = await async_gen_maker()
-        await start_trace(path, {"__location__": "library_client"})
-        if provider_data:
-            set_request_provider_data(
-                {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
-            )
-        try:
-            async for item in await gen:
-                result_queue.put(item)
-        except Exception as e:
-            print(f"Error in generator {e}")
-            result_queue.put(e)
-        except asyncio.CancelledError:
-            return
-        finally:
-            result_queue.put(StopIteration)
-            stop_event.set()
-            await end_trace()
-
-    def run_async():
-        # Run our own loop to avoid double async generator cleanup which is done
-        # by asyncio.run()
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-        try:
-            task = loop.create_task(consumer())
-            loop.run_until_complete(task)
-        finally:
-            # Handle pending tasks like a generator's athrow()
-            pending = asyncio.all_tasks(loop)
-            if pending:
-                loop.run_until_complete(
-                    asyncio.gather(*pending, return_exceptions=True)
-                )
-            loop.close()
-
-    future = pool_executor.submit(run_async)
-
-    try:
-        # yield results as they come in
-        while not stop_event.is_set() or not result_queue.empty():
-            try:
-                item = result_queue.get(timeout=0.1)
-                if item is StopIteration:
-                    break
-                if isinstance(item, Exception):
-                    raise item
-                yield item
-            except queue.Empty:
-                continue
-    finally:
-        future.result()
-
-
 def convert_pydantic_to_json_value(value: Any) -> Any:
     if isinstance(value, Enum):
         return value.value
@@ -184,7 +117,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
     ):
         super().__init__()
         self.async_client = AsyncLlamaStackAsLibraryClient(
-            config_path_or_template_name, custom_provider_registry
+            config_path_or_template_name, custom_provider_registry, provider_data
        )
         self.pool_executor = ThreadPoolExecutor(max_workers=4)
         self.skip_logger_removal = skip_logger_removal
@@ -210,39 +143,30 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             root_logger.removeHandler(handler)
             print(f"Removed handler {handler.__class__.__name__} from root logger")
 
-    def _get_path(
-        self,
-        cast_to: Any,
-        options: Any,
-        *,
-        stream=False,
-        stream_cls=None,
-    ):
-        return options.url
-
     def request(self, *args, **kwargs):
-        path = self._get_path(*args, **kwargs)
         if kwargs.get("stream"):
-            return stream_across_asyncio_run_boundary(
-                lambda: self.async_client.request(*args, **kwargs),
-                self.pool_executor,
-                path=path,
-                provider_data=self.provider_data,
-            )
-        else:
+            # NOTE: We are using AsyncLlamaStackClient under the hood
+            # A new event loop is needed to convert the AsyncStream
+            # from async client into SyncStream return type for streaming
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
 
-            async def _traced_request():
-                if self.provider_data:
-                    set_request_provider_data(
-                        {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
-                    )
-                await start_trace(path, {"__location__": "library_client"})
+            def sync_generator():
                 try:
-                    return await self.async_client.request(*args, **kwargs)
+                    async_stream = loop.run_until_complete(
+                        self.async_client.request(*args, **kwargs)
+                    )
+                    while True:
+                        chunk = loop.run_until_complete(async_stream.__anext__())
+                        yield chunk
+                except StopAsyncIteration:
+                    pass
                 finally:
-                    await end_trace()
+                    loop.close()
 
-            return asyncio.run(_traced_request())
+            return sync_generator()
+        else:
+            return asyncio.run(self.async_client.request(*args, **kwargs))
 
 
 class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
@@ -250,9 +174,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self,
         config_path_or_template_name: str,
         custom_provider_registry: Optional[ProviderRegistry] = None,
+        provider_data: Optional[dict[str, Any]] = None,
     ):
         super().__init__()
-
         # when using the library client, we should not log to console since many
         # of our logs are intended for server-side usage
         current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
@@ -273,6 +197,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self.config_path_or_template_name = config_path_or_template_name
         self.config = config
         self.custom_provider_registry = custom_provider_registry
+        self.provider_data = provider_data
 
     async def initialize(self):
         try:
@@ -329,17 +254,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if not self.endpoint_impls:
             raise ValueError("Client not initialized")
 
+        if self.provider_data:
+            set_request_provider_data(
+                {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
+            )
+        await start_trace(options.url, {"__location__": "library_client"})
         if stream:
-            return self._call_streaming(
+            response = await self._call_streaming(
                 cast_to=cast_to,
                 options=options,
                 stream_cls=stream_cls,
             )
         else:
-            return await self._call_non_streaming(
+            response = await self._call_non_streaming(
                 cast_to=cast_to,
                 options=options,
             )
+        await end_trace()
+        return response
 
     async def _call_non_streaming(
         self,
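A rough sketch of the notebook flow exercised in the test plan above. The constructor argument ("ollama" as a template name), the `initialize()` call, and the `inference.chat_completion` signature (`model_id`, `messages`, `stream`) reflect the client API as assumed at the time of this PR and may differ from the notebook; treat them as placeholders rather than a definitive example.

```python
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

# Assumed template name; together / fireworks would also need provider API keys.
client = LlamaStackAsLibraryClient("ollama")
client.initialize()

# Non-streaming request: handled via asyncio.run() in LlamaStackAsLibraryClient.request.
response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    messages=[{"role": "user", "content": "Hello"}],
)
print(response)

# Streaming request: now served by the sync_generator() wrapper introduced in this
# patch, so iterating in a Jupyter cell no longer triggers the GeneratorExit warning.
for chunk in client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
):
    print(chunk)
```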