docs: update test_agents to use new Agent SDK API (#1402)

# Summary:
A new Agent SDK API was added in
https://github.com/meta-llama/llama-stack-client-python/pull/178

Update docs and tests to reflect this.

Closes https://github.com/meta-llama/llama-stack/issues/1365
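
For reference, the change replaces the separate `AgentConfig` object with keyword arguments passed directly to the `Agent` constructor, and merges builtin toolgroups and client tools into a single `tools` argument. A minimal before/after sketch based on the diffs below (`client` and `model_id` are assumed to be defined in the surrounding notebook):

```python
# Old API: build a separate AgentConfig, then wrap it in an Agent
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agent_create_params import AgentConfig

agent_config = AgentConfig(
    model=model_id,
    instructions="You are a helpful assistant",
    toolgroups=["builtin::websearch"],
    input_shields=[],
    output_shields=[],
    enable_session_persistence=False,
)
agent = Agent(client, agent_config)

# New API: pass configuration as keyword arguments directly to Agent
agent = Agent(
    client,
    model=model_id,
    instructions="You are a helpful assistant",
    tools=["builtin::websearch"],
)
```

Client-provided tools no longer go through `get_tool_definition()` and a separate `client_tools` argument; the tool function itself is passed in the same `tools` list, e.g. `Agent(client, ..., tools=[my_tool])`.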

# Test Plan:
```bash
py.test -v -s --nbval-lax ./docs/getting_started.ipynb

LLAMA_STACK_CONFIG=fireworks \
  pytest -s -v tests/integration/agents/test_agents.py \
  --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct
```
Commit: ca2910d27a (parent: 3d71e5a036)
Author: ehhuang (committed by GitHub)
Date: 2025-03-06 15:21:12 -08:00
13 changed files with 121 additions and 206 deletions

@ -1635,18 +1635,14 @@
"source": [
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from termcolor import cprint\n",
"\n",
"agent_config = AgentConfig(\n",
"agent = Agent(\n",
" client, \n",
" model=model_id,\n",
" instructions=\"You are a helpful assistant\",\n",
" toolgroups=[\"builtin::websearch\"],\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=False,\n",
" tools=[\"builtin::websearch\"],\n",
")\n",
"agent = Agent(client, agent_config)\n",
"user_prompts = [\n",
" \"Hello\",\n",
" \"Which teams played in the NBA western conference finals of 2024\",\n",
@ -1815,7 +1811,6 @@
"import uuid\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from termcolor import cprint\n",
"from llama_stack_client.types import Document\n",
"\n",
@ -1841,11 +1836,11 @@
" vector_db_id=vector_db_id,\n",
" chunk_size_in_tokens=512,\n",
")\n",
"agent_config = AgentConfig(\n",
"rag_agent = Agent(\n",
" client, \n",
" model=model_id,\n",
" instructions=\"You are a helpful assistant\",\n",
" enable_session_persistence=False,\n",
" toolgroups = [\n",
" tools = [\n",
" {\n",
" \"name\": \"builtin::rag/knowledge_search\",\n",
" \"args\" : {\n",
@ -1854,7 +1849,6 @@
" }\n",
" ],\n",
")\n",
"rag_agent = Agent(client, agent_config)\n",
"session_id = rag_agent.create_session(\"test-session\")\n",
"user_prompts = [\n",
" \"What are the top 5 topics that were explained? Only list succinct bullet points.\",\n",
@ -1978,23 +1972,19 @@
"source": [
"from llama_stack_client.types.agents.turn_create_params import Document\n",
"\n",
"agent_config = AgentConfig(\n",
"codex_agent = Agent(\n",
" client, \n",
" model=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
" instructions=\"You are a helpful assistant\",\n",
" tools=[\n",
" \"builtin::code_interpreter\",\n",
" \"builtin::websearch\"\n",
" ],\n",
" sampling_params = {\n",
" \"max_tokens\" : 4096,\n",
" \"temperature\": 0.0\n",
" },\n",
" model=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
" instructions=\"You are a helpful assistant\",\n",
" toolgroups=[\n",
" \"builtin::code_interpreter\",\n",
" \"builtin::websearch\"\n",
" ],\n",
" tool_choice=\"auto\",\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=False,\n",
")\n",
"codex_agent = Agent(client, agent_config)\n",
"session_id = codex_agent.create_session(\"test-session\")\n",
"\n",
"\n",
@ -2904,18 +2894,14 @@
"# NBVAL_SKIP\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from termcolor import cprint\n",
"\n",
"agent_config = AgentConfig(\n",
"agent = Agent(\n",
" client, \n",
" model=model_id,\n",
" instructions=\"You are a helpful assistant\",\n",
" toolgroups=[\"mcp::filesystem\"],\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=False,\n",
" tools=[\"mcp::filesystem\"],\n",
")\n",
"agent = Agent(client, agent_config)\n",
"user_prompts = [\n",
" \"Hello\",\n",
" \"list all the files /content\",\n",
@ -3010,17 +2996,13 @@
"source": [
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"\n",
"agent_config = AgentConfig(\n",
"agent = Agent(\n",
" client, \n",
" model=\"meta-llama/Llama-3.3-70B-Instruct\",\n",
" instructions=\"You are a helpful assistant. Use search tool to answer the questions. \",\n",
" toolgroups=[\"builtin::websearch\"],\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=False,\n",
" tools=[\"builtin::websearch\"],\n",
")\n",
"agent = Agent(client, agent_config)\n",
"user_prompts = [\n",
" \"Which teams played in the NBA western conference finals of 2024. Search the web for the answer.\",\n",
" \"In which episode and season of South Park does Bill Cosby (BSM-471) first appear? Give me the number and title. Search the web for the answer.\",\n",
@ -4346,16 +4328,11 @@
}
],
"source": [
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"\n",
"agent_config = AgentConfig(\n",
"agent = Agent(\n",
" client, \n",
" model=vision_model_id,\n",
" instructions=\"You are a helpful assistant\",\n",
" enable_session_persistence=False,\n",
" toolgroups=[],\n",
")\n",
"\n",
"agent = Agent(client, agent_config)\n",
"session_id = agent.create_session(\"test-session\")\n",
"\n",
"response = agent.create_turn(\n",

@ -49,7 +49,6 @@
"source": [
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from rich.pretty import pprint\n",
"import json\n",
@ -71,20 +70,12 @@
"\n",
"MODEL_ID = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
"\n",
"base_agent_config = AgentConfig(\n",
"base_agent_config = dict(\n",
" model=MODEL_ID,\n",
" instructions=\"You are a helpful assistant.\",\n",
" sampling_params={\n",
" \"strategy\": {\"type\": \"top_p\", \"temperature\": 1.0, \"top_p\": 0.9},\n",
" },\n",
" toolgroups=[],\n",
" tool_config={\n",
" \"tool_choice\": \"auto\",\n",
" \"tool_prompt_format\": \"python_list\",\n",
" },\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=False,\n",
")"
]
},
@ -172,7 +163,7 @@
}
],
"source": [
"vanilla_agent_config = AgentConfig({\n",
"vanilla_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"\n",
" You are a helpful assistant capable of structuring data extraction and formatting. \n",
@ -189,9 +180,9 @@
" Employee satisfaction is at 87 points.\n",
" Operating margin improved to 34%.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"vanilla_agent = Agent(client, vanilla_agent_config)\n",
"vanilla_agent = Agent(client, **vanilla_agent_config)\n",
"prompt_chaining_session_id = vanilla_agent.create_session(session_name=f\"vanilla_agent_{uuid.uuid4()}\")\n",
"\n",
"prompts = [\n",
@ -778,7 +769,7 @@
],
"source": [
"# 1. Define a couple of specialized agents\n",
"billing_agent_config = AgentConfig({\n",
"billing_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You are a billing support specialist. Follow these guidelines:\n",
" 1. Always start with \"Billing Support Response:\"\n",
@ -789,9 +780,9 @@
" \n",
" Keep responses professional but friendly.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"technical_agent_config = AgentConfig({\n",
"technical_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You are a technical support engineer. Follow these guidelines:\n",
" 1. Always start with \"Technical Support Response:\"\n",
@ -802,9 +793,9 @@
" \n",
" Use clear, numbered steps and technical details.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"account_agent_config = AgentConfig({\n",
"account_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You are an account security specialist. Follow these guidelines:\n",
" 1. Always start with \"Account Support Response:\"\n",
@ -815,9 +806,9 @@
" \n",
" Maintain a serious, security-focused tone.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"product_agent_config = AgentConfig({\n",
"product_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You are a product specialist. Follow these guidelines:\n",
" 1. Always start with \"Product Support Response:\"\n",
@ -828,13 +819,13 @@
" \n",
" Be educational and encouraging in tone.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"specialized_agents = {\n",
" \"billing\": Agent(client, billing_agent_config),\n",
" \"technical\": Agent(client, technical_agent_config),\n",
" \"account\": Agent(client, account_agent_config),\n",
" \"product\": Agent(client, product_agent_config),\n",
" \"billing\": Agent(client, **billing_agent_config),\n",
" \"technical\": Agent(client, **technical_agent_config),\n",
" \"account\": Agent(client, **account_agent_config),\n",
" \"product\": Agent(client, **product_agent_config),\n",
"}\n",
"\n",
"# 2. Define a routing agent\n",
@ -842,7 +833,7 @@
" reasoning: str\n",
" support_team: str\n",
"\n",
"routing_agent_config = AgentConfig({\n",
"routing_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": f\"\"\"You are a routing agent. Analyze the user's input and select the most appropriate support team from these options: \n",
"\n",
@ -862,9 +853,9 @@
" \"type\": \"json_schema\",\n",
" \"json_schema\": OutputSchema.model_json_schema()\n",
" }\n",
"})\n",
"}\n",
"\n",
"routing_agent = Agent(client, routing_agent_config)\n",
"routing_agent = Agent(client, **routing_agent_config)\n",
"\n",
"# 3. Create a session for all agents\n",
"routing_agent_session_id = routing_agent.create_session(session_name=f\"routing_agent_{uuid.uuid4()}\")\n",
@ -1725,17 +1716,17 @@
"from concurrent.futures import ThreadPoolExecutor\n",
"from typing import List\n",
"\n",
"worker_agent_config = AgentConfig({\n",
"worker_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You are a helpful assistant that can analyze the impact of market changes on stakeholders.\n",
" Analyze how market changes will impact this stakeholder group.\n",
" Provide specific impacts and recommended actions.\n",
" Format with clear sections and priorities.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"def create_worker_task(task: str):\n",
" worker_agent = Agent(client, worker_agent_config)\n",
" worker_agent = Agent(client, **worker_agent_config)\n",
" worker_session_id = worker_agent.create_session(session_name=f\"worker_agent_{uuid.uuid4()}\")\n",
" task_response = worker_agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": task}],\n",
@ -2248,7 +2239,7 @@
" thoughts: str\n",
" response: str\n",
"\n",
"generator_agent_config = AgentConfig({\n",
"generator_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"Your goal is to complete the task based on <user input>. If there are feedback \n",
" from your previous generations, you should reflect on them to improve your solution\n",
@ -2263,13 +2254,13 @@
" \"type\": \"json_schema\",\n",
" \"json_schema\": GeneratorOutputSchema.model_json_schema()\n",
" }\n",
"})\n",
"}\n",
"\n",
"class EvaluatorOutputSchema(BaseModel):\n",
" evaluation: str\n",
" feedback: str\n",
"\n",
"evaluator_agent_config = AgentConfig({\n",
"evaluator_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"Evaluate this following code implementation for:\n",
" 1. code correctness\n",
@ -2293,10 +2284,10 @@
" \"type\": \"json_schema\",\n",
" \"json_schema\": EvaluatorOutputSchema.model_json_schema()\n",
" }\n",
"})\n",
"}\n",
"\n",
"generator_agent = Agent(client, generator_agent_config)\n",
"evaluator_agent = Agent(client, evaluator_agent_config)\n",
"generator_agent = Agent(client, **generator_agent_config)\n",
"evaluator_agent = Agent(client, **evaluator_agent_config)\n",
"generator_session_id = generator_agent.create_session(session_name=f\"generator_agent_{uuid.uuid4()}\")\n",
"evaluator_session_id = evaluator_agent.create_session(session_name=f\"evaluator_agent_{uuid.uuid4()}\")\n",
"\n",
@ -2628,7 +2619,7 @@
" analysis: str\n",
" tasks: List[Dict[str, str]]\n",
"\n",
"orchestrator_agent_config = AgentConfig({\n",
"orchestrator_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"Your job is to analyize the task provided by the user andbreak it down into 2-3 distinct approaches:\n",
"\n",
@ -2651,9 +2642,9 @@
" \"type\": \"json_schema\",\n",
" \"json_schema\": OrchestratorOutputSchema.model_json_schema()\n",
" }\n",
"})\n",
"}\n",
"\n",
"worker_agent_config = AgentConfig({\n",
"worker_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You will be given a task guideline. Generate content based on the provided\n",
" task, following the style and guideline descriptions. \n",
@ -2662,7 +2653,7 @@
"\n",
" Response: Your content here, maintaining the specified style and fully addressing requirements.\n",
" \"\"\",\n",
"})\n"
"}\n"
]
},
{
@ -2673,7 +2664,7 @@
"source": [
"def orchestrator_worker_workflow(task, context):\n",
" # single orchestrator agent\n",
" orchestrator_agent = Agent(client, orchestrator_agent_config)\n",
" orchestrator_agent = Agent(client, **orchestrator_agent_config)\n",
" orchestrator_session_id = orchestrator_agent.create_session(session_name=f\"orchestrator_agent_{uuid.uuid4()}\")\n",
"\n",
" orchestrator_response = orchestrator_agent.create_turn(\n",
@ -2689,7 +2680,7 @@
" workers = {}\n",
" # spawn multiple worker agents\n",
" for task in orchestrator_result[\"tasks\"]:\n",
" worker_agent = Agent(client, worker_agent_config)\n",
" worker_agent = Agent(client, **worker_agent_config)\n",
" worker_session_id = worker_agent.create_session(session_name=f\"worker_agent_{uuid.uuid4()}\")\n",
" workers[task[\"type\"]] = worker_agent\n",
" \n",

@ -14,18 +14,16 @@ Agents are configured using the `AgentConfig` class, which includes:
- **Safety Shields**: Guardrails to ensure responsible AI behavior
```python
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.lib.agents.agent import Agent
# Configure an agent
agent_config = AgentConfig(
model="meta-llama/Llama-3-70b-chat",
instructions="You are a helpful assistant that can use tools to answer questions.",
toolgroups=["builtin::code_interpreter", "builtin::rag/knowledge_search"],
)
# Create the agent
agent = Agent(llama_stack_client, agent_config)
agent = Agent(
llama_stack_client,
model="meta-llama/Llama-3-70b-chat",
instructions="You are a helpful assistant that can use tools to answer questions.",
tools=["builtin::code_interpreter", "builtin::rag/knowledge_search"],
)
```
### 2. Sessions

@ -70,18 +70,18 @@ Each step in this process can be monitored and controlled through configurations
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from rich.pretty import pprint
# Replace host and port
client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")
agent_config = AgentConfig(
agent = Agent(
client,
# Check with `llama-stack-client models list`
model="Llama3.2-3B-Instruct",
instructions="You are a helpful assistant",
# Enable both RAG and tool usage
toolgroups=[
tools=[
{
"name": "builtin::rag/knowledge_search",
"args": {"vector_db_ids": ["my_docs"]},
@ -98,8 +98,6 @@ agent_config = AgentConfig(
"max_tokens": 2048,
},
)
agent = Agent(client, agent_config)
session_id = agent.create_session("monitored_session")
# Stream the agent's execution steps

@ -25,17 +25,13 @@ In this example, we will show you how to:
```python
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
agent_config = AgentConfig(
agent = Agent(
client,
model="meta-llama/Llama-3.3-70B-Instruct",
instructions="You are a helpful assistant. Use search tool to answer the questions. ",
toolgroups=["builtin::websearch"],
input_shields=[],
output_shields=[],
enable_session_persistence=False,
tools=["builtin::websearch"],
)
agent = Agent(client, agent_config)
user_prompts = [
"Which teams played in the NBA western conference finals of 2024. Search the web for the answer.",
"In which episode and season of South Park does Bill Cosby (BSM-471) first appear? Give me the number and title. Search the web for the answer.",

@ -86,15 +86,14 @@ results = client.tool_runtime.rag_tool.query(
One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
```python
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.lib.agents.agent import Agent
# Configure agent with memory
agent_config = AgentConfig(
# Create agent with memory
agent = Agent(
client,
model="meta-llama/Llama-3.3-70B-Instruct",
instructions="You are a helpful assistant",
enable_session_persistence=False,
toolgroups=[
tools=[
{
"name": "builtin::rag/knowledge_search",
"args": {
@ -103,8 +102,6 @@ agent_config = AgentConfig(
}
],
)
agent = Agent(client, agent_config)
session_id = agent.create_session("rag_session")

@ -149,15 +149,7 @@ def my_tool(input: int) -> int:
Once defined, simply pass the tool to the agent config. `Agent` will take care of the rest (calling the model with the tool definition, executing the tool, and returning the result to the model for the next iteration).
```python
# Example agent config with client provided tools
client_tools = [
my_tool,
]
agent_config = AgentConfig(
...,
client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
)
agent = Agent(client, agent_config, client_tools)
agent = Agent(client, ..., tools=[my_tool])
```
Refer to [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/e2e_loop_with_client_tools.py) for an example of how to use client provided tools.
@ -194,10 +186,10 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools")
```python
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agent_create_params import AgentConfig
# Configure the AI agent with necessary parameters
agent_config = AgentConfig(
# Instantiate the AI agent with the given configuration
agent = Agent(
client,
name="code-interpreter",
description="A code interpreter agent for executing Python code snippets",
instructions="""
@ -205,14 +197,10 @@ agent_config = AgentConfig(
Always show the generated code, never generate your own code, and never anticipate results.
""",
model="meta-llama/Llama-3.2-3B-Instruct",
toolgroups=["builtin::code_interpreter"],
tools=["builtin::code_interpreter"],
max_infer_iters=5,
enable_session_persistence=False,
)
# Instantiate the AI agent with the given configuration
agent = Agent(client, agent_config)
# Start a session
session_id = agent.create_session("tool_session")

@ -184,7 +184,6 @@ from termcolor import cprint
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types import Document
@ -241,13 +240,14 @@ client.tool_runtime.rag_tool.insert(
chunk_size_in_tokens=512,
)
agent_config = AgentConfig(
rag_agent = Agent(
client,
model=os.environ["INFERENCE_MODEL"],
# Define instructions for the agent (aka system prompt)
instructions="You are a helpful assistant",
enable_session_persistence=False,
# Define tools available to the agent
toolgroups=[
tools=[
{
"name": "builtin::rag/knowledge_search",
"args": {
@ -256,8 +256,6 @@ agent_config = AgentConfig(
}
],
)
rag_agent = Agent(client, agent_config)
session_id = rag_agent.create_session("test-session")
user_prompts = [

@ -294,8 +294,9 @@
" # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n",
" webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
"\n",
" # Define the agent configuration, including the model and tool setup\n",
" agent_config = AgentConfig(\n",
" # Create an agent instance with the client and configuration\n",
" agent = Agent(\n",
" client, \n",
" model=MODEL_NAME,\n",
" instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
" sampling_params={\n",
@ -303,17 +304,12 @@
" \"type\": \"greedy\",\n",
" },\n",
" },\n",
" tools=[webSearchTool.get_tool_definition()],\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"python_list\",\n",
" tools=[webSearchTool],\n",
" input_shields=input_shields,\n",
" output_shields=output_shields,\n",
" enable_session_persistence=False,\n",
" )\n",
"\n",
" # Create an agent instance with the client and configuration\n",
" agent = Agent(client, agent_config, [webSearchTool])\n",
"\n",
" # Create a session for interaction and print the session ID\n",
" session_id = agent.create_session(\"test-session\")\n",
" print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n",

@ -110,12 +110,12 @@
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"\n",
"\n",
"async def agent_example():\n",
" client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n",
" agent_config = AgentConfig(\n",
" agent = Agent(\n",
" client, \n",
" model=MODEL_NAME,\n",
" instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n",
" sampling_params={\n",
@ -130,14 +130,7 @@
" \"api_key\": BRAVE_SEARCH_API_KEY,\n",
" }\n",
" ],\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"function_tag\",\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=False,\n",
" )\n",
"\n",
" agent = Agent(client, agent_config)\n",
" session_id = agent.create_session(\"test-session\")\n",
" print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n",
"\n",

@ -103,7 +103,6 @@
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import (\n",
" AgentConfig,\n",
" AgentConfigToolSearchToolDefinition,\n",
")\n",
"\n",
@ -117,7 +116,8 @@
") -> Agent:\n",
" \"\"\"Create an agent with specified tools.\"\"\"\n",
" print(\"Using the following model: \", model)\n",
" agent_config = AgentConfig(\n",
" return Agent(\n",
" client, \n",
" model=model,\n",
" instructions=instructions,\n",
" sampling_params={\n",
@ -126,12 +126,7 @@
" },\n",
" },\n",
" tools=tools,\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"json\",\n",
" enable_session_persistence=True,\n",
" )\n",
"\n",
" return Agent(client, agent_config)\n"
" )\n"
]
},
{
@ -360,9 +355,9 @@
" # Create the agent with the tool\n",
" weather_tool = WeatherTool()\n",
"\n",
" agent_config = AgentConfig(\n",
" agent = Agent(\n",
" client=client, \n",
" model=LLAMA31_8B_INSTRUCT,\n",
" # model=model_name,\n",
" instructions=\"\"\"\n",
" You are a weather assistant that can provide weather information.\n",
" Always specify the location clearly in your responses.\n",
@ -373,16 +368,9 @@
" \"type\": \"greedy\",\n",
" },\n",
" },\n",
" tools=[weather_tool.get_tool_definition()],\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"json\",\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=True,\n",
" tools=[weather_tool],\n",
" )\n",
"\n",
" agent = Agent(client=client, agent_config=agent_config, custom_tools=[weather_tool])\n",
"\n",
" return agent\n",
"\n",
"\n",

@ -7,7 +7,6 @@
import streamlit as st
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.memory_insert_params import Document
from modules.api import llama_stack_api
from modules.utils import data_url_from_file
@ -124,13 +123,14 @@ def rag_chat_page():
else:
strategy = {"type": "greedy"}
agent_config = AgentConfig(
agent = Agent(
llama_stack_api.client,
model=selected_model,
instructions=system_prompt,
sampling_params={
"strategy": strategy,
},
toolgroups=[
tools=[
dict(
name="builtin::rag/knowledge_search",
args={
@ -138,12 +138,7 @@ def rag_chat_page():
},
)
],
tool_choice="auto",
tool_prompt_format="json",
enable_session_persistence=False,
)
agent = Agent(llama_stack_api.client, agent_config)
session_id = agent.create_session("rag-session")
# Chat input

@ -64,7 +64,7 @@ def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> D
def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()]
available_shields = available_shields[:1]
agent_config = AgentConfig(
agent_config = dict(
model=text_model_id,
instructions="You are a helpful assistant",
sampling_params={
@ -74,7 +74,7 @@ def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
"top_p": 0.9,
},
},
toolgroups=[],
tools=[],
input_shields=available_shields,
output_shields=available_shields,
enable_session_persistence=False,
@ -83,7 +83,7 @@ def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
simple_hello = agent.create_turn(
@ -137,7 +137,7 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
agent_config = AgentConfig(
**common_params,
)
Server__AgentConfig(**agent_config)
Server__AgentConfig(**common_params)
agent_config = AgentConfig(
**common_params,
@ -179,11 +179,11 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
agent_config = {
**agent_config,
"toolgroups": [
"tools": [
"builtin::websearch",
],
}
agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
@ -209,11 +209,11 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
agent_config = {
**agent_config,
"toolgroups": [
"tools": [
"builtin::code_interpreter",
],
}
agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
@ -238,12 +238,12 @@ def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, a
def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference, agent_config):
agent_config = {
**agent_config,
"toolgroups": [
"tools": [
"builtin::code_interpreter",
],
}
codex_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
codex_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = codex_agent.create_session(f"test-session-{uuid4()}")
inflation_doc = AgentDocument(
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
@ -275,11 +275,11 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
client_tool = get_boiling_point
agent_config = {
**agent_config,
"toolgroups": ["builtin::websearch"],
"tools": ["builtin::websearch", client_tool],
"client_tools": [client_tool.get_tool_definition()],
}
agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
@ -303,11 +303,11 @@ def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, age
agent_config = {
**agent_config,
"instructions": "You are a helpful assistant Always respond with tool calls no matter what. ",
"client_tools": [client_tool.get_tool_definition()],
"tools": [client_tool],
"max_infer_iters": 5,
}
agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
@ -332,10 +332,10 @@ def test_tool_choice(llama_stack_client_with_mocked_inference, agent_config):
test_agent_config = {
**agent_config,
"tool_config": {"tool_choice": tool_choice},
"client_tools": [client_tool.get_tool_definition()],
"tools": [client_tool],
}
agent = Agent(llama_stack_client_with_mocked_inference, test_agent_config, client_tools=(client_tool,))
agent = Agent(llama_stack_client_with_mocked_inference, **test_agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
@ -387,7 +387,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
)
agent_config = {
**agent_config,
"toolgroups": [
"tools": [
dict(
name=rag_tool_name,
args={
@ -396,7 +396,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
)
],
}
rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = rag_agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
(
@ -422,7 +422,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
@pytest.mark.parametrize(
"toolgroup",
"tool",
[
dict(
name="builtin::rag/knowledge_search",
@ -433,7 +433,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
"builtin::rag/knowledge_search",
],
)
def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, toolgroup):
def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, tool):
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [
Document(
@ -446,9 +446,9 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
]
agent_config = {
**agent_config,
"toolgroups": [toolgroup],
"tools": [tool],
}
rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = rag_agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
(
@ -521,7 +521,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
)
agent_config = {
**agent_config,
"toolgroups": [
"tools": [
dict(
name="builtin::rag/knowledge_search",
args={"vector_db_ids": [vector_db_id]},
@ -529,7 +529,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
"builtin::code_interpreter",
],
}
agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
inflation_doc = Document(
document_id="test_csv",
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
@ -578,10 +578,10 @@ def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_co
**agent_config,
"input_shields": [],
"output_shields": [],
"client_tools": [client_tool.get_tool_definition()],
"tools": [client_tool],
}
agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(