diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index b79ed0f7c..81d7eae48 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import inspect
+import queue
+import threading
+from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import Any, get_args, get_origin, Optional
+from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
 
 import yaml
-from llama_stack_client import LlamaStackClient, NOT_GIVEN
+from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
 from pydantic import TypeAdapter
 from rich.console import Console
 
@@ -25,6 +29,65 @@ from llama_stack.distribution.stack import (
     replace_env_vars,
 )
 
+T = TypeVar("T")
+
+
+def stream_across_asyncio_run_boundary(
+    async_gen_maker,
+    pool_executor: ThreadPoolExecutor,
+) -> Generator[T, None, None]:
+    result_queue = queue.Queue()
+    stop_event = threading.Event()
+
+    async def consumer():
+        # make sure we make the generator in the event loop context
+        gen = await async_gen_maker()
+        try:
+            async for item in gen:
+                result_queue.put(item)
+        except Exception as e:
+            print(f"Error in generator {e}")
+            result_queue.put(e)
+        except asyncio.CancelledError:
+            return
+        finally:
+            result_queue.put(StopIteration)
+            stop_event.set()
+
+    def run_async():
+        # Run our own loop to avoid double async generator cleanup which is done
+        # by asyncio.run()
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            task = loop.create_task(consumer())
+            loop.run_until_complete(task)
+        finally:
+            # Handle pending tasks like a generator's athrow()
+            pending = asyncio.all_tasks(loop)
+            if pending:
+                loop.run_until_complete(
+                    asyncio.gather(*pending, return_exceptions=True)
+                )
+            loop.close()
+
+    future = pool_executor.submit(run_async)
+
+    try:
+        # yield results as they come in
+        while not stop_event.is_set() or not result_queue.empty():
+            try:
+                item = result_queue.get(timeout=0.1)
+                if item is StopIteration:
+                    break
+                if isinstance(item, Exception):
+                    raise item
+                yield item
+            except queue.Empty:
+                continue
+    finally:
+        future.result()
+
 
 class LlamaStackAsLibraryClient(LlamaStackClient):
     def __init__(
@@ -32,6 +95,37 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         config_path_or_template_name: str,
         custom_provider_registry: Optional[ProviderRegistry] = None,
     ):
+        super().__init__()
+        self.async_client = LlamaStackAsLibraryAsyncClient(
+            config_path_or_template_name, custom_provider_registry
+        )
+        self.pool_executor = ThreadPoolExecutor(max_workers=4)
+
+    def initialize(self):
+        asyncio.run(self.async_client.initialize())
+
+    def get(self, *args, **kwargs):
+        assert not kwargs.get("stream"), "GET never called with stream=True"
+        return asyncio.run(self.async_client.get(*args, **kwargs))
+
+    def post(self, *args, **kwargs):
+        if kwargs.get("stream"):
+            return stream_across_asyncio_run_boundary(
+                lambda: self.async_client.post(*args, **kwargs),
+                self.pool_executor,
+            )
+        else:
+            return asyncio.run(self.async_client.post(*args, **kwargs))
+
+
+class LlamaStackAsLibraryAsyncClient(AsyncLlamaStackClient):
+    def __init__(
+        self,
+        config_path_or_template_name: str,
+        custom_provider_registry: Optional[ProviderRegistry] = None,
+    ):
+        super().__init__()
+
         if config_path_or_template_name.endswith(".yaml"):
             config_path = Path(config_path_or_template_name)
             if not config_path.exists():
@@ -46,8 +140,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         self.config = config
         self.custom_provider_registry = custom_provider_registry
 
-        super().__init__()
-
     async def initialize(self):
         try:
             self.impls = await construct_stack(
@@ -153,6 +245,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             try:
                 return [self._convert_param(item_type, item) for item in value]
             except Exception:
+                print(f"Error converting list {value}")
                 return value
 
         elif origin is dict:
@@ -160,6 +253,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             try:
                 return {k: self._convert_param(val_type, v) for k, v in value.items()}
             except Exception:
+                print(f"Error converting dict {value}")
                 return value
 
         try:
diff --git a/llama_stack/distribution/tests/library_client_test.py b/llama_stack/distribution/tests/library_client_test.py
index fc1ab5e7e..d6b1130c6 100644
--- a/llama_stack/distribution/tests/library_client_test.py
+++ b/llama_stack/distribution/tests/library_client_test.py
@@ -7,38 +7,43 @@
 import argparse
 
 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+from llama_stack_client.lib.inference.event_logger import EventLogger
 from llama_stack_client.types import UserMessage
 
 
-async def main(config_path: str):
+def main(config_path: str):
     client = LlamaStackAsLibraryClient(config_path)
-    await client.initialize()
+    client.initialize()
+
+    models = client.models.list()
+    print("\nModels:")
+    for model in models:
+        print(model)
 
-    models = await client.models.list()
-    print(models)
     if not models:
         print("No models found, skipping chat completion test")
         return
 
     model_id = models[0].identifier
-    response = await client.inference.chat_completion(
+    response = client.inference.chat_completion(
         messages=[UserMessage(content="What is the capital of France?", role="user")],
         model_id=model_id,
         stream=False,
     )
-    print("\nChat completion response:")
+    print("\nChat completion response (non-stream):")
     print(response)
 
-    response = await client.inference.chat_completion(
+    response = client.inference.chat_completion(
         messages=[UserMessage(content="What is the capital of France?", role="user")],
         model_id=model_id,
         stream=True,
     )
-    print("\nChat completion stream response:")
-    async for chunk in response:
-        print(chunk)
-    response = await client.memory_banks.register(
+    print("\nChat completion response (stream):")
+    for log in EventLogger().log(response):
+        log.print()
+
+    response = client.memory_banks.register(
         memory_bank_id="memory_bank_id",
         params={
             "chunk_size_in_tokens": 0,
@@ -51,9 +56,7 @@ async def main(config_path: str):
 
 
 if __name__ == "__main__":
-    import asyncio
-
     parser = argparse.ArgumentParser()
     parser.add_argument("config_path", help="Path to the config YAML file")
     args = parser.parse_args()
-    asyncio.run(main(args.config_path))
+    main(args.config_path)
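Note on the core technique: stream_across_asyncio_run_boundary is what lets the synchronous LlamaStackAsLibraryClient return a plain generator for stream=True calls even though the underlying client is async. A worker thread runs the async generator on its own event loop and hands items back through a thread-safe queue. Below is a minimal, self-contained sketch of that pattern, not part of the diff; the names ticker and sync_stream are illustrative, and it uses asyncio.run for brevity where the diff manages the loop manually to control async-generator cleanup.

import asyncio
import queue
import threading
from concurrent.futures import ThreadPoolExecutor


async def ticker():
    # stand-in for an async streaming response (e.g. chat_completion chunks)
    for i in range(3):
        await asyncio.sleep(0.01)
        yield i


def sync_stream(async_gen_maker, pool: ThreadPoolExecutor):
    # bridge: a worker thread drives the async generator on its own event
    # loop and feeds a thread-safe queue; the caller iterates synchronously
    result_queue = queue.Queue()
    done = threading.Event()
    sentinel = object()

    async def consume():
        try:
            async for item in async_gen_maker():
                result_queue.put(item)
        finally:
            result_queue.put(sentinel)
            done.set()

    future = pool.submit(asyncio.run, consume())
    try:
        while not done.is_set() or not result_queue.empty():
            try:
                item = result_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            if item is sentinel:
                break
            yield item
    finally:
        future.result()  # re-raise anything the worker thread hit


if __name__ == "__main__":
    with ThreadPoolExecutor(max_workers=1) as pool:
        for value in sync_stream(ticker, pool):
            print(value)  # 0, 1, 2 -- consumed from fully synchronous code

The same shape explains the diff's design choices: the sentinel (StopIteration there) marks end-of-stream, exceptions are shipped through the queue so the sync caller sees them, and the trailing future.result() ensures the worker's failures are never silently dropped.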