make direct client streaming work properly

Ashwin Bharambe 2024-12-07 11:38:56 -08:00
parent fd48cf3fc1
commit 86b5743081
2 changed files with 115 additions and 18 deletions

View file

@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import asyncio
 import inspect
+import queue
+import threading
+from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import Any, get_args, get_origin, Optional
+from typing import Any, Generator, get_args, get_origin, Optional, TypeVar

 import yaml
-from llama_stack_client import LlamaStackClient, NOT_GIVEN
+from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
 from pydantic import TypeAdapter
 from rich.console import Console
@@ -25,6 +29,65 @@ from llama_stack.distribution.stack import (
     replace_env_vars,
 )

+T = TypeVar("T")
+
+
+def stream_across_asyncio_run_boundary(
+    async_gen_maker,
+    pool_executor: ThreadPoolExecutor,
+) -> Generator[T, None, None]:
+    result_queue = queue.Queue()
+    stop_event = threading.Event()
+
+    async def consumer():
+        # make sure we make the generator in the event loop context
+        gen = await async_gen_maker()
+        try:
+            async for item in gen:
+                result_queue.put(item)
+        except Exception as e:
+            print(f"Error in generator {e}")
+            result_queue.put(e)
+        except asyncio.CancelledError:
+            return
+        finally:
+            result_queue.put(StopIteration)
+            stop_event.set()
+
+    def run_async():
+        # Run our own loop to avoid double async generator cleanup which is done
+        # by asyncio.run()
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            task = loop.create_task(consumer())
+            loop.run_until_complete(task)
+        finally:
+            # Handle pending tasks like a generator's athrow()
+            pending = asyncio.all_tasks(loop)
+            if pending:
+                loop.run_until_complete(
+                    asyncio.gather(*pending, return_exceptions=True)
+                )
+            loop.close()
+
+    future = pool_executor.submit(run_async)
+    try:
+        # yield results as they come in
+        while not stop_event.is_set() or not result_queue.empty():
+            try:
+                item = result_queue.get(timeout=0.1)
+                if item is StopIteration:
+                    break
+                if isinstance(item, Exception):
+                    raise item
+                yield item
+            except queue.Empty:
+                continue
+    finally:
+        future.result()
+
+
 class LlamaStackAsLibraryClient(LlamaStackClient):
     def __init__(
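
The helper above is the heart of this change: it runs an async generator on its own event loop inside a worker thread and relays items through a queue, so plain synchronous code can iterate the results. A minimal usage sketch, not part of the commit, is below; the toy fake_stream and gen_maker names are made up for illustration, and the import assumes the helper lives in llama_stack.distribution.library_client (the module the test script further down imports from).

import asyncio
from concurrent.futures import ThreadPoolExecutor

from llama_stack.distribution.library_client import stream_across_asyncio_run_boundary


async def fake_stream():
    # Stand-in for a streaming endpoint that yields chunks asynchronously.
    for i in range(3):
        await asyncio.sleep(0.01)
        yield f"chunk-{i}"


async def gen_maker():
    # Mirrors `lambda: self.async_client.post(...)`: awaited inside the worker
    # thread's event loop, it must resolve to the async generator to consume.
    return fake_stream()


executor = ThreadPoolExecutor(max_workers=1)
for chunk in stream_across_asyncio_run_boundary(gen_maker, executor):
    print(chunk)  # plain synchronous iteration, no asyncio in the caller

The gen_maker indirection matters because the helper awaits async_gen_maker() inside the worker loop, so the async generator is created in the same loop that consumes it (the point of the "make sure we make the generator in the event loop context" comment above).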
@@ -32,6 +95,37 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         config_path_or_template_name: str,
         custom_provider_registry: Optional[ProviderRegistry] = None,
     ):
+        super().__init__()
+        self.async_client = LlamaStackAsLibraryAsyncClient(
+            config_path_or_template_name, custom_provider_registry
+        )
+        self.pool_executor = ThreadPoolExecutor(max_workers=4)
+
+    def initialize(self):
+        asyncio.run(self.async_client.initialize())
+
+    def get(self, *args, **kwargs):
+        assert not kwargs.get("stream"), "GET never called with stream=True"
+        return asyncio.run(self.async_client.get(*args, **kwargs))
+
+    def post(self, *args, **kwargs):
+        if kwargs.get("stream"):
+            return stream_across_asyncio_run_boundary(
+                lambda: self.async_client.post(*args, **kwargs),
+                self.pool_executor,
+            )
+        else:
+            return asyncio.run(self.async_client.post(*args, **kwargs))
+
+
+class LlamaStackAsLibraryAsyncClient(AsyncLlamaStackClient):
+    def __init__(
+        self,
+        config_path_or_template_name: str,
+        custom_provider_registry: Optional[ProviderRegistry] = None,
+    ):
+        super().__init__()
         if config_path_or_template_name.endswith(".yaml"):
             config_path = Path(config_path_or_template_name)
             if not config_path.exists():
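
With this hunk the synchronous LlamaStackAsLibraryClient becomes a facade over the new LlamaStackAsLibraryAsyncClient: initialize() and non-streaming get()/post() calls are bridged with asyncio.run(), while stream=True calls come back as ordinary generators via stream_across_asyncio_run_boundary. A sketch of what calling code sees, assuming an illustrative config path and mirroring the test script updated later in this commit:

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.types import UserMessage

client = LlamaStackAsLibraryClient("/path/to/run.yaml")  # illustrative path
client.initialize()  # runs the async initialize() via asyncio.run()

models = client.models.list()  # non-streaming call, resolved with asyncio.run()

response = client.inference.chat_completion(
    messages=[UserMessage(content="What is the capital of France?", role="user")],
    model_id=models[0].identifier,
    stream=True,
)
# Streaming: `response` is a plain generator, so no event loop is needed here.
for chunk in response:
    print(chunk)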
@@ -46,8 +140,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         self.config = config
         self.custom_provider_registry = custom_provider_registry

-        super().__init__()
-
     async def initialize(self):
         try:
             self.impls = await construct_stack(
@@ -153,6 +245,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             try:
                 return [self._convert_param(item_type, item) for item in value]
             except Exception:
+                print(f"Error converting list {value}")
                 return value

         elif origin is dict:
@@ -160,6 +253,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             try:
                 return {k: self._convert_param(val_type, v) for k, v in value.items()}
             except Exception:
+                print(f"Error converting dict {value}")
                 return value

         try:
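
The two hunks above only add debug prints inside _convert_param's list and dict branches. For orientation, here is a simplified, standalone sketch of the conversion pattern those branches implement with pydantic's TypeAdapter; the Point model and the trimmed-down convert_param below are illustrative assumptions, not the method from this file.

from typing import get_args, get_origin

from pydantic import BaseModel, TypeAdapter


class Point(BaseModel):  # illustrative model, not from the codebase
    x: int
    y: int


def convert_param(annotation, value):
    # Simplified shape of the real _convert_param: recurse into list/dict
    # annotations, then let TypeAdapter coerce the leaves.
    origin = get_origin(annotation)
    if origin is list:
        (item_type,) = get_args(annotation)
        return [convert_param(item_type, v) for v in value]
    if origin is dict:
        _key_type, val_type = get_args(annotation)
        return {k: convert_param(val_type, v) for k, v in value.items()}
    return TypeAdapter(annotation).validate_python(value)


print(convert_param(list[Point], [{"x": 1, "y": 2}]))            # [Point(x=1, y=2)]
print(convert_param(dict[str, Point], {"a": {"x": 3, "y": 4}}))  # {'a': Point(x=3, y=4)}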

View file

@@ -7,38 +7,43 @@
 import argparse

 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+from llama_stack_client.lib.inference.event_logger import EventLogger
 from llama_stack_client.types import UserMessage


-async def main(config_path: str):
+def main(config_path: str):
     client = LlamaStackAsLibraryClient(config_path)
-    await client.initialize()
+    client.initialize()
+
+    models = client.models.list()
+    print("\nModels:")
+    for model in models:
+        print(model)

-    models = await client.models.list()
-    print(models)
     if not models:
         print("No models found, skipping chat completion test")
         return

     model_id = models[0].identifier
-    response = await client.inference.chat_completion(
+    response = client.inference.chat_completion(
         messages=[UserMessage(content="What is the capital of France?", role="user")],
         model_id=model_id,
         stream=False,
     )
-    print("\nChat completion response:")
+    print("\nChat completion response (non-stream):")
     print(response)

-    response = await client.inference.chat_completion(
+    response = client.inference.chat_completion(
         messages=[UserMessage(content="What is the capital of France?", role="user")],
         model_id=model_id,
         stream=True,
     )
-    print("\nChat completion stream response:")
-    async for chunk in response:
-        print(chunk)
-
-    response = await client.memory_banks.register(
+    print("\nChat completion response (stream):")
+    for log in EventLogger().log(response):
+        log.print()
+
+    response = client.memory_banks.register(
         memory_bank_id="memory_bank_id",
         params={
             "chunk_size_in_tokens": 0,
@@ -51,9 +56,7 @@ async def main(config_path: str):
 if __name__ == "__main__":
-    import asyncio
-
     parser = argparse.ArgumentParser()
     parser.add_argument("config_path", help="Path to the config YAML file")
     args = parser.parse_args()
-    asyncio.run(main(args.config_path))
+    main(args.config_path)
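
The test script is now fully synchronous, but the async client introduced in this commit can still be driven from async code, following the shape the script had before this change. A sketch under that assumption (the import path and config path are illustrative, since the diff view does not show file names):

import asyncio

from llama_stack.distribution.library_client import LlamaStackAsLibraryAsyncClient
from llama_stack_client.types import UserMessage


async def async_main(config_path: str):
    client = LlamaStackAsLibraryAsyncClient(config_path)
    await client.initialize()

    models = await client.models.list()
    if not models:
        return

    response = await client.inference.chat_completion(
        messages=[UserMessage(content="What is the capital of France?", role="user")],
        model_id=models[0].identifier,
        stream=True,
    )
    async for chunk in response:
        print(chunk)


# asyncio.run(async_main("/path/to/run.yaml"))  # path is illustrative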