Create Tool_Calling101.ipynb

This commit is contained in:
Sanyam Bhutani 2024-11-05 14:43:46 -08:00
parent 863f58ce2f
commit 40793cd8ad

View file

@ -0,0 +1,558 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Getting Started with LlamaStack: Tool Calling Tutorial\n",
"\n",
"Welcome! This notebook will guide you through creating and using custom tools with LlamaStack.\n",
"We'll start with the basics and work our way up to more complex examples.\n",
"\n",
"Table of Contents:\n",
"1. Setup and Installation\n",
"2. Understanding Tool Basics\n",
"3. Creating Your First Tool\n",
"4. Building a Mock Weather Tool\n",
"5. Setting Up the LlamaStack Agent\n",
"6. Running Examples\n",
"7. Next Steps\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Setup\n",
"#### Before we begin, let's import all the required packages:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import asyncio\n",
"import json\n",
"from typing import Dict\n",
"from datetime import datetime"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# LlamaStack specific imports\n",
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from llama_stack_client.types.tool_param_definition_param import ToolParamDefinitionParam"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Understanding Tool Basics\n",
"\n",
"In LlamaStack, a tool is like a special function that our AI assistant can use. Think of it as giving the AI a new \n",
"capability, like using a calculator or checking the weather.\n",
"\n",
"Every tool needs:\n",
"- A name: What we call the tool\n",
"- A description: What the tool does\n",
"- Parameters: What information the tool needs to work\n",
"- Implementation: The actual code that does the work\n",
"\n",
"Let's create a base class that all our tools will inherit from:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"class SingleMessageCustomTool:\n",
" \"\"\"Base class for all our custom tools\"\"\"\n",
" \n",
" async def run(self, messages=None):\n",
" \"\"\"\n",
" Main entry point for running the tool\n",
" Args:\n",
" messages: List of messages (can be None for backward compatibility)\n",
" \"\"\"\n",
" if messages and len(messages) > 0:\n",
" # Extract parameters from the message if it contains function parameters\n",
" message = messages[0]\n",
" if hasattr(message, 'function_parameters'):\n",
" return await self.run_impl(**message.function_parameters)\n",
" else:\n",
" return await self.run_impl()\n",
" return await self.run_impl()\n",
" \n",
" async def run_impl(self, **kwargs):\n",
" \"\"\"Each tool will implement this method with their specific logic\"\"\"\n",
" raise NotImplementedError()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Creating Your First Tool: Calculator\n",
" \n",
"Let's create a simple calculator tool. This will help us understand the basic structure of a tool.\n",
"Our calculator can:\n",
"- Add\n",
"- Subtract\n",
"- Multiply\n",
"- Divide\n"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# Calculator Tool implementation\n",
"class CalculatorTool(SingleMessageCustomTool):\n",
" \"\"\"A simple calculator tool that can perform basic math operations\"\"\"\n",
" \n",
" def get_name(self) -> str:\n",
" return \"calculator\"\n",
" \n",
" def get_description(self) -> str:\n",
" return \"Perform basic arithmetic operations (add, subtract, multiply, divide)\"\n",
" \n",
" def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:\n",
" return {\n",
" \"operation\": ToolParamDefinitionParam(\n",
" param_type=\"str\",\n",
" description=\"Operation to perform (add, subtract, multiply, divide)\",\n",
" required=True\n",
" ),\n",
" \"x\": ToolParamDefinitionParam(\n",
" param_type=\"float\",\n",
" description=\"First number\",\n",
" required=True\n",
" ),\n",
" \"y\": ToolParamDefinitionParam(\n",
" param_type=\"float\",\n",
" description=\"Second number\",\n",
" required=True\n",
" )\n",
" }\n",
" \n",
" async def run_impl(self, operation: str = None, x: float = None, y: float = None):\n",
" \"\"\"The actual implementation of our calculator\"\"\"\n",
" if not all([operation, x, y]):\n",
" return json.dumps({\"error\": \"Missing required parameters\"})\n",
" \n",
" # Dictionary of math operations\n",
" operations = {\n",
" \"add\": lambda a, b: a + b,\n",
" \"subtract\": lambda a, b: a - b,\n",
" \"multiply\": lambda a, b: a * b,\n",
" \"divide\": lambda a, b: a / b if b != 0 else \"Error: Division by zero\"\n",
" }\n",
" \n",
" # Check if the operation is valid\n",
" if operation not in operations:\n",
" return json.dumps({\"error\": f\"Unknown operation '{operation}'\"})\n",
" \n",
" try:\n",
" # Convert string inputs to float if needed\n",
" x = float(x) if isinstance(x, str) else x\n",
" y = float(y) if isinstance(y, str) else y\n",
" \n",
" # Perform the calculation\n",
" result = operations[operation](x, y)\n",
" return json.dumps({\"result\": result})\n",
" except ValueError:\n",
" return json.dumps({\"error\": \"Invalid number format\"})\n",
" except Exception as e:\n",
" return json.dumps({\"error\": str(e)})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Building a Mock Weather Tool\n",
" \n",
"Now let's create something a bit more complex: a weather tool! \n",
"While this is just a mock version (it doesn't actually fetch real weather data),\n",
"it shows how you might structure a tool that interfaces with an external API."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"class WeatherTool(SingleMessageCustomTool):\n",
" \"async def run_single_query(agent, session_id, query: str):\n",
" \"\"\"Run a single query through our agent with complete interaction cycle\"\"\"\n",
" print(\"\\n\" + \"=\"*50)\n",
" print(f\"🤔 User asks: {query}\")\n",
" print(\"=\"*50)\n",
" \n",
" # Get the initial response and tool call\n",
" response = agent.create_turn(\n",
" messages=[\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": query,\n",
" }\n",
" ],\n",
" session_id=session_id,\n",
" )\n",
" \n",
" # Process all events including tool calls and final response\n",
" async for event in EventLogger().log(response):\n",
" event.print()\n",
" \n",
" # If this was a tool call, we need to create another turn with the result\n",
" if hasattr(event, 'tool_calls') and event.tool_calls:\n",
" tool_call = event.tool_calls[0] # Get the first tool call\n",
" \n",
" # Execute the custom tool\n",
" if tool_call.tool_name in [t.get_name() for t in agent.custom_tools]:\n",
" tool = [t for t in agent.custom_tools if t.get_name() == tool_call.tool_name][0]\n",
" result = await tool.run_impl(**tool_call.arguments)\n",
" \n",
" # Create a follow-up turn with the tool result\n",
" follow_up = agent.create_turn(\n",
" messages=[\n",
" {\n",
" \"role\": \"tool\",\n",
" \"content\": result,\n",
" \"tool_call_id\": tool_call.call_id,\n",
" \"name\": tool_call.tool_name\n",
" }\n",
" ],\n",
" session_id=session_id,\n",
" )\n",
" \n",
" # Process the follow-up response\n",
" async for follow_up_event in EventLogger().log(follow_up):\n",
" follow_up_event.print()\"\"A mock weather tool that simulates getting weather data\"\"\"\n",
" \n",
" def get_name(self) -> str:\n",
" return \"get_weather\"\n",
" \n",
" def get_description(self) -> str:\n",
" return \"Get current weather information for major cities\"\n",
" \n",
" def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:\n",
" return {\n",
" \"city\": ToolParamDefinitionParam(\n",
" param_type=\"str\",\n",
" description=\"Name of the city (e.g., New York, London, Tokyo)\",\n",
" required=True\n",
" ),\n",
" \"date\": ToolParamDefinitionParam(\n",
" param_type=\"str\",\n",
" description=\"Date in YYYY-MM-DD format (optional)\",\n",
" required=False\n",
" )\n",
" }\n",
" \n",
" async def run_impl(self, city: str = None, date: str = None):\n",
" if not city:\n",
" return json.dumps({\"error\": \"City parameter is required\"})\n",
" \n",
" # Mock database of weather information\n",
" weather_data = {\n",
" \"New York\": {\"temp\": 20, \"condition\": \"sunny\"},\n",
" \"London\": {\"temp\": 15, \"condition\": \"rainy\"},\n",
" \"Tokyo\": {\"temp\": 25, \"condition\": \"cloudy\"}\n",
" }\n",
" \n",
" try:\n",
" # Check if we have data for the requested city\n",
" if city not in weather_data:\n",
" return json.dumps({\n",
" \"error\": f\"Sorry! No data available for {city}\",\n",
" \"available_cities\": list(weather_data.keys())\n",
" })\n",
" \n",
" # Return the weather information\n",
" return json.dumps({\n",
" \"city\": city,\n",
" \"date\": date or datetime.now().strftime(\"%Y-%m-%d\"),\n",
" \"data\": weather_data[city]\n",
" })\n",
" except Exception as e:\n",
" return json.dumps({\"error\": str(e)})"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"# ## 5. Setting Up the LlamaStack Agent\n",
"# \n",
"# Now that we have our tools, we need to create an agent that can use them.\n",
"# The agent is like a smart assistant that knows how to use our tools when needed."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"async def setup_agent(host: str = \"localhost\", port: int = 5001):\n",
" \"\"\"Creates and configures our LlamaStack agent\"\"\"\n",
" \n",
" # Create a client to connect to the LlamaStack server\n",
" client = LlamaStackClient(\n",
" base_url=f\"http://{host}:{port}\",\n",
" )\n",
" \n",
" # Configure how we want our agent to behave\n",
" agent_config = AgentConfig(\n",
" model=\"Llama3.1-8B-Instruct\",\n",
" instructions=\"\"\"You are a helpful assistant that can:\n",
" 1. Perform mathematical calculations\n",
" 2. Check weather information\n",
" Always explain your thinking before using a tool.\"\"\",\n",
" \n",
" sampling_params={\n",
" \"strategy\": \"greedy\",\n",
" \"temperature\": 1.0,\n",
" \"top_p\": 0.9,\n",
" },\n",
" \n",
" # List of tools available to the agent\n",
" tools=[\n",
" {\n",
" \"function_name\": \"calculator\",\n",
" \"description\": \"Perform basic arithmetic operations\",\n",
" \"parameters\": {\n",
" \"operation\": {\n",
" \"param_type\": \"str\",\n",
" \"description\": \"Operation to perform (add, subtract, multiply, divide)\",\n",
" \"required\": True,\n",
" },\n",
" \"x\": {\n",
" \"param_type\": \"float\",\n",
" \"description\": \"First number\",\n",
" \"required\": True,\n",
" },\n",
" \"y\": {\n",
" \"param_type\": \"float\",\n",
" \"description\": \"Second number\",\n",
" \"required\": True,\n",
" },\n",
" },\n",
" \"type\": \"function_call\",\n",
" },\n",
" {\n",
" \"function_name\": \"get_weather\",\n",
" \"description\": \"Get weather information for a given city\",\n",
" \"parameters\": {\n",
" \"city\": {\n",
" \"param_type\": \"str\",\n",
" \"description\": \"Name of the city\",\n",
" \"required\": True,\n",
" },\n",
" \"date\": {\n",
" \"param_type\": \"str\",\n",
" \"description\": \"Date in YYYY-MM-DD format\",\n",
" \"required\": False,\n",
" },\n",
" },\n",
" \"type\": \"function_call\",\n",
" },\n",
" ],\n",
" tool_choice=\"auto\",\n",
" # Using standard JSON format for tools\n",
" tool_prompt_format=\"json\", \n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=False,\n",
" )\n",
" \n",
" # Create our tools\n",
" custom_tools = [CalculatorTool(), WeatherTool()]\n",
" \n",
" # Create the agent\n",
" agent = Agent(client, agent_config, custom_tools)\n",
" session_id = agent.create_session(\"tutorial-session\")\n",
" print(f\"🎉 Created session_id={session_id} for Agent({agent.agent_id})\")\n",
" \n",
" return agent, session_id"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"# ## 6. Running Examples\n",
"# \n",
"# Let's try out our agent with some example questions!\n",
"\n",
"# %%"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"import nest_asyncio\n",
"nest_asyncio.apply() # This allows async operations to work in Jupyter\n",
"\n",
"# %%\n",
"# Initialize the agent\n",
"async def init_agent():\n",
" \"\"\"Initialize our agent - run this first!\"\"\"\n",
" agent, session_id = await setup_agent()\n",
" print(f\"✨ Agent initialized with session {session_id}\")\n",
" return agent, session_id\n",
"\n",
"# %%\n",
"# Function to run a single query\n",
"async def run_single_query(agent, session_id, query: str):\n",
" \"\"\"Run a single query through our agent\"\"\"\n",
" print(\"\\n\" + \"=\"*50)\n",
" print(f\"🤔 User asks: {query}\")\n",
" print(\"=\"*50)\n",
" \n",
" response = agent.create_turn(\n",
" messages=[\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": query,\n",
" }\n",
" ],\n",
" session_id=session_id,\n",
" )\n",
" \n",
" async for log in EventLogger().log(response):\n",
" log.print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's run everything and see it in action!\n",
"\n",
"Create and run our agent"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🎉 Created session_id=fbe83bb6-bdfd-497c-b920-d7307482d8ba for Agent(3997eeda-4ffd-4b05-9026-28b4da206a11)\n",
"✨ Agent initialized with session fbe83bb6-bdfd-497c-b920-d7307482d8ba\n"
]
}
],
"source": [
"agent, session_id = await init_agent()"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"==================================================\n",
"🤔 User asks: What's 25 plus 17?\n",
"==================================================\n",
"\u001b[30m\u001b[0m\u001b[33minference> \u001b[0m\u001b[36m\u001b[0m\u001b[36m{\"\u001b[0m\u001b[36mtype\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mfunction\u001b[0m\u001b[36m\",\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mname\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mcalculator\u001b[0m\u001b[36m\",\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mparameters\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m {\"\u001b[0m\u001b[36moperation\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m \"\u001b[0m\u001b[36madd\u001b[0m\u001b[36m\",\u001b[0m\u001b[36m \"\u001b[0m\u001b[36my\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m \"\u001b[0m\u001b[36m17\u001b[0m\u001b[36m\",\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mx\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m \"\u001b[0m\u001b[36m25\u001b[0m\u001b[36m\"}}\u001b[0m\u001b[97m\u001b[0m\n"
]
}
],
"source": [
"await run_single_query(agent, session_id, \"What's 25 plus 17?\")"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"==================================================\n",
"🤔 User asks: What's the weather like in Tokyo?\n",
"==================================================\n",
"\u001b[30m\u001b[0m\u001b[33minference> \u001b[0m\u001b[36m\u001b[0m\u001b[36m{\"\u001b[0m\u001b[36mtype\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mfunction\u001b[0m\u001b[36m\",\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mname\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mget\u001b[0m\u001b[36m_weather\u001b[0m\u001b[36m\",\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mparameters\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m {\"\u001b[0m\u001b[36mcity\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m \"\u001b[0m\u001b[36mTok\u001b[0m\u001b[36myo\u001b[0m\u001b[36m\"}}\u001b[0m\u001b[97m\u001b[0m\n"
]
}
],
"source": [
"await run_single_query(agent, session_id, \"What's the weather like in Tokyo?\")"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"#fin"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}