Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 01:03:59 +00:00)
make direct client streaming work properly
This commit is contained in:
parent fd48cf3fc1
commit 86b5743081
2 changed files with 115 additions and 18 deletions
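
The change lets synchronous callers use the stack as a library, including streamed inference, without writing any asyncio code themselves. A minimal sketch of the intended usage after this commit (the config path is a placeholder; the calls mirror the updated test script in the second changed file below):

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.types import UserMessage

client = LlamaStackAsLibraryClient("path/to/run.yaml")  # or a template name
client.initialize()

model_id = client.models.list()[0].identifier

# Non-streaming: the call is run to completion on a private event loop.
response = client.inference.chat_completion(
    messages=[UserMessage(content="What is the capital of France?", role="user")],
    model_id=model_id,
    stream=False,
)
print(response)

# Streaming: the call returns a plain synchronous generator of chunks.
for chunk in client.inference.chat_completion(
    messages=[UserMessage(content="What is the capital of France?", role="user")],
    model_id=model_id,
    stream=True,
):
    print(chunk)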
llama_stack/distribution/library_client.py

@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import inspect
+import queue
+import threading
+from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import Any, get_args, get_origin, Optional
+from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
 
 import yaml
-from llama_stack_client import LlamaStackClient, NOT_GIVEN
+from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
 from pydantic import TypeAdapter
 from rich.console import Console
 
@@ -25,6 +29,65 @@ from llama_stack.distribution.stack import (
     replace_env_vars,
 )
 
+T = TypeVar("T")
+
+
+def stream_across_asyncio_run_boundary(
+    async_gen_maker,
+    pool_executor: ThreadPoolExecutor,
+) -> Generator[T, None, None]:
+    result_queue = queue.Queue()
+    stop_event = threading.Event()
+
+    async def consumer():
+        # make sure we make the generator in the event loop context
+        gen = await async_gen_maker()
+        try:
+            async for item in gen:
+                result_queue.put(item)
+        except Exception as e:
+            print(f"Error in generator {e}")
+            result_queue.put(e)
+        except asyncio.CancelledError:
+            return
+        finally:
+            result_queue.put(StopIteration)
+            stop_event.set()
+
+    def run_async():
+        # Run our own loop to avoid double async generator cleanup which is done
+        # by asyncio.run()
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            task = loop.create_task(consumer())
+            loop.run_until_complete(task)
+        finally:
+            # Handle pending tasks like a generator's athrow()
+            pending = asyncio.all_tasks(loop)
+            if pending:
+                loop.run_until_complete(
+                    asyncio.gather(*pending, return_exceptions=True)
+                )
+            loop.close()
+
+    future = pool_executor.submit(run_async)
+
+    try:
+        # yield results as they come in
+        while not stop_event.is_set() or not result_queue.empty():
+            try:
+                item = result_queue.get(timeout=0.1)
+                if item is StopIteration:
+                    break
+                if isinstance(item, Exception):
+                    raise item
+                yield item
+            except queue.Empty:
+                continue
+    finally:
+        future.result()
+
 
 class LlamaStackAsLibraryClient(LlamaStackClient):
     def __init__(
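
stream_across_asyncio_run_boundary() is the core of the change: it runs the async generator on its own event loop in a pool thread and feeds items back to the calling thread through a queue, using StopIteration as an end-of-stream sentinel and re-raising any producer exception on the consumer side. A self-contained sketch of the same pattern, simplified to use asyncio.run() and a toy async generator rather than llama-stack code:

import asyncio
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncIterator, Generator


async def make_gen() -> AsyncIterator[int]:
    # Stand-in for an async streaming API such as a streaming HTTP response.
    for i in range(5):
        await asyncio.sleep(0.01)
        yield i


def sync_stream(pool: ThreadPoolExecutor) -> Generator[int, None, None]:
    q: queue.Queue = queue.Queue()
    done = threading.Event()

    async def consume() -> None:
        try:
            async for item in make_gen():
                q.put(item)
        except Exception as e:
            q.put(e)
        finally:
            q.put(StopIteration)  # end-of-stream sentinel, as in the diff
            done.set()

    future = pool.submit(lambda: asyncio.run(consume()))
    try:
        while not done.is_set() or not q.empty():
            try:
                item = q.get(timeout=0.1)
            except queue.Empty:
                continue
            if item is StopIteration:
                break
            if isinstance(item, Exception):
                raise item
            yield item
    finally:
        future.result()  # surfaces any exception raised in the worker thread


if __name__ == "__main__":
    with ThreadPoolExecutor(max_workers=1) as pool:
        print(list(sync_stream(pool)))  # prints [0, 1, 2, 3, 4]

The real helper manages its own event loop instead of calling asyncio.run() so that pending async-generator cleanup can be gathered explicitly, as the run_async() comment in the hunk above explains.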
@@ -32,6 +95,37 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         config_path_or_template_name: str,
         custom_provider_registry: Optional[ProviderRegistry] = None,
     ):
+        super().__init__()
+        self.async_client = LlamaStackAsLibraryAsyncClient(
+            config_path_or_template_name, custom_provider_registry
+        )
+        self.pool_executor = ThreadPoolExecutor(max_workers=4)
+
+    def initialize(self):
+        asyncio.run(self.async_client.initialize())
+
+    def get(self, *args, **kwargs):
+        assert not kwargs.get("stream"), "GET never called with stream=True"
+        return asyncio.run(self.async_client.get(*args, **kwargs))
+
+    def post(self, *args, **kwargs):
+        if kwargs.get("stream"):
+            return stream_across_asyncio_run_boundary(
+                lambda: self.async_client.post(*args, **kwargs),
+                self.pool_executor,
+            )
+        else:
+            return asyncio.run(self.async_client.post(*args, **kwargs))
+
+
+class LlamaStackAsLibraryAsyncClient(AsyncLlamaStackClient):
+    def __init__(
+        self,
+        config_path_or_template_name: str,
+        custom_provider_registry: Optional[ProviderRegistry] = None,
+    ):
+        super().__init__()
+
         if config_path_or_template_name.endswith(".yaml"):
             config_path = Path(config_path_or_template_name)
             if not config_path.exists():
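
With this split, LlamaStackAsLibraryClient becomes a thin synchronous facade over LlamaStackAsLibraryAsyncClient: non-streaming get()/post() calls are run to completion with asyncio.run(), while streaming post() calls go through the bridge above using the client's small thread pool. A toy illustration of the non-streaming half of that dispatch (hypothetical names, not llama-stack code):

import asyncio


class ToyAsyncBackend:
    """Stand-in for an async client such as AsyncLlamaStackClient."""

    async def post(self, path: str, body: dict) -> dict:
        await asyncio.sleep(0.01)  # pretend to do async I/O
        return {"path": path, "echo": body}


class ToySyncFacade:
    """Runs each async call to completion, like the non-streaming branch
    of LlamaStackAsLibraryClient.post()."""

    def __init__(self) -> None:
        self._async = ToyAsyncBackend()

    def post(self, path: str, body: dict) -> dict:
        # A fresh event loop is created and torn down on every call;
        # this is why streaming needs the separate queue-and-thread bridge.
        return asyncio.run(self._async.post(path, body))


if __name__ == "__main__":
    print(ToySyncFacade().post("/inference/chat_completion", {"stream": False}))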
@@ -46,8 +140,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         self.config = config
         self.custom_provider_registry = custom_provider_registry
 
-        super().__init__()
-
     async def initialize(self):
         try:
             self.impls = await construct_stack(
@@ -153,6 +245,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
            try:
                return [self._convert_param(item_type, item) for item in value]
            except Exception:
+                print(f"Error converting list {value}")
                return value
 
        elif origin is dict:
@@ -160,6 +253,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
            try:
                return {k: self._convert_param(val_type, v) for k, v in value.items()}
            except Exception:
+                print(f"Error converting dict {value}")
                return value
 
        try:
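
The two hunks above only add debug prints when _convert_param() fails to coerce a list or dict parameter. For orientation, a hypothetical and heavily simplified sketch of that kind of recursive conversion using pydantic's TypeAdapter (the module already imports get_args, get_origin, and TypeAdapter); this is not the actual implementation:

from typing import get_args, get_origin

from pydantic import BaseModel, TypeAdapter


class Point(BaseModel):
    x: int
    y: int


def convert(annotation, value):
    origin = get_origin(annotation)
    if origin is list:
        (item_type,) = get_args(annotation)
        try:
            return [convert(item_type, v) for v in value]
        except Exception:
            print(f"Error converting list {value}")
            return value
    if origin is dict:
        _, val_type = get_args(annotation)
        try:
            return {k: convert(val_type, v) for k, v in value.items()}
        except Exception:
            print(f"Error converting dict {value}")
            return value
    return TypeAdapter(annotation).validate_python(value)


if __name__ == "__main__":
    print(convert(list[Point], [{"x": 1, "y": 2}, {"x": 3, "y": 4}]))
    # -> [Point(x=1, y=2), Point(x=3, y=4)]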
Second changed file (a test script that exercises the library client):

@@ -7,38 +7,43 @@
 import argparse
 
 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+from llama_stack_client.lib.inference.event_logger import EventLogger
 from llama_stack_client.types import UserMessage
 
 
-async def main(config_path: str):
+def main(config_path: str):
     client = LlamaStackAsLibraryClient(config_path)
-    await client.initialize()
+    client.initialize()
 
-    models = await client.models.list()
-    print(models)
+    models = client.models.list()
+    print("\nModels:")
+    for model in models:
+        print(model)
+
     if not models:
         print("No models found, skipping chat completion test")
         return
 
     model_id = models[0].identifier
-    response = await client.inference.chat_completion(
+    response = client.inference.chat_completion(
         messages=[UserMessage(content="What is the capital of France?", role="user")],
         model_id=model_id,
         stream=False,
     )
-    print("\nChat completion response:")
+    print("\nChat completion response (non-stream):")
     print(response)
 
-    response = await client.inference.chat_completion(
+    response = client.inference.chat_completion(
         messages=[UserMessage(content="What is the capital of France?", role="user")],
         model_id=model_id,
         stream=True,
     )
-    print("\nChat completion stream response:")
-    async for chunk in response:
-        print(chunk)
+    print("\nChat completion response (stream):")
+    for log in EventLogger().log(response):
+        log.print()
 
-    response = await client.memory_banks.register(
+    response = client.memory_banks.register(
         memory_bank_id="memory_bank_id",
         params={
             "chunk_size_in_tokens": 0,
@@ -51,9 +56,7 @@ async def main(config_path: str):
 
 
 if __name__ == "__main__":
-    import asyncio
-
     parser = argparse.ArgumentParser()
     parser.add_argument("config_path", help="Path to the config YAML file")
     args = parser.parse_args()
-    asyncio.run(main(args.config_path))
+    main(args.config_path)