# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Smoke test for running Llama Stack as an in-process library client.

Lists the registered models, exercises non-streaming and streaming chat
completion against the first one, then runs a short multi-turn agent session.
"""

import argparse
import os

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger as AgentEventLogger
from llama_stack_client.lib.inference.event_logger import EventLogger
from llama_stack_client.types import UserMessage
from llama_stack_client.types.agent_create_params import AgentConfig


def _run_chat_completion_checks(client, model_id: str) -> None:
    """Ask ``model_id`` one question, first non-streamed and then streamed, printing both."""
    response = client.inference.chat_completion(
        messages=[UserMessage(content="What is the capital of France?", role="user")],
        model_id=model_id,
        stream=False,
    )
    print("\nChat completion response (non-stream):")
    print(response)

    response = client.inference.chat_completion(
        messages=[UserMessage(content="What is the capital of France?", role="user")],
        model_id=model_id,
        stream=True,
    )
    print("\nChat completion response (stream):")
    for log in EventLogger().log(response):
        log.print()


def _build_agent_config(model_id: str) -> AgentConfig:
    """Build the agent configuration used by the agent check.

    The Brave search tool is attached only when the BRAVE_SEARCH_API_KEY
    environment variable is set to a non-empty value; otherwise the agent is
    configured with no tools.
    """
    # Read the key once so the presence check and the value can't diverge.
    brave_api_key = os.getenv("BRAVE_SEARCH_API_KEY")
    tools = (
        [
            {
                "type": "brave_search",
                "engine": "brave",
                "api_key": brave_api_key,
            }
        ]
        if brave_api_key
        else []
    )
    return AgentConfig(
        model=model_id,
        instructions="You are a helpful assistant",
        sampling_params={
            "strategy": "greedy",
            # NOTE(review): with "greedy" decoding, temperature/top_p are
            # presumably ignored by the provider — confirm before tuning them.
            "temperature": 1.0,
            "top_p": 0.9,
        },
        tools=tools,
        tool_choice="auto",
        tool_prompt_format="json",
        input_shields=[],
        output_shields=[],
        enable_session_persistence=False,
    )


def _run_agent_check(client, model_id: str) -> None:
    """Run each test prompt as a turn of a single agent session and print the event log."""
    print("\nAgent test:")
    agent = Agent(client, _build_agent_config(model_id))
    user_prompts = [
        "Hello",
        "Which players played in the winning team of the NBA western conference "
        "semifinals of 2024, please use tools",
    ]

    session_id = agent.create_session("test-session")
    for prompt in user_prompts:
        response = agent.create_turn(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            session_id=session_id,
        )
        for log in AgentEventLogger().log(response):
            log.print()


def main(config_path: str) -> None:
    """Initialize a library client from ``config_path`` and run the smoke checks.

    Prints every registered model. If none are registered, it prints a notice
    and returns without running the chat-completion or agent checks.

    Args:
        config_path: Path to the stack's config YAML file.
    """
    client = LlamaStackAsLibraryClient(config_path)
    client.initialize()

    models = client.models.list()
    print("\nModels:")
    for model in models:
        print(model)

    if not models:
        print("No models found, skipping chat completion test")
        return

    # NOTE(review): the first registered model is used as-is; if the stack can
    # register non-chat models, this may need filtering — confirm.
    model_id = models[0].identifier
    _run_chat_completion_checks(client, model_id)
    _run_agent_check(client, model_id)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("config_path", help="Path to the config YAML file")
    args = parser.parse_args()
    main(args.config_path)