fix streaming inline client

This commit is contained in:
Xi Yan 2025-01-13 20:16:43 -08:00
parent ee4e04804f
commit d6dd2ba471

View file

@ -9,12 +9,10 @@ import inspect
import json import json
import logging import logging
import os import os
import queue
import threading
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar from typing import Any, get_args, get_origin, Optional, TypeVar
import httpx import httpx
import yaml import yaml
@ -64,69 +62,69 @@ def in_notebook():
return True return True
def stream_across_asyncio_run_boundary( # def stream_across_asyncio_run_boundary(
async_gen_maker, # async_gen_maker,
pool_executor: ThreadPoolExecutor, # pool_executor: ThreadPoolExecutor,
path: Optional[str] = None, # path: Optional[str] = None,
provider_data: Optional[dict[str, Any]] = None, # provider_data: Optional[dict[str, Any]] = None,
) -> Generator[T, None, None]: # ) -> Generator[T, None, None]:
result_queue = queue.Queue() # result_queue = queue.Queue()
stop_event = threading.Event() # stop_event = threading.Event()
async def consumer(): # async def consumer():
# make sure we make the generator in the event loop context # # make sure we make the generator in the event loop context
gen = await async_gen_maker() # gen = await async_gen_maker()
await start_trace(path, {"__location__": "library_client"}) # await start_trace(path, {"__location__": "library_client"})
if provider_data: # if provider_data:
set_request_provider_data( # set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(provider_data)} # {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
) # )
try: # try:
async for item in await gen: # async for item in await gen:
result_queue.put(item) # result_queue.put(item)
except Exception as e: # except Exception as e:
print(f"Error in generator {e}") # print(f"Error in generator {e}")
result_queue.put(e) # result_queue.put(e)
except asyncio.CancelledError: # except asyncio.CancelledError:
return # return
finally: # finally:
result_queue.put(StopIteration) # result_queue.put(StopIteration)
stop_event.set() # stop_event.set()
await end_trace() # await end_trace()
def run_async(): # def run_async():
# Run our own loop to avoid double async generator cleanup which is done # # Run our own loop to avoid double async generator cleanup which is done
# by asyncio.run() # # by asyncio.run()
loop = asyncio.new_event_loop() # loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) # asyncio.set_event_loop(loop)
try: # try:
task = loop.create_task(consumer()) # task = loop.create_task(consumer())
loop.run_until_complete(task) # loop.run_until_complete(task)
finally: # finally:
# Handle pending tasks like a generator's athrow() # # Handle pending tasks like a generator's athrow()
pending = asyncio.all_tasks(loop) # pending = asyncio.all_tasks(loop)
if pending: # if pending:
loop.run_until_complete( # loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True) # asyncio.gather(*pending, return_exceptions=True)
) # )
loop.close() # loop.close()
future = pool_executor.submit(run_async) # future = pool_executor.submit(run_async)
try: # try:
# yield results as they come in # # yield results as they come in
while not stop_event.is_set() or not result_queue.empty(): # while not stop_event.is_set() or not result_queue.empty():
try: # try:
item = result_queue.get(timeout=0.1) # item = result_queue.get(timeout=0.1)
if item is StopIteration: # if item is StopIteration:
break # break
if isinstance(item, Exception): # if isinstance(item, Exception):
raise item # raise item
yield item # yield item
except queue.Empty: # except queue.Empty:
continue # continue
finally: # finally:
future.result() # future.result()
def convert_pydantic_to_json_value(value: Any) -> Any: def convert_pydantic_to_json_value(value: Any) -> Any:
@ -223,12 +221,36 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
def request(self, *args, **kwargs): def request(self, *args, **kwargs):
path = self._get_path(*args, **kwargs) path = self._get_path(*args, **kwargs)
if kwargs.get("stream"): if kwargs.get("stream"):
return stream_across_asyncio_run_boundary( # NOTE: We are using AsyncLlamaStackClient under the hood
lambda: self.async_client.request(*args, **kwargs), # A new event loop is needed to convert the AsyncStream
self.pool_executor, # from async client into SyncStream return type for streaming
path=path, print("create new event loop")
provider_data=self.provider_data, loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Call the async client request and get the AsyncStream
async_stream = loop.run_until_complete(
self.async_client.request(*args, **kwargs)
) )
def sync_generator():
try:
while True:
chunk = loop.run_until_complete(async_stream.__anext__())
yield chunk
except StopAsyncIteration:
print("StopAsyncIteration in sync_generator")
finally:
loop.close()
return sync_generator()
# return stream_across_asyncio_run_boundary(
# lambda: self.async_client.request(*args, **kwargs),
# self.pool_executor,
# path=path,
# provider_data=self.provider_data,
# )
else: else:
async def _traced_request(): async def _traced_request():
@ -330,7 +352,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise ValueError("Client not initialized") raise ValueError("Client not initialized")
if stream: if stream:
return self._call_streaming( return await self._call_streaming(
cast_to=cast_to, cast_to=cast_to,
options=options, options=options,
stream_cls=stream_cls, stream_cls=stream_cls,