diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 50af2cdea..aefffc326 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -9,12 +9,10 @@ import inspect import json import logging import os -import queue -import threading from concurrent.futures import ThreadPoolExecutor from enum import Enum from pathlib import Path -from typing import Any, Generator, get_args, get_origin, Optional, TypeVar +from typing import Any, get_args, get_origin, Optional, TypeVar import httpx import yaml @@ -64,69 +62,69 @@ def in_notebook(): return True -def stream_across_asyncio_run_boundary( - async_gen_maker, - pool_executor: ThreadPoolExecutor, - path: Optional[str] = None, - provider_data: Optional[dict[str, Any]] = None, -) -> Generator[T, None, None]: - result_queue = queue.Queue() - stop_event = threading.Event() +# def stream_across_asyncio_run_boundary( +# async_gen_maker, +# pool_executor: ThreadPoolExecutor, +# path: Optional[str] = None, +# provider_data: Optional[dict[str, Any]] = None, +# ) -> Generator[T, None, None]: +# result_queue = queue.Queue() +# stop_event = threading.Event() - async def consumer(): - # make sure we make the generator in the event loop context - gen = await async_gen_maker() - await start_trace(path, {"__location__": "library_client"}) - if provider_data: - set_request_provider_data( - {"X-LlamaStack-Provider-Data": json.dumps(provider_data)} - ) - try: - async for item in await gen: - result_queue.put(item) - except Exception as e: - print(f"Error in generator {e}") - result_queue.put(e) - except asyncio.CancelledError: - return - finally: - result_queue.put(StopIteration) - stop_event.set() - await end_trace() +# async def consumer(): +# # make sure we make the generator in the event loop context +# gen = await async_gen_maker() +# await start_trace(path, {"__location__": "library_client"}) +# if provider_data: +# set_request_provider_data( +# {"X-LlamaStack-Provider-Data": json.dumps(provider_data)} +# ) +# try: +# async for item in await gen: +# result_queue.put(item) +# except Exception as e: +# print(f"Error in generator {e}") +# result_queue.put(e) +# except asyncio.CancelledError: +# return +# finally: +# result_queue.put(StopIteration) +# stop_event.set() +# await end_trace() - def run_async(): - # Run our own loop to avoid double async generator cleanup which is done - # by asyncio.run() - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - task = loop.create_task(consumer()) - loop.run_until_complete(task) - finally: - # Handle pending tasks like a generator's athrow() - pending = asyncio.all_tasks(loop) - if pending: - loop.run_until_complete( - asyncio.gather(*pending, return_exceptions=True) - ) - loop.close() +# def run_async(): +# # Run our own loop to avoid double async generator cleanup which is done +# # by asyncio.run() +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) +# try: +# task = loop.create_task(consumer()) +# loop.run_until_complete(task) +# finally: +# # Handle pending tasks like a generator's athrow() +# pending = asyncio.all_tasks(loop) +# if pending: +# loop.run_until_complete( +# asyncio.gather(*pending, return_exceptions=True) +# ) +# loop.close() - future = pool_executor.submit(run_async) +# future = pool_executor.submit(run_async) - try: - # yield results as they come in - while not stop_event.is_set() or not result_queue.empty(): - try: - item = result_queue.get(timeout=0.1) - if item is StopIteration: - break - if isinstance(item, Exception): - raise item - yield item - except queue.Empty: - continue - finally: - future.result() +# try: +# # yield results as they come in +# while not stop_event.is_set() or not result_queue.empty(): +# try: +# item = result_queue.get(timeout=0.1) +# if item is StopIteration: +# break +# if isinstance(item, Exception): +# raise item +# yield item +# except queue.Empty: +# continue +# finally: +# future.result() def convert_pydantic_to_json_value(value: Any) -> Any: @@ -223,12 +221,36 @@ class LlamaStackAsLibraryClient(LlamaStackClient): def request(self, *args, **kwargs): path = self._get_path(*args, **kwargs) if kwargs.get("stream"): - return stream_across_asyncio_run_boundary( - lambda: self.async_client.request(*args, **kwargs), - self.pool_executor, - path=path, - provider_data=self.provider_data, + # NOTE: We are using AsyncLlamaStackClient under the hood + # A new event loop is needed to convert the AsyncStream + # from async client into SyncStream return type for streaming + print("create new event loop") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Call the async client request and get the AsyncStream + async_stream = loop.run_until_complete( + self.async_client.request(*args, **kwargs) ) + + def sync_generator(): + try: + while True: + chunk = loop.run_until_complete(async_stream.__anext__()) + yield chunk + except StopAsyncIteration: + print("StopAsyncIteration in sync_generator") + finally: + loop.close() + + return sync_generator() + + # return stream_across_asyncio_run_boundary( + # lambda: self.async_client.request(*args, **kwargs), + # self.pool_executor, + # path=path, + # provider_data=self.provider_data, + # ) else: async def _traced_request(): @@ -330,7 +352,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): raise ValueError("Client not initialized") if stream: - return self._call_streaming( + return await self._call_streaming( cast_to=cast_to, options=options, stream_cls=stream_cls,