docs: update test_agents to use new Agent SDK API (#1402)

# Summary:
new Agent SDK API is added in
https://github.com/meta-llama/llama-stack-client-python/pull/178

Update docs and test to reflect this.

Closes https://github.com/meta-llama/llama-stack/issues/1365

# Test Plan:
```bash
py.test -v -s --nbval-lax ./docs/getting_started.ipynb

LLAMA_STACK_CONFIG=fireworks \
   pytest -s -v tests/integration/agents/test_agents.py \
  --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct
```
This commit is contained in:
ehhuang 2025-03-06 15:21:12 -08:00 committed by GitHub
parent 3d71e5a036
commit ca2910d27a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 121 additions and 206 deletions

View file

@ -49,7 +49,6 @@
"source": [
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from rich.pretty import pprint\n",
"import json\n",
@ -71,20 +70,12 @@
"\n",
"MODEL_ID = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
"\n",
"base_agent_config = AgentConfig(\n",
"base_agent_config = dict(\n",
" model=MODEL_ID,\n",
" instructions=\"You are a helpful assistant.\",\n",
" sampling_params={\n",
" \"strategy\": {\"type\": \"top_p\", \"temperature\": 1.0, \"top_p\": 0.9},\n",
" },\n",
" toolgroups=[],\n",
" tool_config={\n",
" \"tool_choice\": \"auto\",\n",
" \"tool_prompt_format\": \"python_list\",\n",
" },\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=False,\n",
")"
]
},
@ -172,7 +163,7 @@
}
],
"source": [
"vanilla_agent_config = AgentConfig({\n",
"vanilla_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"\n",
" You are a helpful assistant capable of structuring data extraction and formatting. \n",
@ -189,9 +180,9 @@
" Employee satisfaction is at 87 points.\n",
" Operating margin improved to 34%.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"vanilla_agent = Agent(client, vanilla_agent_config)\n",
"vanilla_agent = Agent(client, **vanilla_agent_config)\n",
"prompt_chaining_session_id = vanilla_agent.create_session(session_name=f\"vanilla_agent_{uuid.uuid4()}\")\n",
"\n",
"prompts = [\n",
@ -778,7 +769,7 @@
],
"source": [
"# 1. Define a couple of specialized agents\n",
"billing_agent_config = AgentConfig({\n",
"billing_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You are a billing support specialist. Follow these guidelines:\n",
" 1. Always start with \"Billing Support Response:\"\n",
@ -789,9 +780,9 @@
" \n",
" Keep responses professional but friendly.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"technical_agent_config = AgentConfig({\n",
"technical_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You are a technical support engineer. Follow these guidelines:\n",
" 1. Always start with \"Technical Support Response:\"\n",
@ -802,9 +793,9 @@
" \n",
" Use clear, numbered steps and technical details.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"account_agent_config = AgentConfig({\n",
"account_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You are an account security specialist. Follow these guidelines:\n",
" 1. Always start with \"Account Support Response:\"\n",
@ -815,9 +806,9 @@
" \n",
" Maintain a serious, security-focused tone.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"product_agent_config = AgentConfig({\n",
"product_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You are a product specialist. Follow these guidelines:\n",
" 1. Always start with \"Product Support Response:\"\n",
@ -828,13 +819,13 @@
" \n",
" Be educational and encouraging in tone.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"specialized_agents = {\n",
" \"billing\": Agent(client, billing_agent_config),\n",
" \"technical\": Agent(client, technical_agent_config),\n",
" \"account\": Agent(client, account_agent_config),\n",
" \"product\": Agent(client, product_agent_config),\n",
" \"billing\": Agent(client, **billing_agent_config),\n",
" \"technical\": Agent(client, **technical_agent_config),\n",
" \"account\": Agent(client, **account_agent_config),\n",
" \"product\": Agent(client, **product_agent_config),\n",
"}\n",
"\n",
"# 2. Define a routing agent\n",
@ -842,7 +833,7 @@
" reasoning: str\n",
" support_team: str\n",
"\n",
"routing_agent_config = AgentConfig({\n",
"routing_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": f\"\"\"You are a routing agent. Analyze the user's input and select the most appropriate support team from these options: \n",
"\n",
@ -862,9 +853,9 @@
" \"type\": \"json_schema\",\n",
" \"json_schema\": OutputSchema.model_json_schema()\n",
" }\n",
"})\n",
"}\n",
"\n",
"routing_agent = Agent(client, routing_agent_config)\n",
"routing_agent = Agent(client, **routing_agent_config)\n",
"\n",
"# 3. Create a session for all agents\n",
"routing_agent_session_id = routing_agent.create_session(session_name=f\"routing_agent_{uuid.uuid4()}\")\n",
@ -1725,17 +1716,17 @@
"from concurrent.futures import ThreadPoolExecutor\n",
"from typing import List\n",
"\n",
"worker_agent_config = AgentConfig({\n",
"worker_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You are a helpful assistant that can analyze the impact of market changes on stakeholders.\n",
" Analyze how market changes will impact this stakeholder group.\n",
" Provide specific impacts and recommended actions.\n",
" Format with clear sections and priorities.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"def create_worker_task(task: str):\n",
" worker_agent = Agent(client, worker_agent_config)\n",
" worker_agent = Agent(client, **worker_agent_config)\n",
" worker_session_id = worker_agent.create_session(session_name=f\"worker_agent_{uuid.uuid4()}\")\n",
" task_response = worker_agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": task}],\n",
@ -2248,7 +2239,7 @@
" thoughts: str\n",
" response: str\n",
"\n",
"generator_agent_config = AgentConfig({\n",
"generator_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"Your goal is to complete the task based on <user input>. If there are feedback \n",
" from your previous generations, you should reflect on them to improve your solution\n",
@ -2263,13 +2254,13 @@
" \"type\": \"json_schema\",\n",
" \"json_schema\": GeneratorOutputSchema.model_json_schema()\n",
" }\n",
"})\n",
"}\n",
"\n",
"class EvaluatorOutputSchema(BaseModel):\n",
" evaluation: str\n",
" feedback: str\n",
"\n",
"evaluator_agent_config = AgentConfig({\n",
"evaluator_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"Evaluate this following code implementation for:\n",
" 1. code correctness\n",
@ -2293,10 +2284,10 @@
" \"type\": \"json_schema\",\n",
" \"json_schema\": EvaluatorOutputSchema.model_json_schema()\n",
" }\n",
"})\n",
"}\n",
"\n",
"generator_agent = Agent(client, generator_agent_config)\n",
"evaluator_agent = Agent(client, evaluator_agent_config)\n",
"generator_agent = Agent(client, **generator_agent_config)\n",
"evaluator_agent = Agent(client, **evaluator_agent_config)\n",
"generator_session_id = generator_agent.create_session(session_name=f\"generator_agent_{uuid.uuid4()}\")\n",
"evaluator_session_id = evaluator_agent.create_session(session_name=f\"evaluator_agent_{uuid.uuid4()}\")\n",
"\n",
@ -2628,7 +2619,7 @@
" analysis: str\n",
" tasks: List[Dict[str, str]]\n",
"\n",
"orchestrator_agent_config = AgentConfig({\n",
"orchestrator_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"Your job is to analyize the task provided by the user andbreak it down into 2-3 distinct approaches:\n",
"\n",
@ -2651,9 +2642,9 @@
" \"type\": \"json_schema\",\n",
" \"json_schema\": OrchestratorOutputSchema.model_json_schema()\n",
" }\n",
"})\n",
"}\n",
"\n",
"worker_agent_config = AgentConfig({\n",
"worker_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You will be given a task guideline. Generate content based on the provided\n",
" task, following the style and guideline descriptions. \n",
@ -2662,7 +2653,7 @@
"\n",
" Response: Your content here, maintaining the specified style and fully addressing requirements.\n",
" \"\"\",\n",
"})\n"
"}\n"
]
},
{
@ -2673,7 +2664,7 @@
"source": [
"def orchestrator_worker_workflow(task, context):\n",
" # single orchestrator agent\n",
" orchestrator_agent = Agent(client, orchestrator_agent_config)\n",
" orchestrator_agent = Agent(client, **orchestrator_agent_config)\n",
" orchestrator_session_id = orchestrator_agent.create_session(session_name=f\"orchestrator_agent_{uuid.uuid4()}\")\n",
"\n",
" orchestrator_response = orchestrator_agent.create_turn(\n",
@ -2689,7 +2680,7 @@
" workers = {}\n",
" # spawn multiple worker agents\n",
" for task in orchestrator_result[\"tasks\"]:\n",
" worker_agent = Agent(client, worker_agent_config)\n",
" worker_agent = Agent(client, **worker_agent_config)\n",
" worker_session_id = worker_agent.create_session(session_name=f\"worker_agent_{uuid.uuid4()}\")\n",
" workers[task[\"type\"]] = worker_agent\n",
" \n",