mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-17 17:22:38 +00:00)
make direct client streaming work properly
parent: fd48cf3fc1
commit: 86b5743081

2 changed files with 115 additions and 18 deletions
@@ -7,38 +7,43 @@
 import argparse

 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+from llama_stack_client.lib.inference.event_logger import EventLogger
 from llama_stack_client.types import UserMessage


-async def main(config_path: str):
+def main(config_path: str):
     client = LlamaStackAsLibraryClient(config_path)
-    await client.initialize()
+    client.initialize()

-    models = await client.models.list()
-    print(models)
+    models = client.models.list()
+    print("\nModels:")
+    for model in models:
+        print(model)
+
     if not models:
         print("No models found, skipping chat completion test")
         return

     model_id = models[0].identifier
-    response = await client.inference.chat_completion(
+    response = client.inference.chat_completion(
         messages=[UserMessage(content="What is the capital of France?", role="user")],
         model_id=model_id,
         stream=False,
     )
-    print("\nChat completion response:")
+    print("\nChat completion response (non-stream):")
     print(response)

-    response = await client.inference.chat_completion(
+    response = client.inference.chat_completion(
         messages=[UserMessage(content="What is the capital of France?", role="user")],
         model_id=model_id,
         stream=True,
     )
-    print("\nChat completion stream response:")
-    async for chunk in response:
-        print(chunk)
-
-    response = await client.memory_banks.register(
+    print("\nChat completion response (stream):")
+    for log in EventLogger().log(response):
+        log.print()
+
+    response = client.memory_banks.register(
         memory_bank_id="memory_bank_id",
         params={
             "chunk_size_in_tokens": 0,

@@ -51,9 +56,7 @@ async def main(config_path: str):


 if __name__ == "__main__":
-    import asyncio
-
     parser = argparse.ArgumentParser()
     parser.add_argument("config_path", help="Path to the config YAML file")
     args = parser.parse_args()
-    asyncio.run(main(args.config_path))
+    main(args.config_path)
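Note: the diff above covers only the test script; the second changed file (the library client itself, not shown on this page) is where streaming is actually bridged from the async implementation to this synchronous surface. As a rough sketch of the general technique, here is one way to expose an async iterator to a plain for loop. Everything below is illustrative: iter_sync and ticker are hypothetical names, not part of llama-stack.

import asyncio
from typing import AsyncIterator, Iterator, TypeVar

T = TypeVar("T")


def iter_sync(async_it: AsyncIterator[T]) -> Iterator[T]:
    # Hypothetical helper: drives an async iterator from synchronous
    # code by advancing an event loop one item at a time, so callers
    # can write "for chunk in ..." instead of "async for chunk in ...".
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                yield loop.run_until_complete(async_it.__anext__())
            except StopAsyncIteration:
                break
    finally:
        loop.close()


async def ticker(n: int) -> AsyncIterator[int]:
    # Stand-in for an async streaming response.
    for i in range(n):
        yield i


for chunk in iter_sync(ticker(3)):
    print(chunk)  # prints 0, 1, 2

This is the shape of change the test exercises: client.initialize(), client.models.list(), and the streaming chat_completion call are consumed without any event loop in user code, and EventLogger().log(response) walks the stream with an ordinary for loop.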