docs: update test_agents to use new Agent SDK API (#1402)

# Summary:
A new Agent SDK API was added in
https://github.com/meta-llama/llama-stack-client-python/pull/178

Update the docs and tests to reflect this.

Closes https://github.com/meta-llama/llama-stack/issues/1365
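
The gist of the change: `AgentConfig` is no longer constructed separately; its fields are passed as keyword arguments to `Agent`, and the old `toolgroups`/`client_tools` split collapses into a single `tools` argument. A minimal before/after sketch (names taken from the diffs below):

```python
from llama_stack_client.lib.agents.agent import Agent

# Before: build an AgentConfig, then wrap it in an Agent
# from llama_stack_client.types.agent_create_params import AgentConfig
# agent_config = AgentConfig(
#     model=model_id,
#     instructions="You are a helpful assistant",
#     toolgroups=["builtin::websearch"],
#     input_shields=[],
#     output_shields=[],
#     enable_session_persistence=False,
# )
# agent = Agent(client, agent_config)

# After: pass the client and the config fields directly to Agent
agent = Agent(
    client,
    model=model_id,
    instructions="You are a helpful assistant",
    tools=["builtin::websearch"],
)
```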

# Test Plan:
```bash
py.test -v -s --nbval-lax ./docs/getting_started.ipynb

LLAMA_STACK_CONFIG=fireworks \
  pytest -s -v tests/integration/agents/test_agents.py \
  --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct
```
Author: ehhuang
Date: 2025-03-06 15:21:12 -08:00 (committed by GitHub)
Parent: 3d71e5a036
Commit: ca2910d27a
13 changed files with 121 additions and 206 deletions

View file

@@ -1635,18 +1635,14 @@
 "source": [
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
 "from termcolor import cprint\n",
 "\n",
-"agent_config = AgentConfig(\n",
+"agent = Agent(\n",
+"    client,\n",
 "    model=model_id,\n",
 "    instructions=\"You are a helpful assistant\",\n",
-"    toolgroups=[\"builtin::websearch\"],\n",
-"    input_shields=[],\n",
-"    output_shields=[],\n",
-"    enable_session_persistence=False,\n",
+"    tools=[\"builtin::websearch\"],\n",
 ")\n",
-"agent = Agent(client, agent_config)\n",
 "user_prompts = [\n",
 "    \"Hello\",\n",
 "    \"Which teams played in the NBA western conference finals of 2024\",\n",
@@ -1815,7 +1811,6 @@
 "import uuid\n",
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
 "from termcolor import cprint\n",
 "from llama_stack_client.types import Document\n",
 "\n",
@@ -1841,11 +1836,11 @@
 "    vector_db_id=vector_db_id,\n",
 "    chunk_size_in_tokens=512,\n",
 ")\n",
-"agent_config = AgentConfig(\n",
+"rag_agent = Agent(\n",
+"    client,\n",
 "    model=model_id,\n",
 "    instructions=\"You are a helpful assistant\",\n",
-"    enable_session_persistence=False,\n",
-"    toolgroups = [\n",
+"    tools = [\n",
 "        {\n",
 "            \"name\": \"builtin::rag/knowledge_search\",\n",
 "            \"args\" : {\n",
@@ -1854,7 +1849,6 @@
 "        }\n",
 "    ],\n",
 ")\n",
-"rag_agent = Agent(client, agent_config)\n",
 "session_id = rag_agent.create_session(\"test-session\")\n",
 "user_prompts = [\n",
 "    \"What are the top 5 topics that were explained? Only list succinct bullet points.\",\n",
@@ -1978,23 +1972,19 @@
 "source": [
 "from llama_stack_client.types.agents.turn_create_params import Document\n",
 "\n",
-"agent_config = AgentConfig(\n",
+"codex_agent = Agent(\n",
+"    client,\n",
+"    model=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
+"    instructions=\"You are a helpful assistant\",\n",
+"    tools=[\n",
+"        \"builtin::code_interpreter\",\n",
+"        \"builtin::websearch\"\n",
+"    ],\n",
 "    sampling_params = {\n",
 "        \"max_tokens\" : 4096,\n",
 "        \"temperature\": 0.0\n",
 "    },\n",
-"    model=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
-"    instructions=\"You are a helpful assistant\",\n",
-"    toolgroups=[\n",
-"        \"builtin::code_interpreter\",\n",
-"        \"builtin::websearch\"\n",
-"    ],\n",
-"    tool_choice=\"auto\",\n",
-"    input_shields=[],\n",
-"    output_shields=[],\n",
-"    enable_session_persistence=False,\n",
 ")\n",
-"codex_agent = Agent(client, agent_config)\n",
 "session_id = codex_agent.create_session(\"test-session\")\n",
 "\n",
 "\n",
@@ -2904,18 +2894,14 @@
 "# NBVAL_SKIP\n",
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
 "from termcolor import cprint\n",
 "\n",
-"agent_config = AgentConfig(\n",
+"agent = Agent(\n",
+"    client,\n",
 "    model=model_id,\n",
 "    instructions=\"You are a helpful assistant\",\n",
-"    toolgroups=[\"mcp::filesystem\"],\n",
-"    input_shields=[],\n",
-"    output_shields=[],\n",
-"    enable_session_persistence=False,\n",
+"    tools=[\"mcp::filesystem\"],\n",
 ")\n",
-"agent = Agent(client, agent_config)\n",
 "user_prompts = [\n",
 "    \"Hello\",\n",
 "    \"list all the files /content\",\n",
@@ -3010,17 +2996,13 @@
 "source": [
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
 "\n",
-"agent_config = AgentConfig(\n",
+"agent = Agent(\n",
+"    client,\n",
 "    model=\"meta-llama/Llama-3.3-70B-Instruct\",\n",
 "    instructions=\"You are a helpful assistant. Use search tool to answer the questions. \",\n",
-"    toolgroups=[\"builtin::websearch\"],\n",
-"    input_shields=[],\n",
-"    output_shields=[],\n",
-"    enable_session_persistence=False,\n",
+"    tools=[\"builtin::websearch\"],\n",
 ")\n",
-"agent = Agent(client, agent_config)\n",
 "user_prompts = [\n",
 "    \"Which teams played in the NBA western conference finals of 2024. Search the web for the answer.\",\n",
 "    \"In which episode and season of South Park does Bill Cosby (BSM-471) first appear? Give me the number and title. Search the web for the answer.\",\n",
@@ -4346,16 +4328,11 @@
 }
 ],
 "source": [
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
-"\n",
-"agent_config = AgentConfig(\n",
+"agent = Agent(\n",
+"    client,\n",
 "    model=vision_model_id,\n",
 "    instructions=\"You are a helpful assistant\",\n",
-"    enable_session_persistence=False,\n",
-"    toolgroups=[],\n",
 ")\n",
-"\n",
-"agent = Agent(client, agent_config)\n",
 "session_id = agent.create_session(\"test-session\")\n",
 "\n",
 "response = agent.create_turn(\n",

View file

@@ -49,7 +49,6 @@
 "source": [
 "from llama_stack_client import LlamaStackClient\n",
 "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from rich.pretty import pprint\n",
 "import json\n",
@@ -71,20 +70,12 @@
 "\n",
 "MODEL_ID = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
 "\n",
-"base_agent_config = AgentConfig(\n",
+"base_agent_config = dict(\n",
 "    model=MODEL_ID,\n",
 "    instructions=\"You are a helpful assistant.\",\n",
 "    sampling_params={\n",
 "        \"strategy\": {\"type\": \"top_p\", \"temperature\": 1.0, \"top_p\": 0.9},\n",
 "    },\n",
-"    toolgroups=[],\n",
-"    tool_config={\n",
-"        \"tool_choice\": \"auto\",\n",
-"        \"tool_prompt_format\": \"python_list\",\n",
-"    },\n",
-"    input_shields=[],\n",
-"    output_shields=[],\n",
-"    enable_session_persistence=False,\n",
 ")"
 ]
 },
@@ -172,7 +163,7 @@
 }
 ],
 "source": [
-"vanilla_agent_config = AgentConfig({\n",
+"vanilla_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"\n",
 "    You are a helpful assistant capable of structuring data extraction and formatting. \n",
@@ -189,9 +180,9 @@
 "    Employee satisfaction is at 87 points.\n",
 "    Operating margin improved to 34%.\n",
 "    \"\"\",\n",
-"})\n",
+"}\n",
 "\n",
-"vanilla_agent = Agent(client, vanilla_agent_config)\n",
+"vanilla_agent = Agent(client, **vanilla_agent_config)\n",
 "prompt_chaining_session_id = vanilla_agent.create_session(session_name=f\"vanilla_agent_{uuid.uuid4()}\")\n",
 "\n",
 "prompts = [\n",
@@ -778,7 +769,7 @@
 ],
 "source": [
 "# 1. Define a couple of specialized agents\n",
-"billing_agent_config = AgentConfig({\n",
+"billing_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"You are a billing support specialist. Follow these guidelines:\n",
 "    1. Always start with \"Billing Support Response:\"\n",
@@ -789,9 +780,9 @@
 "    \n",
 "    Keep responses professional but friendly.\n",
 "    \"\"\",\n",
-"})\n",
+"}\n",
 "\n",
-"technical_agent_config = AgentConfig({\n",
+"technical_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"You are a technical support engineer. Follow these guidelines:\n",
 "    1. Always start with \"Technical Support Response:\"\n",
@@ -802,9 +793,9 @@
 "    \n",
 "    Use clear, numbered steps and technical details.\n",
 "    \"\"\",\n",
-"})\n",
+"}\n",
 "\n",
-"account_agent_config = AgentConfig({\n",
+"account_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"You are an account security specialist. Follow these guidelines:\n",
 "    1. Always start with \"Account Support Response:\"\n",
@@ -815,9 +806,9 @@
 "    \n",
 "    Maintain a serious, security-focused tone.\n",
 "    \"\"\",\n",
-"})\n",
+"}\n",
 "\n",
-"product_agent_config = AgentConfig({\n",
+"product_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"You are a product specialist. Follow these guidelines:\n",
 "    1. Always start with \"Product Support Response:\"\n",
@@ -828,13 +819,13 @@
 "    \n",
 "    Be educational and encouraging in tone.\n",
 "    \"\"\",\n",
-"})\n",
+"}\n",
 "\n",
 "specialized_agents = {\n",
-"    \"billing\": Agent(client, billing_agent_config),\n",
-"    \"technical\": Agent(client, technical_agent_config),\n",
-"    \"account\": Agent(client, account_agent_config),\n",
-"    \"product\": Agent(client, product_agent_config),\n",
+"    \"billing\": Agent(client, **billing_agent_config),\n",
+"    \"technical\": Agent(client, **technical_agent_config),\n",
+"    \"account\": Agent(client, **account_agent_config),\n",
+"    \"product\": Agent(client, **product_agent_config),\n",
 "}\n",
 "\n",
 "# 2. Define a routing agent\n",
@@ -842,7 +833,7 @@
 "    reasoning: str\n",
 "    support_team: str\n",
 "\n",
-"routing_agent_config = AgentConfig({\n",
+"routing_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": f\"\"\"You are a routing agent. Analyze the user's input and select the most appropriate support team from these options: \n",
 "\n",
@@ -862,9 +853,9 @@
 "        \"type\": \"json_schema\",\n",
 "        \"json_schema\": OutputSchema.model_json_schema()\n",
 "    }\n",
-"})\n",
+"}\n",
 "\n",
-"routing_agent = Agent(client, routing_agent_config)\n",
+"routing_agent = Agent(client, **routing_agent_config)\n",
 "\n",
 "# 3. Create a session for all agents\n",
 "routing_agent_session_id = routing_agent.create_session(session_name=f\"routing_agent_{uuid.uuid4()}\")\n",
@@ -1725,17 +1716,17 @@
 "from concurrent.futures import ThreadPoolExecutor\n",
 "from typing import List\n",
 "\n",
-"worker_agent_config = AgentConfig({\n",
+"worker_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"You are a helpful assistant that can analyze the impact of market changes on stakeholders.\n",
 "    Analyze how market changes will impact this stakeholder group.\n",
 "    Provide specific impacts and recommended actions.\n",
 "    Format with clear sections and priorities.\n",
 "    \"\"\",\n",
-"})\n",
+"}\n",
 "\n",
 "def create_worker_task(task: str):\n",
-"    worker_agent = Agent(client, worker_agent_config)\n",
+"    worker_agent = Agent(client, **worker_agent_config)\n",
 "    worker_session_id = worker_agent.create_session(session_name=f\"worker_agent_{uuid.uuid4()}\")\n",
 "    task_response = worker_agent.create_turn(\n",
 "        messages=[{\"role\": \"user\", \"content\": task}],\n",
@@ -2248,7 +2239,7 @@
 "    thoughts: str\n",
 "    response: str\n",
 "\n",
-"generator_agent_config = AgentConfig({\n",
+"generator_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"Your goal is to complete the task based on <user input>. If there are feedback \n",
 "    from your previous generations, you should reflect on them to improve your solution\n",
@@ -2263,13 +2254,13 @@
 "        \"type\": \"json_schema\",\n",
 "        \"json_schema\": GeneratorOutputSchema.model_json_schema()\n",
 "    }\n",
-"})\n",
+"}\n",
 "\n",
 "class EvaluatorOutputSchema(BaseModel):\n",
 "    evaluation: str\n",
 "    feedback: str\n",
 "\n",
-"evaluator_agent_config = AgentConfig({\n",
+"evaluator_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"Evaluate this following code implementation for:\n",
 "    1. code correctness\n",
@@ -2293,10 +2284,10 @@
 "        \"type\": \"json_schema\",\n",
 "        \"json_schema\": EvaluatorOutputSchema.model_json_schema()\n",
 "    }\n",
-"})\n",
+"}\n",
 "\n",
-"generator_agent = Agent(client, generator_agent_config)\n",
-"evaluator_agent = Agent(client, evaluator_agent_config)\n",
+"generator_agent = Agent(client, **generator_agent_config)\n",
+"evaluator_agent = Agent(client, **evaluator_agent_config)\n",
 "generator_session_id = generator_agent.create_session(session_name=f\"generator_agent_{uuid.uuid4()}\")\n",
 "evaluator_session_id = evaluator_agent.create_session(session_name=f\"evaluator_agent_{uuid.uuid4()}\")\n",
 "\n",
@@ -2628,7 +2619,7 @@
 "    analysis: str\n",
 "    tasks: List[Dict[str, str]]\n",
 "\n",
-"orchestrator_agent_config = AgentConfig({\n",
+"orchestrator_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"Your job is to analyize the task provided by the user andbreak it down into 2-3 distinct approaches:\n",
 "\n",
@@ -2651,9 +2642,9 @@
 "        \"type\": \"json_schema\",\n",
 "        \"json_schema\": OrchestratorOutputSchema.model_json_schema()\n",
 "    }\n",
-"})\n",
+"}\n",
 "\n",
-"worker_agent_config = AgentConfig({\n",
+"worker_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"You will be given a task guideline. Generate content based on the provided\n",
 "    task, following the style and guideline descriptions. \n",
@@ -2662,7 +2653,7 @@
 "\n",
 "    Response: Your content here, maintaining the specified style and fully addressing requirements.\n",
 "    \"\"\",\n",
-"})\n"
+"}\n"
 ]
 },
 {
@@ -2673,7 +2664,7 @@
 "source": [
 "def orchestrator_worker_workflow(task, context):\n",
 "    # single orchestrator agent\n",
-"    orchestrator_agent = Agent(client, orchestrator_agent_config)\n",
+"    orchestrator_agent = Agent(client, **orchestrator_agent_config)\n",
 "    orchestrator_session_id = orchestrator_agent.create_session(session_name=f\"orchestrator_agent_{uuid.uuid4()}\")\n",
 "\n",
 "    orchestrator_response = orchestrator_agent.create_turn(\n",
@@ -2689,7 +2680,7 @@
 "    workers = {}\n",
 "    # spawn multiple worker agents\n",
 "    for task in orchestrator_result[\"tasks\"]:\n",
-"        worker_agent = Agent(client, worker_agent_config)\n",
+"        worker_agent = Agent(client, **worker_agent_config)\n",
 "        worker_session_id = worker_agent.create_session(session_name=f\"worker_agent_{uuid.uuid4()}\")\n",
 "        workers[task[\"type\"]] = worker_agent\n",
 "    \n",

View file

@@ -14,18 +14,16 @@ Agents are configured using the `AgentConfig` class, which includes:
 - **Safety Shields**: Guardrails to ensure responsible AI behavior
 
 ```python
-from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.lib.agents.agent import Agent
 
-# Configure an agent
-agent_config = AgentConfig(
-    model="meta-llama/Llama-3-70b-chat",
-    instructions="You are a helpful assistant that can use tools to answer questions.",
-    toolgroups=["builtin::code_interpreter", "builtin::rag/knowledge_search"],
-)
-
 # Create the agent
-agent = Agent(llama_stack_client, agent_config)
+agent = Agent(
+    llama_stack_client,
+    model="meta-llama/Llama-3-70b-chat",
+    instructions="You are a helpful assistant that can use tools to answer questions.",
+    tools=["builtin::code_interpreter", "builtin::rag/knowledge_search"],
+)
 ```
 
 ### 2. Sessions

View file

@@ -70,18 +70,18 @@ Each step in this process can be monitored and controlled through configurations
 from llama_stack_client import LlamaStackClient
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.event_logger import EventLogger
-from llama_stack_client.types.agent_create_params import AgentConfig
 from rich.pretty import pprint
 
 # Replace host and port
 client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")
 
-agent_config = AgentConfig(
+agent = Agent(
+    client,
     # Check with `llama-stack-client models list`
     model="Llama3.2-3B-Instruct",
     instructions="You are a helpful assistant",
     # Enable both RAG and tool usage
-    toolgroups=[
+    tools=[
         {
             "name": "builtin::rag/knowledge_search",
            "args": {"vector_db_ids": ["my_docs"]},
@@ -98,8 +98,6 @@ agent_config = AgentConfig(
         "max_tokens": 2048,
     },
 )
-agent = Agent(client, agent_config)
 session_id = agent.create_session("monitored_session")
 
 # Stream the agent's execution steps

View file

@@ -25,17 +25,13 @@ In this example, we will show you how to:
 ```python
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.event_logger import EventLogger
-from llama_stack_client.types.agent_create_params import AgentConfig
 
-agent_config = AgentConfig(
+agent = Agent(
+    client,
     model="meta-llama/Llama-3.3-70B-Instruct",
     instructions="You are a helpful assistant. Use search tool to answer the questions. ",
-    toolgroups=["builtin::websearch"],
-    input_shields=[],
-    output_shields=[],
-    enable_session_persistence=False,
+    tools=["builtin::websearch"],
 )
-agent = Agent(client, agent_config)
 user_prompts = [
     "Which teams played in the NBA western conference finals of 2024. Search the web for the answer.",
     "In which episode and season of South Park does Bill Cosby (BSM-471) first appear? Give me the number and title. Search the web for the answer.",

View file

@@ -86,15 +86,14 @@ results = client.tool_runtime.rag_tool.query(
 One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
 
 ```python
-from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.lib.agents.agent import Agent
 
-# Configure agent with memory
-agent_config = AgentConfig(
+# Create agent with memory
+agent = Agent(
+    client,
     model="meta-llama/Llama-3.3-70B-Instruct",
     instructions="You are a helpful assistant",
-    enable_session_persistence=False,
-    toolgroups=[
+    tools=[
         {
             "name": "builtin::rag/knowledge_search",
             "args": {
@@ -103,8 +102,6 @@ agent_config = AgentConfig(
         }
     ],
 )
-agent = Agent(client, agent_config)
 
 session_id = agent.create_session("rag_session")

View file

@@ -149,15 +149,7 @@ def my_tool(input: int) -> int:
 Once defined, simply pass the tool to the agent config. `Agent` will take care of the rest (calling the model with the tool definition, executing the tool, and returning the result to the model for the next iteration).
 
 ```python
 # Example agent config with client provided tools
-client_tools = [
-    my_tool,
-]
-
-agent_config = AgentConfig(
-    ...,
-    client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
-)
-
-agent = Agent(client, agent_config, client_tools)
+agent = Agent(client, ..., tools=[my_tool])
 ```
 
 Refer to [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/e2e_loop_with_client_tools.py) for an example of how to use client provided tools.
@@ -194,10 +186,10 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools")
 ```python
 from llama_stack_client.lib.agents.agent import Agent
-from llama_stack_client.types.agent_create_params import AgentConfig
 
-# Configure the AI agent with necessary parameters
-agent_config = AgentConfig(
+# Instantiate the AI agent with the given configuration
+agent = Agent(
+    client,
     name="code-interpreter",
     description="A code interpreter agent for executing Python code snippets",
     instructions="""
@@ -205,14 +197,10 @@ agent_config = AgentConfig(
     Always show the generated code, never generate your own code, and never anticipate results.
     """,
     model="meta-llama/Llama-3.2-3B-Instruct",
-    toolgroups=["builtin::code_interpreter"],
+    tools=["builtin::code_interpreter"],
     max_infer_iters=5,
-    enable_session_persistence=False,
 )
-
-# Instantiate the AI agent with the given configuration
-agent = Agent(client, agent_config)
 
 # Start a session
 session_id = agent.create_session("tool_session")
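
The same collapse applies to client tools (see the test changes at the bottom of this diff): instead of registering `client_tools=[tool.get_tool_definition()]` in the config plus a separate constructor argument, the Python function itself goes into `tools`, alongside builtin toolgroup names. A rough sketch, with the tool body purely illustrative:

```python
from llama_stack_client.lib.agents.agent import Agent

def my_tool(input: int) -> int:
    """Runs my awesome tool.

    :param input: some int parameter
    """
    return input * 2  # illustrative body

# Builtin toolgroups (by name) and client tools (by reference) mix freely
agent = Agent(
    client,
    model="meta-llama/Llama-3.2-3B-Instruct",
    instructions="You are a helpful assistant",
    tools=["builtin::websearch", my_tool],
)
```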

View file

@@ -184,7 +184,6 @@ from termcolor import cprint
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.event_logger import EventLogger
-from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types import Document
@@ -241,13 +240,14 @@ client.tool_runtime.rag_tool.insert(
     chunk_size_in_tokens=512,
 )
 
-agent_config = AgentConfig(
+rag_agent = Agent(
+    client,
     model=os.environ["INFERENCE_MODEL"],
     # Define instructions for the agent ( aka system prompt)
     instructions="You are a helpful assistant",
     enable_session_persistence=False,
     # Define tools available to the agent
-    toolgroups=[
+    tools=[
         {
             "name": "builtin::rag/knowledge_search",
             "args": {
@@ -256,8 +256,6 @@ agent_config = AgentConfig(
         }
     ],
 )
-rag_agent = Agent(client, agent_config)
 session_id = rag_agent.create_session("test-session")
 
 user_prompts = [

View file

@@ -294,8 +294,9 @@
 "    # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n",
 "    webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
 "\n",
-"    # Define the agent configuration, including the model and tool setup\n",
-"    agent_config = AgentConfig(\n",
+"    # Create an agent instance with the client and configuration\n",
+"    agent = Agent(\n",
+"        client,\n",
 "        model=MODEL_NAME,\n",
 "        instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
 "        sampling_params={\n",
@@ -303,17 +304,12 @@
 "                \"type\": \"greedy\",\n",
 "            },\n",
 "        },\n",
-"        tools=[webSearchTool.get_tool_definition()],\n",
-"        tool_choice=\"auto\",\n",
-"        tool_prompt_format=\"python_list\",\n",
+"        tools=[webSearchTool],\n",
 "        input_shields=input_shields,\n",
 "        output_shields=output_shields,\n",
 "        enable_session_persistence=False,\n",
 "    )\n",
 "\n",
-"    # Create an agent instance with the client and configuration\n",
-"    agent = Agent(client, agent_config, [webSearchTool])\n",
-"\n",
 "    # Create a session for interaction and print the session ID\n",
 "    session_id = agent.create_session(\"test-session\")\n",
 "    print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n",

View file

@@ -110,12 +110,12 @@
 "from llama_stack_client import LlamaStackClient\n",
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
 "\n",
 "\n",
 "async def agent_example():\n",
 "    client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n",
-"    agent_config = AgentConfig(\n",
+"    agent = Agent(\n",
+"        client,\n",
 "        model=MODEL_NAME,\n",
 "        instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n",
 "        sampling_params={\n",
@@ -130,14 +130,7 @@
 "                \"api_key\": BRAVE_SEARCH_API_KEY,\n",
 "            }\n",
 "        ],\n",
-"        tool_choice=\"auto\",\n",
-"        tool_prompt_format=\"function_tag\",\n",
-"        input_shields=[],\n",
-"        output_shields=[],\n",
-"        enable_session_persistence=False,\n",
 "    )\n",
-"\n",
-"    agent = Agent(client, agent_config)\n",
 "    session_id = agent.create_session(\"test-session\")\n",
 "    print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n",
 "\n",

View file

@@ -103,7 +103,6 @@
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
 "from llama_stack_client.types.agent_create_params import (\n",
-"    AgentConfig,\n",
 "    AgentConfigToolSearchToolDefinition,\n",
 ")\n",
 "\n",
@@ -117,7 +116,8 @@
 ") -> Agent:\n",
 "    \"\"\"Create an agent with specified tools.\"\"\"\n",
 "    print(\"Using the following model: \", model)\n",
-"    agent_config = AgentConfig(\n",
+"    return Agent(\n",
+"        client,\n",
 "        model=model,\n",
 "        instructions=instructions,\n",
 "        sampling_params={\n",
@@ -126,12 +126,7 @@
 "            },\n",
 "        },\n",
 "        tools=tools,\n",
-"        tool_choice=\"auto\",\n",
-"        tool_prompt_format=\"json\",\n",
-"        enable_session_persistence=True,\n",
-"    )\n",
-"\n",
-"    return Agent(client, agent_config)\n"
+"    )\n"
 ]
 },
 {
@@ -360,9 +355,9 @@
 "    # Create the agent with the tool\n",
 "    weather_tool = WeatherTool()\n",
 "\n",
-"    agent_config = AgentConfig(\n",
+"    agent = Agent(\n",
+"        client=client,\n",
 "        model=LLAMA31_8B_INSTRUCT,\n",
+"        # model=model_name,\n",
 "        instructions=\"\"\"\n",
 "        You are a weather assistant that can provide weather information.\n",
 "        Always specify the location clearly in your responses.\n",
@@ -373,16 +368,9 @@
 "                \"type\": \"greedy\",\n",
 "            },\n",
 "        },\n",
-"        tools=[weather_tool.get_tool_definition()],\n",
-"        tool_choice=\"auto\",\n",
-"        tool_prompt_format=\"json\",\n",
-"        input_shields=[],\n",
-"        output_shields=[],\n",
-"        enable_session_persistence=True,\n",
+"        tools=[weather_tool],\n",
 "    )\n",
 "\n",
-"    agent = Agent(client=client, agent_config=agent_config, custom_tools=[weather_tool])\n",
-"\n",
 "    return agent\n",
 "\n",
 "\n",

View file

@@ -7,7 +7,6 @@
 import streamlit as st
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.event_logger import EventLogger
-from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types.memory_insert_params import Document
 from modules.api import llama_stack_api
 from modules.utils import data_url_from_file
@@ -124,13 +123,14 @@ def rag_chat_page():
     else:
         strategy = {"type": "greedy"}
 
-    agent_config = AgentConfig(
+    agent = Agent(
+        llama_stack_api.client,
         model=selected_model,
         instructions=system_prompt,
         sampling_params={
             "strategy": strategy,
         },
-        toolgroups=[
+        tools=[
             dict(
                 name="builtin::rag/knowledge_search",
                 args={
@@ -138,12 +138,7 @@ def rag_chat_page():
                 },
             )
         ],
-        tool_choice="auto",
-        tool_prompt_format="json",
-        enable_session_persistence=False,
     )
-    agent = Agent(llama_stack_api.client, agent_config)
 
     session_id = agent.create_session("rag-session")
 
     # Chat input

View file

@@ -64,7 +64,7 @@ def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> D
 def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
     available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()]
     available_shields = available_shields[:1]
-    agent_config = AgentConfig(
+    agent_config = dict(
         model=text_model_id,
         instructions="You are a helpful assistant",
         sampling_params={
@@ -74,7 +74,7 @@ def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
                 "top_p": 0.9,
             },
         },
-        toolgroups=[],
+        tools=[],
         input_shields=available_shields,
         output_shields=available_shields,
         enable_session_persistence=False,
@@ -83,7 +83,7 @@ def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
 
 def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
     simple_hello = agent.create_turn(
@@ -137,7 +137,7 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = AgentConfig(
         **common_params,
     )
-    Server__AgentConfig(**agent_config)
+    Server__AgentConfig(**common_params)
 
     agent_config = AgentConfig(
         **common_params,
@@ -179,11 +179,11 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
 def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
-        "toolgroups": [
+        "tools": [
             "builtin::websearch",
         ],
     }
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
     response = agent.create_turn(
@@ -209,11 +209,11 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
 def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
-        "toolgroups": [
+        "tools": [
             "builtin::code_interpreter",
         ],
     }
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
     response = agent.create_turn(
@@ -238,12 +238,12 @@ def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, a
 def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
-        "toolgroups": [
+        "tools": [
             "builtin::code_interpreter",
         ],
     }
-    codex_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    codex_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = codex_agent.create_session(f"test-session-{uuid4()}")
 
     inflation_doc = AgentDocument(
         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
@@ -275,11 +275,11 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
-        "toolgroups": ["builtin::websearch"],
-        "client_tools": [client_tool.get_tool_definition()],
+        "tools": ["builtin::websearch", client_tool],
     }
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
     response = agent.create_turn(
@@ -303,11 +303,11 @@ def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, age
     agent_config = {
         **agent_config,
         "instructions": "You are a helpful assistant Always respond with tool calls no matter what. ",
-        "client_tools": [client_tool.get_tool_definition()],
+        "tools": [client_tool],
         "max_infer_iters": 5,
     }
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
    session_id = agent.create_session(f"test-session-{uuid4()}")
 
     response = agent.create_turn(
@@ -332,10 +332,10 @@ def test_tool_choice(llama_stack_client_with_mocked_inference, agent_config):
     test_agent_config = {
         **agent_config,
         "tool_config": {"tool_choice": tool_choice},
-        "client_tools": [client_tool.get_tool_definition()],
+        "tools": [client_tool],
     }
-    agent = Agent(llama_stack_client_with_mocked_inference, test_agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, **test_agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
     response = agent.create_turn(
@@ -387,7 +387,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
     )
     agent_config = {
         **agent_config,
-        "toolgroups": [
+        "tools": [
             dict(
                 name=rag_tool_name,
                 args={
@@ -396,7 +396,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
             )
         ],
     }
-    rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
     user_prompts = [
         (
@@ -422,7 +422,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
 @pytest.mark.parametrize(
-    "toolgroup",
+    "tool",
     [
         dict(
             name="builtin::rag/knowledge_search",
@@ -433,7 +433,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
         "builtin::rag/knowledge_search",
     ],
 )
-def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, toolgroup):
+def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, tool):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -446,9 +446,9 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
     ]
     agent_config = {
         **agent_config,
-        "toolgroups": [toolgroup],
+        "tools": [tool],
     }
-    rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
     user_prompts = [
         (
@@ -521,7 +521,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
     )
     agent_config = {
         **agent_config,
-        "toolgroups": [
+        "tools": [
             dict(
                 name="builtin::rag/knowledge_search",
                 args={"vector_db_ids": [vector_db_id]},
@@ -529,7 +529,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
             "builtin::code_interpreter",
         ],
     }
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     inflation_doc = Document(
         document_id="test_csv",
         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
@@ -578,10 +578,10 @@ def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_co
         **agent_config,
         "input_shields": [],
         "output_shields": [],
-        "client_tools": [client_tool.get_tool_definition()],
+        "tools": [client_tool],
     }
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
     response = agent.create_turn(