mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 10:42:39 +00:00

commit 0b1e71718c (parent 1f04ca357b): test

2 changed files with 448 additions and 156 deletions

docs/notebooks/RAG_as_attchements.ipynb (new file, 193 lines)
@@ -0,0 +1,193 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_stack_client import LlamaStackClient\n",
    "from llama_stack_client.types import Document\n",
    "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
    "from llama_stack_client.types.agent_create_params import AgentConfig\n",
    "from llama_stack_client.lib.agents.agent import Agent\n",
    "from rich.pretty import pprint\n",
    "import json\n",
    "import uuid\n",
    "from pydantic import BaseModel\n",
    "import rich\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_ID = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "\n",
    "client = LlamaStackClient(\n",
    "    base_url=\"http://localhost:8321\",\n",
    "    provider_data={\n",
    "        \"fireworks_api_key\": os.environ[\"FIREWORKS_API_KEY\"]\n",
    "    }\n",
    ")\n",
    "\n",
    "urls = [\n",
    "    \"memory_optimizations.rst\",\n",
    "    \"chat.rst\",\n",
    "    \"llama3.rst\",\n",
    "    \"datasets.rst\",\n",
    "    \"qat_finetune.rst\",\n",
    "    \"lora_finetune.rst\",\n",
    "]\n",
    "\n",
    "attachments = [\n",
    "    {\n",
    "        \"content\": f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n",
    "        \"mime_type\": \"text/plain\",\n",
    "    }\n",
    "\n",
    "    for i, url in enumerate(urls)\n",
    "]\n",
    "\n",
    "simple_agent = Agent(client, model=MODEL_ID, \n",
    "    instructions=\"You are a helpful assistant that can answer questions about the Torchtune project.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">Turn</span><span style=\"font-weight: bold\">(</span>\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">input_messages</span>=<span style=\"font-weight: bold\">[</span>\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">UserMessage</span><span style=\"font-weight: bold\">(</span><span style=\"color: #808000; text-decoration-color: #808000\">content</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'What precision formats does torchtune support?'</span>, <span style=\"color: #808000; text-decoration-color: #808000\">role</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'user'</span>, <span style=\"color: #808000; text-decoration-color: #808000\">context</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span><span style=\"font-weight: bold\">)</span>\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"font-weight: bold\">]</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">output_message</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">CompletionMessage</span><span style=\"font-weight: bold\">(</span>\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">content</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'Torchtune supports the following precision formats:\\n\\n* FP32 (32-bit floating point)\\n* FP16 (16-bit floating point)\\n* INT8 (8-bit integer)\\n* BF16 (Brain Floating Point 16, a 16-bit floating point format)\\n\\nThese precision formats can be used for model weights, activations, and gradients, allowing for flexible and efficient tuning of models for various hardware and performance requirements.'</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">role</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'assistant'</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">stop_reason</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'end_of_turn'</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">tool_calls</span>=<span style=\"font-weight: bold\">[]</span>\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"font-weight: bold\">)</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">session_id</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'1c23c79b-3945-4e99-bda6-7922b6b4e91c'</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">started_at</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">datetime</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">.datetime</span><span style=\"font-weight: bold\">(</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2025</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">20</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">22</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">41</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">175804</span>, <span style=\"color: #808000; text-decoration-color: #808000\">tzinfo</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">TzInfo</span><span style=\"font-weight: bold\">(</span>UTC<span style=\"font-weight: bold\">))</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">steps</span>=<span style=\"font-weight: bold\">[</span>\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">InferenceStep</span><span style=\"font-weight: bold\">(</span>\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">api_model_response</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">CompletionMessage</span><span style=\"font-weight: bold\">(</span>\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">content</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'Torchtune supports the following precision formats:\\n\\n* FP32 (32-bit floating point)\\n* FP16 (16-bit floating point)\\n* INT8 (8-bit integer)\\n* BF16 (Brain Floating Point 16, a 16-bit floating point format)\\n\\nThese precision formats can be used for model weights, activations, and gradients, allowing for flexible and efficient tuning of models for various hardware and performance requirements.'</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">role</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'assistant'</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">stop_reason</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'end_of_turn'</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">tool_calls</span>=<span style=\"font-weight: bold\">[]</span>\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"font-weight: bold\">)</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">step_id</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'bf452f18-8fae-470e-9e97-b1af60628fc1'</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">step_type</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'inference'</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">turn_id</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'efb92c6d-d482-4dd2-ad4b-3250c1e9a231'</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">completed_at</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">datetime</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">.datetime</span><span style=\"font-weight: bold\">(</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2025</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">20</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">22</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">41</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">618765</span>, <span style=\"color: #808000; text-decoration-color: #808000\">tzinfo</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">TzInfo</span><span style=\"font-weight: bold\">(</span>UTC<span style=\"font-weight: bold\">))</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">started_at</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">datetime</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">.datetime</span><span style=\"font-weight: bold\">(</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2025</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">20</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">22</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">41</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">175855</span>, <span style=\"color: #808000; text-decoration-color: #808000\">tzinfo</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">TzInfo</span><span style=\"font-weight: bold\">(</span>UTC<span style=\"font-weight: bold\">))</span>\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">)</span>\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"font-weight: bold\">]</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">turn_id</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'efb92c6d-d482-4dd2-ad4b-3250c1e9a231'</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">completed_at</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">datetime</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">.datetime</span><span style=\"font-weight: bold\">(</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2025</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">20</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">22</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">41</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">631357</span>, <span style=\"color: #808000; text-decoration-color: #808000\">tzinfo</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">TzInfo</span><span style=\"font-weight: bold\">(</span>UTC<span style=\"font-weight: bold\">))</span>,\n",
|
||||||
|
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">output_attachments</span>=<span style=\"font-weight: bold\">[]</span>\n",
|
||||||
|
"<span style=\"font-weight: bold\">)</span>\n",
|
||||||
|
"</pre>\n"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
"\u001b[1;35mTurn\u001b[0m\u001b[1m(\u001b[0m\n",
|
||||||
|
"\u001b[2;32m│ \u001b[0m\u001b[33minput_messages\u001b[0m=\u001b[1m[\u001b[0m\n",
|
||||||
|
"\u001b[2;32m│ │ \u001b[0m\u001b[1;35mUserMessage\u001b[0m\u001b[1m(\u001b[0m\u001b[33mcontent\u001b[0m=\u001b[32m'What precision formats does torchtune support?'\u001b[0m, \u001b[33mrole\u001b[0m=\u001b[32m'user'\u001b[0m, \u001b[33mcontext\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m\n",
|
||||||
|
"\u001b[2;32m│ \u001b[0m\u001b[1m]\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ \u001b[0m\u001b[33moutput_message\u001b[0m=\u001b[1;35mCompletionMessage\u001b[0m\u001b[1m(\u001b[0m\n",
|
||||||
|
"\u001b[2;32m│ │ \u001b[0m\u001b[33mcontent\u001b[0m=\u001b[32m'Torchtune supports the following precision formats:\\n\\n* FP32 \u001b[0m\u001b[32m(\u001b[0m\u001b[32m32-bit floating point\u001b[0m\u001b[32m)\u001b[0m\u001b[32m\\n* FP16 \u001b[0m\u001b[32m(\u001b[0m\u001b[32m16-bit floating point\u001b[0m\u001b[32m)\u001b[0m\u001b[32m\\n* INT8 \u001b[0m\u001b[32m(\u001b[0m\u001b[32m8-bit integer\u001b[0m\u001b[32m)\u001b[0m\u001b[32m\\n* BF16 \u001b[0m\u001b[32m(\u001b[0m\u001b[32mBrain Floating Point 16, a 16-bit floating point format\u001b[0m\u001b[32m)\u001b[0m\u001b[32m\\n\\nThese precision formats can be used for model weights, activations, and gradients, allowing for flexible and efficient tuning of models for various hardware and performance requirements.'\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ │ \u001b[0m\u001b[33mrole\u001b[0m=\u001b[32m'assistant'\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ │ \u001b[0m\u001b[33mstop_reason\u001b[0m=\u001b[32m'end_of_turn'\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ │ \u001b[0m\u001b[33mtool_calls\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
|
||||||
|
"\u001b[2;32m│ \u001b[0m\u001b[1m)\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ \u001b[0m\u001b[33msession_id\u001b[0m=\u001b[32m'1c23c79b-3945-4e99-bda6-7922b6b4e91c'\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ \u001b[0m\u001b[33mstarted_at\u001b[0m=\u001b[1;35mdatetime\u001b[0m\u001b[1;35m.datetime\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m2025\u001b[0m, \u001b[1;36m3\u001b[0m, \u001b[1;36m20\u001b[0m, \u001b[1;36m22\u001b[0m, \u001b[1;36m41\u001b[0m, \u001b[1;36m0\u001b[0m, \u001b[1;36m175804\u001b[0m, \u001b[33mtzinfo\u001b[0m=\u001b[1;35mTzInfo\u001b[0m\u001b[1m(\u001b[0mUTC\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ \u001b[0m\u001b[33msteps\u001b[0m=\u001b[1m[\u001b[0m\n",
|
||||||
|
"\u001b[2;32m│ │ \u001b[0m\u001b[1;35mInferenceStep\u001b[0m\u001b[1m(\u001b[0m\n",
|
||||||
|
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mapi_model_response\u001b[0m=\u001b[1;35mCompletionMessage\u001b[0m\u001b[1m(\u001b[0m\n",
|
||||||
|
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mcontent\u001b[0m=\u001b[32m'Torchtune supports the following precision formats:\\n\\n* FP32 \u001b[0m\u001b[32m(\u001b[0m\u001b[32m32-bit floating point\u001b[0m\u001b[32m)\u001b[0m\u001b[32m\\n* FP16 \u001b[0m\u001b[32m(\u001b[0m\u001b[32m16-bit floating point\u001b[0m\u001b[32m)\u001b[0m\u001b[32m\\n* INT8 \u001b[0m\u001b[32m(\u001b[0m\u001b[32m8-bit integer\u001b[0m\u001b[32m)\u001b[0m\u001b[32m\\n* BF16 \u001b[0m\u001b[32m(\u001b[0m\u001b[32mBrain Floating Point 16, a 16-bit floating point format\u001b[0m\u001b[32m)\u001b[0m\u001b[32m\\n\\nThese precision formats can be used for model weights, activations, and gradients, allowing for flexible and efficient tuning of models for various hardware and performance requirements.'\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mrole\u001b[0m=\u001b[32m'assistant'\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mstop_reason\u001b[0m=\u001b[32m'end_of_turn'\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mtool_calls\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
|
||||||
|
"\u001b[2;32m│ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mstep_id\u001b[0m=\u001b[32m'bf452f18-8fae-470e-9e97-b1af60628fc1'\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mstep_type\u001b[0m=\u001b[32m'inference'\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mturn_id\u001b[0m=\u001b[32m'efb92c6d-d482-4dd2-ad4b-3250c1e9a231'\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mcompleted_at\u001b[0m=\u001b[1;35mdatetime\u001b[0m\u001b[1;35m.datetime\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m2025\u001b[0m, \u001b[1;36m3\u001b[0m, \u001b[1;36m20\u001b[0m, \u001b[1;36m22\u001b[0m, \u001b[1;36m41\u001b[0m, \u001b[1;36m1\u001b[0m, \u001b[1;36m618765\u001b[0m, \u001b[33mtzinfo\u001b[0m=\u001b[1;35mTzInfo\u001b[0m\u001b[1m(\u001b[0mUTC\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mstarted_at\u001b[0m=\u001b[1;35mdatetime\u001b[0m\u001b[1;35m.datetime\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m2025\u001b[0m, \u001b[1;36m3\u001b[0m, \u001b[1;36m20\u001b[0m, \u001b[1;36m22\u001b[0m, \u001b[1;36m41\u001b[0m, \u001b[1;36m0\u001b[0m, \u001b[1;36m175855\u001b[0m, \u001b[33mtzinfo\u001b[0m=\u001b[1;35mTzInfo\u001b[0m\u001b[1m(\u001b[0mUTC\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n",
|
||||||
|
"\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n",
|
||||||
|
"\u001b[2;32m│ \u001b[0m\u001b[1m]\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ \u001b[0m\u001b[33mturn_id\u001b[0m=\u001b[32m'efb92c6d-d482-4dd2-ad4b-3250c1e9a231'\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ \u001b[0m\u001b[33mcompleted_at\u001b[0m=\u001b[1;35mdatetime\u001b[0m\u001b[1;35m.datetime\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m2025\u001b[0m, \u001b[1;36m3\u001b[0m, \u001b[1;36m20\u001b[0m, \u001b[1;36m22\u001b[0m, \u001b[1;36m41\u001b[0m, \u001b[1;36m1\u001b[0m, \u001b[1;36m631357\u001b[0m, \u001b[33mtzinfo\u001b[0m=\u001b[1;35mTzInfo\u001b[0m\u001b[1m(\u001b[0mUTC\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n",
|
||||||
|
"\u001b[2;32m│ \u001b[0m\u001b[33moutput_attachments\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
|
||||||
|
"\u001b[1m)\u001b[0m\n"
|
||||||
|
]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "simple_session_id = simple_agent.create_session(session_name=f\"simple_session_{uuid.uuid4()}\")\n",
    "response = simple_agent.create_turn(\n",
    "    messages=[\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": \"What precision formats does torchtune support?\"\n",
    "        }\n",
    "    ],\n",
    "    session_id=simple_session_id,\n",
    "    stream=False\n",
    ")\n",
    "\n",
    "pprint(response)\n",
    "\n",
    "session_response = client.agents.session.retrieve(agent_id=simple_agent.agent_id, session_id=simple_session_id)\n",
    "pprint(session_response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "master",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
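The notebook defines an attachments list, but the cells captured above never pass it to the agent; the agent-side diff below is what makes document attachments usable on a turn. A minimal usage sketch under that assumption follows (the documents= argument and the field names mirror the Document handling in the diff; the session name is made up for illustration and this cell is not part of the committed notebook):

    # Hypothetical follow-up cell, not present in the commit.
    attachment_session_id = simple_agent.create_session(session_name=f"attachment_session_{uuid.uuid4()}")
    response = simple_agent.create_turn(
        messages=[{"role": "user", "content": "What precision formats does torchtune support?"}],
        documents=attachments,  # assumed: each entry carries "content" (a URL string) and "mime_type"
        session_id=attachment_session_id,
        stream=False,
    )
    pprint(response)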
@@ -40,10 +40,10 @@ from llama_stack.apis.agents import (
     Turn,
 )
 from llama_stack.apis.common.content_types import (
-    URL,
     TextContentItem,
     ToolCallDelta,
     ToolCallParseStatus,
+    URL,
 )
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
@@ -80,7 +80,9 @@ from .safety import SafetyException, ShieldRunnerMixin


 def make_random_string(length: int = 8):
-    return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
+    return "".join(
+        secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
+    )


 TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
@@ -179,7 +181,9 @@ class ChatAgent(ShieldRunnerMixin):
             messages.extend(self.turn_to_messages(turn))
         return messages

-    async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
+    async def create_and_execute_turn(
+        self, request: AgentTurnCreateRequest
+    ) -> AsyncGenerator:
         await self._initialize_tools(request.toolgroups)
         async with tracing.span("create_and_execute_turn") as span:
             span.set_attribute("session_id", request.session_id)
@@ -220,13 +224,16 @@ class ChatAgent(ShieldRunnerMixin):
         messages = await self.get_messages_from_turns(turns)
         if is_resume:
             tool_response_messages = [
-                ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
+                ToolResponseMessage(call_id=x.call_id, content=x.content)
+                for x in request.tool_responses
             ]
             messages.extend(tool_response_messages)
             last_turn = turns[-1]
             last_turn_messages = self.turn_to_messages(last_turn)
             last_turn_messages = [
-                x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
+                x
+                for x in last_turn_messages
+                if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
             ]
             last_turn_messages.extend(tool_response_messages)

@@ -236,17 +243,31 @@ class ChatAgent(ShieldRunnerMixin):
             # mark tool execution step as complete
             # if there's no tool execution in progress step (due to storage, or tool call parsing on client),
             # we'll create a new tool execution step with current time
-            in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
-                request.session_id, request.turn_id
-            )
+            in_progress_tool_call_step = (
+                await self.storage.get_in_progress_tool_call_step(
+                    request.session_id, request.turn_id
+                )
+            )
             now = datetime.now(timezone.utc).isoformat()
             tool_execution_step = ToolExecutionStep(
-                step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
+                step_id=(
+                    in_progress_tool_call_step.step_id
+                    if in_progress_tool_call_step
+                    else str(uuid.uuid4())
+                ),
                 turn_id=request.turn_id,
-                tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
+                tool_calls=(
+                    in_progress_tool_call_step.tool_calls
+                    if in_progress_tool_call_step
+                    else []
+                ),
                 tool_responses=request.tool_responses,
                 completed_at=now,
-                started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
+                started_at=(
+                    in_progress_tool_call_step.started_at
+                    if in_progress_tool_call_step
+                    else now
+                ),
             )
             steps.append(tool_execution_step)
             yield AgentTurnResponseStreamChunk(
@@ -280,9 +301,14 @@ class ChatAgent(ShieldRunnerMixin):
                 output_message = chunk
                 continue

-            assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
+            assert isinstance(
+                chunk, AgentTurnResponseStreamChunk
+            ), f"Unexpected type {type(chunk)}"
             event = chunk.event
-            if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
+            if (
+                event.payload.event_type
+                == AgentTurnResponseEventType.step_complete.value
+            ):
                 steps.append(event.payload.step_details)

             yield chunk
@@ -440,6 +466,18 @@ class ChatAgent(ShieldRunnerMixin):
             )
             span.set_attribute("output", "no violations")

+    async def get_raw_document_text(self, document: Document) -> str:
+        if isinstance(document.content, URL):
+            return await load_data_from_url(document.content)
+        elif isinstance(document.content, str):
+            return document.content
+        elif isinstance(document.content, TextContentItem):
+            return document.content.text
+        else:
+            raise ValueError(
+                f"Unexpected document content type: {type(document.content)}"
+            )
+
     async def _run(
         self,
         session_id: str,
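For reference, a hedged sketch of the Document payloads the new get_raw_document_text helper is written to accept: URL contents are fetched, plain strings are used as-is, and TextContentItem contents contribute their .text. The constructions below are illustrative only (URL and TextContentItem come from llama_stack.apis.common.content_types as imported above; the Document import path and field values are assumptions):

    from llama_stack.apis.agents import Document  # assumed import path for the agents-API Document
    from llama_stack.apis.common.content_types import URL, TextContentItem

    doc_from_url = Document(
        content=URL(uri="https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst"),
        mime_type="text/plain",
    )
    doc_from_str = Document(content="torchtune supports FP32, FP16, BF16 and INT8.", mime_type="text/plain")
    doc_from_item = Document(content=TextContentItem(text="inline tutorial text"), mime_type="text/plain")
    # Any other content type raises ValueError inside get_raw_document_text.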
@@ -449,8 +487,23 @@ class ChatAgent(ShieldRunnerMixin):
         stream: bool = False,
         documents: Optional[List[Document]] = None,
     ) -> AsyncGenerator:
+        # if documents:
+        #     await self.handle_documents(session_id, documents, input_messages)
+
+        # if a document is passed in a turn, we parse the raw text of the document
+        # and send it as a user message
         if documents:
-            await self.handle_documents(session_id, documents, input_messages)
+            contexts = []
+            for document in documents:
+                raw_document_text = await self.get_raw_document_text(document)
+                contexts.append(TextContentItem(text=raw_document_text))
+            # modify the last user message to include the document
+            input_messages.append(
+                ToolResponseMessage(
+                    call_id=str(uuid.uuid4()),
+                    content=contexts,
+                )
+            )

         session_info = await self.storage.get_session_info(session_id)
         # if the session has a memory bank id, let the memory tool use it
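The new branch appends one tool-response style message carrying the raw text of every attached document, rather than merging it into the last user message. Roughly what input_messages gains when two documents are attached (a sketch only; the placeholder strings and the freshly generated call_id mirror the code above and are not captured output):

    # Approximate shape of the appended message, assuming ToolResponseMessage and
    # TextContentItem are imported as in the file above.
    doc_message = ToolResponseMessage(
        call_id=str(uuid.uuid4()),  # not tied to any real tool call
        content=[
            TextContentItem(text="<raw text of document 1>"),
            TextContentItem(text="<raw text of document 2>"),
        ],
    )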
@@ -458,13 +511,19 @@ class ChatAgent(ShieldRunnerMixin):
             for tool_name in self.tool_name_to_args.keys():
                 if tool_name == MEMORY_QUERY_TOOL:
                     if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
-                        self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id]
+                        self.tool_name_to_args[tool_name]["vector_db_ids"] = [
+                            session_info.vector_db_id
+                        ]
                     else:
-                        self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
+                        self.tool_name_to_args[tool_name]["vector_db_ids"].append(
+                            session_info.vector_db_id
+                        )

         output_attachments = []

-        n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
+        n_iter = (
+            await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
+        )

         # Build a map of custom tools to their definitions for faster lookup
         client_tools = {}
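The two branches above only accumulate the session's vector DB id into the memory tool's arguments. A minimal plain-Python equivalent, shown to make the bookkeeping explicit (the dict and the id string are illustrative, not taken from the commit):

    # Equivalent bookkeeping: create the list on first use, then append.
    tool_args: dict = {}
    tool_args.setdefault("vector_db_ids", []).append("vector_db_<session-id>")
    assert tool_args == {"vector_db_ids": ["vector_db_<session-id>"]}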
@@ -487,6 +546,9 @@ class ChatAgent(ShieldRunnerMixin):
         stop_reason = None

         async with tracing.span("inference") as span:
+            from rich.pretty import pprint
+
+            pprint(input_messages)
             async for chunk in await self.inference_api.chat_completion(
                 self.agent_config.model,
                 input_messages,
@@ -542,12 +604,16 @@ class ChatAgent(ShieldRunnerMixin):
             span.set_attribute("stop_reason", stop_reason)
             span.set_attribute(
                 "input",
-                json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
+                json.dumps(
+                    [json.loads(m.model_dump_json()) for m in input_messages]
+                ),
             )
             output_attr = json.dumps(
                 {
                     "content": content,
-                    "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
+                    "tool_calls": [
+                        json.loads(t.model_dump_json()) for t in tool_calls
+                    ],
                 }
             )
             span.set_attribute("output", output_attr)
@@ -611,7 +677,9 @@ class ChatAgent(ShieldRunnerMixin):
                         message.content = [message.content] + output_attachments
                     yield message
                 else:
-                    logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
+                    logger.debug(
+                        f"completion message with EOM (iter: {n_iter}): {str(message)}"
+                    )
                     input_messages = input_messages + [message]
             else:
                 input_messages = input_messages + [message]
@@ -660,7 +728,9 @@ class ChatAgent(ShieldRunnerMixin):
                         "input": message.model_dump_json(),
                     },
                 ) as span:
-                    tool_execution_start_time = datetime.now(timezone.utc).isoformat()
+                    tool_execution_start_time = datetime.now(
+                        timezone.utc
+                    ).isoformat()
                     tool_result = await self.execute_tool_call_maybe(
                         session_id,
                         tool_call,
@@ -709,7 +779,9 @@ class ChatAgent(ShieldRunnerMixin):
                 # TODO: add tool-input touchpoint and a "start" event for this step also
                 # but that needs a lot more refactoring of Tool code potentially
                 if (type(result_message.content) is str) and (
-                    out_attachment := _interpret_content_as_attachment(result_message.content)
+                    out_attachment := _interpret_content_as_attachment(
+                        result_message.content
+                    )
                 ):
                     # NOTE: when we push this message back to the model, the model may ignore the
                     # attached file path etc. since the model is trained to only provide a user message
@@ -746,16 +818,24 @@ class ChatAgent(ShieldRunnerMixin):
         toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
     ) -> None:
         toolgroup_to_args = {}
-        for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
+        for toolgroup in (self.agent_config.toolgroups or []) + (
+            toolgroups_for_turn or []
+        ):
             if isinstance(toolgroup, AgentToolGroupWithArgs):
                 tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
                 toolgroup_to_args[tool_group_name] = toolgroup.args

         # Determine which tools to include
-        tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
+        tool_groups_to_include = (
+            toolgroups_for_turn or self.agent_config.toolgroups or []
+        )
         agent_config_toolgroups = []
         for toolgroup in tool_groups_to_include:
-            name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
+            name = (
+                toolgroup.name
+                if isinstance(toolgroup, AgentToolGroupWithArgs)
+                else toolgroup
+            )
             if name not in agent_config_toolgroups:
                 agent_config_toolgroups.append(name)

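Tool selection above relies on or-chaining: per-turn toolgroups take precedence over the agent's configured toolgroups, with an empty list as the fallback. A small self-contained illustration of the same precedence (the toolgroup names are examples only):

    def pick_toolgroups(for_turn, from_config):
        # Per-turn toolgroups win; otherwise fall back to the agent config; otherwise nothing.
        return for_turn or from_config or []

    assert pick_toolgroups(["builtin::websearch"], ["builtin::rag"]) == ["builtin::websearch"]
    assert pick_toolgroups(None, ["builtin::rag"]) == ["builtin::rag"]
    assert pick_toolgroups(None, None) == []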
@@ -781,20 +861,32 @@ class ChatAgent(ShieldRunnerMixin):
                 },
             )
         for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
-            toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
+            toolgroup_name, input_tool_name = self._parse_toolgroup_name(
+                toolgroup_name_with_maybe_tool_name
+            )
             tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
             if not tools.data:
                 available_tool_groups = ", ".join(
-                    [t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
+                    [
+                        t.identifier
+                        for t in (await self.tool_groups_api.list_tool_groups()).data
+                    ]
                 )
-                raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
-            if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
+                raise ValueError(
+                    f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}"
+                )
+            if input_tool_name is not None and not any(
+                tool.identifier == input_tool_name for tool in tools.data
+            ):
                 raise ValueError(
                     f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
                 )

             for tool_def in tools.data:
-                if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
+                if (
+                    toolgroup_name.startswith("builtin")
+                    and toolgroup_name != RAG_TOOL_GROUP
+                ):
                     identifier: str | BuiltinTool | None = tool_def.identifier
                     if identifier == "web_search":
                         identifier = BuiltinTool.brave_search
@@ -823,11 +915,18 @@ class ChatAgent(ShieldRunnerMixin):
                         for param in tool_def.parameters
                     },
                 )
-                tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
+                tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(
+                    toolgroup_name, {}
+                )

-        self.tool_defs, self.tool_name_to_args = list(tool_name_to_def.values()), tool_name_to_args
+        self.tool_defs, self.tool_name_to_args = (
+            list(tool_name_to_def.values()),
+            tool_name_to_args,
+        )

-    def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
+    def _parse_toolgroup_name(
+        self, toolgroup_name_with_maybe_tool_name: str
+    ) -> tuple[str, Optional[str]]:
         """Parse a toolgroup name into its components.

         Args:
@@ -863,7 +962,9 @@ class ChatAgent(ShieldRunnerMixin):
         else:
             tool_name_str = tool_name

-        logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
+        logger.info(
+            f"executing tool call: {tool_name_str} with args: {tool_call.arguments}"
+        )
         result = await self.tool_runtime_api.invoke_tool(
             tool_name=tool_name_str,
             kwargs={
@@ -876,144 +977,142 @@ class ChatAgent(ShieldRunnerMixin):
         logger.debug(f"tool call {tool_name_str} completed with result: {result}")
         return result

-    async def handle_documents(
-        self,
-        session_id: str,
-        documents: List[Document],
-        input_messages: List[Message],
-    ) -> None:
-        memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs)
-        code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs)
-        content_items = []
-        url_items = []
-        pattern = re.compile("^(https?://|file://|data:)")
-        for d in documents:
-            if isinstance(d.content, URL):
-                url_items.append(d.content)
-            elif pattern.match(d.content):
-                url_items.append(URL(uri=d.content))
-            else:
-                content_items.append(d)
-
-        # Save the contents to a tempdir and use its path as a URL if code interpreter is present
-        if code_interpreter_tool:
-            for c in content_items:
-                temp_file_path = os.path.join(self.tempdir, f"{make_random_string()}.txt")
-                with open(temp_file_path, "w") as temp_file:
-                    temp_file.write(c.content)
-                url_items.append(URL(uri=f"file://{temp_file_path}"))
-
-        if memory_tool and code_interpreter_tool:
-            # if both memory and code_interpreter are available, we download the URLs
-            # and attach the data to the last message.
-            await attachment_message(self.tempdir, url_items, input_messages[-1])
-            # Since memory is present, add all the data to the memory bank
-            await self.add_to_session_vector_db(session_id, documents)
-        elif code_interpreter_tool:
-            # if only code_interpreter is available, we download the URLs to a tempdir
-            # and attach the path to them as a message to inference with the
-            # assumption that the model invokes the code_interpreter tool with the path
-            await attachment_message(self.tempdir, url_items, input_messages[-1])
-        elif memory_tool:
-            # if only memory is available, we load the data from the URLs and content items to the memory bank
-            await self.add_to_session_vector_db(session_id, documents)
-        else:
-            # if no memory or code_interpreter tool is available,
-            # we try to load the data from the URLs and content items as a message to inference
-            # and add it to the last message's context
-            input_messages[-1].context = "\n".join(
-                [doc.content for doc in content_items] + await load_data_from_urls(url_items)
-            )
+    # async def handle_documents(
+    #     [the method above is re-added verbatim on the new side, every line commented out with "# "]

-    async def _ensure_vector_db(self, session_id: str) -> str:
-        session_info = await self.storage.get_session_info(session_id)
-        if session_info is None:
-            raise ValueError(f"Session {session_id} not found")
-
-        if session_info.vector_db_id is None:
-            vector_db_id = f"vector_db_{session_id}"
-
-            # TODO: the semantic for registration is definitely not "creation"
-            # so we need to fix it if we expect the agent to create a new vector db
-            # for each session
-            await self.vector_io_api.register_vector_db(
-                vector_db_id=vector_db_id,
-                embedding_model="all-MiniLM-L6-v2",
-            )
-            await self.storage.add_vector_db_to_session(session_id, vector_db_id)
-        else:
-            vector_db_id = session_info.vector_db_id
-
-        return vector_db_id
+    # async def _ensure_vector_db(self, session_id: str) -> str:
+    #     [the method above is re-added verbatim on the new side, commented out with "# "]

-    async def add_to_session_vector_db(self, session_id: str, data: List[Document]) -> None:
-        vector_db_id = await self._ensure_vector_db(session_id)
-        documents = [
-            RAGDocument(
-                document_id=str(uuid.uuid4()),
-                content=a.content,
-                mime_type=a.mime_type,
-                metadata={},
-            )
-            for a in data
-        ]
-        await self.tool_runtime_api.rag_tool.insert(
-            documents=documents,
-            vector_db_id=vector_db_id,
-            chunk_size_in_tokens=512,
-        )
+    # async def add_to_session_vector_db(
+    #     self, session_id: str, data: List[Document]
+    # ) -> None:
+    #     [the method above is re-added verbatim on the new side, commented out with "# "]


-async def load_data_from_urls(urls: List[URL]) -> List[str]:
-    data = []
-    for url in urls:
-        uri = url.uri
-        if uri.startswith("file://"):
-            filepath = uri[len("file://") :]
-            with open(filepath, "r") as f:
-                data.append(f.read())
-        elif uri.startswith("http"):
-            async with httpx.AsyncClient() as client:
-                r = await client.get(uri)
-                resp = r.text
-                data.append(resp)
-    return data
+async def load_data_from_url(url: URL) -> str:
+    uri = url.uri
+    if uri.startswith("http"):
+        async with httpx.AsyncClient() as client:
+            r = await client.get(uri)
+            resp = r.text
+            return resp
+    return ""


-async def attachment_message(tempdir: str, urls: List[URL], message: UserMessage) -> None:
-    contents = []
-
-    for url in urls:
-        uri = url.uri
-        if uri.startswith("file://"):
-            filepath = uri[len("file://") :]
-        elif uri.startswith("http"):
-            path = urlparse(uri).path
-            basename = os.path.basename(path)
-            filepath = f"{tempdir}/{make_random_string() + basename}"
-            logger.info(f"Downloading {url} -> {filepath}")
-
-            async with httpx.AsyncClient() as client:
-                r = await client.get(uri)
-                resp = r.text
-                with open(filepath, "w") as fp:
-                    fp.write(resp)
-        else:
-            raise ValueError(f"Unsupported URL {url}")
-
-        contents.append(
-            TextContentItem(
-                text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
-            )
-        )
-
-    if isinstance(message.content, list):
-        message.content.extend(contents)
-    else:
-        if isinstance(message.content, str):
-            message.content = [TextContentItem(text=message.content)] + contents
-        else:
-            message.content = [message.content] + contents
+# async def attachment_message(
+#     tempdir: str, urls: List[URL], message: UserMessage
+# ) -> None:
+#     [the function above is re-added verbatim on the new side, commented out with "# "]


 def _interpret_content_as_attachment(
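A short usage sketch of the new single-URL loader. The import path and URL below are assumptions for illustration (the helper is a module-level function in the agent file changed above); note it only fetches http(s) URIs and returns an empty string for anything else:

    import asyncio

    from llama_stack.apis.common.content_types import URL

    async def main() -> None:
        # load_data_from_url is assumed to be importable from the changed agent module.
        text = await load_data_from_url(
            URL(uri="https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst")
        )
        print(text[:200])  # first 200 characters of the fetched tutorial

    asyncio.run(main())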