From 1fba8f80c2c013dfd1c218cdd948d959b0a3496b Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 7 Dec 2024 11:53:24 -0800 Subject: [PATCH] Make sure Agents work with direct client --- .../distribution/tests/library_client_test.py | 57 ++++++++++++++++--- 1 file changed, 49 insertions(+), 8 deletions(-) diff --git a/llama_stack/distribution/tests/library_client_test.py b/llama_stack/distribution/tests/library_client_test.py index d6b1130c6..8381f5470 100644 --- a/llama_stack/distribution/tests/library_client_test.py +++ b/llama_stack/distribution/tests/library_client_test.py @@ -5,10 +5,14 @@ # the root directory of this source tree. import argparse +import os from llama_stack.distribution.library_client import LlamaStackAsLibraryClient +from llama_stack_client.lib.agents.agent import Agent +from llama_stack_client.lib.agents.event_logger import EventLogger as AgentEventLogger from llama_stack_client.lib.inference.event_logger import EventLogger from llama_stack_client.types import UserMessage +from llama_stack_client.types.agent_create_params import AgentConfig def main(config_path: str): @@ -43,16 +47,53 @@ def main(config_path: str): for log in EventLogger().log(response): log.print() - response = client.memory_banks.register( - memory_bank_id="memory_bank_id", - params={ - "chunk_size_in_tokens": 0, - "embedding_model": "embedding_model", - "memory_bank_type": "vector", + print("\nAgent test:") + agent_config = AgentConfig( + model=model_id, + instructions="You are a helpful assistant", + sampling_params={ + "strategy": "greedy", + "temperature": 1.0, + "top_p": 0.9, }, + tools=( + [ + { + "type": "brave_search", + "engine": "brave", + "api_key": os.getenv("BRAVE_SEARCH_API_KEY"), + } + ] + if os.getenv("BRAVE_SEARCH_API_KEY") + else [] + ), + tool_choice="auto", + tool_prompt_format="json", + input_shields=[], + output_shields=[], + enable_session_persistence=False, ) - print("\nRegister memory bank response:") - print(response) + agent = Agent(client, agent_config) + user_prompts = [ + "Hello", + "Which players played in the winning team of the NBA western conference semifinals of 2024, please use tools", + ] + + session_id = agent.create_session("test-session") + + for prompt in user_prompts: + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": prompt, + } + ], + session_id=session_id, + ) + + for log in AgentEventLogger().log(response): + log.print() if __name__ == "__main__":