Make sure Agents work with direct client

This commit is contained in:
Ashwin Bharambe 2024-12-07 11:53:24 -08:00
parent 86b5743081
commit 1fba8f80c2

View file

@ -5,10 +5,14 @@
# the root directory of this source tree.
import argparse
import os
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger as AgentEventLogger
from llama_stack_client.lib.inference.event_logger import EventLogger
from llama_stack_client.types import UserMessage
from llama_stack_client.types.agent_create_params import AgentConfig
def main(config_path: str):
@ -43,16 +47,53 @@ def main(config_path: str):
for log in EventLogger().log(response):
log.print()
response = client.memory_banks.register(
memory_bank_id="memory_bank_id",
params={
"chunk_size_in_tokens": 0,
"embedding_model": "embedding_model",
"memory_bank_type": "vector",
print("\nAgent test:")
agent_config = AgentConfig(
model=model_id,
instructions="You are a helpful assistant",
sampling_params={
"strategy": "greedy",
"temperature": 1.0,
"top_p": 0.9,
},
tools=(
[
{
"type": "brave_search",
"engine": "brave",
"api_key": os.getenv("BRAVE_SEARCH_API_KEY"),
}
]
if os.getenv("BRAVE_SEARCH_API_KEY")
else []
),
tool_choice="auto",
tool_prompt_format="json",
input_shields=[],
output_shields=[],
enable_session_persistence=False,
)
print("\nRegister memory bank response:")
print(response)
agent = Agent(client, agent_config)
user_prompts = [
"Hello",
"Which players played in the winning team of the NBA western conference semifinals of 2024, please use tools",
]
session_id = agent.create_session("test-session")
for prompt in user_prompts:
response = agent.create_turn(
messages=[
{
"role": "user",
"content": prompt,
}
],
session_id=session_id,
)
for log in AgentEventLogger().log(response):
log.print()
if __name__ == "__main__":