make direct client streaming work properly

Ashwin Bharambe 2024-12-07 11:38:56 -08:00
parent fd48cf3fc1
commit 86b5743081
2 changed files with 115 additions and 18 deletions


@@ -7,38 +7,43 @@
 import argparse
 
 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+from llama_stack_client.lib.inference.event_logger import EventLogger
 from llama_stack_client.types import UserMessage
 
 
-async def main(config_path: str):
+def main(config_path: str):
     client = LlamaStackAsLibraryClient(config_path)
-    await client.initialize()
+    client.initialize()
 
-    models = await client.models.list()
-    print(models)
+    models = client.models.list()
+    print("\nModels:")
+    for model in models:
+        print(model)
+
     if not models:
         print("No models found, skipping chat completion test")
         return
 
     model_id = models[0].identifier
-    response = await client.inference.chat_completion(
+    response = client.inference.chat_completion(
         messages=[UserMessage(content="What is the capital of France?", role="user")],
         model_id=model_id,
         stream=False,
     )
-    print("\nChat completion response:")
+    print("\nChat completion response (non-stream):")
     print(response)
 
-    response = await client.inference.chat_completion(
+    response = client.inference.chat_completion(
         messages=[UserMessage(content="What is the capital of France?", role="user")],
         model_id=model_id,
         stream=True,
     )
-    print("\nChat completion stream response:")
-    async for chunk in response:
-        print(chunk)
+    print("\nChat completion response (stream):")
+    for log in EventLogger().log(response):
+        log.print()
 
-    response = await client.memory_banks.register(
+    response = client.memory_banks.register(
         memory_bank_id="memory_bank_id",
         params={
             "chunk_size_in_tokens": 0,
@@ -51,9 +56,7 @@ async def main(config_path: str):
 
 
 if __name__ == "__main__":
-    import asyncio
-
     parser = argparse.ArgumentParser()
     parser.add_argument("config_path", help="Path to the config YAML file")
     args = parser.parse_args()
-    asyncio.run(main(args.config_path))
+    main(args.config_path)
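
Usage sketch (not part of the commit): a minimal example of the synchronous streaming pattern the diff switches to. It assumes llama_stack and llama_stack_client are installed and that "run.yaml" is a hypothetical path to a valid distribution config; the chat_completion arguments mirror those in the diff above.

    from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
    from llama_stack_client.lib.inference.event_logger import EventLogger
    from llama_stack_client.types import UserMessage

    # Build the in-process ("direct") client and initialize it synchronously.
    client = LlamaStackAsLibraryClient("run.yaml")  # "run.yaml" is a placeholder config path
    client.initialize()

    # Pick the first registered model, as the test script in the diff does.
    model_id = client.models.list()[0].identifier

    # With stream=True the call now returns a plain synchronous iterator of chunks.
    response = client.inference.chat_completion(
        messages=[UserMessage(content="What is the capital of France?", role="user")],
        model_id=model_id,
        stream=True,
    )

    # EventLogger pretty-prints each streamed event; iterating the raw chunks
    # directly with a plain `for chunk in response:` loop also works.
    for log in EventLogger().log(response):
        log.print()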