llama-stack-mirror/llama_stack/distribution/tests/library_client_test.py
2024-12-07 11:38:56 -08:00

62 lines
1.8 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.lib.inference.event_logger import EventLogger
from llama_stack_client.types import UserMessage
def main(config_path: str):
    """Smoke-test the library client end to end.

    Lists available models, runs one non-streaming and one streaming chat
    completion against the first model, then registers a memory bank.
    """
    client = LlamaStackAsLibraryClient(config_path)
    client.initialize()

    available_models = client.models.list()
    print("\nModels:")
    for entry in available_models:
        print(entry)

    # Nothing to chat with — bail out early.
    if not available_models:
        print("No models found, skipping chat completion test")
        return

    chosen_model_id = available_models[0].identifier

    # Non-streaming chat completion: full response comes back in one object.
    sync_response = client.inference.chat_completion(
        messages=[UserMessage(content="What is the capital of France?", role="user")],
        model_id=chosen_model_id,
        stream=False,
    )
    print("\nChat completion response (non-stream):")
    print(sync_response)

    # Streaming chat completion: events are rendered as they arrive.
    stream_response = client.inference.chat_completion(
        messages=[UserMessage(content="What is the capital of France?", role="user")],
        model_id=chosen_model_id,
        stream=True,
    )
    print("\nChat completion response (stream):")
    for event in EventLogger().log(stream_response):
        event.print()

    # Exercise the memory-bank registration path with placeholder values.
    register_response = client.memory_banks.register(
        memory_bank_id="memory_bank_id",
        params={
            "chunk_size_in_tokens": 0,
            "embedding_model": "embedding_model",
            "memory_bank_type": "vector",
        },
    )
    print("\nRegister memory bank response:")
    print(register_response)
if __name__ == "__main__":
    # CLI entry point: the only positional argument is the stack config file.
    cli = argparse.ArgumentParser()
    cli.add_argument("config_path", help="Path to the config YAML file")
    parsed = cli.parse_args()
    main(parsed.config_path)