Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
docs: update test_agents to use new Agent SDK API (#1402)
# Summary:
The new Agent SDK API was added in https://github.com/meta-llama/llama-stack-client-python/pull/178. Update the docs and tests to reflect this.

Closes https://github.com/meta-llama/llama-stack/issues/1365

# Test Plan:
```bash
py.test -v -s --nbval-lax ./docs/getting_started.ipynb

LLAMA_STACK_CONFIG=fireworks \
  pytest -s -v tests/integration/agents/test_agents.py \
  --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct
```
parent 3d71e5a036
commit ca2910d27a

13 changed files with 121 additions and 206 deletions
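The change that repeats across every hunk below is mechanical: the `AgentConfig` wrapper drops out of the client-facing examples, its fields move directly onto the `Agent` constructor, and the separate `toolgroups` and `client_tools` parameters merge into a single `tools` argument. A minimal before/after sketch distilled from the hunks (it assumes a `client` and `model_id` already defined, as in the notebook, and `my_tool` stands in for any client-side tool function):

```python
from llama_stack_client.lib.agents.agent import Agent

# -- Built-in toolgroups -----------------------------------------------
# Before: build an AgentConfig, then wrap it in an Agent.
# from llama_stack_client.types.agent_create_params import AgentConfig
# agent_config = AgentConfig(
#     model=model_id,
#     instructions="You are a helpful assistant",
#     toolgroups=["builtin::websearch"],
#     enable_session_persistence=False,
# )
# agent = Agent(client, agent_config)

# After: pass the configuration directly to the Agent constructor.
agent = Agent(
    client,
    model=model_id,
    instructions="You are a helpful assistant",
    tools=["builtin::websearch"],
)
session_id = agent.create_session("test-session")

# -- Client-provided tools ---------------------------------------------
# Before: register tool definitions in the config, then pass the tools
# again to the constructor.
# agent_config = AgentConfig(..., client_tools=[my_tool.get_tool_definition()])
# agent = Agent(client, agent_config, client_tools=(my_tool,))

# After: built-in toolgroups and client tools share the single `tools` list.
# agent = Agent(client, ..., tools=["builtin::websearch", my_tool])
```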
@@ -1635,18 +1635,14 @@
 "source": [
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
 "from termcolor import cprint\n",
 "\n",
-"agent_config = AgentConfig(\n",
+"agent = Agent(\n",
+"    client,\n",
 "    model=model_id,\n",
 "    instructions=\"You are a helpful assistant\",\n",
-"    toolgroups=[\"builtin::websearch\"],\n",
-"    input_shields=[],\n",
-"    output_shields=[],\n",
-"    enable_session_persistence=False,\n",
+"    tools=[\"builtin::websearch\"],\n",
 ")\n",
-"agent = Agent(client, agent_config)\n",
 "user_prompts = [\n",
 "    \"Hello\",\n",
 "    \"Which teams played in the NBA western conference finals of 2024\",\n",
@@ -1815,7 +1811,6 @@
 "import uuid\n",
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
 "from termcolor import cprint\n",
 "from llama_stack_client.types import Document\n",
 "\n",
@@ -1841,11 +1836,11 @@
 "    vector_db_id=vector_db_id,\n",
 "    chunk_size_in_tokens=512,\n",
 ")\n",
-"agent_config = AgentConfig(\n",
+"rag_agent = Agent(\n",
+"    client,\n",
 "    model=model_id,\n",
 "    instructions=\"You are a helpful assistant\",\n",
-"    enable_session_persistence=False,\n",
-"    toolgroups = [\n",
+"    tools = [\n",
 "        {\n",
 "            \"name\": \"builtin::rag/knowledge_search\",\n",
 "            \"args\" : {\n",
@@ -1854,7 +1849,6 @@
 "        }\n",
 "    ],\n",
 ")\n",
-"rag_agent = Agent(client, agent_config)\n",
 "session_id = rag_agent.create_session(\"test-session\")\n",
 "user_prompts = [\n",
 "    \"What are the top 5 topics that were explained? Only list succinct bullet points.\",\n",
@@ -1978,23 +1972,19 @@
 "source": [
 "from llama_stack_client.types.agents.turn_create_params import Document\n",
 "\n",
-"agent_config = AgentConfig(\n",
+"codex_agent = Agent(\n",
+"    client,\n",
+"    model=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
+"    instructions=\"You are a helpful assistant\",\n",
+"    tools=[\n",
+"        \"builtin::code_interpreter\",\n",
+"        \"builtin::websearch\"\n",
+"    ],\n",
 "    sampling_params = {\n",
 "        \"max_tokens\" : 4096,\n",
 "        \"temperature\": 0.0\n",
 "    },\n",
-"    model=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
-"    instructions=\"You are a helpful assistant\",\n",
-"    toolgroups=[\n",
-"        \"builtin::code_interpreter\",\n",
-"        \"builtin::websearch\"\n",
-"    ],\n",
-"    tool_choice=\"auto\",\n",
-"    input_shields=[],\n",
-"    output_shields=[],\n",
-"    enable_session_persistence=False,\n",
 ")\n",
-"codex_agent = Agent(client, agent_config)\n",
 "session_id = codex_agent.create_session(\"test-session\")\n",
 "\n",
 "\n",
@@ -2904,18 +2894,14 @@
 "# NBVAL_SKIP\n",
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
 "from termcolor import cprint\n",
 "\n",
-"agent_config = AgentConfig(\n",
+"agent = Agent(\n",
+"    client,\n",
 "    model=model_id,\n",
 "    instructions=\"You are a helpful assistant\",\n",
-"    toolgroups=[\"mcp::filesystem\"],\n",
-"    input_shields=[],\n",
-"    output_shields=[],\n",
-"    enable_session_persistence=False,\n",
+"    tools=[\"mcp::filesystem\"],\n",
 ")\n",
-"agent = Agent(client, agent_config)\n",
 "user_prompts = [\n",
 "    \"Hello\",\n",
 "    \"list all the files /content\",\n",
@@ -3010,17 +2996,13 @@
 "source": [
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
 "\n",
-"agent_config = AgentConfig(\n",
+"agent = Agent(\n",
+"    client,\n",
 "    model=\"meta-llama/Llama-3.3-70B-Instruct\",\n",
 "    instructions=\"You are a helpful assistant. Use search tool to answer the questions. \",\n",
-"    toolgroups=[\"builtin::websearch\"],\n",
-"    input_shields=[],\n",
-"    output_shields=[],\n",
-"    enable_session_persistence=False,\n",
+"    tools=[\"builtin::websearch\"],\n",
 ")\n",
-"agent = Agent(client, agent_config)\n",
 "user_prompts = [\n",
 "    \"Which teams played in the NBA western conference finals of 2024. Search the web for the answer.\",\n",
 "    \"In which episode and season of South Park does Bill Cosby (BSM-471) first appear? Give me the number and title. Search the web for the answer.\",\n",
@@ -4346,16 +4328,11 @@
 }
 ],
 "source": [
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
-"\n",
-"agent_config = AgentConfig(\n",
+"agent = Agent(\n",
+"    client,\n",
 "    model=vision_model_id,\n",
 "    instructions=\"You are a helpful assistant\",\n",
-"    enable_session_persistence=False,\n",
-"    toolgroups=[],\n",
 ")\n",
-"\n",
-"agent = Agent(client, agent_config)\n",
 "session_id = agent.create_session(\"test-session\")\n",
 "\n",
 "response = agent.create_turn(\n",

@@ -49,7 +49,6 @@
 "source": [
 "from llama_stack_client import LlamaStackClient\n",
 "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from rich.pretty import pprint\n",
 "import json\n",
@@ -71,20 +70,12 @@
 "\n",
 "MODEL_ID = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
 "\n",
-"base_agent_config = AgentConfig(\n",
+"base_agent_config = dict(\n",
 "    model=MODEL_ID,\n",
 "    instructions=\"You are a helpful assistant.\",\n",
 "    sampling_params={\n",
 "        \"strategy\": {\"type\": \"top_p\", \"temperature\": 1.0, \"top_p\": 0.9},\n",
 "    },\n",
-"    toolgroups=[],\n",
-"    tool_config={\n",
-"        \"tool_choice\": \"auto\",\n",
-"        \"tool_prompt_format\": \"python_list\",\n",
-"    },\n",
-"    input_shields=[],\n",
-"    output_shields=[],\n",
-"    enable_session_persistence=False,\n",
 ")"
 ]
 },
@@ -172,7 +163,7 @@
 }
 ],
 "source": [
-"vanilla_agent_config = AgentConfig({\n",
+"vanilla_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"\n",
 "    You are a helpful assistant capable of structuring data extraction and formatting. \n",
@@ -189,9 +180,9 @@
 "    Employee satisfaction is at 87 points.\n",
 "    Operating margin improved to 34%.\n",
 "    \"\"\",\n",
-"})\n",
+"}\n",
 "\n",
-"vanilla_agent = Agent(client, vanilla_agent_config)\n",
+"vanilla_agent = Agent(client, **vanilla_agent_config)\n",
 "prompt_chaining_session_id = vanilla_agent.create_session(session_name=f\"vanilla_agent_{uuid.uuid4()}\")\n",
 "\n",
 "prompts = [\n",
@@ -778,7 +769,7 @@
 ],
 "source": [
 "# 1. Define a couple of specialized agents\n",
-"billing_agent_config = AgentConfig({\n",
+"billing_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"You are a billing support specialist. Follow these guidelines:\n",
 "    1. Always start with \"Billing Support Response:\"\n",
@@ -789,9 +780,9 @@
 "    \n",
 "    Keep responses professional but friendly.\n",
 "    \"\"\",\n",
-"})\n",
+"}\n",
 "\n",
-"technical_agent_config = AgentConfig({\n",
+"technical_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"You are a technical support engineer. Follow these guidelines:\n",
 "    1. Always start with \"Technical Support Response:\"\n",
@@ -802,9 +793,9 @@
 "    \n",
 "    Use clear, numbered steps and technical details.\n",
 "    \"\"\",\n",
-"})\n",
+"}\n",
 "\n",
-"account_agent_config = AgentConfig({\n",
+"account_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"You are an account security specialist. Follow these guidelines:\n",
 "    1. Always start with \"Account Support Response:\"\n",
@@ -815,9 +806,9 @@
 "    \n",
 "    Maintain a serious, security-focused tone.\n",
 "    \"\"\",\n",
-"})\n",
+"}\n",
 "\n",
-"product_agent_config = AgentConfig({\n",
+"product_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"You are a product specialist. Follow these guidelines:\n",
 "    1. Always start with \"Product Support Response:\"\n",
@@ -828,13 +819,13 @@
 "    \n",
 "    Be educational and encouraging in tone.\n",
 "    \"\"\",\n",
-"})\n",
+"}\n",
 "\n",
 "specialized_agents = {\n",
-"    \"billing\": Agent(client, billing_agent_config),\n",
-"    \"technical\": Agent(client, technical_agent_config),\n",
-"    \"account\": Agent(client, account_agent_config),\n",
-"    \"product\": Agent(client, product_agent_config),\n",
+"    \"billing\": Agent(client, **billing_agent_config),\n",
+"    \"technical\": Agent(client, **technical_agent_config),\n",
+"    \"account\": Agent(client, **account_agent_config),\n",
+"    \"product\": Agent(client, **product_agent_config),\n",
 "}\n",
 "\n",
 "# 2. Define a routing agent\n",
@@ -842,7 +833,7 @@
 "    reasoning: str\n",
 "    support_team: str\n",
 "\n",
-"routing_agent_config = AgentConfig({\n",
+"routing_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": f\"\"\"You are a routing agent. Analyze the user's input and select the most appropriate support team from these options: \n",
 "\n",
@@ -862,9 +853,9 @@
 "        \"type\": \"json_schema\",\n",
 "        \"json_schema\": OutputSchema.model_json_schema()\n",
 "    }\n",
-"})\n",
+"}\n",
 "\n",
-"routing_agent = Agent(client, routing_agent_config)\n",
+"routing_agent = Agent(client, **routing_agent_config)\n",
 "\n",
 "# 3. Create a session for all agents\n",
 "routing_agent_session_id = routing_agent.create_session(session_name=f\"routing_agent_{uuid.uuid4()}\")\n",
@@ -1725,17 +1716,17 @@
 "from concurrent.futures import ThreadPoolExecutor\n",
 "from typing import List\n",
 "\n",
-"worker_agent_config = AgentConfig({\n",
+"worker_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"You are a helpful assistant that can analyze the impact of market changes on stakeholders.\n",
 "    Analyze how market changes will impact this stakeholder group.\n",
 "    Provide specific impacts and recommended actions.\n",
 "    Format with clear sections and priorities.\n",
 "    \"\"\",\n",
-"})\n",
+"}\n",
 "\n",
 "def create_worker_task(task: str):\n",
-"    worker_agent = Agent(client, worker_agent_config)\n",
+"    worker_agent = Agent(client, **worker_agent_config)\n",
 "    worker_session_id = worker_agent.create_session(session_name=f\"worker_agent_{uuid.uuid4()}\")\n",
 "    task_response = worker_agent.create_turn(\n",
 "        messages=[{\"role\": \"user\", \"content\": task}],\n",
@@ -2248,7 +2239,7 @@
 "    thoughts: str\n",
 "    response: str\n",
 "\n",
-"generator_agent_config = AgentConfig({\n",
+"generator_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"Your goal is to complete the task based on <user input>. If there are feedback \n",
 "    from your previous generations, you should reflect on them to improve your solution\n",
@@ -2263,13 +2254,13 @@
 "        \"type\": \"json_schema\",\n",
 "        \"json_schema\": GeneratorOutputSchema.model_json_schema()\n",
 "    }\n",
-"})\n",
+"}\n",
 "\n",
 "class EvaluatorOutputSchema(BaseModel):\n",
 "    evaluation: str\n",
 "    feedback: str\n",
 "\n",
-"evaluator_agent_config = AgentConfig({\n",
+"evaluator_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"Evaluate this following code implementation for:\n",
 "    1. code correctness\n",
@@ -2293,10 +2284,10 @@
 "        \"type\": \"json_schema\",\n",
 "        \"json_schema\": EvaluatorOutputSchema.model_json_schema()\n",
 "    }\n",
-"})\n",
+"}\n",
 "\n",
-"generator_agent = Agent(client, generator_agent_config)\n",
-"evaluator_agent = Agent(client, evaluator_agent_config)\n",
+"generator_agent = Agent(client, **generator_agent_config)\n",
+"evaluator_agent = Agent(client, **evaluator_agent_config)\n",
 "generator_session_id = generator_agent.create_session(session_name=f\"generator_agent_{uuid.uuid4()}\")\n",
 "evaluator_session_id = evaluator_agent.create_session(session_name=f\"evaluator_agent_{uuid.uuid4()}\")\n",
 "\n",
@@ -2628,7 +2619,7 @@
 "    analysis: str\n",
 "    tasks: List[Dict[str, str]]\n",
 "\n",
-"orchestrator_agent_config = AgentConfig({\n",
+"orchestrator_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"Your job is to analyize the task provided by the user andbreak it down into 2-3 distinct approaches:\n",
 "\n",
@@ -2651,9 +2642,9 @@
 "        \"type\": \"json_schema\",\n",
 "        \"json_schema\": OrchestratorOutputSchema.model_json_schema()\n",
 "    }\n",
-"})\n",
+"}\n",
 "\n",
-"worker_agent_config = AgentConfig({\n",
+"worker_agent_config = {\n",
 "    **base_agent_config,\n",
 "    \"instructions\": \"\"\"You will be given a task guideline. Generate content based on the provided\n",
 "    task, following the style and guideline descriptions. \n",
@@ -2662,7 +2653,7 @@
 "\n",
 "    Response: Your content here, maintaining the specified style and fully addressing requirements.\n",
 "    \"\"\",\n",
-"})\n"
+"}\n"
 ]
 },
 {
@@ -2673,7 +2664,7 @@
 "source": [
 "def orchestrator_worker_workflow(task, context):\n",
 "    # single orchestrator agent\n",
-"    orchestrator_agent = Agent(client, orchestrator_agent_config)\n",
+"    orchestrator_agent = Agent(client, **orchestrator_agent_config)\n",
 "    orchestrator_session_id = orchestrator_agent.create_session(session_name=f\"orchestrator_agent_{uuid.uuid4()}\")\n",
 "\n",
 "    orchestrator_response = orchestrator_agent.create_turn(\n",
@@ -2689,7 +2680,7 @@
 "    workers = {}\n",
 "    # spawn multiple worker agents\n",
 "    for task in orchestrator_result[\"tasks\"]:\n",
-"        worker_agent = Agent(client, worker_agent_config)\n",
+"        worker_agent = Agent(client, **worker_agent_config)\n",
 "        worker_session_id = worker_agent.create_session(session_name=f\"worker_agent_{uuid.uuid4()}\")\n",
 "        workers[task[\"type\"]] = worker_agent\n",
 "    \n",

@@ -14,18 +14,16 @@ Agents are configured using the `AgentConfig` class, which includes:
 - **Safety Shields**: Guardrails to ensure responsible AI behavior
 
 ```python
-from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.lib.agents.agent import Agent
 
-# Configure an agent
-agent_config = AgentConfig(
-    model="meta-llama/Llama-3-70b-chat",
-    instructions="You are a helpful assistant that can use tools to answer questions.",
-    toolgroups=["builtin::code_interpreter", "builtin::rag/knowledge_search"],
-)
 
 # Create the agent
-agent = Agent(llama_stack_client, agent_config)
+agent = Agent(
+    llama_stack_client,
+    model="meta-llama/Llama-3-70b-chat",
+    instructions="You are a helpful assistant that can use tools to answer questions.",
+    tools=["builtin::code_interpreter", "builtin::rag/knowledge_search"],
+)
 ```
 
 ### 2. Sessions
@@ -70,18 +70,18 @@ Each step in this process can be monitored and controlled through configurations
 from llama_stack_client import LlamaStackClient
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.event_logger import EventLogger
-from llama_stack_client.types.agent_create_params import AgentConfig
 from rich.pretty import pprint
 
 # Replace host and port
 client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")
 
-agent_config = AgentConfig(
+agent = Agent(
+    client,
     # Check with `llama-stack-client models list`
     model="Llama3.2-3B-Instruct",
     instructions="You are a helpful assistant",
     # Enable both RAG and tool usage
-    toolgroups=[
+    tools=[
         {
             "name": "builtin::rag/knowledge_search",
            "args": {"vector_db_ids": ["my_docs"]},
@@ -98,8 +98,6 @@ agent_config = AgentConfig(
         "max_tokens": 2048,
     },
 )
-
-agent = Agent(client, agent_config)
 session_id = agent.create_session("monitored_session")
 
 # Stream the agent's execution steps

@@ -25,17 +25,13 @@ In this example, we will show you how to:
 ```python
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.event_logger import EventLogger
-from llama_stack_client.types.agent_create_params import AgentConfig
 
-agent_config = AgentConfig(
+agent = Agent(
+    client,
     model="meta-llama/Llama-3.3-70B-Instruct",
     instructions="You are a helpful assistant. Use search tool to answer the questions. ",
-    toolgroups=["builtin::websearch"],
-    input_shields=[],
-    output_shields=[],
-    enable_session_persistence=False,
+    tools=["builtin::websearch"],
 )
-agent = Agent(client, agent_config)
 user_prompts = [
     "Which teams played in the NBA western conference finals of 2024. Search the web for the answer.",
     "In which episode and season of South Park does Bill Cosby (BSM-471) first appear? Give me the number and title. Search the web for the answer.",
@@ -86,15 +86,14 @@ results = client.tool_runtime.rag_tool.query(
 One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
 
 ```python
-from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.lib.agents.agent import Agent
 
-# Configure agent with memory
-agent_config = AgentConfig(
+# Create agent with memory
+agent = Agent(
+    client,
     model="meta-llama/Llama-3.3-70B-Instruct",
     instructions="You are a helpful assistant",
-    enable_session_persistence=False,
-    toolgroups=[
+    tools=[
         {
             "name": "builtin::rag/knowledge_search",
             "args": {
@@ -103,8 +102,6 @@ agent_config = AgentConfig(
         }
     ],
 )
-
-agent = Agent(client, agent_config)
 session_id = agent.create_session("rag_session")
 
 
@@ -149,15 +149,7 @@ def my_tool(input: int) -> int:
 Once defined, simply pass the tool to the agent config. `Agent` will take care of the rest (calling the model with the tool definition, executing the tool, and returning the result to the model for the next iteration).
 ```python
 # Example agent config with client provided tools
-client_tools = [
-    my_tool,
-]
-
-agent_config = AgentConfig(
-    ...,
-    client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
-)
-agent = Agent(client, agent_config, client_tools)
+agent = Agent(client, ..., tools=[my_tool])
 ```
 
 Refer to [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/e2e_loop_with_client_tools.py) for an example of how to use client provided tools.
@@ -194,10 +186,10 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools")
 
 ```python
 from llama_stack_client.lib.agents.agent import Agent
-from llama_stack_client.types.agent_create_params import AgentConfig
 
-# Configure the AI agent with necessary parameters
-agent_config = AgentConfig(
+# Instantiate the AI agent with the given configuration
+agent = Agent(
+    client,
     name="code-interpreter",
     description="A code interpreter agent for executing Python code snippets",
     instructions="""
@@ -205,14 +197,10 @@ agent_config = AgentConfig(
     Always show the generated code, never generate your own code, and never anticipate results.
     """,
     model="meta-llama/Llama-3.2-3B-Instruct",
-    toolgroups=["builtin::code_interpreter"],
+    tools=["builtin::code_interpreter"],
     max_infer_iters=5,
-    enable_session_persistence=False,
 )
 
-# Instantiate the AI agent with the given configuration
-agent = Agent(client, agent_config)
-
 # Start a session
 session_id = agent.create_session("tool_session")
 

@@ -184,7 +184,6 @@ from termcolor import cprint
 
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.event_logger import EventLogger
-from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types import Document
 
 
@@ -241,13 +240,14 @@ client.tool_runtime.rag_tool.insert(
     chunk_size_in_tokens=512,
 )
 
-agent_config = AgentConfig(
+rag_agent = Agent(
+    client,
     model=os.environ["INFERENCE_MODEL"],
     # Define instructions for the agent ( aka system prompt)
     instructions="You are a helpful assistant",
     enable_session_persistence=False,
     # Define tools available to the agent
-    toolgroups=[
+    tools=[
         {
             "name": "builtin::rag/knowledge_search",
             "args": {
@@ -256,8 +256,6 @@ agent_config = AgentConfig(
         }
     ],
 )
-
-rag_agent = Agent(client, agent_config)
 session_id = rag_agent.create_session("test-session")
 
 user_prompts = [

@@ -294,8 +294,9 @@
 "    # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n",
 "    webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
 "\n",
-"    # Define the agent configuration, including the model and tool setup\n",
-"    agent_config = AgentConfig(\n",
+"    # Create an agent instance with the client and configuration\n",
+"    agent = Agent(\n",
+"        client,\n",
 "        model=MODEL_NAME,\n",
 "        instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
 "        sampling_params={\n",
@@ -303,17 +304,12 @@
 "                \"type\": \"greedy\",\n",
 "            },\n",
 "        },\n",
-"        tools=[webSearchTool.get_tool_definition()],\n",
-"        tool_choice=\"auto\",\n",
-"        tool_prompt_format=\"python_list\",\n",
+"        tools=[webSearchTool],\n",
 "        input_shields=input_shields,\n",
 "        output_shields=output_shields,\n",
 "        enable_session_persistence=False,\n",
 "    )\n",
 "\n",
-"    # Create an agent instance with the client and configuration\n",
-"    agent = Agent(client, agent_config, [webSearchTool])\n",
-"\n",
 "    # Create a session for interaction and print the session ID\n",
 "    session_id = agent.create_session(\"test-session\")\n",
 "    print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n",

@@ -110,12 +110,12 @@
 "from llama_stack_client import LlamaStackClient\n",
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
 "\n",
 "\n",
 "async def agent_example():\n",
 "    client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n",
-"    agent_config = AgentConfig(\n",
+"    agent = Agent(\n",
+"        client,\n",
 "        model=MODEL_NAME,\n",
 "        instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n",
 "        sampling_params={\n",
@@ -130,14 +130,7 @@
 "                \"api_key\": BRAVE_SEARCH_API_KEY,\n",
 "            }\n",
 "        ],\n",
-"        tool_choice=\"auto\",\n",
-"        tool_prompt_format=\"function_tag\",\n",
-"        input_shields=[],\n",
-"        output_shields=[],\n",
-"        enable_session_persistence=False,\n",
 "    )\n",
-"\n",
-"    agent = Agent(client, agent_config)\n",
 "    session_id = agent.create_session(\"test-session\")\n",
 "    print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n",
 "\n",

@@ -103,7 +103,6 @@
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
 "from llama_stack_client.types.agent_create_params import (\n",
-"    AgentConfig,\n",
 "    AgentConfigToolSearchToolDefinition,\n",
 ")\n",
 "\n",
@@ -117,7 +116,8 @@
 ") -> Agent:\n",
 "    \"\"\"Create an agent with specified tools.\"\"\"\n",
 "    print(\"Using the following model: \", model)\n",
-"    agent_config = AgentConfig(\n",
+"    return Agent(\n",
+"        client,\n",
 "        model=model,\n",
 "        instructions=instructions,\n",
 "        sampling_params={\n",
@@ -126,12 +126,7 @@
 "            },\n",
 "        },\n",
 "        tools=tools,\n",
-"        tool_choice=\"auto\",\n",
-"        tool_prompt_format=\"json\",\n",
-"        enable_session_persistence=True,\n",
-"    )\n",
-"\n",
-"    return Agent(client, agent_config)\n"
+"    )\n"
 ]
 },
 {
@@ -360,9 +355,9 @@
 "    # Create the agent with the tool\n",
 "    weather_tool = WeatherTool()\n",
 "\n",
-"    agent_config = AgentConfig(\n",
+"    agent = Agent(\n",
+"        client=client,\n",
 "        model=LLAMA31_8B_INSTRUCT,\n",
-"        # model=model_name,\n",
 "        instructions=\"\"\"\n",
 "        You are a weather assistant that can provide weather information.\n",
 "        Always specify the location clearly in your responses.\n",
@@ -373,16 +368,9 @@
 "                \"type\": \"greedy\",\n",
 "            },\n",
 "        },\n",
-"        tools=[weather_tool.get_tool_definition()],\n",
-"        tool_choice=\"auto\",\n",
-"        tool_prompt_format=\"json\",\n",
-"        input_shields=[],\n",
-"        output_shields=[],\n",
-"        enable_session_persistence=True,\n",
+"        tools=[weather_tool],\n",
 "    )\n",
 "\n",
-"    agent = Agent(client=client, agent_config=agent_config, custom_tools=[weather_tool])\n",
-"\n",
 "    return agent\n",
 "\n",
 "\n",

@@ -7,7 +7,6 @@
 import streamlit as st
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.event_logger import EventLogger
-from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types.memory_insert_params import Document
 from modules.api import llama_stack_api
 from modules.utils import data_url_from_file
@@ -124,13 +123,14 @@ def rag_chat_page():
     else:
         strategy = {"type": "greedy"}
 
-    agent_config = AgentConfig(
+    agent = Agent(
+        llama_stack_api.client,
         model=selected_model,
         instructions=system_prompt,
         sampling_params={
             "strategy": strategy,
         },
-        toolgroups=[
+        tools=[
             dict(
                 name="builtin::rag/knowledge_search",
                 args={
@@ -138,12 +138,7 @@ def rag_chat_page():
                 },
             )
         ],
-        tool_choice="auto",
-        tool_prompt_format="json",
-        enable_session_persistence=False,
     )
-
-    agent = Agent(llama_stack_api.client, agent_config)
     session_id = agent.create_session("rag-session")
 
     # Chat input

@@ -64,7 +64,7 @@ def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> D
 def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
     available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()]
     available_shields = available_shields[:1]
-    agent_config = AgentConfig(
+    agent_config = dict(
         model=text_model_id,
         instructions="You are a helpful assistant",
         sampling_params={
@@ -74,7 +74,7 @@ def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
                 "top_p": 0.9,
             },
         },
-        toolgroups=[],
+        tools=[],
        input_shields=available_shields,
        output_shields=available_shields,
        enable_session_persistence=False,
@@ -83,7 +83,7 @@ def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
 
 
 def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
     simple_hello = agent.create_turn(
@@ -137,7 +137,7 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = AgentConfig(
         **common_params,
     )
-    Server__AgentConfig(**agent_config)
+    Server__AgentConfig(**common_params)
 
     agent_config = AgentConfig(
         **common_params,
@@ -179,11 +179,11 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
 def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
-        "toolgroups": [
+        "tools": [
             "builtin::websearch",
         ],
     }
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
     response = agent.create_turn(
@@ -209,11 +209,11 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
 def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
-        "toolgroups": [
+        "tools": [
             "builtin::code_interpreter",
         ],
     }
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
     response = agent.create_turn(
@@ -238,12 +238,12 @@ def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, a
 def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
-        "toolgroups": [
+        "tools": [
             "builtin::code_interpreter",
         ],
     }
 
-    codex_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    codex_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = codex_agent.create_session(f"test-session-{uuid4()}")
     inflation_doc = AgentDocument(
         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
@@ -275,11 +275,11 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
-        "toolgroups": ["builtin::websearch"],
+        "tools": ["builtin::websearch", client_tool],
         "client_tools": [client_tool.get_tool_definition()],
     }
 
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
     response = agent.create_turn(
@@ -303,11 +303,11 @@ def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, age
     agent_config = {
         **agent_config,
         "instructions": "You are a helpful assistant Always respond with tool calls no matter what. ",
-        "client_tools": [client_tool.get_tool_definition()],
+        "tools": [client_tool],
         "max_infer_iters": 5,
     }
 
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
     response = agent.create_turn(
@@ -332,10 +332,10 @@ def test_tool_choice(llama_stack_client_with_mocked_inference, agent_config):
     test_agent_config = {
         **agent_config,
         "tool_config": {"tool_choice": tool_choice},
-        "client_tools": [client_tool.get_tool_definition()],
+        "tools": [client_tool],
     }
 
-    agent = Agent(llama_stack_client_with_mocked_inference, test_agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, **test_agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
     response = agent.create_turn(
@@ -387,7 +387,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
     )
     agent_config = {
         **agent_config,
-        "toolgroups": [
+        "tools": [
             dict(
                 name=rag_tool_name,
                 args={
@@ -396,7 +396,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
             )
         ],
     }
-    rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
     user_prompts = [
         (
@@ -422,7 +422,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
 
 
 @pytest.mark.parametrize(
-    "toolgroup",
+    "tool",
     [
         dict(
             name="builtin::rag/knowledge_search",
@@ -433,7 +433,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
         "builtin::rag/knowledge_search",
     ],
 )
-def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, toolgroup):
+def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, tool):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -446,9 +446,9 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
     ]
     agent_config = {
         **agent_config,
-        "toolgroups": [toolgroup],
+        "tools": [tool],
     }
-    rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
     user_prompts = [
         (
@@ -521,7 +521,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
     )
     agent_config = {
         **agent_config,
-        "toolgroups": [
+        "tools": [
             dict(
                 name="builtin::rag/knowledge_search",
                 args={"vector_db_ids": [vector_db_id]},
@@ -529,7 +529,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
             "builtin::code_interpreter",
         ],
     }
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     inflation_doc = Document(
         document_id="test_csv",
         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
@@ -578,10 +578,10 @@ def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_co
         **agent_config,
         "input_shields": [],
         "output_shields": [],
-        "client_tools": [client_tool.get_tool_definition()],
+        "tools": [client_tool],
     }
 
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
    session_id = agent.create_session(f"test-session-{uuid4()}")
 
     response = agent.create_turn(