From 40793cd8ad5db6fdaaf722e00753b9060ba3d745 Mon Sep 17 00:00:00 2001 From: Sanyam Bhutani Date: Tue, 5 Nov 2024 14:43:46 -0800 Subject: [PATCH] Create Tool_Calling101.ipynb --- docs/zero_to_hero_guide/Tool_Calling101.ipynb | 558 ++++++++++++++++++ 1 file changed, 558 insertions(+) create mode 100644 docs/zero_to_hero_guide/Tool_Calling101.ipynb diff --git a/docs/zero_to_hero_guide/Tool_Calling101.ipynb b/docs/zero_to_hero_guide/Tool_Calling101.ipynb new file mode 100644 index 000000000..a4c57ddff --- /dev/null +++ b/docs/zero_to_hero_guide/Tool_Calling101.ipynb @@ -0,0 +1,558 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Getting Started with LlamaStack: Tool Calling Tutorial\n", + "\n", + "Welcome! This notebook will guide you through creating and using custom tools with LlamaStack.\n", + "We'll start with the basics and work our way up to more complex examples.\n", + "\n", + "Table of Contents:\n", + "1. Setup and Installation\n", + "2. Understanding Tool Basics\n", + "3. Creating Your First Tool\n", + "4. Building a Mock Weather Tool\n", + "5. Setting Up the LlamaStack Agent\n", + "6. Running Examples\n", + "7. Next Steps\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Setup\n", + "#### Before we begin, let's import all the required packages:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import asyncio\n", + "import json\n", + "from typing import Dict\n", + "from datetime import datetime" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# LlamaStack specific imports\n", + "from llama_stack_client import LlamaStackClient\n", + "from llama_stack_client.lib.agents.agent import Agent\n", + "from llama_stack_client.lib.agents.event_logger import EventLogger\n", + "from llama_stack_client.types.agent_create_params import AgentConfig\n", + "from llama_stack_client.types.tool_param_definition_param import ToolParamDefinitionParam" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Understanding Tool Basics\n", + "\n", + "In LlamaStack, a tool is like a special function that our AI assistant can use. Think of it as giving the AI a new \n", + "capability, like using a calculator or checking the weather.\n", + "\n", + "Every tool needs:\n", + "- A name: What we call the tool\n", + "- A description: What the tool does\n", + "- Parameters: What information the tool needs to work\n", + "- Implementation: The actual code that does the work\n", + "\n", + "Let's create a base class that all our tools will inherit from:" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "class SingleMessageCustomTool:\n", + " \"\"\"Base class for all our custom tools\"\"\"\n", + " \n", + " async def run(self, messages=None):\n", + " \"\"\"\n", + " Main entry point for running the tool\n", + " Args:\n", + " messages: List of messages (can be None for backward compatibility)\n", + " \"\"\"\n", + " if messages and len(messages) > 0:\n", + " # Extract parameters from the message if it contains function parameters\n", + " message = messages[0]\n", + " if hasattr(message, 'function_parameters'):\n", + " return await self.run_impl(**message.function_parameters)\n", + " else:\n", + " return await self.run_impl()\n", + " return await self.run_impl()\n", + " \n", + " async def run_impl(self, **kwargs):\n", + " \"\"\"Each tool will implement this method with their specific logic\"\"\"\n", + " raise NotImplementedError()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Creating Your First Tool: Calculator\n", + " \n", + "Let's create a simple calculator tool. This will help us understand the basic structure of a tool.\n", + "Our calculator can:\n", + "- Add\n", + "- Subtract\n", + "- Multiply\n", + "- Divide\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "# Calculator Tool implementation\n", + "class CalculatorTool(SingleMessageCustomTool):\n", + " \"\"\"A simple calculator tool that can perform basic math operations\"\"\"\n", + " \n", + " def get_name(self) -> str:\n", + " return \"calculator\"\n", + " \n", + " def get_description(self) -> str:\n", + " return \"Perform basic arithmetic operations (add, subtract, multiply, divide)\"\n", + " \n", + " def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:\n", + " return {\n", + " \"operation\": ToolParamDefinitionParam(\n", + " param_type=\"str\",\n", + " description=\"Operation to perform (add, subtract, multiply, divide)\",\n", + " required=True\n", + " ),\n", + " \"x\": ToolParamDefinitionParam(\n", + " param_type=\"float\",\n", + " description=\"First number\",\n", + " required=True\n", + " ),\n", + " \"y\": ToolParamDefinitionParam(\n", + " param_type=\"float\",\n", + " description=\"Second number\",\n", + " required=True\n", + " )\n", + " }\n", + " \n", + " async def run_impl(self, operation: str = None, x: float = None, y: float = None):\n", + " \"\"\"The actual implementation of our calculator\"\"\"\n", + " if not all([operation, x, y]):\n", + " return json.dumps({\"error\": \"Missing required parameters\"})\n", + " \n", + " # Dictionary of math operations\n", + " operations = {\n", + " \"add\": lambda a, b: a + b,\n", + " \"subtract\": lambda a, b: a - b,\n", + " \"multiply\": lambda a, b: a * b,\n", + " \"divide\": lambda a, b: a / b if b != 0 else \"Error: Division by zero\"\n", + " }\n", + " \n", + " # Check if the operation is valid\n", + " if operation not in operations:\n", + " return json.dumps({\"error\": f\"Unknown operation '{operation}'\"})\n", + " \n", + " try:\n", + " # Convert string inputs to float if needed\n", + " x = float(x) if isinstance(x, str) else x\n", + " y = float(y) if isinstance(y, str) else y\n", + " \n", + " # Perform the calculation\n", + " result = operations[operation](x, y)\n", + " return json.dumps({\"result\": result})\n", + " except ValueError:\n", + " return json.dumps({\"error\": \"Invalid number format\"})\n", + " except Exception as e:\n", + " return json.dumps({\"error\": str(e)})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Building a Mock Weather Tool\n", + " \n", + "Now let's create something a bit more complex: a weather tool! \n", + "While this is just a mock version (it doesn't actually fetch real weather data),\n", + "it shows how you might structure a tool that interfaces with an external API." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "class WeatherTool(SingleMessageCustomTool):\n", + " \"async def run_single_query(agent, session_id, query: str):\n", + " \"\"\"Run a single query through our agent with complete interaction cycle\"\"\"\n", + " print(\"\\n\" + \"=\"*50)\n", + " print(f\"🤔 User asks: {query}\")\n", + " print(\"=\"*50)\n", + " \n", + " # Get the initial response and tool call\n", + " response = agent.create_turn(\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": query,\n", + " }\n", + " ],\n", + " session_id=session_id,\n", + " )\n", + " \n", + " # Process all events including tool calls and final response\n", + " async for event in EventLogger().log(response):\n", + " event.print()\n", + " \n", + " # If this was a tool call, we need to create another turn with the result\n", + " if hasattr(event, 'tool_calls') and event.tool_calls:\n", + " tool_call = event.tool_calls[0] # Get the first tool call\n", + " \n", + " # Execute the custom tool\n", + " if tool_call.tool_name in [t.get_name() for t in agent.custom_tools]:\n", + " tool = [t for t in agent.custom_tools if t.get_name() == tool_call.tool_name][0]\n", + " result = await tool.run_impl(**tool_call.arguments)\n", + " \n", + " # Create a follow-up turn with the tool result\n", + " follow_up = agent.create_turn(\n", + " messages=[\n", + " {\n", + " \"role\": \"tool\",\n", + " \"content\": result,\n", + " \"tool_call_id\": tool_call.call_id,\n", + " \"name\": tool_call.tool_name\n", + " }\n", + " ],\n", + " session_id=session_id,\n", + " )\n", + " \n", + " # Process the follow-up response\n", + " async for follow_up_event in EventLogger().log(follow_up):\n", + " follow_up_event.print()\"\"A mock weather tool that simulates getting weather data\"\"\"\n", + " \n", + " def get_name(self) -> str:\n", + " return \"get_weather\"\n", + " \n", + " def get_description(self) -> str:\n", + " return \"Get current weather information for major cities\"\n", + " \n", + " def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:\n", + " return {\n", + " \"city\": ToolParamDefinitionParam(\n", + " param_type=\"str\",\n", + " description=\"Name of the city (e.g., New York, London, Tokyo)\",\n", + " required=True\n", + " ),\n", + " \"date\": ToolParamDefinitionParam(\n", + " param_type=\"str\",\n", + " description=\"Date in YYYY-MM-DD format (optional)\",\n", + " required=False\n", + " )\n", + " }\n", + " \n", + " async def run_impl(self, city: str = None, date: str = None):\n", + " if not city:\n", + " return json.dumps({\"error\": \"City parameter is required\"})\n", + " \n", + " # Mock database of weather information\n", + " weather_data = {\n", + " \"New York\": {\"temp\": 20, \"condition\": \"sunny\"},\n", + " \"London\": {\"temp\": 15, \"condition\": \"rainy\"},\n", + " \"Tokyo\": {\"temp\": 25, \"condition\": \"cloudy\"}\n", + " }\n", + " \n", + " try:\n", + " # Check if we have data for the requested city\n", + " if city not in weather_data:\n", + " return json.dumps({\n", + " \"error\": f\"Sorry! No data available for {city}\",\n", + " \"available_cities\": list(weather_data.keys())\n", + " })\n", + " \n", + " # Return the weather information\n", + " return json.dumps({\n", + " \"city\": city,\n", + " \"date\": date or datetime.now().strftime(\"%Y-%m-%d\"),\n", + " \"data\": weather_data[city]\n", + " })\n", + " except Exception as e:\n", + " return json.dumps({\"error\": str(e)})" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "# ## 5. Setting Up the LlamaStack Agent\n", + "# \n", + "# Now that we have our tools, we need to create an agent that can use them.\n", + "# The agent is like a smart assistant that knows how to use our tools when needed." + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "async def setup_agent(host: str = \"localhost\", port: int = 5001):\n", + " \"\"\"Creates and configures our LlamaStack agent\"\"\"\n", + " \n", + " # Create a client to connect to the LlamaStack server\n", + " client = LlamaStackClient(\n", + " base_url=f\"http://{host}:{port}\",\n", + " )\n", + " \n", + " # Configure how we want our agent to behave\n", + " agent_config = AgentConfig(\n", + " model=\"Llama3.1-8B-Instruct\",\n", + " instructions=\"\"\"You are a helpful assistant that can:\n", + " 1. Perform mathematical calculations\n", + " 2. Check weather information\n", + " Always explain your thinking before using a tool.\"\"\",\n", + " \n", + " sampling_params={\n", + " \"strategy\": \"greedy\",\n", + " \"temperature\": 1.0,\n", + " \"top_p\": 0.9,\n", + " },\n", + " \n", + " # List of tools available to the agent\n", + " tools=[\n", + " {\n", + " \"function_name\": \"calculator\",\n", + " \"description\": \"Perform basic arithmetic operations\",\n", + " \"parameters\": {\n", + " \"operation\": {\n", + " \"param_type\": \"str\",\n", + " \"description\": \"Operation to perform (add, subtract, multiply, divide)\",\n", + " \"required\": True,\n", + " },\n", + " \"x\": {\n", + " \"param_type\": \"float\",\n", + " \"description\": \"First number\",\n", + " \"required\": True,\n", + " },\n", + " \"y\": {\n", + " \"param_type\": \"float\",\n", + " \"description\": \"Second number\",\n", + " \"required\": True,\n", + " },\n", + " },\n", + " \"type\": \"function_call\",\n", + " },\n", + " {\n", + " \"function_name\": \"get_weather\",\n", + " \"description\": \"Get weather information for a given city\",\n", + " \"parameters\": {\n", + " \"city\": {\n", + " \"param_type\": \"str\",\n", + " \"description\": \"Name of the city\",\n", + " \"required\": True,\n", + " },\n", + " \"date\": {\n", + " \"param_type\": \"str\",\n", + " \"description\": \"Date in YYYY-MM-DD format\",\n", + " \"required\": False,\n", + " },\n", + " },\n", + " \"type\": \"function_call\",\n", + " },\n", + " ],\n", + " tool_choice=\"auto\",\n", + " # Using standard JSON format for tools\n", + " tool_prompt_format=\"json\", \n", + " input_shields=[],\n", + " output_shields=[],\n", + " enable_session_persistence=False,\n", + " )\n", + " \n", + " # Create our tools\n", + " custom_tools = [CalculatorTool(), WeatherTool()]\n", + " \n", + " # Create the agent\n", + " agent = Agent(client, agent_config, custom_tools)\n", + " session_id = agent.create_session(\"tutorial-session\")\n", + " print(f\"🎉 Created session_id={session_id} for Agent({agent.agent_id})\")\n", + " \n", + " return agent, session_id" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "# ## 6. Running Examples\n", + "# \n", + "# Let's try out our agent with some example questions!\n", + "\n", + "# %%" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "import nest_asyncio\n", + "nest_asyncio.apply() # This allows async operations to work in Jupyter\n", + "\n", + "# %%\n", + "# Initialize the agent\n", + "async def init_agent():\n", + " \"\"\"Initialize our agent - run this first!\"\"\"\n", + " agent, session_id = await setup_agent()\n", + " print(f\"✨ Agent initialized with session {session_id}\")\n", + " return agent, session_id\n", + "\n", + "# %%\n", + "# Function to run a single query\n", + "async def run_single_query(agent, session_id, query: str):\n", + " \"\"\"Run a single query through our agent\"\"\"\n", + " print(\"\\n\" + \"=\"*50)\n", + " print(f\"🤔 User asks: {query}\")\n", + " print(\"=\"*50)\n", + " \n", + " response = agent.create_turn(\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": query,\n", + " }\n", + " ],\n", + " session_id=session_id,\n", + " )\n", + " \n", + " async for log in EventLogger().log(response):\n", + " log.print()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's run everything and see it in action!\n", + "\n", + "Create and run our agent" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🎉 Created session_id=fbe83bb6-bdfd-497c-b920-d7307482d8ba for Agent(3997eeda-4ffd-4b05-9026-28b4da206a11)\n", + "✨ Agent initialized with session fbe83bb6-bdfd-497c-b920-d7307482d8ba\n" + ] + } + ], + "source": [ + "agent, session_id = await init_agent()" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "==================================================\n", + "🤔 User asks: What's 25 plus 17?\n", + "==================================================\n", + "\u001b[30m\u001b[0m\u001b[33minference> \u001b[0m\u001b[36m\u001b[0m\u001b[36m{\"\u001b[0m\u001b[36mtype\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mfunction\u001b[0m\u001b[36m\",\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mname\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mcalculator\u001b[0m\u001b[36m\",\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mparameters\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m {\"\u001b[0m\u001b[36moperation\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m \"\u001b[0m\u001b[36madd\u001b[0m\u001b[36m\",\u001b[0m\u001b[36m \"\u001b[0m\u001b[36my\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m \"\u001b[0m\u001b[36m17\u001b[0m\u001b[36m\",\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mx\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m \"\u001b[0m\u001b[36m25\u001b[0m\u001b[36m\"}}\u001b[0m\u001b[97m\u001b[0m\n" + ] + } + ], + "source": [ + "await run_single_query(agent, session_id, \"What's 25 plus 17?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "==================================================\n", + "🤔 User asks: What's the weather like in Tokyo?\n", + "==================================================\n", + "\u001b[30m\u001b[0m\u001b[33minference> \u001b[0m\u001b[36m\u001b[0m\u001b[36m{\"\u001b[0m\u001b[36mtype\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mfunction\u001b[0m\u001b[36m\",\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mname\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mget\u001b[0m\u001b[36m_weather\u001b[0m\u001b[36m\",\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mparameters\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m {\"\u001b[0m\u001b[36mcity\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mTok\u001b[0m\u001b[36myo\u001b[0m\u001b[36m\"}}\u001b[0m\u001b[97m\u001b[0m\n" + ] + } + ], + "source": [ + "await run_single_query(agent, session_id, \"What's the weather like in Tokyo?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [], + "source": [ + "#fin" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}