Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
docs: update test_agents to use new Agent SDK API (#1402)
# Summary:
A new Agent SDK API was added in https://github.com/meta-llama/llama-stack-client-python/pull/178. Update the docs and tests to reflect this.

Closes https://github.com/meta-llama/llama-stack/issues/1365

# Test Plan:
```bash
py.test -v -s --nbval-lax ./docs/getting_started.ipynb

LLAMA_STACK_CONFIG=fireworks \
  pytest -s -v tests/integration/agents/test_agents.py \
  --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct
```
This commit is contained in:
parent 3d71e5a036 · commit ca2910d27a
13 changed files with 121 additions and 206 deletions
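The migration this commit applies is the same everywhere: the `AgentConfig` wrapper is dropped, its fields move onto `Agent` as keyword arguments, and `toolgroups`/`client_tools` collapse into a single `tools` parameter. A minimal before/after sketch of the pattern (the local server URL and model are assumptions drawn from the test plan above):

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server
model_id = "meta-llama/Llama-3.1-8B-Instruct"  # model used in the test plan above

# Old API (removed): build an AgentConfig, then wrap it in an Agent
# from llama_stack_client.types.agent_create_params import AgentConfig
# agent_config = AgentConfig(
#     model=model_id,
#     instructions="You are a helpful assistant",
#     toolgroups=["builtin::websearch"],
# )
# agent = Agent(client, agent_config)

# New API: pass the same settings directly as keyword arguments,
# with toolgroups/client_tools merged into a single `tools` list
agent = Agent(
    client,
    model=model_id,
    instructions="You are a helpful assistant",
    tools=["builtin::websearch"],
)
session_id = agent.create_session("demo-session")
```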
@@ -1635,18 +1635,14 @@
"source": [
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from termcolor import cprint\n",
"\n",
"agent_config = AgentConfig(\n",
"agent = Agent(\n",
" client, \n",
" model=model_id,\n",
" instructions=\"You are a helpful assistant\",\n",
" toolgroups=[\"builtin::websearch\"],\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=False,\n",
" tools=[\"builtin::websearch\"],\n",
")\n",
"agent = Agent(client, agent_config)\n",
"user_prompts = [\n",
" \"Hello\",\n",
" \"Which teams played in the NBA western conference finals of 2024\",\n",

@@ -1815,7 +1811,6 @@
"import uuid\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from termcolor import cprint\n",
"from llama_stack_client.types import Document\n",
"\n",

@@ -1841,11 +1836,11 @@
" vector_db_id=vector_db_id,\n",
" chunk_size_in_tokens=512,\n",
")\n",
"agent_config = AgentConfig(\n",
"rag_agent = Agent(\n",
" client, \n",
" model=model_id,\n",
" instructions=\"You are a helpful assistant\",\n",
" enable_session_persistence=False,\n",
" toolgroups = [\n",
" tools = [\n",
" {\n",
" \"name\": \"builtin::rag/knowledge_search\",\n",
" \"args\" : {\n",

@@ -1854,7 +1849,6 @@
" }\n",
" ],\n",
")\n",
"rag_agent = Agent(client, agent_config)\n",
"session_id = rag_agent.create_session(\"test-session\")\n",
"user_prompts = [\n",
" \"What are the top 5 topics that were explained? Only list succinct bullet points.\",\n",

@@ -1978,23 +1972,19 @@
"source": [
"from llama_stack_client.types.agents.turn_create_params import Document\n",
"\n",
"agent_config = AgentConfig(\n",
"codex_agent = Agent(\n",
" client, \n",
" model=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
" instructions=\"You are a helpful assistant\",\n",
" tools=[\n",
" \"builtin::code_interpreter\",\n",
" \"builtin::websearch\"\n",
" ],\n",
" sampling_params = {\n",
" \"max_tokens\" : 4096,\n",
" \"temperature\": 0.0\n",
" },\n",
" model=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
" instructions=\"You are a helpful assistant\",\n",
" toolgroups=[\n",
" \"builtin::code_interpreter\",\n",
" \"builtin::websearch\"\n",
" ],\n",
" tool_choice=\"auto\",\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=False,\n",
")\n",
"codex_agent = Agent(client, agent_config)\n",
"session_id = codex_agent.create_session(\"test-session\")\n",
"\n",
"\n",

@@ -2904,18 +2894,14 @@
"# NBVAL_SKIP\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from termcolor import cprint\n",
"\n",
"agent_config = AgentConfig(\n",
"agent = Agent(\n",
" client, \n",
" model=model_id,\n",
" instructions=\"You are a helpful assistant\",\n",
" toolgroups=[\"mcp::filesystem\"],\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=False,\n",
" tools=[\"mcp::filesystem\"],\n",
")\n",
"agent = Agent(client, agent_config)\n",
"user_prompts = [\n",
" \"Hello\",\n",
" \"list all the files /content\",\n",

@@ -3010,17 +2996,13 @@
"source": [
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"\n",
"agent_config = AgentConfig(\n",
"agent = Agent(\n",
" client, \n",
" model=\"meta-llama/Llama-3.3-70B-Instruct\",\n",
" instructions=\"You are a helpful assistant. Use search tool to answer the questions. \",\n",
" toolgroups=[\"builtin::websearch\"],\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=False,\n",
" tools=[\"builtin::websearch\"],\n",
")\n",
"agent = Agent(client, agent_config)\n",
"user_prompts = [\n",
" \"Which teams played in the NBA western conference finals of 2024. Search the web for the answer.\",\n",
" \"In which episode and season of South Park does Bill Cosby (BSM-471) first appear? Give me the number and title. Search the web for the answer.\",\n",

@@ -4346,16 +4328,11 @@
}
],
"source": [
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"\n",
"agent_config = AgentConfig(\n",
"agent = Agent(\n",
" client, \n",
" model=vision_model_id,\n",
" instructions=\"You are a helpful assistant\",\n",
" enable_session_persistence=False,\n",
" toolgroups=[],\n",
")\n",
"\n",
"agent = Agent(client, agent_config)\n",
"session_id = agent.create_session(\"test-session\")\n",
"\n",
"response = agent.create_turn(\n",
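As the RAG cells above show, the new `tools` list also accepts parameterized entries alongside plain toolgroup names: a dict naming the builtin tool plus its arguments. A sketch of that form, assuming `client`, `model_id`, and an already-registered and populated `vector_db_id` from the notebook context:

```python
from llama_stack_client.lib.agents.agent import Agent

# `client`, `model_id`, and `vector_db_id` are assumed from the notebook context
rag_agent = Agent(
    client,
    model=model_id,
    instructions="You are a helpful assistant",
    tools=[
        {
            # Builtin RAG tool, parameterized with the vector DBs it may search
            "name": "builtin::rag/knowledge_search",
            "args": {"vector_db_ids": [vector_db_id]},
        }
    ],
)
session_id = rag_agent.create_session("test-session")
```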
@@ -49,7 +49,6 @@
"source": [
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from rich.pretty import pprint\n",
"import json\n",

@@ -71,20 +70,12 @@
"\n",
"MODEL_ID = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
"\n",
"base_agent_config = AgentConfig(\n",
"base_agent_config = dict(\n",
" model=MODEL_ID,\n",
" instructions=\"You are a helpful assistant.\",\n",
" sampling_params={\n",
" \"strategy\": {\"type\": \"top_p\", \"temperature\": 1.0, \"top_p\": 0.9},\n",
" },\n",
" toolgroups=[],\n",
" tool_config={\n",
" \"tool_choice\": \"auto\",\n",
" \"tool_prompt_format\": \"python_list\",\n",
" },\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=False,\n",
")"
]
},

@@ -172,7 +163,7 @@
}
],
"source": [
"vanilla_agent_config = AgentConfig({\n",
"vanilla_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"\n",
" You are a helpful assistant capable of structuring data extraction and formatting. \n",

@@ -189,9 +180,9 @@
" Employee satisfaction is at 87 points.\n",
" Operating margin improved to 34%.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"vanilla_agent = Agent(client, vanilla_agent_config)\n",
"vanilla_agent = Agent(client, **vanilla_agent_config)\n",
"prompt_chaining_session_id = vanilla_agent.create_session(session_name=f\"vanilla_agent_{uuid.uuid4()}\")\n",
"\n",
"prompts = [\n",

@@ -778,7 +769,7 @@
],
"source": [
"# 1. Define a couple of specialized agents\n",
"billing_agent_config = AgentConfig({\n",
"billing_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You are a billing support specialist. Follow these guidelines:\n",
" 1. Always start with \"Billing Support Response:\"\n",

@@ -789,9 +780,9 @@
" \n",
" Keep responses professional but friendly.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"technical_agent_config = AgentConfig({\n",
"technical_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You are a technical support engineer. Follow these guidelines:\n",
" 1. Always start with \"Technical Support Response:\"\n",

@@ -802,9 +793,9 @@
" \n",
" Use clear, numbered steps and technical details.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"account_agent_config = AgentConfig({\n",
"account_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You are an account security specialist. Follow these guidelines:\n",
" 1. Always start with \"Account Support Response:\"\n",

@@ -815,9 +806,9 @@
" \n",
" Maintain a serious, security-focused tone.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"product_agent_config = AgentConfig({\n",
"product_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You are a product specialist. Follow these guidelines:\n",
" 1. Always start with \"Product Support Response:\"\n",

@@ -828,13 +819,13 @@
" \n",
" Be educational and encouraging in tone.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"specialized_agents = {\n",
" \"billing\": Agent(client, billing_agent_config),\n",
" \"technical\": Agent(client, technical_agent_config),\n",
" \"account\": Agent(client, account_agent_config),\n",
" \"product\": Agent(client, product_agent_config),\n",
" \"billing\": Agent(client, **billing_agent_config),\n",
" \"technical\": Agent(client, **technical_agent_config),\n",
" \"account\": Agent(client, **account_agent_config),\n",
" \"product\": Agent(client, **product_agent_config),\n",
"}\n",
"\n",
"# 2. Define a routing agent\n",

@@ -842,7 +833,7 @@
" reasoning: str\n",
" support_team: str\n",
"\n",
"routing_agent_config = AgentConfig({\n",
"routing_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": f\"\"\"You are a routing agent. Analyze the user's input and select the most appropriate support team from these options: \n",
"\n",

@@ -862,9 +853,9 @@
" \"type\": \"json_schema\",\n",
" \"json_schema\": OutputSchema.model_json_schema()\n",
" }\n",
"})\n",
"}\n",
"\n",
"routing_agent = Agent(client, routing_agent_config)\n",
"routing_agent = Agent(client, **routing_agent_config)\n",
"\n",
"# 3. Create a session for all agents\n",
"routing_agent_session_id = routing_agent.create_session(session_name=f\"routing_agent_{uuid.uuid4()}\")\n",

@@ -1725,17 +1716,17 @@
"from concurrent.futures import ThreadPoolExecutor\n",
"from typing import List\n",
"\n",
"worker_agent_config = AgentConfig({\n",
"worker_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You are a helpful assistant that can analyze the impact of market changes on stakeholders.\n",
" Analyze how market changes will impact this stakeholder group.\n",
" Provide specific impacts and recommended actions.\n",
" Format with clear sections and priorities.\n",
" \"\"\",\n",
"})\n",
"}\n",
"\n",
"def create_worker_task(task: str):\n",
" worker_agent = Agent(client, worker_agent_config)\n",
" worker_agent = Agent(client, **worker_agent_config)\n",
" worker_session_id = worker_agent.create_session(session_name=f\"worker_agent_{uuid.uuid4()}\")\n",
" task_response = worker_agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": task}],\n",

@@ -2248,7 +2239,7 @@
" thoughts: str\n",
" response: str\n",
"\n",
"generator_agent_config = AgentConfig({\n",
"generator_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"Your goal is to complete the task based on <user input>. If there are feedback \n",
" from your previous generations, you should reflect on them to improve your solution\n",

@@ -2263,13 +2254,13 @@
" \"type\": \"json_schema\",\n",
" \"json_schema\": GeneratorOutputSchema.model_json_schema()\n",
" }\n",
"})\n",
"}\n",
"\n",
"class EvaluatorOutputSchema(BaseModel):\n",
" evaluation: str\n",
" feedback: str\n",
"\n",
"evaluator_agent_config = AgentConfig({\n",
"evaluator_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"Evaluate this following code implementation for:\n",
" 1. code correctness\n",

@@ -2293,10 +2284,10 @@
" \"type\": \"json_schema\",\n",
" \"json_schema\": EvaluatorOutputSchema.model_json_schema()\n",
" }\n",
"})\n",
"}\n",
"\n",
"generator_agent = Agent(client, generator_agent_config)\n",
"evaluator_agent = Agent(client, evaluator_agent_config)\n",
"generator_agent = Agent(client, **generator_agent_config)\n",
"evaluator_agent = Agent(client, **evaluator_agent_config)\n",
"generator_session_id = generator_agent.create_session(session_name=f\"generator_agent_{uuid.uuid4()}\")\n",
"evaluator_session_id = evaluator_agent.create_session(session_name=f\"evaluator_agent_{uuid.uuid4()}\")\n",
"\n",

@@ -2628,7 +2619,7 @@
" analysis: str\n",
" tasks: List[Dict[str, str]]\n",
"\n",
"orchestrator_agent_config = AgentConfig({\n",
"orchestrator_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"Your job is to analyize the task provided by the user andbreak it down into 2-3 distinct approaches:\n",
"\n",

@@ -2651,9 +2642,9 @@
" \"type\": \"json_schema\",\n",
" \"json_schema\": OrchestratorOutputSchema.model_json_schema()\n",
" }\n",
"})\n",
"}\n",
"\n",
"worker_agent_config = AgentConfig({\n",
"worker_agent_config = {\n",
" **base_agent_config,\n",
" \"instructions\": \"\"\"You will be given a task guideline. Generate content based on the provided\n",
" task, following the style and guideline descriptions. \n",

@@ -2662,7 +2653,7 @@
"\n",
" Response: Your content here, maintaining the specified style and fully addressing requirements.\n",
" \"\"\",\n",
"})\n"
"}\n"
]
},
{

@@ -2673,7 +2664,7 @@
"source": [
"def orchestrator_worker_workflow(task, context):\n",
" # single orchestrator agent\n",
" orchestrator_agent = Agent(client, orchestrator_agent_config)\n",
" orchestrator_agent = Agent(client, **orchestrator_agent_config)\n",
" orchestrator_session_id = orchestrator_agent.create_session(session_name=f\"orchestrator_agent_{uuid.uuid4()}\")\n",
"\n",
" orchestrator_response = orchestrator_agent.create_turn(\n",

@@ -2689,7 +2680,7 @@
" workers = {}\n",
" # spawn multiple worker agents\n",
" for task in orchestrator_result[\"tasks\"]:\n",
" worker_agent = Agent(client, worker_agent_config)\n",
" worker_agent = Agent(client, **worker_agent_config)\n",
" worker_session_id = worker_agent.create_session(session_name=f\"worker_agent_{uuid.uuid4()}\")\n",
" workers[task[\"type\"]] = worker_agent\n",
" \n",
@@ -14,18 +14,16 @@ Agents are configured using the `AgentConfig` class, which includes:
- **Safety Shields**: Guardrails to ensure responsible AI behavior

```python
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.lib.agents.agent import Agent

# Configure an agent
agent_config = AgentConfig(
model="meta-llama/Llama-3-70b-chat",
instructions="You are a helpful assistant that can use tools to answer questions.",
toolgroups=["builtin::code_interpreter", "builtin::rag/knowledge_search"],
)

# Create the agent
agent = Agent(llama_stack_client, agent_config)
agent = Agent(
llama_stack_client,
model="meta-llama/Llama-3-70b-chat",
instructions="You are a helpful assistant that can use tools to answer questions.",
tools=["builtin::code_interpreter", "builtin::rag/knowledge_search"],
)
```

### 2. Sessions
@@ -70,18 +70,18 @@ Each step in this process can be monitored and controlled through configurations
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from rich.pretty import pprint

# Replace host and port
client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")

agent_config = AgentConfig(
agent = Agent(
client,
# Check with `llama-stack-client models list`
model="Llama3.2-3B-Instruct",
instructions="You are a helpful assistant",
# Enable both RAG and tool usage
toolgroups=[
tools=[
{
"name": "builtin::rag/knowledge_search",
"args": {"vector_db_ids": ["my_docs"]},

@@ -98,8 +98,6 @@ agent_config = AgentConfig(
"max_tokens": 2048,
},
)

agent = Agent(client, agent_config)
session_id = agent.create_session("monitored_session")

# Stream the agent's execution steps
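For reference, the loop this section builds toward streams each execution step through the `EventLogger` imported above; a sketch under the same setup, with the user message as a placeholder:

```python
# Assumes the `agent` constructed above and the EventLogger import from this page
session_id = agent.create_session("monitored_session")

response = agent.create_turn(
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    session_id=session_id,
)

# Render each streamed event (inference, tool execution, shield checks) as it arrives
for log in EventLogger().log(response):
    log.print()
```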
@@ -25,17 +25,13 @@ In this example, we will show you how to:
```python
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig

agent_config = AgentConfig(
agent = Agent(
client,
model="meta-llama/Llama-3.3-70B-Instruct",
instructions="You are a helpful assistant. Use search tool to answer the questions. ",
toolgroups=["builtin::websearch"],
input_shields=[],
output_shields=[],
enable_session_persistence=False,
tools=["builtin::websearch"],
)
agent = Agent(client, agent_config)
user_prompts = [
"Which teams played in the NBA western conference finals of 2024. Search the web for the answer.",
"In which episode and season of South Park does Bill Cosby (BSM-471) first appear? Give me the number and title. Search the web for the answer.",
@@ -86,15 +86,14 @@ results = client.tool_runtime.rag_tool.query(
One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:

```python
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.lib.agents.agent import Agent

# Configure agent with memory
agent_config = AgentConfig(
# Create agent with memory
agent = Agent(
client,
model="meta-llama/Llama-3.3-70B-Instruct",
instructions="You are a helpful assistant",
enable_session_persistence=False,
toolgroups=[
tools=[
{
"name": "builtin::rag/knowledge_search",
"args": {

@@ -103,8 +102,6 @@ agent_config = AgentConfig(
}
],
)

agent = Agent(client, agent_config)
session_id = agent.create_session("rag_session")

@@ -149,15 +149,7 @@ def my_tool(input: int) -> int:
Once defined, simply pass the tool to the agent config. `Agent` will take care of the rest (calling the model with the tool definition, executing the tool, and returning the result to the model for the next iteration).
```python
# Example agent config with client provided tools
client_tools = [
my_tool,
]

agent_config = AgentConfig(
...,
client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
)
agent = Agent(client, agent_config, client_tools)
agent = Agent(client, ..., tools=[my_tool])
```

Refer to [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/e2e_loop_with_client_tools.py) for an example of how to use client provided tools.

@@ -194,10 +186,10 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools")

```python
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agent_create_params import AgentConfig

# Configure the AI agent with necessary parameters
agent_config = AgentConfig(
# Instantiate the AI agent with the given configuration
agent = Agent(
client,
name="code-interpreter",
description="A code interpreter agent for executing Python code snippets",
instructions="""

@@ -205,14 +197,10 @@ agent_config = AgentConfig(
Always show the generated code, never generate your own code, and never anticipate results.
""",
model="meta-llama/Llama-3.2-3B-Instruct",
toolgroups=["builtin::code_interpreter"],
tools=["builtin::code_interpreter"],
max_infer_iters=5,
enable_session_persistence=False,
)

# Instantiate the AI agent with the given configuration
agent = Agent(client, agent_config)

# Start a session
session_id = agent.create_session("tool_session")
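A compact sketch of the new client-tool path shown above, assuming the `client_tool` decorator from `llama_stack_client.lib.agents.client_tool` (the SDK derives the tool definition from the function's signature and docstring); `client` and the model name are placeholders:

```python
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.client_tool import client_tool  # assumed import path


@client_tool
def my_tool(input: int) -> int:
    """Double a number.

    :param input: the number to double
    :returns: the doubled number
    """
    return input * 2


# The function itself goes in `tools`; Agent sends the definition to the model,
# executes the call locally, and feeds the result back for the next iteration.
agent = Agent(
    client,  # assumed: an existing LlamaStackClient
    model="meta-llama/Llama-3.3-70B-Instruct",
    instructions="You are a helpful assistant",
    tools=[my_tool],
)
```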
@@ -184,7 +184,6 @@ from termcolor import cprint

from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types import Document

@@ -241,13 +240,14 @@ client.tool_runtime.rag_tool.insert(
chunk_size_in_tokens=512,
)

agent_config = AgentConfig(
rag_agent = Agent(
client,
model=os.environ["INFERENCE_MODEL"],
# Define instructions for the agent ( aka system prompt)
instructions="You are a helpful assistant",
enable_session_persistence=False,
# Define tools available to the agent
toolgroups=[
tools=[
{
"name": "builtin::rag/knowledge_search",
"args": {

@@ -256,8 +256,6 @@ agent_config = AgentConfig(
}
],
)

rag_agent = Agent(client, agent_config)
session_id = rag_agent.create_session("test-session")

user_prompts = [
@@ -294,8 +294,9 @@
" # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n",
" webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
"\n",
" # Define the agent configuration, including the model and tool setup\n",
" agent_config = AgentConfig(\n",
" # Create an agent instance with the client and configuration\n",
" agent = Agent(\n",
" client, \n",
" model=MODEL_NAME,\n",
" instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
" sampling_params={\n",

@@ -303,17 +304,12 @@
" \"type\": \"greedy\",\n",
" },\n",
" },\n",
" tools=[webSearchTool.get_tool_definition()],\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"python_list\",\n",
" tools=[webSearchTool],\n",
" input_shields=input_shields,\n",
" output_shields=output_shields,\n",
" enable_session_persistence=False,\n",
" )\n",
"\n",
" # Create an agent instance with the client and configuration\n",
" agent = Agent(client, agent_config, [webSearchTool])\n",
"\n",
" # Create a session for interaction and print the session ID\n",
" session_id = agent.create_session(\"test-session\")\n",
" print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n",
@@ -110,12 +110,12 @@
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"\n",
"\n",
"async def agent_example():\n",
" client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n",
" agent_config = AgentConfig(\n",
" agent = Agent(\n",
" client, \n",
" model=MODEL_NAME,\n",
" instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n",
" sampling_params={\n",

@@ -130,14 +130,7 @@
" \"api_key\": BRAVE_SEARCH_API_KEY,\n",
" }\n",
" ],\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"function_tag\",\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=False,\n",
" )\n",
"\n",
" agent = Agent(client, agent_config)\n",
" session_id = agent.create_session(\"test-session\")\n",
" print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n",
"\n",
@@ -103,7 +103,6 @@
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import (\n",
" AgentConfig,\n",
" AgentConfigToolSearchToolDefinition,\n",
")\n",
"\n",

@@ -117,7 +116,8 @@
") -> Agent:\n",
" \"\"\"Create an agent with specified tools.\"\"\"\n",
" print(\"Using the following model: \", model)\n",
" agent_config = AgentConfig(\n",
" return Agent(\n",
" client, \n",
" model=model,\n",
" instructions=instructions,\n",
" sampling_params={\n",

@@ -126,12 +126,7 @@
" },\n",
" },\n",
" tools=tools,\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"json\",\n",
" enable_session_persistence=True,\n",
" )\n",
"\n",
" return Agent(client, agent_config)\n"
" )\n"
]
},
{

@@ -360,9 +355,9 @@
" # Create the agent with the tool\n",
" weather_tool = WeatherTool()\n",
"\n",
" agent_config = AgentConfig(\n",
" agent = Agent(\n",
" client=client, \n",
" model=LLAMA31_8B_INSTRUCT,\n",
" # model=model_name,\n",
" instructions=\"\"\"\n",
" You are a weather assistant that can provide weather information.\n",
" Always specify the location clearly in your responses.\n",

@@ -373,16 +368,9 @@
" \"type\": \"greedy\",\n",
" },\n",
" },\n",
" tools=[weather_tool.get_tool_definition()],\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"json\",\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=True,\n",
" tools=[weather_tool],\n",
" )\n",
"\n",
" agent = Agent(client=client, agent_config=agent_config, custom_tools=[weather_tool])\n",
"\n",
" return agent\n",
"\n",
"\n",
@@ -7,7 +7,6 @@
import streamlit as st
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.memory_insert_params import Document
from modules.api import llama_stack_api
from modules.utils import data_url_from_file

@@ -124,13 +123,14 @@ def rag_chat_page():
else:
strategy = {"type": "greedy"}

agent_config = AgentConfig(
agent = Agent(
llama_stack_api.client,
model=selected_model,
instructions=system_prompt,
sampling_params={
"strategy": strategy,
},
toolgroups=[
tools=[
dict(
name="builtin::rag/knowledge_search",
args={

@@ -138,12 +138,7 @@ def rag_chat_page():
},
)
],
tool_choice="auto",
tool_prompt_format="json",
enable_session_persistence=False,
)

agent = Agent(llama_stack_api.client, agent_config)
session_id = agent.create_session("rag-session")

# Chat input
@@ -64,7 +64,7 @@ def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> D
def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()]
available_shields = available_shields[:1]
agent_config = AgentConfig(
agent_config = dict(
model=text_model_id,
instructions="You are a helpful assistant",
sampling_params={

@@ -74,7 +74,7 @@ def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
"top_p": 0.9,
},
},
toolgroups=[],
tools=[],
input_shields=available_shields,
output_shields=available_shields,
enable_session_persistence=False,

@@ -83,7 +83,7 @@ def agent_config(llama_stack_client_with_mocked_inference, text_model_id):


def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")

simple_hello = agent.create_turn(

@@ -137,7 +137,7 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
agent_config = AgentConfig(
**common_params,
)
Server__AgentConfig(**agent_config)
Server__AgentConfig(**common_params)

agent_config = AgentConfig(
**common_params,

@@ -179,11 +179,11 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
agent_config = {
**agent_config,
"toolgroups": [
"tools": [
"builtin::websearch",
],
}
agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")

response = agent.create_turn(

@@ -209,11 +209,11 @@ def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent
def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
agent_config = {
**agent_config,
"toolgroups": [
"tools": [
"builtin::code_interpreter",
],
}
agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")

response = agent.create_turn(

@@ -238,12 +238,12 @@ def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference, a
def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference, agent_config):
agent_config = {
**agent_config,
"toolgroups": [
"tools": [
"builtin::code_interpreter",
],
}

codex_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
codex_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = codex_agent.create_session(f"test-session-{uuid4()}")
inflation_doc = AgentDocument(
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",

@@ -275,11 +275,11 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
client_tool = get_boiling_point
agent_config = {
**agent_config,
"toolgroups": ["builtin::websearch"],
"tools": ["builtin::websearch", client_tool],
"client_tools": [client_tool.get_tool_definition()],
}

agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")

response = agent.create_turn(

@@ -303,11 +303,11 @@ def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, age
agent_config = {
**agent_config,
"instructions": "You are a helpful assistant Always respond with tool calls no matter what. ",
"client_tools": [client_tool.get_tool_definition()],
"tools": [client_tool],
"max_infer_iters": 5,
}

agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")

response = agent.create_turn(

@@ -332,10 +332,10 @@ def test_tool_choice(llama_stack_client_with_mocked_inference, agent_config):
test_agent_config = {
**agent_config,
"tool_config": {"tool_choice": tool_choice},
"client_tools": [client_tool.get_tool_definition()],
"tools": [client_tool],
}

agent = Agent(llama_stack_client_with_mocked_inference, test_agent_config, client_tools=(client_tool,))
agent = Agent(llama_stack_client_with_mocked_inference, **test_agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")

response = agent.create_turn(

@@ -387,7 +387,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
)
agent_config = {
**agent_config,
"toolgroups": [
"tools": [
dict(
name=rag_tool_name,
args={

@@ -396,7 +396,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
)
],
}
rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = rag_agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
(

@@ -422,7 +422,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t


@pytest.mark.parametrize(
"toolgroup",
"tool",
[
dict(
name="builtin::rag/knowledge_search",

@@ -433,7 +433,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
"builtin::rag/knowledge_search",
],
)
def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, toolgroup):
def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, tool):
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [
Document(

@@ -446,9 +446,9 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
]
agent_config = {
**agent_config,
"toolgroups": [toolgroup],
"tools": [tool],
}
rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = rag_agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
(

@@ -521,7 +521,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
)
agent_config = {
**agent_config,
"toolgroups": [
"tools": [
dict(
name="builtin::rag/knowledge_search",
args={"vector_db_ids": [vector_db_id]},

@@ -529,7 +529,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
"builtin::code_interpreter",
],
}
agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
inflation_doc = Document(
document_id="test_csv",
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",

@@ -578,10 +578,10 @@ def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_co
**agent_config,
"input_shields": [],
"output_shields": [],
"client_tools": [client_tool.get_tool_definition()],
"tools": [client_tool],
}

agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")

response = agent.create_turn(
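The test changes all reduce to one pattern: the `agent_config` fixture now returns a plain dict, individual tests override fields with a `**` spread, and the dict is splatted into `Agent` as keyword arguments. A sketch under those assumptions (`client` and `text_model_id` stand in for the test fixtures):

```python
from uuid import uuid4

from llama_stack_client.lib.agents.agent import Agent

# Fixture-style base configuration: a plain dict rather than AgentConfig
base_config = dict(
    model=text_model_id,  # assumed: supplied by the test harness
    instructions="You are a helpful assistant",
    tools=[],
)

# A test overrides only what it needs by spreading the base dict
websearch_config = {
    **base_config,
    "tools": ["builtin::websearch"],
}

agent = Agent(client, **websearch_config)  # assumed: `client` is the test client
session_id = agent.create_session(f"test-session-{uuid4()}")
```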