{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's explore how to have a conversation about images using Llama Stack's inference API! This section will show you how to:\n",
"1. Load and prepare images for the API\n",
"2. Send image-based queries\n",
"3. Create an interactive chat loop with images\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"import base64\n",
"import mimetypes\n",
"from pathlib import Path\n",
"from typing import Optional, Union\n",
"\n",
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.types import UserMessage\n",
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
"from termcolor import cprint\n",
"\n",
"# Helper function to convert image to data URL\n",
"def image_to_data_url(file_path: Union[str, Path]) -> str:\n",
"    \"\"\"Convert an image file to a data URL format.\n",
"\n",
"    Args:\n",
"        file_path: Path to the image file\n",
"\n",
"    Returns:\n",
"        str: Data URL containing the encoded image\n",
"    \"\"\"\n",
"    file_path = Path(file_path)\n",
"    if not file_path.exists():\n",
"        raise FileNotFoundError(f\"Image not found: {file_path}\")\n",
"\n",
"    mime_type, _ = mimetypes.guess_type(str(file_path))\n",
"    if mime_type is None:\n",
"        raise ValueError(\"Could not determine MIME type of the image\")\n",
"\n",
"    with open(file_path, \"rb\") as image_file:\n",
"        encoded_string = base64.b64encode(image_file.read()).decode(\"utf-8\")\n",
"\n",
"    return f\"data:{mime_type};base64,{encoded_string}\""
]
},
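{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before building the interactive loop, here is a minimal single-turn image query using the same inference API the loop relies on. It is a sketch that assumes a Llama Stack server on `localhost:8000` and an image at `your_image.jpg`; adjust both for your setup."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal single-turn image query (sketch; adjust base_url, model, and image path)\n",
"client = LlamaStackClient(\n",
"    base_url=\"http://localhost:8000\",\n",
")\n",
"\n",
"message = UserMessage(\n",
"    role=\"user\",\n",
"    content=[\n",
"        {\"image\": {\"uri\": image_to_data_url(\"your_image.jpg\")}},\n",
"        \"Describe this image in two sentences.\",\n",
"    ],\n",
")\n",
"\n",
"response = client.inference.chat_completion(\n",
"    messages=[message],\n",
"    model=\"Llama3.2-11B-Vision-Instruct\",\n",
"    stream=True,\n",
")\n",
"\n",
"# Top-level await works in Jupyter notebooks\n",
"async for log in EventLogger().log(response):\n",
"    log.print()"
]
},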
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Create an Interactive Image Chat\n",
"\n",
"Let's create a function that enables back-and-forth conversation about an image:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Image, display\n",
"import ipywidgets as widgets\n",
"\n",
"# Display the image we'll be chatting about\n",
"image_path = \"your_image.jpg\"  # Replace with your image path\n",
"display(Image(filename=image_path))\n",
"\n",
"# Initialize the client\n",
"client = LlamaStackClient(\n",
"    base_url=\"http://localhost:8000\",  # Adjust host/port as needed\n",
")\n",
"\n",
"# Create chat interface\n",
"output = widgets.Output()\n",
"text_input = widgets.Text(\n",
"    value='',\n",
"    placeholder='Type your question about the image...',\n",
"    description='Ask:',\n",
"    disabled=False\n",
")\n",
"\n",
"# Display interface\n",
"display(text_input, output)\n",
"\n",
"# Handle chat interaction\n",
"async def on_submit(change):\n",
"    with output:\n",
"        question = text_input.value\n",
"        if question.lower() == 'exit':\n",
"            print(\"Chat ended.\")\n",
"            return\n",
"\n",
"        message = UserMessage(\n",
"            role=\"user\",\n",
"            content=[\n",
"                {\"image\": {\"uri\": image_to_data_url(image_path)}},\n",
"                question,\n",
"            ],\n",
"        )\n",
"\n",
"        print(f\"\\nUser> {question}\")\n",
"        response = client.inference.chat_completion(\n",
"            messages=[message],\n",
"            model=\"Llama3.2-11B-Vision-Instruct\",\n",
"            stream=True,\n",
"        )\n",
"\n",
"        print(\"Assistant> \", end='')\n",
"        async for log in EventLogger().log(response):\n",
"            log.print()\n",
"\n",
"        text_input.value = ''  # Clear input after sending\n",
"\n",
"text_input.on_submit(lambda x: asyncio.create_task(on_submit(x)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool Calling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this section, we'll explore how to enhance your applications with tool calling capabilities. We'll cover:\n",
"1. Setting up and using the Brave Search API\n",
"2. Creating custom tools\n",
"3. Configuring tool prompts and safety settings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"import os\n",
"from typing import Dict, List, Optional\n",
"from dotenv import load_dotenv\n",
"\n",
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import (\n",
"    AgentConfig,\n",
"    AgentConfigToolSearchToolDefinition,\n",
")\n",
"\n",
"# Load environment variables\n",
"load_dotenv()\n",
"\n",
"# Helper function to create an agent with tools\n",
"async def create_tool_agent(\n",
"    client: LlamaStackClient,\n",
"    tools: List[Dict],\n",
"    instructions: str = \"You are a helpful assistant\",\n",
"    model: str = \"Llama3.1-8B-Instruct\",\n",
") -> Agent:\n",
"    \"\"\"Create an agent with specified tools.\"\"\"\n",
"    agent_config = AgentConfig(\n",
"        model=model,\n",
"        instructions=instructions,\n",
"        sampling_params={\n",
"            \"strategy\": \"greedy\",  # with greedy decoding, temperature/top_p below are typically ignored\n",
"            \"temperature\": 1.0,\n",
"            \"top_p\": 0.9,\n",
"        },\n",
"        tools=tools,\n",
"        tool_choice=\"auto\",\n",
"        tool_prompt_format=\"json\",\n",
"        input_shields=[\"Llama-Guard-3-1B\"],\n",
"        output_shields=[\"Llama-Guard-3-1B\"],\n",
"        enable_session_persistence=True,\n",
"    )\n",
"\n",
"    return Agent(client, agent_config)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, create a `.env` file in your notebook directory with your Brave Search API key:\n",
"\n",
"```\n",
"BRAVE_SEARCH_API_KEY=your_key_here\n",
"```\n"
]
},
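{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick, optional sanity check that the key actually loads before we create any agents (assumes the `python-dotenv` package is installed and the `.env` file sits next to this notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: confirm the key is visible to this kernel.\n",
"# Assumes python-dotenv is installed (pip install python-dotenv if not).\n",
"import os\n",
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv()\n",
"assert os.getenv(\"BRAVE_SEARCH_API_KEY\"), \"BRAVE_SEARCH_API_KEY is not set - check your .env file\"\n",
"print(\"Brave Search API key loaded.\")"
]
},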
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"async def create_search_agent(client: LlamaStackClient) -> Agent:\n",
"    \"\"\"Create an agent with Brave Search capability.\"\"\"\n",
"    search_tool = AgentConfigToolSearchToolDefinition(\n",
"        type=\"brave_search\",\n",
"        engine=\"brave\",\n",
"        api_key=os.getenv(\"BRAVE_SEARCH_API_KEY\"),\n",
"    )\n",
"\n",
"    return await create_tool_agent(\n",
"        client=client,\n",
"        tools=[search_tool],\n",
"        instructions=\"\"\"\n",
"        You are a research assistant that can search the web.\n",
"        Always cite your sources with URLs when providing information.\n",
"        Format your responses as:\n",
"\n",
"        FINDINGS:\n",
"        [Your summary here]\n",
"\n",
"        SOURCES:\n",
"        - [Source title](URL)\n",
"        \"\"\"\n",
"    )\n",
"\n",
"# Example usage\n",
"async def search_example():\n",
"    client = LlamaStackClient(base_url=\"http://localhost:8000\")\n",
"    agent = await create_search_agent(client)\n",
"\n",
"    # Create a session\n",
"    session_id = agent.create_session(\"search-session\")\n",
"\n",
"    # Example queries\n",
"    queries = [\n",
"        \"What are the latest developments in quantum computing?\",\n",
"        \"Who won the most recent Super Bowl?\",\n",
"    ]\n",
"\n",
"    for query in queries:\n",
"        print(f\"\\nQuery: {query}\")\n",
"        print(\"-\" * 50)\n",
"\n",
"        response = agent.create_turn(\n",
"            messages=[{\"role\": \"user\", \"content\": query}],\n",
"            session_id=session_id,\n",
"        )\n",
"\n",
"        async for log in EventLogger().log(response):\n",
"            log.print()\n",
"\n",
"# Run the example (top-level await works in Jupyter; use asyncio.run(search_example()) in a plain script)\n",
"await search_example()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Custom Tool Creation\n",
"\n",
"Let's create a custom weather tool:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import TypedDict, Optional\n",
"from datetime import datetime\n",
"\n",
"# Define tool types\n",
"class WeatherInput(TypedDict):\n",
"    location: str\n",
"    date: Optional[str]\n",
"\n",
"class WeatherOutput(TypedDict):\n",
"    temperature: float\n",
"    conditions: str\n",
"    humidity: float\n",
"\n",
"class WeatherTool:\n",
"    \"\"\"Example custom tool for weather information.\"\"\"\n",
"\n",
"    def __init__(self, api_key: Optional[str] = None):\n",
"        self.api_key = api_key\n",
"\n",
"    async def get_weather(self, location: str, date: Optional[str] = None) -> WeatherOutput:\n",
"        \"\"\"Simulate getting weather data (replace with actual API call).\"\"\"\n",
"        # Mock implementation\n",
"        return {\n",
"            \"temperature\": 72.5,\n",
"            \"conditions\": \"partly cloudy\",\n",
"            \"humidity\": 65.0\n",
"        }\n",
"\n",
"    async def __call__(self, input_data: WeatherInput) -> WeatherOutput:\n",
"        \"\"\"Make the tool callable with structured input.\"\"\"\n",
"        return await self.get_weather(\n",
"            location=input_data[\"location\"],\n",
"            date=input_data.get(\"date\")\n",
"        )\n",
"\n",
"async def create_weather_agent(client: LlamaStackClient) -> Agent:\n",
"    \"\"\"Create an agent with weather tool capability.\"\"\"\n",
"    weather_tool = {\n",
"        \"type\": \"function\",\n",
"        \"function\": {\n",
"            \"name\": \"get_weather\",\n",
"            \"description\": \"Get weather information for a location\",\n",
"            \"parameters\": {\n",
"                \"type\": \"object\",\n",
"                \"properties\": {\n",
"                    \"location\": {\n",
"                        \"type\": \"string\",\n",
"                        \"description\": \"City or location name\"\n",
"                    },\n",
"                    \"date\": {\n",
"                        \"type\": \"string\",\n",
"                        \"description\": \"Optional date (YYYY-MM-DD)\",\n",
"                        \"format\": \"date\"\n",
"                    }\n",
"                },\n",
"                \"required\": [\"location\"]\n",
"            }\n",
"        },\n",
"        \"implementation\": WeatherTool()\n",
"    }\n",
"\n",
"    return await create_tool_agent(\n",
"        client=client,\n",
"        tools=[weather_tool],\n",
"        instructions=\"\"\"\n",
"        You are a weather assistant that can provide weather information.\n",
"        Always specify the location clearly in your responses.\n",
"        Include both temperature and conditions in your summaries.\n",
"        \"\"\"\n",
"    )\n",
"\n",
"# Example usage\n",
"async def weather_example():\n",
"    client = LlamaStackClient(base_url=\"http://localhost:8000\")\n",
"    agent = await create_weather_agent(client)\n",
"\n",
"    session_id = agent.create_session(\"weather-session\")\n",
"\n",
"    queries = [\n",
"        \"What's the weather like in San Francisco?\",\n",
"        \"Tell me the weather in Tokyo tomorrow\",\n",
"    ]\n",
"\n",
"    for query in queries:\n",
"        print(f\"\\nQuery: {query}\")\n",
"        print(\"-\" * 50)\n",
"\n",
"        response = agent.create_turn(\n",
"            messages=[{\"role\": \"user\", \"content\": query}],\n",
"            session_id=session_id,\n",
"        )\n",
"\n",
"        async for log in EventLogger().log(response):\n",
"            log.print()\n",
"\n",
"# Run the example\n",
"await weather_example()"
]
},
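{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also sanity-check the custom tool on its own, outside of any agent - it simply returns the mock data defined above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Call the tool directly with structured input (returns the mock data)\n",
"weather_tool = WeatherTool()\n",
"result = await weather_tool({\"location\": \"San Francisco\"})\n",
"print(result)"
]
},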
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multi-Tool Agent"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"async def create_multi_tool_agent(client: LlamaStackClient) -> Agent:\n",
"    \"\"\"Create an agent with multiple tools.\"\"\"\n",
"    tools = [\n",
"        # Brave Search tool\n",
"        AgentConfigToolSearchToolDefinition(\n",
"            type=\"brave_search\",\n",
"            engine=\"brave\",\n",
"            api_key=os.getenv(\"BRAVE_SEARCH_API_KEY\"),\n",
"        ),\n",
"        # Weather tool\n",
"        {\n",
"            \"type\": \"function\",\n",
"            \"function\": {\n",
"                \"name\": \"get_weather\",\n",
"                \"description\": \"Get weather information for a location\",\n",
"                \"parameters\": {\n",
"                    \"type\": \"object\",\n",
"                    \"properties\": {\n",
"                        \"location\": {\"type\": \"string\"},\n",
"                        \"date\": {\"type\": \"string\", \"format\": \"date\"}\n",
"                    },\n",
"                    \"required\": [\"location\"]\n",
"                }\n",
"            },\n",
"            \"implementation\": WeatherTool()\n",
"        }\n",
"    ]\n",
"\n",
"    return await create_tool_agent(\n",
"        client=client,\n",
"        tools=tools,\n",
"        instructions=\"\"\"\n",
"        You are an assistant that can search the web and check weather information.\n",
"        Use the appropriate tool based on the user's question.\n",
"        For weather queries, always specify location and conditions.\n",
"        For web searches, always cite your sources.\n",
"        \"\"\"\n",
"    )\n",
"\n",
"# Interactive example with multi-tool agent\n",
"async def interactive_multi_tool():\n",
"    client = LlamaStackClient(base_url=\"http://localhost:8000\")\n",
"    agent = await create_multi_tool_agent(client)\n",
"    session_id = agent.create_session(\"interactive-session\")\n",
"\n",
"    print(\"🤖 Multi-tool Agent Ready! (type 'exit' to quit)\")\n",
"    print(\"Example questions:\")\n",
"    print(\"- What's the weather in Paris and what events are happening there?\")\n",
"    print(\"- Tell me about recent space discoveries and the weather on Mars\")\n",
"\n",
"    while True:\n",
"        query = input(\"\\nYour question: \")\n",
"        if query.lower() == 'exit':\n",
"            break\n",
"\n",
"        print(\"\\nThinking...\")\n",
"        try:\n",
"            response = agent.create_turn(\n",
"                messages=[{\"role\": \"user\", \"content\": query}],\n",
"                session_id=session_id,\n",
"            )\n",
"\n",
"            async for log in EventLogger().log(response):\n",
"                log.print()\n",
"        except Exception as e:\n",
"            print(f\"Error: {e}\")\n",
"\n",
"# Run interactive example\n",
"await interactive_multi_tool()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Memory"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Getting Started with Memory API Tutorial 🚀\n",
"\n",
"Welcome! This interactive tutorial will guide you through using the Memory API, a powerful tool for document storage and retrieval. Whether you're new to vector databases or an experienced developer, this notebook will help you understand the basics and get up and running quickly.\n",
"\n",
"What you'll learn:\n",
"\n",
"1. How to set up and configure the Memory API client\n",
"2. Creating and managing memory banks (vector stores)\n",
"3. Different ways to insert documents into the system\n",
"4. How to perform intelligent queries on your documents\n",
"\n",
"Prerequisites:\n",
"\n",
"- Basic Python knowledge\n",
"- A running instance of the Memory API server (we'll use localhost in this tutorial)\n",
"\n",
"Let's start by installing the required packages:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install the client library and a helper package for colored output\n",
"!pip install llama-stack-client termcolor\n",
"\n",
"# 💡 Note: If you're running this in a new environment, you might need to restart\n",
"# your kernel after installation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Initial Setup\n",
"\n",
"First, we'll import the necessary libraries and set up some helper functions. Let's break down what each import does:\n",
"\n",
"- `llama_stack_client`: Our main interface to the Memory API\n",
"- `base64`: Helps us encode files for transmission\n",
"- `mimetypes`: Determines file types automatically\n",
"- `termcolor`: Makes our output prettier with colors\n",
"\n",
"❓ Question: Why do we need to convert files to data URLs?\n",
"\n",
"Answer: Data URLs allow us to embed file contents directly in our requests, making it easier to transmit files to the API without needing separate file uploads."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import base64\n",
"import json\n",
"import mimetypes\n",
"import os\n",
"from pathlib import Path\n",
"\n",
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.types.memory_insert_params import Document\n",
"from termcolor import cprint\n",
"\n",
"# Helper function to convert files to data URLs\n",
"def data_url_from_file(file_path: str) -> str:\n",
"    \"\"\"Convert a file to a data URL for API transmission\n",
"\n",
"    Args:\n",
"        file_path (str): Path to the file to convert\n",
"\n",
"    Returns:\n",
"        str: Data URL containing the file's contents\n",
"\n",
"    Example:\n",
"        >>> url = data_url_from_file('example.txt')\n",
"        >>> print(url[:30])  # Preview the start of the URL\n",
"        'data:text/plain;base64,SGVsbG8='\n",
"    \"\"\"\n",
"    if not os.path.exists(file_path):\n",
"        raise FileNotFoundError(f\"File not found: {file_path}\")\n",
"\n",
"    with open(file_path, \"rb\") as file:\n",
"        file_content = file.read()\n",
"\n",
"    base64_content = base64.b64encode(file_content).decode(\"utf-8\")\n",
"    mime_type, _ = mimetypes.guess_type(file_path)\n",
"    if mime_type is None:\n",
"        raise ValueError(f\"Could not determine MIME type of: {file_path}\")\n",
"\n",
"    data_url = f\"data:{mime_type};base64,{base64_content}\"\n",
"    return data_url"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Initialize Client and Create Memory Bank\n",
"\n",
"Now we'll set up our connection to the Memory API and create our first memory bank. A memory bank is like a specialized database that stores document embeddings for semantic search.\n",
"\n",
"❓ Key Concepts:\n",
"\n",
"- `embedding_model`: The model used to convert text into vector representations\n",
"- `chunk_size`: How large each piece of text should be when splitting documents\n",
"- `overlap_size`: How much overlap between chunks (helps maintain context)\n",
"\n",
"✨ Pro Tip: Choose your chunk size based on your use case. Smaller chunks (256-512 tokens) are better for precise retrieval, while larger chunks (1024+ tokens) maintain more context."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Configure connection parameters\n",
"HOST = \"localhost\"  # Replace with your host if using a remote server\n",
"PORT = 8000  # Replace with your port if different\n",
"\n",
"# Initialize client\n",
"client = LlamaStackClient(\n",
"    base_url=f\"http://{HOST}:{PORT}\",\n",
")\n",
"\n",
"# Let's see what providers are available\n",
"# Providers determine where and how your data is stored\n",
"providers = client.providers.list()\n",
"print(\"Available providers:\")\n",
"print(json.dumps(providers, indent=2, default=str))  # default=str handles non-JSON-serializable fields\n",
"\n",
"# Create a memory bank with optimized settings for general use\n",
"client.memory_banks.register(\n",
"    memory_bank={\n",
"        \"identifier\": \"tutorial_bank\",  # A unique name for your memory bank\n",
"        \"embedding_model\": \"all-MiniLM-L6-v2\",  # A lightweight but effective model\n",
"        \"chunk_size_in_tokens\": 512,  # Good balance between precision and context\n",
"        \"overlap_size_in_tokens\": 64,  # Helps maintain context between chunks\n",
"        \"provider_id\": providers[\"memory\"][0].provider_id,  # Use the first available provider\n",
"    }\n",
")\n",
"\n",
"# Let's verify our memory bank was created\n",
"memory_banks = client.memory_banks.list()\n",
"print(\"\\nRegistered memory banks:\")\n",
"print(json.dumps(memory_banks, indent=2, default=str))\n",
"\n",
"# 🎯 Exercise: Try creating another memory bank with different settings!\n",
"# What happens if you try to create a bank with the same identifier?"
]
},
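{
"cell_type": "markdown",
"metadata": {},
"source": [
"To build intuition for `chunk_size_in_tokens` and `overlap_size_in_tokens`, here is a toy, local illustration of overlapping chunking (words stand in for tokens; the server has its own chunking logic - this sketch only shows why overlap preserves context across chunk boundaries):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy illustration of chunking with overlap - NOT the server's implementation.\n",
"def chunk_words(text: str, chunk_size: int, overlap: int):\n",
"    assert 0 <= overlap < chunk_size, \"overlap must be smaller than chunk_size\"\n",
"    words = text.split()\n",
"    step = chunk_size - overlap  # how far the window advances each time\n",
"    return [\" \".join(words[i:i + chunk_size]) for i in range(0, len(words), step)]\n",
"\n",
"sample = \"one two three four five six seven eight nine ten\"\n",
"for chunk in chunk_words(sample, chunk_size=4, overlap=1):\n",
"    print(chunk)\n",
"\n",
"# Note how the last word of each chunk reappears at the start of the next -\n",
"# that shared context is what overlap buys you."
]
},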
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Insert Documents\n",
"\n",
"The Memory API supports multiple ways to add documents. We'll demonstrate two common approaches:\n",
"\n",
"1. Loading documents from URLs\n",
"2. Loading documents from local files\n",
"\n",
"❓ Important Concepts:\n",
"\n",
"- Each document needs a unique `document_id`\n",
"- Metadata helps organize and filter documents later\n",
"- The API automatically processes and chunks documents"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example URLs to documentation\n",
"# 💡 Replace these with your own URLs or use the examples\n",
"urls = [\n",
"    \"memory_optimizations.rst\",\n",
"    \"chat.rst\",\n",
"    \"llama3.rst\",\n",
"]\n",
"\n",
"# Create documents from URLs\n",
"# We add metadata to help organize our documents\n",
"url_documents = [\n",
"    Document(\n",
"        document_id=f\"url-doc-{i}\",  # Unique ID for each document\n",
"        content=f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n",
"        mime_type=\"text/plain\",\n",
"        metadata={\"source\": \"url\", \"filename\": url},  # Metadata helps with organization\n",
"    )\n",
"    for i, url in enumerate(urls)\n",
"]\n",
"\n",
"# Example with local files\n",
"# 💡 Replace these with your actual files\n",
"local_files = [\"example.txt\", \"readme.md\"]\n",
"file_documents = [\n",
"    Document(\n",
"        document_id=f\"file-doc-{i}\",\n",
"        content=data_url_from_file(path),\n",
"        metadata={\"source\": \"local\", \"filename\": path},\n",
"    )\n",
"    for i, path in enumerate(local_files)\n",
"    if os.path.exists(path)\n",
"]\n",
"\n",
"# Combine all documents\n",
"all_documents = url_documents + file_documents\n",
"\n",
"# Insert documents into memory bank\n",
"response = client.memory.insert(\n",
"    bank_id=\"tutorial_bank\",\n",
"    documents=all_documents,\n",
")\n",
"\n",
"print(\"Documents inserted successfully!\")\n",
"\n",
"# 🎯 Exercise: Try adding your own documents!\n",
"# - What happens if you try to insert a document with an existing ID?\n",
"# - What other metadata might be useful to add?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Query the Memory Bank\n",
"\n",
"Now for the exciting part - querying our documents! The Memory API uses semantic search to find relevant content based on meaning, not just keywords.\n",
"\n",
"❓ Understanding Scores:\n",
"\n",
"- Scores range from 0 to 1, with 1 being the most relevant\n",
"- Generally, scores above 0.7 indicate strong relevance\n",
"- Consider your use case when deciding on score thresholds"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def print_query_results(query: str):\n",
"    \"\"\"Helper function to print query results in a readable format\n",
"\n",
"    Args:\n",
"        query (str): The search query to execute\n",
"    \"\"\"\n",
"    print(f\"\\nQuery: {query}\")\n",
"    print(\"-\" * 50)\n",
"\n",
"    response = client.memory.query(\n",
"        bank_id=\"tutorial_bank\",\n",
"        query=[query],  # The API accepts multiple queries at once!\n",
"    )\n",
"\n",
"    for i, (chunk, score) in enumerate(zip(response.chunks, response.scores)):\n",
"        print(f\"\\nResult {i+1} (Score: {score:.3f})\")\n",
"        print(\"=\" * 40)\n",
"        print(chunk)\n",
"        print(\"=\" * 40)\n",
"\n",
"# Let's try some example queries\n",
"queries = [\n",
"    \"How do I use LoRA?\",  # Technical question\n",
"    \"Tell me about memory optimizations\",  # General topic\n",
"    \"What are the key features of Llama 3?\"  # Product-specific\n",
"]\n",
"\n",
"for query in queries:\n",
"    print_query_results(query)\n",
"\n",
"# 🎯 Exercises:\n",
"# 1. Try writing your own queries! What works well? What doesn't?\n",
"# 2. How do different phrasings of the same question affect results?\n",
"# 3. What happens if you query for content that isn't in your documents?"
]
},
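{
"cell_type": "markdown",
"metadata": {},
"source": [
"Building on the scoring guidance above, here is a small sketch that keeps only results above a threshold, using the 0.7 guideline as a starting point (tune it for your own use case):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Keep only strong matches, using the 0.7 guideline from above (tune as needed)\n",
"SCORE_THRESHOLD = 0.7\n",
"\n",
"response = client.memory.query(\n",
"    bank_id=\"tutorial_bank\",\n",
"    query=[\"How do I use LoRA?\"],\n",
")\n",
"\n",
"strong_matches = [\n",
"    (chunk, score)\n",
"    for chunk, score in zip(response.chunks, response.scores)\n",
"    if score >= SCORE_THRESHOLD\n",
"]\n",
"\n",
"print(f\"{len(strong_matches)} result(s) above {SCORE_THRESHOLD}\")\n",
"for chunk, score in strong_matches:\n",
"    print(f\"\\nScore: {score:.3f}\")\n",
"    print(chunk)"
]
},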
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5. Advanced Usage: Query with Metadata Filtering\n",
"\n",
"One powerful feature is the ability to filter results based on metadata. This helps when you want to search within specific subsets of your documents.\n",
"\n",
"❓ Use Cases for Metadata Filtering:\n",
"\n",
"- Search within specific document types\n",
"- Filter by date ranges\n",
"- Limit results to certain authors or sources"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Query with metadata filter\n",
"response = client.memory.query(\n",
"    bank_id=\"tutorial_bank\",\n",
"    query=[\"Tell me about optimization\"],\n",
"    metadata_filter={\"source\": \"url\"}  # Only search in URL documents\n",
")\n",
"\n",
"print(\"\\nFiltered Query Results:\")\n",
"print(\"-\" * 50)\n",
"for chunk, score in zip(response.chunks, response.scores):\n",
"    print(f\"Score: {score:.3f}\")\n",
"    print(f\"Chunk:\\n{chunk}\\n\")\n",
"\n",
"# 🎯 Advanced Exercises:\n",
"# 1. Try combining multiple metadata filters\n",
"# 2. Compare results with and without filters\n",
"# 3. What happens with non-existent metadata fields?"
]
},
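{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a starting point for the first exercise, here is a sketch combining two metadata keys in one filter. It assumes the filter matches all keys exactly (AND semantics), which may vary by provider, so verify against your own results:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: combine multiple metadata keys (assumed AND semantics - verify with your provider)\n",
"response = client.memory.query(\n",
"    bank_id=\"tutorial_bank\",\n",
"    query=[\"Tell me about optimization\"],\n",
"    metadata_filter={\"source\": \"url\", \"filename\": \"memory_optimizations.rst\"},\n",
")\n",
"\n",
"for chunk, score in zip(response.chunks, response.scores):\n",
"    print(f\"Score: {score:.3f}\")\n",
"    print(f\"Chunk:\\n{chunk}\\n\")"
]
}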
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}