Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)
make direct client streaming work properly

Parent: fd48cf3fc1
Commit: 86b5743081
2 changed files with 115 additions and 18 deletions
(first changed file: the library client implementation)

@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

 import asyncio
 import inspect
+import queue
+import threading
+from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import Any, get_args, get_origin, Optional
+from typing import Any, Generator, get_args, get_origin, Optional, TypeVar

 import yaml
-from llama_stack_client import LlamaStackClient, NOT_GIVEN
+from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
 from pydantic import TypeAdapter
 from rich.console import Console
@@ -25,6 +29,65 @@ from llama_stack.distribution.stack import (
     replace_env_vars,
 )

+T = TypeVar("T")
+
+
+def stream_across_asyncio_run_boundary(
+    async_gen_maker,
+    pool_executor: ThreadPoolExecutor,
+) -> Generator[T, None, None]:
+    result_queue = queue.Queue()
+    stop_event = threading.Event()
+
+    async def consumer():
+        # make sure we make the generator in the event loop context
+        gen = await async_gen_maker()
+        try:
+            async for item in gen:
+                result_queue.put(item)
+        except Exception as e:
+            print(f"Error in generator {e}")
+            result_queue.put(e)
+        except asyncio.CancelledError:
+            return
+        finally:
+            result_queue.put(StopIteration)
+            stop_event.set()
+
+    def run_async():
+        # Run our own loop to avoid double async generator cleanup which is done
+        # by asyncio.run()
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            task = loop.create_task(consumer())
+            loop.run_until_complete(task)
+        finally:
+            # Handle pending tasks like a generator's athrow()
+            pending = asyncio.all_tasks(loop)
+            if pending:
+                loop.run_until_complete(
+                    asyncio.gather(*pending, return_exceptions=True)
+                )
+            loop.close()
+
+    future = pool_executor.submit(run_async)
+
+    try:
+        # yield results as they come in
+        while not stop_event.is_set() or not result_queue.empty():
+            try:
+                item = result_queue.get(timeout=0.1)
+                if item is StopIteration:
+                    break
+                if isinstance(item, Exception):
+                    raise item
+                yield item
+            except queue.Empty:
+                continue
+    finally:
+        future.result()
+
+
 class LlamaStackAsLibraryClient(LlamaStackClient):
     def __init__(
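Note (not part of the diff): the helper above turns an async generator into a plain synchronous generator by running the async side on a worker thread and handing items back through a queue, with StopIteration used as an end-of-stream sentinel. A minimal usage sketch, assuming the helper can be imported from the library client module and substituting a toy async generator for a real streaming response:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    from llama_stack.distribution.library_client import stream_across_asyncio_run_boundary


    async def make_gen():
        # Stands in for something like async_client.post(..., stream=True).
        async def numbers():
            for i in range(3):
                await asyncio.sleep(0.01)  # pretend we are waiting on the network
                yield i

        return numbers()


    executor = ThreadPoolExecutor(max_workers=1)
    for item in stream_across_asyncio_run_boundary(make_gen, executor):
        print(item)  # 0, 1, 2 -- consumed synchronously, no event loop in the caller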
@@ -32,6 +95,37 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         config_path_or_template_name: str,
         custom_provider_registry: Optional[ProviderRegistry] = None,
     ):
+        super().__init__()
+        self.async_client = LlamaStackAsLibraryAsyncClient(
+            config_path_or_template_name, custom_provider_registry
+        )
+        self.pool_executor = ThreadPoolExecutor(max_workers=4)
+
+    def initialize(self):
+        asyncio.run(self.async_client.initialize())
+
+    def get(self, *args, **kwargs):
+        assert not kwargs.get("stream"), "GET never called with stream=True"
+        return asyncio.run(self.async_client.get(*args, **kwargs))
+
+    def post(self, *args, **kwargs):
+        if kwargs.get("stream"):
+            return stream_across_asyncio_run_boundary(
+                lambda: self.async_client.post(*args, **kwargs),
+                self.pool_executor,
+            )
+        else:
+            return asyncio.run(self.async_client.post(*args, **kwargs))
+
+
+class LlamaStackAsLibraryAsyncClient(AsyncLlamaStackClient):
+    def __init__(
+        self,
+        config_path_or_template_name: str,
+        custom_provider_registry: Optional[ProviderRegistry] = None,
+    ):
+        super().__init__()
+
         if config_path_or_template_name.endswith(".yaml"):
             config_path = Path(config_path_or_template_name)
             if not config_path.exists():
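Note (not part of the diff): the synchronous wrapper above can wrap each non-streaming call in asyncio.run() because the awaited call completes before the event loop is torn down; a streaming response cannot be handled that way, since asyncio.run() would close the loop while the async generator still needs it, which is exactly what the thread-based bridge works around. A hypothetical sketch of the naive approach, shown only for contrast (not how the client is implemented):

    import asyncio


    def naive_streaming_post(async_client, *args, **kwargs):
        # With stream=True this hands back an async iterator whose event loop
        # asyncio.run() has already torn down, so it cannot be consumed afterwards.
        return asyncio.run(async_client.post(*args, **kwargs))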
@@ -46,8 +140,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         self.config = config
         self.custom_provider_registry = custom_provider_registry

-        super().__init__()
-
     async def initialize(self):
         try:
             self.impls = await construct_stack(
@@ -153,6 +245,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             try:
                 return [self._convert_param(item_type, item) for item in value]
             except Exception:
+                print(f"Error converting list {value}")
                 return value

         elif origin is dict:
@@ -160,6 +253,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             try:
                 return {k: self._convert_param(val_type, v) for k, v in value.items()}
             except Exception:
+                print(f"Error converting dict {value}")
                 return value

         try:
(second changed file: an example script exercising the client)

@@ -7,38 +7,43 @@
 import argparse

 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+from llama_stack_client.lib.inference.event_logger import EventLogger
 from llama_stack_client.types import UserMessage


-async def main(config_path: str):
+def main(config_path: str):
     client = LlamaStackAsLibraryClient(config_path)
-    await client.initialize()
-
-    models = await client.models.list()
-    print(models)
+    client.initialize()
+
+    models = client.models.list()
+    print("\nModels:")
+    for model in models:
+        print(model)
+
     if not models:
         print("No models found, skipping chat completion test")
         return

     model_id = models[0].identifier
-    response = await client.inference.chat_completion(
+    response = client.inference.chat_completion(
         messages=[UserMessage(content="What is the capital of France?", role="user")],
         model_id=model_id,
         stream=False,
     )
-    print("\nChat completion response:")
+    print("\nChat completion response (non-stream):")
     print(response)

-    response = await client.inference.chat_completion(
+    response = client.inference.chat_completion(
         messages=[UserMessage(content="What is the capital of France?", role="user")],
         model_id=model_id,
         stream=True,
     )
-    print("\nChat completion stream response:")
-    async for chunk in response:
-        print(chunk)
+    print("\nChat completion response (stream):")
+    for log in EventLogger().log(response):
+        log.print()

-    response = await client.memory_banks.register(
+    response = client.memory_banks.register(
         memory_bank_id="memory_bank_id",
         params={
             "chunk_size_in_tokens": 0,
@@ -51,9 +56,7 @@ async def main(config_path: str):


 if __name__ == "__main__":
-    import asyncio
-
     parser = argparse.ArgumentParser()
     parser.add_argument("config_path", help="Path to the config YAML file")
     args = parser.parse_args()
-    asyncio.run(main(args.config_path))
+    main(args.config_path)
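Note (an assumption, since this view does not show the script's path): with the __main__ block above, the example is run directly, passing the run configuration YAML as its single positional argument, e.g.:

    python <path-to-this-example-script> path/to/run.yaml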