llama-stack-mirror/docs/tool_Calling101.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool Calling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this section, we'll explore how to enhance your applications with tool calling capabilities. We'll cover:\n",
"1. Setting up and using the Brave Search API\n",
"2. Creating custom tools\n",
"3. Configuring tool prompts and safety settings"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: llama-stack-client in ./.conda/envs/quick/lib/python3.13/site-packages (0.0.48)\n",
"Requirement already satisfied: anyio<5,>=3.5.0 in ./.conda/envs/quick/lib/python3.13/site-packages (from llama-stack-client) (4.6.2.post1)\n",
"Requirement already satisfied: distro<2,>=1.7.0 in ./.conda/envs/quick/lib/python3.13/site-packages (from llama-stack-client) (1.9.0)\n",
"Requirement already satisfied: httpx<1,>=0.23.0 in ./.conda/envs/quick/lib/python3.13/site-packages (from llama-stack-client) (0.27.2)\n",
"Requirement already satisfied: pydantic<3,>=1.9.0 in ./.conda/envs/quick/lib/python3.13/site-packages (from llama-stack-client) (2.9.2)\n",
"Requirement already satisfied: sniffio in ./.conda/envs/quick/lib/python3.13/site-packages (from llama-stack-client) (1.3.1)\n",
"Requirement already satisfied: tabulate>=0.9.0 in ./.conda/envs/quick/lib/python3.13/site-packages (from llama-stack-client) (0.9.0)\n",
"Requirement already satisfied: typing-extensions<5,>=4.7 in ./.conda/envs/quick/lib/python3.13/site-packages (from llama-stack-client) (4.12.2)\n",
"Requirement already satisfied: idna>=2.8 in ./.conda/envs/quick/lib/python3.13/site-packages (from anyio<5,>=3.5.0->llama-stack-client) (3.10)\n",
"Requirement already satisfied: certifi in ./.conda/envs/quick/lib/python3.13/site-packages (from httpx<1,>=0.23.0->llama-stack-client) (2024.8.30)\n",
"Requirement already satisfied: httpcore==1.* in ./.conda/envs/quick/lib/python3.13/site-packages (from httpx<1,>=0.23.0->llama-stack-client) (1.0.6)\n",
"Requirement already satisfied: h11<0.15,>=0.13 in ./.conda/envs/quick/lib/python3.13/site-packages (from httpcore==1.*->httpx<1,>=0.23.0->llama-stack-client) (0.14.0)\n",
"Requirement already satisfied: annotated-types>=0.6.0 in ./.conda/envs/quick/lib/python3.13/site-packages (from pydantic<3,>=1.9.0->llama-stack-client) (0.7.0)\n",
"Requirement already satisfied: pydantic-core==2.23.4 in ./.conda/envs/quick/lib/python3.13/site-packages (from pydantic<3,>=1.9.0->llama-stack-client) (2.23.4)\n"
]
}
],
"source": [
"!pip install llama-stack-client --upgrade"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'Agent' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[4], line 23\u001b[0m\n\u001b[1;32m 15\u001b[0m load_dotenv()\n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# Helper function to create an agent with tools\u001b[39;00m\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_tool_agent\u001b[39m(\n\u001b[1;32m 19\u001b[0m client: LlamaStackClient,\n\u001b[1;32m 20\u001b[0m tools: List[Dict],\n\u001b[1;32m 21\u001b[0m instructions: \u001b[38;5;28mstr\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou are a helpful assistant\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 22\u001b[0m model: \u001b[38;5;28mstr\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLlama3.1-8B-Instruct\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m---> 23\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[43mAgent\u001b[49m:\n\u001b[1;32m 24\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Create an agent with specified tools.\"\"\"\u001b[39;00m\n\u001b[1;32m 25\u001b[0m agent_config \u001b[38;5;241m=\u001b[39m AgentConfig(\n\u001b[1;32m 26\u001b[0m model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[1;32m 27\u001b[0m instructions\u001b[38;5;241m=\u001b[39minstructions,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 38\u001b[0m enable_session_persistence\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 39\u001b[0m )\n",
"\u001b[0;31mNameError\u001b[0m: name 'Agent' is not defined"
]
}
],
"source": [
"import asyncio\n",
"import os\n",
"from typing import Dict, List, Optional\n",
"from dotenv import load_dotenv\n",
"\n",
"from llama_stack_client import LlamaStackClient\n",
"#from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import (\n",
" AgentConfig,\n",
" AgentConfigToolSearchToolDefinition,\n",
")\n",
"\n",
"# Load environment variables\n",
"load_dotenv()\n",
"\n",
"# Helper function to create an agent with tools\n",
"async def create_tool_agent(\n",
" client: LlamaStackClient,\n",
" tools: List[Dict],\n",
" instructions: str = \"You are a helpful assistant\",\n",
" model: str = \"Llama3.1-8B-Instruct\",\n",
") -> Agent:\n",
" \"\"\"Create an agent with specified tools.\"\"\"\n",
" agent_config = AgentConfig(\n",
" model=model,\n",
" instructions=instructions,\n",
" sampling_params={\n",
" \"strategy\": \"greedy\",\n",
" \"temperature\": 1.0,\n",
" \"top_p\": 0.9,\n",
" },\n",
" tools=tools,\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"json\",\n",
" input_shields=[\"llama_guard\"],\n",
" output_shields=[\"llama_guard\"],\n",
" enable_session_persistence=True,\n",
" )\n",
"\n",
" return Agent(client, agent_config)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, create a `.env` file in your notebook directory with your Brave Search API key:\n",
"\n",
"```\n",
"BRAVE_SEARCH_API_KEY=your_key_here\n",
"```\n"
]
},
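{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before building the search agent, it can help to confirm the key was actually picked up. The optional check below is a minimal sketch: it only re-runs `load_dotenv()` and reports whether `BRAVE_SEARCH_API_KEY` is visible to the notebook (without printing the key itself)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from dotenv import load_dotenv\n",
"\n",
"# Optional sanity check: assumes BRAVE_SEARCH_API_KEY is defined in the .env file above.\n",
"load_dotenv()\n",
"\n",
"# Report only whether the key is present; never print the key itself.\n",
"if os.getenv(\"BRAVE_SEARCH_API_KEY\"):\n",
"    print(\"BRAVE_SEARCH_API_KEY is set\")\n",
"else:\n",
"    print(\"BRAVE_SEARCH_API_KEY is missing - check your .env file\")"
]
},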
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"async def create_search_agent(client: LlamaStackClient) -> Agent:\n",
" \"\"\"Create an agent with Brave Search capability.\"\"\"\n",
" search_tool = AgentConfigToolSearchToolDefinition(\n",
" type=\"brave_search\",\n",
" engine=\"brave\",\n",
" api_key=os.getenv(\"BRAVE_SEARCH_API_KEY\"),\n",
" )\n",
"\n",
" return await create_tool_agent(\n",
" client=client,\n",
" tools=[search_tool],\n",
" instructions=\"\"\"\n",
" You are a research assistant that can search the web.\n",
" Always cite your sources with URLs when providing information.\n",
" Format your responses as:\n",
"\n",
" FINDINGS:\n",
" [Your summary here]\n",
"\n",
" SOURCES:\n",
" - [Source title](URL)\n",
" \"\"\"\n",
" )\n",
"\n",
"# Example usage\n",
"async def search_example():\n",
" client = LlamaStackClient(base_url=\"http://localhost:8000\")\n",
" agent = await create_search_agent(client)\n",
"\n",
" # Create a session\n",
" session_id = agent.create_session(\"search-session\")\n",
"\n",
" # Example queries\n",
" queries = [\n",
" \"What are the latest developments in quantum computing?\",\n",
" \"Who won the most recent Super Bowl?\",\n",
" ]\n",
"\n",
" for query in queries:\n",
" print(f\"\\nQuery: {query}\")\n",
" print(\"-\" * 50)\n",
"\n",
" response = agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": query}],\n",
" session_id=session_id,\n",
" )\n",
"\n",
" async for log in EventLogger().log(response):\n",
" log.print()\n",
"\n",
"# Run the example (in Jupyter, use asyncio.run())\n",
"await search_example()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Custom Tool Creation\n",
"\n",
"Let's create a custom weather tool:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import TypedDict, Optional\n",
"from datetime import datetime\n",
"\n",
"# Define tool types\n",
"class WeatherInput(TypedDict):\n",
" location: str\n",
" date: Optional[str]\n",
"\n",
"class WeatherOutput(TypedDict):\n",
" temperature: float\n",
" conditions: str\n",
" humidity: float\n",
"\n",
"class WeatherTool:\n",
" \"\"\"Example custom tool for weather information.\"\"\"\n",
"\n",
" def __init__(self, api_key: Optional[str] = None):\n",
" self.api_key = api_key\n",
"\n",
" async def get_weather(self, location: str, date: Optional[str] = None) -> WeatherOutput:\n",
" \"\"\"Simulate getting weather data (replace with actual API call).\"\"\"\n",
" # Mock implementation\n",
" return {\n",
" \"temperature\": 72.5,\n",
" \"conditions\": \"partly cloudy\",\n",
" \"humidity\": 65.0\n",
" }\n",
"\n",
" async def __call__(self, input_data: WeatherInput) -> WeatherOutput:\n",
" \"\"\"Make the tool callable with structured input.\"\"\"\n",
" return await self.get_weather(\n",
" location=input_data[\"location\"],\n",
" date=input_data.get(\"date\")\n",
" )\n",
"\n",
"async def create_weather_agent(client: LlamaStackClient) -> Agent:\n",
" \"\"\"Create an agent with weather tool capability.\"\"\"\n",
" weather_tool = {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_weather\",\n",
" \"description\": \"Get weather information for a location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"City or location name\"\n",
" },\n",
" \"date\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"Optional date (YYYY-MM-DD)\",\n",
" \"format\": \"date\"\n",
" }\n",
" },\n",
" \"required\": [\"location\"]\n",
" }\n",
" },\n",
" \"implementation\": WeatherTool()\n",
" }\n",
"\n",
" return await create_tool_agent(\n",
" client=client,\n",
" tools=[weather_tool],\n",
" instructions=\"\"\"\n",
" You are a weather assistant that can provide weather information.\n",
" Always specify the location clearly in your responses.\n",
" Include both temperature and conditions in your summaries.\n",
" \"\"\"\n",
" )\n",
"\n",
"# Example usage\n",
"async def weather_example():\n",
" client = LlamaStackClient(base_url=\"http://localhost:8000\")\n",
" agent = await create_weather_agent(client)\n",
"\n",
" session_id = agent.create_session(\"weather-session\")\n",
"\n",
" queries = [\n",
" \"What's the weather like in San Francisco?\",\n",
" \"Tell me the weather in Tokyo tomorrow\",\n",
" ]\n",
"\n",
" for query in queries:\n",
" print(f\"\\nQuery: {query}\")\n",
" print(\"-\" * 50)\n",
"\n",
" response = agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": query}],\n",
" session_id=session_id,\n",
" )\n",
"\n",
" async for log in EventLogger().log(response):\n",
" log.print()\n",
"\n",
"# Run the example\n",
"await weather_example()"
]
},
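{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, you can also call `WeatherTool` directly, without going through an agent or the Llama Stack server. The cell below only exercises the mock `get_weather` implementation defined above, so the values it prints are the hard-coded placeholders."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Exercise the custom tool on its own (mock data only - no external API calls).\n",
"weather_tool_impl = WeatherTool()\n",
"\n",
"result = await weather_tool_impl({\"location\": \"San Francisco\", \"date\": None})\n",
"print(result)"
]
}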
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}