enhanced tool calling guide

This commit is contained in:
Justin Lee 2024-11-09 16:40:13 -08:00 committed by Justin Lee
parent c6b2ffdc9c
commit 85d312d44c
3 changed files with 733 additions and 779 deletions

View file

@ -1,483 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "LLZwsT_J6OnZ"
},
"source": [
"<a href=\"https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/Tool_Calling101_Using_Together's_Llama_Stack_Server.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ME7IXK4M6Ona"
},
"source": [
"If you'd prefer not to set up a local server, explore this on tool calling with the Together API. This guide will show you how to leverage Together.ai's Llama Stack Server API, allowing you to get started with Llama Stack without the need for a locally built and running server.\n",
"\n",
"## Tool Calling w Together API\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rWl1f1Hc6Onb"
},
"source": [
"In this section, we'll explore how to enhance your applications with tool calling capabilities. We'll cover:\n",
"1. Setting up and using the Brave Search API\n",
"2. Creating custom tools\n",
"3. Configuring tool prompts and safety settings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "sRkJcA_O77hP",
"outputId": "49d33c5c-3300-4dc0-89a6-ff80bfc0bbdf"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting llama-stack-client\n",
" Downloading llama_stack_client-0.0.50-py3-none-any.whl.metadata (13 kB)\n",
"Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from llama-stack-client) (3.7.1)\n",
"Requirement already satisfied: distro<2,>=1.7.0 in /usr/local/lib/python3.10/dist-packages (from llama-stack-client) (1.9.0)\n",
"Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from llama-stack-client) (0.27.2)\n",
"Requirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from llama-stack-client) (2.9.2)\n",
"Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from llama-stack-client) (1.3.1)\n",
"Requirement already satisfied: tabulate>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from llama-stack-client) (0.9.0)\n",
"Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.10/dist-packages (from llama-stack-client) (4.12.2)\n",
"Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->llama-stack-client) (3.10)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->llama-stack-client) (1.2.2)\n",
"Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->llama-stack-client) (2024.8.30)\n",
"Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->llama-stack-client) (1.0.6)\n",
"Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.10/dist-packages (from httpcore==1.*->httpx<1,>=0.23.0->llama-stack-client) (0.14.0)\n",
"Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->llama-stack-client) (0.7.0)\n",
"Requirement already satisfied: pydantic-core==2.23.4 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->llama-stack-client) (2.23.4)\n",
"Downloading llama_stack_client-0.0.50-py3-none-any.whl (282 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m283.0/283.0 kB\u001b[0m \u001b[31m3.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hInstalling collected packages: llama-stack-client\n",
"Successfully installed llama-stack-client-0.0.50\n"
]
}
],
"source": [
"!pip install llama-stack-client"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T_EW_jV81ldl"
},
"outputs": [],
"source": [
"LLAMA_STACK_API_TOGETHER_URL=\"https://llama-stack.together.ai\"\n",
"LLAMA31_8B_INSTRUCT = \"Llama3.1-8B-Instruct\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "n_QHq45B6Onb"
},
"outputs": [],
"source": [
"import asyncio\n",
"import os\n",
"from typing import Dict, List, Optional\n",
"\n",
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import (\n",
" AgentConfig,\n",
" AgentConfigToolSearchToolDefinition,\n",
")\n",
"\n",
"# Helper function to create an agent with tools\n",
"async def create_tool_agent(\n",
" client: LlamaStackClient,\n",
" tools: List[Dict],\n",
" instructions: str = \"You are a helpful assistant\",\n",
" model: str = LLAMA31_8B_INSTRUCT\n",
") -> Agent:\n",
" \"\"\"Create an agent with specified tools.\"\"\"\n",
" print(\"Using the following model: \", model)\n",
" agent_config = AgentConfig(\n",
" model=model,\n",
" instructions=instructions,\n",
" sampling_params={\n",
" \"strategy\": \"greedy\",\n",
" \"temperature\": 1.0,\n",
" \"top_p\": 0.9,\n",
" },\n",
" tools=tools,\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"json\",\n",
" enable_session_persistence=True,\n",
" )\n",
"\n",
" return Agent(client, agent_config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iMVYso6_xoDV"
},
"source": [
"Quickly and easily get a free Together.ai API key [here](https://api.together.ai) and replace \"YOUR_TOGETHER_API_KEY\" below with it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3Bjr891C6Onc",
"outputId": "85245ae4-fba4-4ddb-8775-11262ddb1c29"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using the following model: Llama3.1-8B-Instruct\n",
"\n",
"Query: What are the latest developments in quantum computing?\n",
"--------------------------------------------------\n",
"inference> FINDINGS:\n",
"The latest developments in quantum computing involve significant advancements in the field of quantum processors, error correction, and the development of practical applications. Some of the recent breakthroughs include:\n",
"\n",
"* Google's 53-qubit Sycamore processor, which achieved quantum supremacy in 2019 (Source: Google AI Blog, https://ai.googleblog.com/2019/10/experiment-advances-quantum-computing.html)\n",
"* The development of a 100-qubit quantum processor by the Chinese company, Origin Quantum (Source: Physics World, https://physicsworld.com/a/origin-quantum-scales-up-to-100-qubits/)\n",
"* IBM's 127-qubit Eagle processor, which has the potential to perform complex calculations that are currently unsolvable by classical computers (Source: IBM Research Blog, https://www.ibm.com/blogs/research/2020/11/ibm-advances-quantum-computing-research-with-new-127-qubit-processor/)\n",
"* The development of topological quantum computers, which have the potential to solve complex problems in materials science and chemistry (Source: MIT Technology Review, https://www.technologyreview.com/2020/02/24/914776/topological-quantum-computers-are-a-game-changer-for-materials-science/)\n",
"* The development of a new type of quantum error correction code, known as the \"surface code\", which has the potential to solve complex problems in quantum computing (Source: Nature Physics, https://www.nature.com/articles/s41567-021-01314-2)\n",
"\n",
"SOURCES:\n",
"- Google AI Blog: https://ai.googleblog.com/2019/10/experiment-advances-quantum-computing.html\n",
"- Physics World: https://physicsworld.com/a/origin-quantum-scales-up-to-100-qubits/\n",
"- IBM Research Blog: https://www.ibm.com/blogs/research/2020/11/ibm-advances-quantum-computing-research-with-new-127-qubit-processor/\n",
"- MIT Technology Review: https://www.technologyreview.com/2020/02/24/914776/topological-quantum-computers-are-a-game-changer-for-materials-science/\n",
"- Nature Physics: https://www.nature.com/articles/s41567-021-01314-2\n"
]
}
],
"source": [
"# comment this if you don't have a BRAVE_SEARCH_API_KEY\n",
"os.environ[\"BRAVE_SEARCH_API_KEY\"] = 'YOUR_BRAVE_SEARCH_API_KEY'\n",
"\n",
"async def create_search_agent(client: LlamaStackClient) -> Agent:\n",
" \"\"\"Create an agent with Brave Search capability.\"\"\"\n",
"\n",
" # comment this if you don't have a BRAVE_SEARCH_API_KEY\n",
" search_tool = AgentConfigToolSearchToolDefinition(\n",
" type=\"brave_search\",\n",
" engine=\"brave\",\n",
" api_key=os.getenv(\"BRAVE_SEARCH_API_KEY\"),\n",
" )\n",
"\n",
" return await create_tool_agent(\n",
" client=client,\n",
" tools=[search_tool], # set this to [] if you don't have a BRAVE_SEARCH_API_KEY\n",
" model = LLAMA31_8B_INSTRUCT,\n",
" instructions=\"\"\"\n",
" You are a research assistant that can search the web.\n",
" Always cite your sources with URLs when providing information.\n",
" Format your responses as:\n",
"\n",
" FINDINGS:\n",
" [Your summary here]\n",
"\n",
" SOURCES:\n",
" - [Source title](URL)\n",
" \"\"\"\n",
" )\n",
"\n",
"# Example usage\n",
"async def search_example():\n",
" client = LlamaStackClient(base_url=LLAMA_STACK_API_TOGETHER_URL)\n",
" agent = await create_search_agent(client)\n",
"\n",
" # Create a session\n",
" session_id = agent.create_session(\"search-session\")\n",
"\n",
" # Example queries\n",
" queries = [\n",
" \"What are the latest developments in quantum computing?\",\n",
" #\"Who won the most recent Super Bowl?\",\n",
" ]\n",
"\n",
" for query in queries:\n",
" print(f\"\\nQuery: {query}\")\n",
" print(\"-\" * 50)\n",
"\n",
" response = agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": query}],\n",
" session_id=session_id,\n",
" )\n",
"\n",
" async for log in EventLogger().log(response):\n",
" log.print()\n",
"\n",
"# Run the example (in Jupyter, use asyncio.run())\n",
"await search_example()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r3YN6ufb6Onc"
},
"source": [
"## 3. Custom Tool Creation\n",
"\n",
"Let's create a custom weather tool:\n",
"\n",
"#### Key Highlights:\n",
"- **`WeatherTool` Class**: A custom tool that processes weather information requests, supporting location and optional date parameters.\n",
"- **Agent Creation**: The `create_weather_agent` function sets up an agent equipped with the `WeatherTool`, allowing for weather queries in natural language.\n",
"- **Simulation of API Call**: The `run_impl` method simulates fetching weather data. This method can be replaced with an actual API integration for real-world usage.\n",
"- **Interactive Example**: The `weather_example` function shows how to use the agent to handle user queries regarding the weather, providing step-by-step responses."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "A0bOLYGj6Onc",
"outputId": "023a8fb7-49ed-4ab4-e5b7-8050ded5d79a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Query: What's the weather like in San Francisco?\n",
"--------------------------------------------------\n",
"inference> {\n",
" \"function\": \"get_weather\",\n",
" \"parameters\": {\n",
" \"location\": \"San Francisco\"\n",
" }\n",
"}\n",
"\n",
"Query: Tell me the weather in Tokyo tomorrow\n",
"--------------------------------------------------\n",
"inference> {\n",
" \"function\": \"get_weather\",\n",
" \"parameters\": {\n",
" \"location\": \"Tokyo\",\n",
" \"date\": \"tomorrow\"\n",
" }\n",
"}\n"
]
}
],
"source": [
"from typing import TypedDict, Optional, Dict, Any\n",
"from datetime import datetime\n",
"import json\n",
"from llama_stack_client.types.tool_param_definition_param import ToolParamDefinitionParam\n",
"from llama_stack_client.types import CompletionMessage,ToolResponseMessage\n",
"from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
"\n",
"class WeatherTool(CustomTool):\n",
" \"\"\"Example custom tool for weather information.\"\"\"\n",
"\n",
" def get_name(self) -> str:\n",
" return \"get_weather\"\n",
"\n",
" def get_description(self) -> str:\n",
" return \"Get weather information for a location\"\n",
"\n",
" def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:\n",
" return {\n",
" \"location\": ToolParamDefinitionParam(\n",
" param_type=\"str\",\n",
" description=\"City or location name\",\n",
" required=True\n",
" ),\n",
" \"date\": ToolParamDefinitionParam(\n",
" param_type=\"str\",\n",
" description=\"Optional date (YYYY-MM-DD)\",\n",
" required=False\n",
" )\n",
" }\n",
" async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:\n",
" assert len(messages) == 1, \"Expected single message\"\n",
"\n",
" message = messages[0]\n",
"\n",
" tool_call = message.tool_calls[0]\n",
" # location = tool_call.arguments.get(\"location\", None)\n",
" # date = tool_call.arguments.get(\"date\", None)\n",
" try:\n",
" response = await self.run_impl(**tool_call.arguments)\n",
" response_str = json.dumps(response, ensure_ascii=False)\n",
" except Exception as e:\n",
" response_str = f\"Error when running tool: {e}\"\n",
"\n",
" message = ToolResponseMessage(\n",
" call_id=tool_call.call_id,\n",
" tool_name=tool_call.tool_name,\n",
" content=response_str,\n",
" role=\"ipython\",\n",
" )\n",
" return [message]\n",
"\n",
" async def run_impl(self, location: str, date: Optional[str] = None) -> Dict[str, Any]:\n",
" \"\"\"Simulate getting weather data (replace with actual API call).\"\"\"\n",
" # Mock implementation\n",
" if date:\n",
" return {\n",
" \"temperature\": 90.1,\n",
" \"conditions\": \"sunny\",\n",
" \"humidity\": 40.0\n",
" }\n",
" return {\n",
" \"temperature\": 72.5,\n",
" \"conditions\": \"partly cloudy\",\n",
" \"humidity\": 65.0\n",
" }\n",
"\n",
"\n",
"async def create_weather_agent(client: LlamaStackClient) -> Agent:\n",
" \"\"\"Create an agent with weather tool capability.\"\"\"\n",
"\n",
" agent_config = AgentConfig(\n",
" model=LLAMA31_8B_INSTRUCT,\n",
" #model=model_name,\n",
" instructions=\"\"\"\n",
" You are a weather assistant that can provide weather information.\n",
" Always specify the location clearly in your responses.\n",
" Include both temperature and conditions in your summaries.\n",
" \"\"\",\n",
" sampling_params={\n",
" \"strategy\": \"greedy\",\n",
" \"temperature\": 1.0,\n",
" \"top_p\": 0.9,\n",
" },\n",
" tools=[\n",
" {\n",
" \"function_name\": \"get_weather\",\n",
" \"description\": \"Get weather information for a location\",\n",
" \"parameters\": {\n",
" \"location\": {\n",
" \"param_type\": \"str\",\n",
" \"description\": \"City or location name\",\n",
" \"required\": True,\n",
" },\n",
" \"date\": {\n",
" \"param_type\": \"str\",\n",
" \"description\": \"Optional date (YYYY-MM-DD)\",\n",
" \"required\": False,\n",
" },\n",
" },\n",
" \"type\": \"function_call\",\n",
" }\n",
" ],\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"json\",\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=True\n",
" )\n",
"\n",
" # Create the agent with the tool\n",
" weather_tool = WeatherTool()\n",
" agent = Agent(\n",
" client=client,\n",
" agent_config=agent_config,\n",
" custom_tools=[weather_tool]\n",
" )\n",
"\n",
" return agent\n",
"\n",
"# Example usage\n",
"async def weather_example():\n",
" client = LlamaStackClient(base_url=LLAMA_STACK_API_TOGETHER_URL)\n",
" agent = await create_weather_agent(client)\n",
" session_id = agent.create_session(\"weather-session\")\n",
"\n",
" queries = [\n",
" \"What's the weather like in San Francisco?\",\n",
" \"Tell me the weather in Tokyo tomorrow\",\n",
" ]\n",
"\n",
" for query in queries:\n",
" print(f\"\\nQuery: {query}\")\n",
" print(\"-\" * 50)\n",
"\n",
" response = agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": query}],\n",
" session_id=session_id,\n",
" )\n",
"\n",
" async for log in EventLogger().log(response):\n",
" log.print()\n",
"\n",
"# For Jupyter notebooks\n",
"import nest_asyncio\n",
"nest_asyncio.apply()\n",
"\n",
"# Run the example\n",
"await weather_example()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yKhUkVNq6Onc"
},
"source": [
"Thanks for checking out this tutorial, hopefully you can now automate everything with Llama! :D\n",
"\n",
"Next up, we learn another hot topic of LLMs: Memory and Rag. Continue learning [here](./04_Memory101.ipynb)!"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View file

@ -0,0 +1,378 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "7a1ac883",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/Tool_Calling101_With_Together_Llama_Stack_Server.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
"\n",
"If you'd prefer not to set up a local server, explore this on tool calling with the Together API. This guide will show you how to leverage Together.ai's Llama Stack Server API, allowing you to get started with Llama Stack without the need for a locally built and running server.\n",
"\n",
"## Creating a Custom Tool and Agent Tool Calling with Together API\n"
]
},
{
"cell_type": "markdown",
"id": "d3d3ec91",
"metadata": {},
"source": [
"## Step 1: Import Necessary Packages and import api keys"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2fbe7011",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import requests\n",
"import json\n",
"import asyncio\n",
"import nest_asyncio\n",
"from typing import Dict, List\n",
"from dotenv import load_dotenv\n",
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
"from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage\n",
"from llama_stack_client.types import CompletionMessage\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"\n",
"# Allow asyncio to run in Jupyter Notebook\n",
"nest_asyncio.apply()\n",
"\n",
"LLAMA_STACK_API_TOGETHER_URL=\"https://llama-stack.together.ai\""
]
},
{
"cell_type": "markdown",
"id": "ac6042d8",
"metadata": {},
"source": [
"Create a `.env` file and add you brave api key\n",
"\n",
"`BRAVE_SEARCH_API_KEY = \"YOUR_BRAVE_API_KEY_HERE\"`\n",
"\n",
"Now load the `.env` file into your jupyter notebook."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b4b3300c",
"metadata": {},
"outputs": [],
"source": [
"load_dotenv()\n",
"BRAVE_SEARCH_API_KEY = os.environ['BRAVE_SEARCH_API_KEY']"
]
},
{
"cell_type": "markdown",
"id": "c838bb40",
"metadata": {},
"source": [
"## Step 2: Create a class for the Brave Search API integration\n",
"\n",
"Let's create the `BraveSearch` class, which encapsulates the logic for making web search queries using the Brave Search API and formatting the response. The class includes methods for sending requests, processing results, and extracting relevant data to support the integration with an AI toolchain."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "62271ed2",
"metadata": {},
"outputs": [],
"source": [
"class BraveSearch:\n",
" def __init__(self, api_key: str) -> None:\n",
" self.api_key = api_key\n",
"\n",
" async def search(self, query: str) -> str:\n",
" url = \"https://api.search.brave.com/res/v1/web/search\"\n",
" headers = {\n",
" \"X-Subscription-Token\": self.api_key,\n",
" \"Accept-Encoding\": \"gzip\",\n",
" \"Accept\": \"application/json\",\n",
" }\n",
" payload = {\"q\": query}\n",
" response = requests.get(url=url, params=payload, headers=headers)\n",
" return json.dumps(self._clean_brave_response(response.json()))\n",
"\n",
" def _clean_brave_response(self, search_response, top_k=3):\n",
" query = search_response.get(\"query\", {}).get(\"original\", None)\n",
" clean_response = []\n",
" mixed_results = search_response.get(\"mixed\", {}).get(\"main\", [])[:top_k]\n",
"\n",
" for m in mixed_results:\n",
" r_type = m[\"type\"]\n",
" results = search_response.get(r_type, {}).get(\"results\", [])\n",
" if r_type == \"web\" and results:\n",
" idx = m[\"index\"]\n",
" selected_keys = [\"title\", \"url\", \"description\"]\n",
" cleaned = {k: v for k, v in results[idx].items() if k in selected_keys}\n",
" clean_response.append(cleaned)\n",
"\n",
" return {\"query\": query, \"top_k\": clean_response}"
]
},
{
"cell_type": "markdown",
"id": "d987d48f",
"metadata": {},
"source": [
"## Step 3: Create a custom tool class for integration with Llama Stack\n",
"\n",
"Here, we defines the `WebSearchTool` class, which extends `CustomTool` to integrate the Brave Search API with Llama Stack, enabling web search capabilities within AI workflows. The class handles incoming user queries, interacts with the `BraveSearch` class for data retrieval, and formats results for effective response generation."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "92e75cf8",
"metadata": {},
"outputs": [],
"source": [
"class WebSearchTool(CustomTool):\n",
" def __init__(self, api_key: str):\n",
" self.api_key = api_key\n",
" self.engine = BraveSearch(api_key)\n",
"\n",
" def get_name(self) -> str:\n",
" return \"web_search\"\n",
"\n",
" def get_description(self) -> str:\n",
" return \"Search the web for a given query\"\n",
"\n",
" async def run_impl(self, query: str):\n",
" return await self.engine.search(query)\n",
"\n",
" async def run(self, messages):\n",
" query = None\n",
" for message in messages:\n",
" if isinstance(message, CompletionMessage) and message.tool_calls:\n",
" for tool_call in message.tool_calls:\n",
" if 'query' in tool_call.arguments:\n",
" query = tool_call.arguments['query']\n",
" call_id = tool_call.call_id\n",
"\n",
" if query:\n",
" search_result = await self.run_impl(query)\n",
" return [ToolResponseMessage(\n",
" call_id=call_id,\n",
" role=\"ipython\",\n",
" content=self._format_response_for_agent(search_result),\n",
" tool_name=\"brave_search\"\n",
" )]\n",
"\n",
" return [ToolResponseMessage(\n",
" call_id=\"no_call_id\",\n",
" role=\"ipython\",\n",
" content=\"No query provided.\",\n",
" tool_name=\"brave_search\"\n",
" )]\n",
"\n",
" def _format_response_for_agent(self, search_result):\n",
" parsed_result = json.loads(search_result)\n",
" formatted_result = \"Search Results with Citations:\\n\\n\"\n",
" for i, result in enumerate(parsed_result.get(\"top_k\", []), start=1):\n",
" formatted_result += (\n",
" f\"{i}. {result.get('title', 'No Title')}\\n\"\n",
" f\" URL: {result.get('url', 'No URL')}\\n\"\n",
" f\" Description: {result.get('description', 'No Description')}\\n\\n\"\n",
" )\n",
" return formatted_result"
]
},
{
"cell_type": "markdown",
"id": "f282a9bd",
"metadata": {},
"source": [
"## Step 4: Create a function to execute a search query and print the results\n",
"\n",
"Now let's create the `execute_search` function, which initializes the `WebSearchTool`, runs a query asynchronously, and prints the formatted search results for easy viewing."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "aaf5664f",
"metadata": {},
"outputs": [],
"source": [
"async def execute_search(query: str):\n",
" web_search_tool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
" result = await web_search_tool.run_impl(query)\n",
" print(\"Search Results:\", result)"
]
},
{
"cell_type": "markdown",
"id": "7cc3a039",
"metadata": {},
"source": [
"## Step 5: Run the search with an example query"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "5f22c4e2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Search Results: {\"query\": \"Latest developments in quantum computing\", \"top_k\": [{\"title\": \"Quantum Computing | Latest News, Photos & Videos | WIRED\", \"url\": \"https://www.wired.com/tag/quantum-computing/\", \"description\": \"Find the <strong>latest</strong> <strong>Quantum</strong> <strong>Computing</strong> news from WIRED. See related science and technology articles, photos, slideshows and videos.\"}, {\"title\": \"Quantum Computing News -- ScienceDaily\", \"url\": \"https://www.sciencedaily.com/news/matter_energy/quantum_computing/\", \"description\": \"<strong>Quantum</strong> <strong>Computing</strong> News. Read the <strong>latest</strong> about the <strong>development</strong> <strong>of</strong> <strong>quantum</strong> <strong>computers</strong>.\"}]}\n"
]
}
],
"source": [
"query = \"Latest developments in quantum computing\"\n",
"asyncio.run(execute_search(query))"
]
},
{
"cell_type": "markdown",
"id": "ea58f265-dfd7-4935-ae5e-6f3a6d74d805",
"metadata": {},
"source": [
"## Step 6: Run the search tool using an agent\n",
"\n",
"Here, we setup and execute the `WebSearchTool` within an agent configuration in Llama Stack to handle user queries and generate responses. This involves initializing the client, configuring the agent with tool capabilities, and processing user prompts asynchronously to display results."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "9e704b01-f410-492f-8baf-992589b82803",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Created session_id=e4062b71-8034-40a8-8124-3410d9ab8ddc for Agent(27ab94ce-cfe8-4af5-a71b-4e8f4bef6434)\n",
"\u001b[30m\u001b[0m\u001b[35mshield_call> No Violation\u001b[0m\n",
"\u001b[33minference> \u001b[0m\u001b[33m[\u001b[0m\u001b[33mweb\u001b[0m\u001b[33m_search\u001b[0m\u001b[33m(query\u001b[0m\u001b[33m='\u001b[0m\u001b[33mlatest\u001b[0m\u001b[33m developments\u001b[0m\u001b[33m in\u001b[0m\u001b[33m quantum\u001b[0m\u001b[33m computing\u001b[0m\u001b[33m')]\u001b[0m\u001b[97m\u001b[0m\n",
"\u001b[35mshield_call> No Violation\u001b[0m\n",
"\u001b[32mCustomTool> Search Results with Citations:\n",
"\n",
"1. Quantum Computing | Latest News, Photos & Videos | WIRED\n",
" URL: https://www.wired.com/tag/quantum-computing/\n",
" Description: Find the <strong>latest</strong> <strong>Quantum</strong> <strong>Computing</strong> news from WIRED. See related science and technology articles, photos, slideshows and videos.\n",
"\n",
"2. Quantum Computing News -- ScienceDaily\n",
" URL: https://www.sciencedaily.com/news/matter_energy/quantum_computing/\n",
" Description: <strong>Quantum</strong> <strong>Computing</strong> News. Read the <strong>latest</strong> about the <strong>development</strong> <strong>of</strong> <strong>quantum</strong> <strong>computers</strong>.\n",
"\n",
"\u001b[0m\n"
]
}
],
"source": [
"async def run_main(url=LLAMA_STACK_API_TOGETHER_URL, disable_safety: bool = False):\n",
" # Initialize the Llama Stack client with the specified base URL\n",
" client = LlamaStackClient(\n",
" base_url=url,\n",
" )\n",
"\n",
" # Configure input and output shields for safety (use \"llama_guard\" by default)\n",
" input_shields = [] if disable_safety else [\"llama_guard\"]\n",
" output_shields = [] if disable_safety else [\"llama_guard\"]\n",
"\n",
" # Define the agent configuration, including the model and tool setup\n",
" agent_config = AgentConfig(\n",
" model=\"Llama3.2-3B-Instruct\", \n",
" instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
" sampling_params={\n",
" \"strategy\": \"greedy\", \n",
" \"temperature\": 1.0, \n",
" \"top_p\": 0.9, \n",
" },\n",
" tools=[\n",
" {\n",
" \"function_name\": \"web_search\", # Name of the tool being integrated\n",
" \"description\": \"Search the web for a given query\",\n",
" \"parameters\": {\n",
" \"query\": {\n",
" \"param_type\": \"str\", \n",
" \"description\": \"The query to search for\",\n",
" \"required\": True, \n",
" }\n",
" },\n",
" \"type\": \"function_call\", \n",
" },\n",
" ],\n",
" tool_choice=\"auto\", \n",
" tool_prompt_format=\"python_list\", \n",
" input_shields=input_shields,\n",
" output_shields=output_shields,\n",
" enable_session_persistence=False, \n",
" )\n",
" \n",
" # Initialize custom tools (ensure `WebSearchTool` is defined earlier in the notebook)\n",
" custom_tools = [WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)]\n",
"\n",
" # Create an agent instance with the client and configuration\n",
" agent = Agent(client, agent_config, custom_tools)\n",
" \n",
" # Create a session for interaction and print the session ID\n",
" session_id = agent.create_session(\"test-session\")\n",
" print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n",
"\n",
" response = agent.create_turn(\n",
" messages=[\n",
" {\n",
" \"role\": \"user\", \n",
" \"content\": \"\"\"What are the latest developments in quantum computing?\"\"\", \n",
" }\n",
" ],\n",
" session_id=session_id, # Use the created session ID\n",
" )\n",
"\n",
" # Log and print the response from the agent asynchronously\n",
" async for log in EventLogger().log(response):\n",
" log.print()\n",
"\n",
"# Run the function asynchronously in a Jupyter Notebook cell\n",
"await run_main(disable_safety=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9de15b8d-b12e-4abc-9c30-9f1711a9a215",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}