make direct client streaming work properly

Ashwin Bharambe 2024-12-07 11:38:56 -08:00
parent fd48cf3fc1
commit 86b5743081
2 changed files with 115 additions and 18 deletions
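
For context, this change lets the synchronous library client stream responses without the caller managing an event loop. A minimal usage sketch mirroring the updated test script below (the config path and model identifier are placeholders, not values from this diff):

    from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
    from llama_stack_client.types import UserMessage

    client = LlamaStackAsLibraryClient("path/to/run.yaml")  # placeholder: a run config or template name
    client.initialize()

    # stream=True now returns a plain synchronous generator of chunks
    response = client.inference.chat_completion(
        messages=[UserMessage(content="What is the capital of France?", role="user")],
        model_id="my-model-id",  # placeholder identifier
        stream=True,
    )
    for chunk in response:
        print(chunk)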


@@ -4,12 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import inspect
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, get_args, get_origin, Optional
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
import yaml
from llama_stack_client import LlamaStackClient, NOT_GIVEN
from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
from pydantic import TypeAdapter
from rich.console import Console
@@ -25,6 +29,65 @@ from llama_stack.distribution.stack import (
    replace_env_vars,
)

T = TypeVar("T")


def stream_across_asyncio_run_boundary(
    async_gen_maker,
    pool_executor: ThreadPoolExecutor,
) -> Generator[T, None, None]:
    result_queue = queue.Queue()
    stop_event = threading.Event()

    async def consumer():
        # make sure we make the generator in the event loop context
        gen = await async_gen_maker()
        try:
            async for item in gen:
                result_queue.put(item)
        except Exception as e:
            print(f"Error in generator {e}")
            result_queue.put(e)
        except asyncio.CancelledError:
            return
        finally:
            result_queue.put(StopIteration)
            stop_event.set()

    def run_async():
        # Run our own loop to avoid double async generator cleanup which is done
        # by asyncio.run()
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            task = loop.create_task(consumer())
            loop.run_until_complete(task)
        finally:
            # Handle pending tasks like a generator's athrow()
            pending = asyncio.all_tasks(loop)
            if pending:
                loop.run_until_complete(
                    asyncio.gather(*pending, return_exceptions=True)
                )
            loop.close()

    future = pool_executor.submit(run_async)

    try:
        # yield results as they come in
        while not stop_event.is_set() or not result_queue.empty():
            try:
                item = result_queue.get(timeout=0.1)
                if item is StopIteration:
                    break
                if isinstance(item, Exception):
                    raise item
                yield item
            except queue.Empty:
                continue
    finally:
        future.result()


class LlamaStackAsLibraryClient(LlamaStackClient):
    def __init__(
@@ -32,6 +95,37 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
        config_path_or_template_name: str,
        custom_provider_registry: Optional[ProviderRegistry] = None,
    ):
        super().__init__()
        self.async_client = LlamaStackAsLibraryAsyncClient(
            config_path_or_template_name, custom_provider_registry
        )
        self.pool_executor = ThreadPoolExecutor(max_workers=4)

    def initialize(self):
        asyncio.run(self.async_client.initialize())

    def get(self, *args, **kwargs):
        assert not kwargs.get("stream"), "GET never called with stream=True"
        return asyncio.run(self.async_client.get(*args, **kwargs))

    def post(self, *args, **kwargs):
        if kwargs.get("stream"):
            return stream_across_asyncio_run_boundary(
                lambda: self.async_client.post(*args, **kwargs),
                self.pool_executor,
            )
        else:
            return asyncio.run(self.async_client.post(*args, **kwargs))


class LlamaStackAsLibraryAsyncClient(AsyncLlamaStackClient):
    def __init__(
        self,
        config_path_or_template_name: str,
        custom_provider_registry: Optional[ProviderRegistry] = None,
    ):
        super().__init__()
        if config_path_or_template_name.endswith(".yaml"):
            config_path = Path(config_path_or_template_name)
            if not config_path.exists():
@@ -46,8 +140,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
        self.config = config
        self.custom_provider_registry = custom_provider_registry
        super().__init__()

    async def initialize(self):
        try:
            self.impls = await construct_stack(
@@ -153,6 +245,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
            try:
                return [self._convert_param(item_type, item) for item in value]
            except Exception:
                print(f"Error converting list {value}")
                return value
        elif origin is dict:
@@ -160,6 +253,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
            try:
                return {k: self._convert_param(val_type, v) for k, v in value.items()}
            except Exception:
                print(f"Error converting dict {value}")
                return value

        try:

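To illustrate the async-to-sync bridging used by stream_across_asyncio_run_boundary above, here is a simplified, self-contained sketch. It swaps the commit's manual event-loop management and StopIteration sentinel for asyncio.run and a private sentinel object, so treat it as an approximation of the technique rather than the exact helper:

    import asyncio
    import queue
    from concurrent.futures import ThreadPoolExecutor

    _DONE = object()  # sentinel marking the end of the stream


    def bridge(async_gen_maker, pool):
        """Consume an async generator on a worker thread and yield its items synchronously."""
        q = queue.Queue()

        async def consume():
            gen = await async_gen_maker()  # create the generator inside the worker's event loop
            try:
                async for item in gen:
                    q.put(item)
            finally:
                q.put(_DONE)

        future = pool.submit(asyncio.run, consume())
        try:
            while True:
                item = q.get()
                if item is _DONE:
                    break
                yield item
        finally:
            future.result()  # surface any exception raised in the worker


    async def make_numbers():
        # toy async generator standing in for the async client's streaming response
        async def gen():
            for i in range(3):
                await asyncio.sleep(0.01)
                yield i
        return gen()


    with ThreadPoolExecutor(max_workers=1) as pool:
        for n in bridge(make_numbers, pool):
            print(n)  # prints 0, 1, 2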

@@ -7,38 +7,43 @@
import argparse

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.lib.inference.event_logger import EventLogger
from llama_stack_client.types import UserMessage


async def main(config_path: str):
def main(config_path: str):
    client = LlamaStackAsLibraryClient(config_path)
    await client.initialize()
    client.initialize()
    models = client.models.list()
    print("\nModels:")
    for model in models:
        print(model)
    models = await client.models.list()
    print(models)
    if not models:
        print("No models found, skipping chat completion test")
        return
    model_id = models[0].identifier
    response = await client.inference.chat_completion(
    response = client.inference.chat_completion(
        messages=[UserMessage(content="What is the capital of France?", role="user")],
        model_id=model_id,
        stream=False,
    )
    print("\nChat completion response:")
    print("\nChat completion response (non-stream):")
    print(response)
    response = await client.inference.chat_completion(
    response = client.inference.chat_completion(
        messages=[UserMessage(content="What is the capital of France?", role="user")],
        model_id=model_id,
        stream=True,
    )
    print("\nChat completion stream response:")
    async for chunk in response:
        print(chunk)
    response = await client.memory_banks.register(
    print("\nChat completion response (stream):")
    for log in EventLogger().log(response):
        log.print()
    response = client.memory_banks.register(
        memory_bank_id="memory_bank_id",
        params={
            "chunk_size_in_tokens": 0,
@@ -51,9 +56,7 @@ async def main(config_path: str):

if __name__ == "__main__":
    import asyncio
    parser = argparse.ArgumentParser()
    parser.add_argument("config_path", help="Path to the config YAML file")
    args = parser.parse_args()
    asyncio.run(main(args.config_path))
    main(args.config_path)
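
As a rough usage note: the updated script now runs as a plain synchronous program; assuming it is saved as library_client_test.py (its path is not shown in this diff), an invocation would look like:

    python library_client_test.py path/to/run.yaml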