fix streaming inline client

Xi Yan 2025-01-13 20:16:43 -08:00
parent ee4e04804f
commit d6dd2ba471


@@ -9,12 +9,10 @@ import inspect
import json
import logging
import os
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
from typing import Any, get_args, get_origin, Optional, TypeVar
import httpx
import yaml
@@ -64,69 +62,69 @@ def in_notebook():
return True
def stream_across_asyncio_run_boundary(
async_gen_maker,
pool_executor: ThreadPoolExecutor,
path: Optional[str] = None,
provider_data: Optional[dict[str, Any]] = None,
) -> Generator[T, None, None]:
result_queue = queue.Queue()
stop_event = threading.Event()
# def stream_across_asyncio_run_boundary(
# async_gen_maker,
# pool_executor: ThreadPoolExecutor,
# path: Optional[str] = None,
# provider_data: Optional[dict[str, Any]] = None,
# ) -> Generator[T, None, None]:
# result_queue = queue.Queue()
# stop_event = threading.Event()
async def consumer():
# make sure we make the generator in the event loop context
gen = await async_gen_maker()
await start_trace(path, {"__location__": "library_client"})
if provider_data:
set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
)
try:
async for item in await gen:
result_queue.put(item)
except Exception as e:
print(f"Error in generator {e}")
result_queue.put(e)
except asyncio.CancelledError:
return
finally:
result_queue.put(StopIteration)
stop_event.set()
await end_trace()
# async def consumer():
# # make sure we make the generator in the event loop context
# gen = await async_gen_maker()
# await start_trace(path, {"__location__": "library_client"})
# if provider_data:
# set_request_provider_data(
# {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
# )
# try:
# async for item in await gen:
# result_queue.put(item)
# except Exception as e:
# print(f"Error in generator {e}")
# result_queue.put(e)
# except asyncio.CancelledError:
# return
# finally:
# result_queue.put(StopIteration)
# stop_event.set()
# await end_trace()
def run_async():
# Run our own loop to avoid double async generator cleanup which is done
# by asyncio.run()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
task = loop.create_task(consumer())
loop.run_until_complete(task)
finally:
# Handle pending tasks like a generator's athrow()
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
loop.close()
# def run_async():
# # Run our own loop to avoid double async generator cleanup which is done
# # by asyncio.run()
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# try:
# task = loop.create_task(consumer())
# loop.run_until_complete(task)
# finally:
# # Handle pending tasks like a generator's athrow()
# pending = asyncio.all_tasks(loop)
# if pending:
# loop.run_until_complete(
# asyncio.gather(*pending, return_exceptions=True)
# )
# loop.close()
future = pool_executor.submit(run_async)
# future = pool_executor.submit(run_async)
try:
# yield results as they come in
while not stop_event.is_set() or not result_queue.empty():
try:
item = result_queue.get(timeout=0.1)
if item is StopIteration:
break
if isinstance(item, Exception):
raise item
yield item
except queue.Empty:
continue
finally:
future.result()
# try:
# # yield results as they come in
# while not stop_event.is_set() or not result_queue.empty():
# try:
# item = result_queue.get(timeout=0.1)
# if item is StopIteration:
# break
# if isinstance(item, Exception):
# raise item
# yield item
# except queue.Empty:
# continue
# finally:
# future.result()
def convert_pydantic_to_json_value(value: Any) -> Any:
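
The function disabled in the hunk above bridged an async generator into a synchronous one by driving the event loop on a worker thread and passing items back through a queue.Queue. A minimal, self-contained sketch of that thread-and-queue pattern follows; the names bridge_async_gen, make_gen, and pool are placeholders rather than the library's API, and for brevity it uses asyncio.run, whereas the original managed its own loop specifically to avoid double async-generator cleanup.

import asyncio
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncIterator, Callable, Generator


def bridge_async_gen(
    make_gen: Callable[[], AsyncIterator[Any]],
    pool: ThreadPoolExecutor,
) -> Generator[Any, None, None]:
    """Drive an async iterator on a worker thread; yield its items synchronously."""
    results: queue.Queue = queue.Queue()
    done = threading.Event()

    async def consume() -> None:
        try:
            async for item in make_gen():
                results.put(item)
        except Exception as exc:  # surface errors on the consuming thread
            results.put(exc)
        finally:
            results.put(StopIteration)  # sentinel: no more items
            done.set()

    future = pool.submit(asyncio.run, consume())
    try:
        while not done.is_set() or not results.empty():
            try:
                item = results.get(timeout=0.1)
            except queue.Empty:
                continue
            if item is StopIteration:
                break
            if isinstance(item, Exception):
                raise item
            yield item
    finally:
        future.result()  # propagate any failure from the worker thread

The StopIteration sentinel plays the same role as in the removed code: it tells the consuming thread that the worker has finished, so polling can stop even if the stop event races with the last item.
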
@@ -223,12 +221,36 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
def request(self, *args, **kwargs):
path = self._get_path(*args, **kwargs)
if kwargs.get("stream"):
return stream_across_asyncio_run_boundary(
lambda: self.async_client.request(*args, **kwargs),
self.pool_executor,
path=path,
provider_data=self.provider_data,
)
# NOTE: We are using AsyncLlamaStackClient under the hood
# A new event loop is needed to convert the AsyncStream
# from async client into SyncStream return type for streaming
print("create new event loop")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Call the async client request and get the AsyncStream
async_stream = loop.run_until_complete(
self.async_client.request(*args, **kwargs)
)
def sync_generator():
try:
while True:
chunk = loop.run_until_complete(async_stream.__anext__())
yield chunk
except StopAsyncIteration:
print("StopAsyncIteration in sync_generator")
finally:
loop.close()
return sync_generator()
# return stream_across_asyncio_run_boundary(
# lambda: self.async_client.request(*args, **kwargs),
# self.pool_executor,
# path=path,
# provider_data=self.provider_data,
# )
else:
async def _traced_request():
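
The replacement streaming path above creates a dedicated event loop, resolves the async client's request into an AsyncStream, and then wraps it in a plain generator that advances the stream one chunk at a time via loop.run_until_complete(...__anext__()). A standalone sketch of that conversion follows; fetch_chunks and to_sync_stream are illustrative stand-ins, not the client's real API, and in the real code the stream itself is first obtained with run_until_complete because async_client.request(...) is a coroutine.

import asyncio
from typing import Any, AsyncIterator, Iterator


async def fetch_chunks() -> AsyncIterator[str]:
    # Stand-in for the stream returned by async_client.request(..., stream=True).
    for part in ("hello", " ", "world"):
        await asyncio.sleep(0)  # pretend to wait on network I/O
        yield part


def to_sync_stream(stream: AsyncIterator[Any], loop: asyncio.AbstractEventLoop) -> Iterator[Any]:
    """Step an async iterator from synchronous code, one item per loop run."""
    try:
        while True:
            yield loop.run_until_complete(stream.__anext__())
    except StopAsyncIteration:
        pass  # async iterator exhausted
    finally:
        loop.close()


loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
print("".join(to_sync_stream(fetch_chunks(), loop)))  # -> hello world

Each chunk costs one run_until_complete round trip on the dedicated loop, which is why the loop is created once per streaming request and only closed when the wrapping generator finishes.
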
@@ -330,7 +352,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise ValueError("Client not initialized")
if stream:
return self._call_streaming(
return await self._call_streaming(
cast_to=cast_to,
options=options,
stream_cls=stream_cls,
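
The final hunk adds the missing await: _call_streaming is presumably a coroutine function here, so calling it without await would hand the caller an un-awaited coroutine object instead of the resolved stream. A toy illustration of the difference, with Demo and its methods as hypothetical stand-ins:

import asyncio


class Demo:
    async def _call_streaming(self) -> str:
        return "stream"  # stand-in for the resolved streaming response

    async def broken(self):
        return self._call_streaming()        # coroutine object, not the stream

    async def fixed(self):
        return await self._call_streaming()  # the actual value


async def main() -> None:
    d = Demo()
    leaked = await d.broken()
    print(type(leaked).__name__)  # coroutine
    leaked.close()                # silence the "never awaited" warning
    print(await d.fixed())        # stream


asyncio.run(main())
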