diff --git a/docs/zero_to_hero_guide/07_Agents101.ipynb b/docs/zero_to_hero_guide/07_Agents101.ipynb index b6df2a4c8..905799946 100644 --- a/docs/zero_to_hero_guide/07_Agents101.ipynb +++ b/docs/zero_to_hero_guide/07_Agents101.ipynb @@ -45,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -65,7 +65,7 @@ "from dotenv import load_dotenv\n", "\n", "load_dotenv()\n", - "BRAVE_SEARCH_API_KEY = os.environ[\"BRAVE_SEARCH_API_KEY\"]\n" + "TAVILY_SEARCH_API_KEY = os.environ[\"TAVILY_SEARCH_API_KEY\"]\n" ] }, { @@ -110,10 +110,17 @@ "from llama_stack_client import LlamaStackClient\n", "from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client.lib.agents.event_logger import EventLogger\n", + "from llama_stack_client.types import UserMessage\n", + "from typing import cast, Iterator\n", "\n", "\n", "async def agent_example():\n", - " client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n", + " client = LlamaStackClient(\n", + " base_url=f\"http://{HOST}:{PORT}\",\n", + " provider_data={\n", + " \"tavily_search_api_key\": TAVILY_SEARCH_API_KEY,\n", + " }\n", + " )\n", " agent = Agent(\n", " client,\n", " model=MODEL_NAME,\n", @@ -123,13 +130,7 @@ " \"type\": \"greedy\",\n", " },\n", " },\n", - " tools=[\n", - " {\n", - " \"type\": \"brave_search\",\n", - " \"engine\": \"brave\",\n", - " \"api_key\": BRAVE_SEARCH_API_KEY,\n", - " }\n", - " ],\n", + " tools=[\"builtin::websearch\"],\n", " )\n", " session_id = agent.create_session(\"test-session\")\n", " print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n", @@ -142,15 +143,13 @@ " for prompt in user_prompts:\n", " response = agent.create_turn(\n", " messages=[\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": prompt,\n", - " }\n", + " UserMessage(role=\"user\", content=prompt)\n", " ],\n", " session_id=session_id,\n", + " stream=True,\n", " )\n", "\n", - " async for log in EventLogger().log(response):\n", + " for log in EventLogger().log(cast(Iterator, response)):\n", " log.print()\n", "\n", "\n",