forked from phoenix-oss/llama-stack-mirror
		
	
		
			
				
	
	
		
			129 lines
		
	
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			129 lines
		
	
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import argparse
 | |
| import os
 | |
| 
 | |
| from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
 | |
| from llama_stack_client.lib.agents.agent import Agent
 | |
| from llama_stack_client.lib.agents.event_logger import EventLogger as AgentEventLogger
 | |
| from llama_stack_client.lib.inference.event_logger import EventLogger
 | |
| from llama_stack_client.types import Attachment, UserMessage
 | |
| from llama_stack_client.types.agent_create_params import AgentConfig
 | |
| 
 | |
| 
 | |
def main(config_path: str) -> None:
    """Smoke-test a Llama Stack distribution through ``LlamaStackAsLibraryClient``.

    Lists the registered models, runs a non-streaming and a streaming chat
    completion against an 8B model, then drives a multi-turn agent session
    over an attached CSV file. All results are printed to stdout.

    Args:
        config_path: Path to the distribution config YAML; passed unchanged to
            ``LlamaStackAsLibraryClient``.
    """
    client = LlamaStackAsLibraryClient(config_path)
    # A falsy return from initialize() is treated as failure; nothing to test.
    if not client.initialize():
        return

    # Materialize once: the result is printed, truth-tested and searched below.
    models = list(client.models.list())
    print("\nModels:")
    for model in models:
        print(model)

    if not models:
        print("No models found, skipping chat completion test")
        return

    # Prefer an 8B model. The explicit default keeps next() from raising an
    # uncaught StopIteration when the distribution serves no such model.
    model_id = next(
        (m.identifier for m in models if "8b" in m.identifier.lower()), None
    )
    if model_id is None:
        print("No 8B model found, skipping chat completion test")
        return
    print(f"Using model: {model_id}")

    _run_chat_completion_tests(client, model_id)
    _run_agent_test(client, model_id)


def _run_chat_completion_tests(client, model_id: str) -> None:
    """Send the same prompt as a non-streaming and then a streaming chat completion."""
    prompt = "What is the capital of France?"

    response = client.inference.chat_completion(
        messages=[UserMessage(content=prompt, role="user")],
        model_id=model_id,
        stream=False,
    )
    print("\nChat completion response (non-stream):")
    print(response)

    response = client.inference.chat_completion(
        messages=[UserMessage(content=prompt, role="user")],
        model_id=model_id,
        stream=True,
    )
    print("\nChat completion response (stream):")
    # EventLogger turns the streamed chunks into printable log entries.
    for log in EventLogger().log(response):
        log.print()


def _run_agent_test(client, model_id: str) -> None:
    """Run a multi-turn agent session that analyses an attached inflation CSV."""
    print("\nAgent test:")

    # Brave web search is only offered when an API key is configured;
    # the code interpreter tool is always enabled.
    brave_api_key = os.getenv("BRAVE_SEARCH_API_KEY")
    tools = []
    if brave_api_key:
        tools.append(
            {
                "type": "brave_search",
                "engine": "brave",
                "api_key": brave_api_key,
            }
        )
    tools.append({"type": "code_interpreter"})

    agent_config = AgentConfig(
        model=model_id,
        instructions="You are a helpful assistant",
        # NOTE(review): "greedy" is paired with temperature/top_p values;
        # confirm whether those are meant to take effect under this strategy.
        sampling_params={
            "strategy": "greedy",
            "temperature": 1.0,
            "top_p": 0.9,
        },
        tools=tools,
        tool_choice="required",
        input_shields=[],
        output_shields=[],
        enable_session_persistence=False,
    )
    agent = Agent(client, agent_config)

    # Each entry is (prompt, attachments). Only the first turn attaches the
    # CSV; later turns pass None and ask follow-up questions about it.
    user_prompts = [
        (
            "Here is a csv, can you describe it ?",
            [
                Attachment(
                    content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
                    mime_type="text/csv",
                )
            ],
        ),
        ("Which year ended with the highest inflation ?", None),
        (
            "What macro economic situations that led to such high inflation in that period?",
            None,
        ),
        ("Plot average yearly inflation as a time series", None),
    ]

    session_id = agent.create_session("test-session")

    for prompt, attachments in user_prompts:
        response = agent.create_turn(
            messages=[{"role": "user", "content": prompt}],
            attachments=attachments,
            session_id=session_id,
        )
        for log in AgentEventLogger().log(response):
            log.print()
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Command-line entry point: the only argument is the distribution config path.
    cli = argparse.ArgumentParser()
    cli.add_argument("config_path", help="Path to the config YAML file")
    main(cli.parse_args().config_path)
 |