add more about safety and agent docs

Kai Wu 2024-11-04 16:23:46 -08:00
parent d61f328ffb
commit 87904d329f
4 changed files with 161 additions and 81 deletions


@@ -30,24 +30,24 @@
"# Helper function to convert image to data URL\n",
"def image_to_data_url(file_path: Union[str, Path]) -> str:\n",
" \"\"\"Convert an image file to a data URL format.\n",
" \n",
"\n",
" Args:\n",
" file_path: Path to the image file\n",
" \n",
"\n",
" Returns:\n",
" str: Data URL containing the encoded image\n",
" \"\"\"\n",
" file_path = Path(file_path)\n",
" if not file_path.exists():\n",
" raise FileNotFoundError(f\"Image not found: {file_path}\")\n",
" \n",
"\n",
" mime_type, _ = mimetypes.guess_type(str(file_path))\n",
" if mime_type is None:\n",
" raise ValueError(\"Could not determine MIME type of the image\")\n",
" \n",
"\n",
" with open(file_path, \"rb\") as image_file:\n",
" encoded_string = base64.b64encode(image_file.read()).decode(\"utf-8\")\n",
" \n",
"\n",
" return f\"data:{mime_type};base64,{encoded_string}\""
]
},
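The helper above is self-contained, so a quick sanity check could look like the sketch below (the file name is made up):

    # Hypothetical usage of the data-URL helper defined above
    url = image_to_data_url("llama.png")   # any local image file
    print(url[:30])                        # e.g. "data:image/png;base64,iVBORw0K"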
@@ -97,7 +97,7 @@
" if question.lower() == 'exit':\n",
" print(\"Chat ended.\")\n",
" return\n",
" \n",
"\n",
" message = UserMessage(\n",
" role=\"user\",\n",
" content=[\n",
@@ -105,18 +105,18 @@
" question,\n",
" ],\n",
" )\n",
" \n",
"\n",
" print(f\"\\nUser> {question}\")\n",
" response = client.inference.chat_completion(\n",
" messages=[message],\n",
" model=\"Llama3.2-11B-Vision-Instruct\",\n",
" stream=True,\n",
" )\n",
" \n",
"\n",
" print(\"Assistant> \", end='')\n",
" async for log in EventLogger().log(response):\n",
" log.print()\n",
" \n",
"\n",
" text_input.value = '' # Clear input after sending\n",
"\n",
"text_input.on_submit(lambda x: asyncio.create_task(on_submit(x)))"
@@ -184,7 +184,7 @@
" output_shields=[\"llama_guard\"],\n",
" enable_session_persistence=True,\n",
" )\n",
" \n",
"\n",
" return Agent(client, agent_config)"
]
},
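Since the commit adds safety documentation, it may help to see the shield wiring from this cell in one place. Below is a minimal sketch: output_shields=["llama_guard"] and enable_session_persistence=True come from the cell above, while the AgentConfig/Agent import paths, the model name, the instructions, and the input_shields field are assumptions.

    # Sketch of a shielded agent; import paths and extra fields are assumptions
    from llama_stack_client import LlamaStackClient
    from llama_stack_client.lib.agents.agent import Agent                   # assumed path
    from llama_stack_client.types.agent_create_params import AgentConfig    # assumed path

    client = LlamaStackClient(base_url="http://localhost:8000")
    agent_config = AgentConfig(
        model="Llama3.1-8B-Instruct",         # illustrative model name
        instructions="You are a helpful assistant.",
        input_shields=["llama_guard"],        # assumption: screen user input as well
        output_shields=["llama_guard"],       # from the cell above
        enable_session_persistence=True,      # from the cell above
    )
    agent = Agent(client, agent_config)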
@@ -212,7 +212,7 @@
" engine=\"brave\",\n",
" api_key=os.getenv(\"BRAVE_SEARCH_API_KEY\"),\n",
" )\n",
" \n",
"\n",
" return await create_tool_agent(\n",
" client=client,\n",
" tools=[search_tool],\n",
@@ -220,10 +220,10 @@
" You are a research assistant that can search the web.\n",
" Always cite your sources with URLs when providing information.\n",
" Format your responses as:\n",
" \n",
"\n",
" FINDINGS:\n",
" [Your summary here]\n",
" \n",
"\n",
" SOURCES:\n",
" - [Source title](URL)\n",
" \"\"\"\n",
@@ -233,25 +233,25 @@
"async def search_example():\n",
" client = LlamaStackClient(base_url=\"http://localhost:8000\")\n",
" agent = await create_search_agent(client)\n",
" \n",
"\n",
" # Create a session\n",
" session_id = agent.create_session(\"search-session\")\n",
" \n",
"\n",
" # Example queries\n",
" queries = [\n",
" \"What are the latest developments in quantum computing?\",\n",
" \"Who won the most recent Super Bowl?\",\n",
" ]\n",
" \n",
"\n",
" for query in queries:\n",
" print(f\"\\nQuery: {query}\")\n",
" print(\"-\" * 50)\n",
" \n",
"\n",
" response = agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": query}],\n",
" session_id=session_id,\n",
" )\n",
" \n",
"\n",
" async for log in EventLogger().log(response):\n",
" log.print()\n",
"\n",
@@ -289,10 +289,10 @@
"\n",
"class WeatherTool:\n",
" \"\"\"Example custom tool for weather information.\"\"\"\n",
" \n",
"\n",
" def __init__(self, api_key: Optional[str] = None):\n",
" self.api_key = api_key\n",
" \n",
"\n",
" async def get_weather(self, location: str, date: Optional[str] = None) -> WeatherOutput:\n",
" \"\"\"Simulate getting weather data (replace with actual API call).\"\"\"\n",
" # Mock implementation\n",
@@ -301,7 +301,7 @@
" \"conditions\": \"partly cloudy\",\n",
" \"humidity\": 65.0\n",
" }\n",
" \n",
"\n",
" async def __call__(self, input_data: WeatherInput) -> WeatherOutput:\n",
" \"\"\"Make the tool callable with structured input.\"\"\"\n",
" return await self.get_weather(\n",
@@ -334,7 +334,7 @@
" },\n",
" \"implementation\": WeatherTool()\n",
" }\n",
" \n",
"\n",
" return await create_tool_agent(\n",
" client=client,\n",
" tools=[weather_tool],\n",
@@ -349,23 +349,23 @@
"async def weather_example():\n",
" client = LlamaStackClient(base_url=\"http://localhost:8000\")\n",
" agent = await create_weather_agent(client)\n",
" \n",
"\n",
" session_id = agent.create_session(\"weather-session\")\n",
" \n",
"\n",
" queries = [\n",
" \"What's the weather like in San Francisco?\",\n",
" \"Tell me the weather in Tokyo tomorrow\",\n",
" ]\n",
" \n",
"\n",
" for query in queries:\n",
" print(f\"\\nQuery: {query}\")\n",
" print(\"-\" * 50)\n",
" \n",
"\n",
" response = agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": query}],\n",
" session_id=session_id,\n",
" )\n",
" \n",
"\n",
" async for log in EventLogger().log(response):\n",
" log.print()\n",
"\n",
@@ -413,7 +413,7 @@
" \"implementation\": WeatherTool()\n",
" }\n",
" ]\n",
" \n",
"\n",
" return await create_tool_agent(\n",
" client=client,\n",
" tools=tools,\n",
@@ -430,24 +430,24 @@
" client = LlamaStackClient(base_url=\"http://localhost:8000\")\n",
" agent = await create_multi_tool_agent(client)\n",
" session_id = agent.create_session(\"interactive-session\")\n",
" \n",
"\n",
" print(\"🤖 Multi-tool Agent Ready! (type 'exit' to quit)\")\n",
" print(\"Example questions:\")\n",
" print(\"- What's the weather in Paris and what events are happening there?\")\n",
" print(\"- Tell me about recent space discoveries and the weather on Mars\")\n",
" \n",
"\n",
" while True:\n",
" query = input(\"\\nYour question: \")\n",
" if query.lower() == 'exit':\n",
" break\n",
" \n",
"\n",
" print(\"\\nThinking...\")\n",
" try:\n",
" response = agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": query}],\n",
" session_id=session_id,\n",
" )\n",
" \n",
"\n",
" async for log in EventLogger().log(response):\n",
" log.print()\n",
" except Exception as e:\n",
@@ -533,13 +533,13 @@
"# Helper function to convert files to data URLs\n",
"def data_url_from_file(file_path: str) -> str:\n",
" \"\"\"Convert a file to a data URL for API transmission\n",
" \n",
"\n",
" Args:\n",
" file_path (str): Path to the file to convert\n",
" \n",
"\n",
" Returns:\n",
" str: Data URL containing the file's contents\n",
" \n",
"\n",
" Example:\n",
" >>> url = data_url_from_file('example.txt')\n",
" >>> print(url[:30]) # Preview the start of the URL\n",
@@ -707,18 +707,18 @@
"source": [
"def print_query_results(query: str):\n",
" \"\"\"Helper function to print query results in a readable format\n",
" \n",
"\n",
" Args:\n",
" query (str): The search query to execute\n",
" \"\"\"\n",
" print(f\"\\nQuery: {query}\")\n",
" print(\"-\" * 50)\n",
" \n",
"\n",
" response = client.memory.query(\n",
" bank_id=\"tutorial_bank\",\n",
" query=[query], # The API accepts multiple queries at once!\n",
" )\n",
" \n",
"\n",
" for i, (chunk, score) in enumerate(zip(response.chunks, response.scores)):\n",
" print(f\"\\nResult {i+1} (Score: {score:.3f})\")\n",
" print(\"=\" * 40)\n",