make direct client streaming work properly

This commit is contained in:
Ashwin Bharambe 2024-12-07 11:38:56 -08:00
parent fd48cf3fc1
commit 86b5743081
2 changed files with 115 additions and 18 deletions

View file

@ -4,12 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import inspect
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, get_args, get_origin, Optional
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
import yaml
from llama_stack_client import LlamaStackClient, NOT_GIVEN
from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
from pydantic import TypeAdapter
from rich.console import Console
@ -25,6 +29,65 @@ from llama_stack.distribution.stack import (
replace_env_vars,
)
T = TypeVar("T")
def stream_across_asyncio_run_boundary(
async_gen_maker,
pool_executor: ThreadPoolExecutor,
) -> Generator[T, None, None]:
result_queue = queue.Queue()
stop_event = threading.Event()
async def consumer():
# make sure we make the generator in the event loop context
gen = await async_gen_maker()
try:
async for item in gen:
result_queue.put(item)
except Exception as e:
print(f"Error in generator {e}")
result_queue.put(e)
except asyncio.CancelledError:
return
finally:
result_queue.put(StopIteration)
stop_event.set()
def run_async():
# Run our own loop to avoid double async generator cleanup which is done
# by asyncio.run()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
task = loop.create_task(consumer())
loop.run_until_complete(task)
finally:
# Handle pending tasks like a generator's athrow()
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
loop.close()
future = pool_executor.submit(run_async)
try:
# yield results as they come in
while not stop_event.is_set() or not result_queue.empty():
try:
item = result_queue.get(timeout=0.1)
if item is StopIteration:
break
if isinstance(item, Exception):
raise item
yield item
except queue.Empty:
continue
finally:
future.result()
class LlamaStackAsLibraryClient(LlamaStackClient):
def __init__(
@ -32,6 +95,37 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
config_path_or_template_name: str,
custom_provider_registry: Optional[ProviderRegistry] = None,
):
super().__init__()
self.async_client = LlamaStackAsLibraryAsyncClient(
config_path_or_template_name, custom_provider_registry
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
def initialize(self):
asyncio.run(self.async_client.initialize())
def get(self, *args, **kwargs):
assert not kwargs.get("stream"), "GET never called with stream=True"
return asyncio.run(self.async_client.get(*args, **kwargs))
def post(self, *args, **kwargs):
if kwargs.get("stream"):
return stream_across_asyncio_run_boundary(
lambda: self.async_client.post(*args, **kwargs),
self.pool_executor,
)
else:
return asyncio.run(self.async_client.post(*args, **kwargs))
class LlamaStackAsLibraryAsyncClient(AsyncLlamaStackClient):
def __init__(
self,
config_path_or_template_name: str,
custom_provider_registry: Optional[ProviderRegistry] = None,
):
super().__init__()
if config_path_or_template_name.endswith(".yaml"):
config_path = Path(config_path_or_template_name)
if not config_path.exists():
@ -46,8 +140,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
self.config = config
self.custom_provider_registry = custom_provider_registry
super().__init__()
async def initialize(self):
try:
self.impls = await construct_stack(
@ -153,6 +245,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
try:
return [self._convert_param(item_type, item) for item in value]
except Exception:
print(f"Error converting list {value}")
return value
elif origin is dict:
@ -160,6 +253,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
try:
return {k: self._convert_param(val_type, v) for k, v in value.items()}
except Exception:
print(f"Error converting dict {value}")
return value
try: