docs: update test_agents to use new Agent SDK API (#1402)

# Summary:
A new Agent SDK API was added in
https://github.com/meta-llama/llama-stack-client-python/pull/178

Update the docs and tests to reflect this.

Closes https://github.com/meta-llama/llama-stack/issues/1365
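
For context, the new API passes the agent configuration directly to the `Agent` constructor instead of building a separate `AgentConfig`. A minimal sketch of the updated pattern (the `base_url` is a placeholder for your own Llama Stack server; the model name matches the one used in the test plan below):

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent

# Placeholder URL; point this at a running Llama Stack server.
client = LlamaStackClient(base_url="http://localhost:8321")

# Configuration now goes straight into the Agent constructor,
# with no intermediate AgentConfig object.
agent = Agent(
    client,
    model="meta-llama/Llama-3.1-8B-Instruct",
    instructions="You are a helpful assistant.",
    sampling_params={"strategy": {"type": "greedy"}},
)

session_id = agent.create_session("test-session")
print(f"Created session_id={session_id} for Agent({agent.agent_id})")
```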

# Test Plan:
```bash
py.test -v -s --nbval-lax ./docs/getting_started.ipynb

LLAMA_STACK_CONFIG=fireworks \
  pytest -s -v tests/integration/agents/test_agents.py \
  --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct
```
Author: ehhuang
Date: 2025-03-06 15:21:12 -08:00 (committed by GitHub)
Parent: 3d71e5a036
Commit: ca2910d27a
13 changed files with 121 additions and 206 deletions

@@ -294,8 +294,9 @@
" # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n",
" webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
"\n",
" # Define the agent configuration, including the model and tool setup\n",
" agent_config = AgentConfig(\n",
" # Create an agent instance with the client and configuration\n",
" agent = Agent(\n",
" client, \n",
" model=MODEL_NAME,\n",
" instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
" sampling_params={\n",
@@ -303,17 +304,12 @@
" \"type\": \"greedy\",\n",
" },\n",
" },\n",
" tools=[webSearchTool.get_tool_definition()],\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"python_list\",\n",
" tools=[webSearchTool],\n",
" input_shields=input_shields,\n",
" output_shields=output_shields,\n",
" enable_session_persistence=False,\n",
" )\n",
"\n",
" # Create an agent instance with the client and configuration\n",
" agent = Agent(client, agent_config, [webSearchTool])\n",
"\n",
" # Create a session for interaction and print the session ID\n",
" session_id = agent.create_session(\"test-session\")\n",
" print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n",

@@ -110,12 +110,12 @@
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"\n",
"\n",
"async def agent_example():\n",
" client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n",
" agent_config = AgentConfig(\n",
" agent = Agent(\n",
" client, \n",
" model=MODEL_NAME,\n",
" instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n",
" sampling_params={\n",
@@ -130,14 +130,7 @@
" \"api_key\": BRAVE_SEARCH_API_KEY,\n",
" }\n",
" ],\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"function_tag\",\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=False,\n",
" )\n",
"\n",
" agent = Agent(client, agent_config)\n",
" session_id = agent.create_session(\"test-session\")\n",
" print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n",
"\n",

@@ -103,7 +103,6 @@
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import (\n",
" AgentConfig,\n",
" AgentConfigToolSearchToolDefinition,\n",
")\n",
"\n",
@@ -117,7 +116,8 @@
") -> Agent:\n",
" \"\"\"Create an agent with specified tools.\"\"\"\n",
" print(\"Using the following model: \", model)\n",
" agent_config = AgentConfig(\n",
" return Agent(\n",
" client, \n",
" model=model,\n",
" instructions=instructions,\n",
" sampling_params={\n",
@@ -126,12 +126,7 @@
" },\n",
" },\n",
" tools=tools,\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"json\",\n",
" enable_session_persistence=True,\n",
" )\n",
"\n",
" return Agent(client, agent_config)\n"
" )\n"
]
},
{
@@ -360,9 +355,9 @@
" # Create the agent with the tool\n",
" weather_tool = WeatherTool()\n",
"\n",
" agent_config = AgentConfig(\n",
" agent = Agent(\n",
" client=client, \n",
" model=LLAMA31_8B_INSTRUCT,\n",
" # model=model_name,\n",
" instructions=\"\"\"\n",
" You are a weather assistant that can provide weather information.\n",
" Always specify the location clearly in your responses.\n",
@@ -373,16 +368,9 @@
" \"type\": \"greedy\",\n",
" },\n",
" },\n",
" tools=[weather_tool.get_tool_definition()],\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"json\",\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=True,\n",
" tools=[weather_tool],\n",
" )\n",
"\n",
" agent = Agent(client=client, agent_config=agent_config, custom_tools=[weather_tool])\n",
"\n",
" return agent\n",
"\n",
"\n",