mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-24 00:47:00 +00:00
# What does this PR do?

- update README.md format and Python version
- update package name: CustomTool was renamed to ClientTool in https://github.com/meta-llama/llama-stack-client-python/pull/73

Closes #2556

## Test Plan

Signed-off-by: Wen Zhou <wenzhou@redhat.com>
361 lines
15 KiB
Text
{
"cells": [
{
"cell_type": "markdown",
"id": "7a1ac883",
"metadata": {},
"source": [
"## Tool Calling\n",
"\n",
"\n",
"## Creating a Custom Tool and Agent Tool Calling\n"
]
},
{
"cell_type": "markdown",
"id": "d3d3ec91",
"metadata": {},
"source": [
"## Step 1: Import Necessary Packages and API Keys"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2fbe7011",
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"import json\n",
"import os\n",
"from typing import Dict, List\n",
"\n",
"import nest_asyncio\n",
"import requests\n",
"from dotenv import load_dotenv\n",
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.client_tool import ClientTool\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types import CompletionMessage\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage\n",
"\n",
"# Allow asyncio to run in Jupyter Notebook\n",
"nest_asyncio.apply()\n",
"\n",
"HOST = \"localhost\"\n",
"PORT = 8321\n",
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
]
},
{
"cell_type": "markdown",
"id": "ac6042d8",
"metadata": {},
"source": [
"Create a `.env` file and add your Brave API key:\n",
"\n",
"`BRAVE_SEARCH_API_KEY = \"YOUR_BRAVE_API_KEY_HERE\"`\n",
"\n",
"Now load the `.env` file into your Jupyter notebook."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b4b3300c",
"metadata": {},
"outputs": [],
"source": [
"load_dotenv()\n",
"BRAVE_SEARCH_API_KEY = os.environ[\"BRAVE_SEARCH_API_KEY\"]\n"
]
},
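`os.environ[...]` raises a bare `KeyError` when the variable is missing, which can be confusing in a notebook. Here is a minimal sketch of a friendlier lookup. The helper name `require_env` is hypothetical; it is not part of the original notebook or of `python-dotenv`.

```python
import os
from typing import Mapping, Optional


def require_env(name: str, env: Optional[Mapping[str, str]] = None) -> str:
    """Return a non-empty environment variable or raise a descriptive error.

    `require_env` is a hypothetical helper used only for illustration.
    """
    source = os.environ if env is None else env
    value = source.get(name, "").strip()
    if not value:
        raise RuntimeError(
            f"{name} is not set; add it to your .env file and call load_dotenv() again."
        )
    return value


# Example with an explicit mapping so the snippet runs without a real key.
print(require_env("BRAVE_SEARCH_API_KEY", {"BRAVE_SEARCH_API_KEY": "demo-key"}))
```

In the notebook this could replace the direct lookup, e.g. `BRAVE_SEARCH_API_KEY = require_env("BRAVE_SEARCH_API_KEY")`.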
{
"cell_type": "markdown",
"id": "c838bb40",
"metadata": {},
"source": [
"## Step 2: Create a class for the Brave Search API integration\n",
"\n",
"Let's create the `BraveSearch` class, which encapsulates the logic for making web search queries using the Brave Search API and formatting the response. The class includes methods for sending requests, processing results, and extracting relevant data to support the integration with an AI toolchain."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "62271ed2",
"metadata": {},
"outputs": [],
"source": [
"class BraveSearch:\n",
"    def __init__(self, api_key: str) -> None:\n",
"        self.api_key = api_key\n",
"\n",
"    async def search(self, query: str) -> str:\n",
"        url = \"https://api.search.brave.com/res/v1/web/search\"\n",
"        headers = {\n",
"            \"X-Subscription-Token\": self.api_key,\n",
"            \"Accept-Encoding\": \"gzip\",\n",
"            \"Accept\": \"application/json\",\n",
"        }\n",
"        payload = {\"q\": query}\n",
"        # requests is synchronous, so this call blocks the event loop until it returns.\n",
"        response = requests.get(url=url, params=payload, headers=headers, timeout=10)\n",
"        # Fail loudly on HTTP errors (e.g. an invalid API key) instead of parsing an error body.\n",
"        response.raise_for_status()\n",
"        return json.dumps(self._clean_brave_response(response.json()))\n",
"\n",
"    def _clean_brave_response(self, search_response, top_k=3):\n",
"        query = search_response.get(\"query\", {}).get(\"original\", None)\n",
"        clean_response = []\n",
"        # \"mixed.main\" lists the ranked result slots; keep only the first top_k.\n",
"        mixed_results = search_response.get(\"mixed\", {}).get(\"main\", [])[:top_k]\n",
"\n",
"        for m in mixed_results:\n",
"            r_type = m[\"type\"]\n",
"            results = search_response.get(r_type, {}).get(\"results\", [])\n",
"            if r_type == \"web\" and results:\n",
"                idx = m[\"index\"]\n",
"                selected_keys = [\"title\", \"url\", \"description\"]\n",
"                cleaned = {k: v for k, v in results[idx].items() if k in selected_keys}\n",
"                clean_response.append(cleaned)\n",
"\n",
"        return {\"query\": query, \"top_k\": clean_response}\n"
]
},
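To see what the cleaning step keeps and drops without calling the API, here is a small standalone sketch. The function mirrors `_clean_brave_response` above. The `sample` payload is hypothetical and heavily trimmed: its shape is inferred only from the fields the code reads (`query.original`, `mixed.main`, `web.results`), not from the Brave API documentation.

```python
import json


def clean_brave_response(search_response: dict, top_k: int = 3) -> dict:
    """Standalone copy of the cleaning logic, for experimenting without an API key."""
    query = search_response.get("query", {}).get("original")
    clean_response = []
    # Walk the first top_k ranked slots and keep only web results.
    for m in search_response.get("mixed", {}).get("main", [])[:top_k]:
        r_type = m["type"]
        results = search_response.get(r_type, {}).get("results", [])
        if r_type == "web" and results:
            selected_keys = ["title", "url", "description"]
            entry = results[m["index"]]
            clean_response.append({k: v for k, v in entry.items() if k in selected_keys})
    return {"query": query, "top_k": clean_response}


# Hypothetical payload for illustration only; a real response carries many more fields.
sample = {
    "query": {"original": "llama stack"},
    "mixed": {
        "main": [
            {"type": "web", "index": 0},
            {"type": "videos", "index": 0},
            {"type": "web", "index": 1},
        ]
    },
    "web": {
        "results": [
            {"title": "A", "url": "https://example.com/a", "description": "first", "age": "1d"},
            {"title": "B", "url": "https://example.com/b", "description": "second", "age": "2d"},
        ]
    },
}

cleaned = clean_brave_response(sample)
print(json.dumps(cleaned, indent=2))
```

Note that `top_k` limits the number of ranked slots inspected, not the number of web results returned: the `videos` slot above consumes one of the three slots and contributes nothing.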
{
"cell_type": "markdown",
"id": "d987d48f",
"metadata": {},
"source": [
"## Step 3: Create a Custom Tool Class\n",
"\n",
"Here, we define the `WebSearchTool` class, which extends `ClientTool` to integrate the Brave Search API with Llama Stack, enabling web search capabilities within AI workflows. The class handles incoming user queries, interacts with the `BraveSearch` class for data retrieval, and formats results for effective response generation."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "92e75cf8",
"metadata": {},
"outputs": [],
"source": [
"class WebSearchTool(ClientTool):\n",
"    def __init__(self, api_key: str):\n",
"        self.api_key = api_key\n",
"        self.engine = BraveSearch(api_key)\n",
"\n",
"    def get_name(self) -> str:\n",
"        return \"web_search\"\n",
"\n",
"    def get_description(self) -> str:\n",
"        return \"Search the web for a given query\"\n",
"\n",
"    async def run_impl(self, query: str):\n",
"        return await self.engine.search(query)\n",
"\n",
"    async def run(self, messages):\n",
"        query = None\n",
"        call_id = None\n",
"        # Use the last tool call that carries a \"query\" argument.\n",
"        for message in messages:\n",
"            if isinstance(message, CompletionMessage) and message.tool_calls:\n",
"                for tool_call in message.tool_calls:\n",
"                    if \"query\" in tool_call.arguments:\n",
"                        query = tool_call.arguments[\"query\"]\n",
"                        call_id = tool_call.call_id\n",
"\n",
"        if query:\n",
"            search_result = await self.run_impl(query)\n",
"            return [\n",
"                ToolResponseMessage(\n",
"                    call_id=call_id,\n",
"                    role=\"ipython\",\n",
"                    content=self._format_response_for_agent(search_result),\n",
"                    tool_name=\"brave_search\",\n",
"                )\n",
"            ]\n",
"\n",
"        # No usable tool call was found, so return a placeholder response.\n",
"        return [\n",
"            ToolResponseMessage(\n",
"                call_id=\"no_call_id\",\n",
"                role=\"ipython\",\n",
"                content=\"No query provided.\",\n",
"                tool_name=\"brave_search\",\n",
"            )\n",
"        ]\n",
"\n",
"    def _format_response_for_agent(self, search_result):\n",
"        parsed_result = json.loads(search_result)\n",
"        formatted_result = \"Search Results with Citations:\\n\\n\"\n",
"        for i, result in enumerate(parsed_result.get(\"top_k\", []), start=1):\n",
"            formatted_result += (\n",
"                f\"{i}. {result.get('title', 'No Title')}\\n\"\n",
"                f\"   URL: {result.get('url', 'No URL')}\\n\"\n",
"                f\"   Description: {result.get('description', 'No Description')}\\n\\n\"\n",
"            )\n",
"        return formatted_result\n"
]
},
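The `run` method above scans the conversation for a tool call that carries a `query` argument. Here is a small sketch of that scan in isolation. `FakeToolCall` and `FakeMessage` are hypothetical stand-ins that expose only the attributes `run` reads; they are not types from `llama_stack_client`.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class FakeToolCall:
    """Hypothetical stand-in exposing only the fields that `run` reads."""

    call_id: str
    arguments: Dict[str, Any]


@dataclass
class FakeMessage:
    """Hypothetical stand-in for a completion message."""

    tool_calls: List[FakeToolCall] = field(default_factory=list)


def find_query(messages: List[FakeMessage]) -> Tuple[Optional[str], Optional[str]]:
    """Return (query, call_id) from the last tool call that has a "query" argument."""
    query, call_id = None, None
    for message in messages:
        for tool_call in message.tool_calls:
            if "query" in tool_call.arguments:
                query = tool_call.arguments["query"]
                call_id = tool_call.call_id
    return query, call_id


messages = [
    FakeMessage(),
    FakeMessage(tool_calls=[FakeToolCall("call-1", {"query": "quantum computing"})]),
]
print(find_query(messages))
```

Because later matches overwrite earlier ones, the last matching tool call wins; when nothing matches, both values stay `None`, which corresponds to the "No query provided." branch.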
{
"cell_type": "markdown",
"id": "f282a9bd",
"metadata": {},
"source": [
"## Step 4: Create a function to execute a search query and print the results\n",
"\n",
"Now let's create the `execute_search` function, which initializes the `WebSearchTool`, runs a query asynchronously, and prints the search results as a JSON string."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "aaf5664f",
"metadata": {},
"outputs": [],
"source": [
"async def execute_search(query: str):\n",
"    web_search_tool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
"    result = await web_search_tool.run_impl(query)\n",
"    print(\"Search Results:\", result)\n"
]
},
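`execute_search` is a coroutine, but the underlying `requests.get` call is synchronous, so the event loop is blocked while the HTTP request runs. One possible way around this, sketched here with a hypothetical `blocking_search` stand-in rather than a real HTTP call, is `asyncio.to_thread` (Python 3.9+):

```python
import asyncio
import time


def blocking_search(query: str) -> str:
    """Hypothetical stand-in for a blocking HTTP call such as `requests.get`."""
    time.sleep(0.2)
    return f"results for {query!r}"


async def search_without_blocking(query: str) -> str:
    # asyncio.to_thread runs the blocking function in a worker thread,
    # so the event loop stays free for other coroutines.
    return await asyncio.to_thread(blocking_search, query)


async def main() -> list:
    # The two searches overlap instead of running back to back.
    return await asyncio.gather(
        search_without_blocking("quantum computing"),
        search_without_blocking("llama stack"),
    )


results = asyncio.run(main())
print(results)
```

As a plain script this runs as written. Inside a notebook, `asyncio.run` needs the `nest_asyncio.apply()` call from Step 1, or you can use `await main()` directly.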
{
"cell_type": "markdown",
"id": "7cc3a039",
"metadata": {},
"source": [
"## Step 5: Run the search with an example query"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5f22c4e2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Search Results: {\"query\": \"Latest developments in quantum computing\", \"top_k\": [{\"title\": \"Quantum Computing | Latest News, Photos & Videos | WIRED\", \"url\": \"https://www.wired.com/tag/quantum-computing/\", \"description\": \"Find the <strong>latest</strong> <strong>Quantum</strong> <strong>Computing</strong> news from WIRED. See related science and technology articles, photos, slideshows and videos.\"}, {\"title\": \"Quantum Computing News -- ScienceDaily\", \"url\": \"https://www.sciencedaily.com/news/matter_energy/quantum_computing/\", \"description\": \"<strong>Quantum</strong> <strong>Computing</strong> News. Read the <strong>latest</strong> about the <strong>development</strong> <strong>of</strong> <strong>quantum</strong> <strong>computers</strong>.\"}]}\n"
]
}
],
"source": [
"query = \"Latest developments in quantum computing\"\n",
"asyncio.run(execute_search(query))\n"
]
},
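The descriptions in the output above still contain `<strong>` highlight tags. If you would rather pass plain text to the model, a simple cleanup could look like the sketch below. `strip_markup` is a hypothetical helper, and the regex is a lightweight approach that is fine for simple inline tags but is not a full HTML parser.

```python
import html
import re

# Matches anything that looks like an HTML tag, e.g. <strong> or </strong>.
TAG_RE = re.compile(r"<[^>]+>")


def strip_markup(text: str) -> str:
    """Remove HTML tags, then decode entities such as &amp;."""
    return html.unescape(TAG_RE.sub("", text))


description = (
    "Find the <strong>latest</strong> <strong>Quantum</strong> "
    "<strong>Computing</strong> news from WIRED."
)
print(strip_markup(description))
```

One place to apply it would be inside `_clean_brave_response`, on the `description` value of each result.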
{
"cell_type": "markdown",
"id": "ea58f265-dfd7-4935-ae5e-6f3a6d74d805",
"metadata": {},
"source": [
"## Step 6: Run the search tool using an agent\n",
"\n",
"Here, we set up and execute the `WebSearchTool` within an agent configuration in Llama Stack to handle user queries and generate responses. This involves initializing the client, configuring the agent with tool capabilities, and processing user prompts asynchronously to display results."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "9e704b01-f410-492f-8baf-992589b82803",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Created session_id=34d2978d-e299-4a2a-9219-4ffe2fb124a2 for Agent(8a68f2c3-2b2a-4f67-a355-c6d5b2451d6a)\n",
"\u001b[30m\u001b[0m\u001b[33minference> \u001b[0m\u001b[33m[\u001b[0m\u001b[33mweb\u001b[0m\u001b[33m_search\u001b[0m\u001b[33m(query\u001b[0m\u001b[33m=\"\u001b[0m\u001b[33mlatest\u001b[0m\u001b[33m developments\u001b[0m\u001b[33m in\u001b[0m\u001b[33m quantum\u001b[0m\u001b[33m computing\u001b[0m\u001b[33m\")]\u001b[0m\u001b[97m\u001b[0m\n",
"\u001b[32mCustomTool> Search Results with Citations:\n",
"\n",
"1. Quantum Computing | Latest News, Photos & Videos | WIRED\n",
"   URL: https://www.wired.com/tag/quantum-computing/\n",
"   Description: Find the <strong>latest</strong> <strong>Quantum</strong> <strong>Computing</strong> news from WIRED. See related science and technology articles, photos, slideshows and videos.\n",
"\n",
"2. Quantum Computing News -- ScienceDaily\n",
"   URL: https://www.sciencedaily.com/news/matter_energy/quantum_computing/\n",
"   Description: <strong>Quantum</strong> <strong>Computing</strong> News. Read the <strong>latest</strong> about the <strong>development</strong> <strong>of</strong> <strong>quantum</strong> <strong>computers</strong>.\n",
"\n",
"\u001b[0m\n"
]
}
],
"source": [
"async def run_main(disable_safety: bool = False):\n",
"    # Initialize the Llama Stack client with the specified base URL\n",
"    client = LlamaStackClient(\n",
"        base_url=f\"http://{HOST}:{PORT}\",\n",
"    )\n",
"\n",
"    # Configure input and output shields for safety (use \"llama_guard\" by default)\n",
"    input_shields = [] if disable_safety else [\"llama_guard\"]\n",
"    output_shields = [] if disable_safety else [\"llama_guard\"]\n",
"\n",
"    # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n",
"    webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
"\n",
"    # Create an agent instance with the client and configuration\n",
"    agent = Agent(\n",
"        client,\n",
"        model=MODEL_NAME,\n",
"        instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
"        sampling_params={\n",
"            \"strategy\": {\n",
"                \"type\": \"greedy\",\n",
"            },\n",
"        },\n",
"        tools=[webSearchTool],\n",
"        input_shields=input_shields,\n",
"        output_shields=output_shields,\n",
"        enable_session_persistence=False,\n",
"    )\n",
"\n",
"    # Create a session for interaction and print the session ID\n",
"    session_id = agent.create_session(\"test-session\")\n",
"    print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n",
"\n",
"    response = agent.create_turn(\n",
"        messages=[\n",
"            {\n",
"                \"role\": \"user\",\n",
"                \"content\": \"\"\"What are the latest developments in quantum computing?\"\"\",\n",
"            }\n",
"        ],\n",
"        session_id=session_id,  # Use the created session ID\n",
"    )\n",
"\n",
"    # Log and print the response from the agent asynchronously\n",
"    async for log in EventLogger().log(response):\n",
"        log.print()\n",
"\n",
"\n",
"# Run the function asynchronously in a Jupyter Notebook cell\n",
"await run_main(disable_safety=True)\n"
]
}
],
"metadata": {
"fileHeader": "",
"fileUid": "f0abbf6d-ed52-40ad-afb4-f5ec99130249",
"isAdHoc": false,
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}