diff --git a/docs/notebooks/RAG_as_attchements.ipynb b/docs/notebooks/RAG_as_attchements.ipynb new file mode 100644 index 000000000..5b98ac506 --- /dev/null +++ b/docs/notebooks/RAG_as_attchements.ipynb @@ -0,0 +1,193 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_stack_client import LlamaStackClient\n", + "from llama_stack_client.types import Document\n", + "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", + "from llama_stack_client.types.agent_create_params import AgentConfig\n", + "from llama_stack_client.lib.agents.agent import Agent\n", + "from rich.pretty import pprint\n", + "import json\n", + "import uuid\n", + "from pydantic import BaseModel\n", + "import rich\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_ID = \"meta-llama/Llama-3.3-70B-Instruct\"\n", + "\n", + "client = LlamaStackClient(\n", + " base_url=\"http://localhost:8321\",\n", + " provider_data={\n", + " \"fireworks_api_key\": os.environ[\"FIREWORKS_API_KEY\"]\n", + " }\n", + ")\n", + "\n", + "urls = [\n", + " \"memory_optimizations.rst\",\n", + " \"chat.rst\",\n", + " \"llama3.rst\",\n", + " \"datasets.rst\",\n", + " \"qat_finetune.rst\",\n", + " \"lora_finetune.rst\",\n", + "]\n", + "\n", + "attachments = [\n", + " {\n", + " \"content\": f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n", + " \"mime_type\": \"text/plain\",\n", + " }\n", + "\n", + " for i, url in enumerate(urls)\n", + "]\n", + "\n", + "simple_agent = Agent(client, model=MODEL_ID, \n", + " instructions=\"You are a helpful assistant that can answer questions about the Torchtune project.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
Turn(\n",
+       "input_messages=[\n",
+       "│   │   UserMessage(content='What precision formats does torchtune support?', role='user', context=None)\n",
+       "],\n",
+       "output_message=CompletionMessage(\n",
+       "│   │   content='Torchtune supports the following precision formats:\\n\\n* FP32 (32-bit floating point)\\n* FP16 (16-bit floating point)\\n* INT8 (8-bit integer)\\n* BF16 (Brain Floating Point 16, a 16-bit floating point format)\\n\\nThese precision formats can be used for model weights, activations, and gradients, allowing for flexible and efficient tuning of models for various hardware and performance requirements.',\n",
+       "│   │   role='assistant',\n",
+       "│   │   stop_reason='end_of_turn',\n",
+       "│   │   tool_calls=[]\n",
+       "),\n",
+       "session_id='1c23c79b-3945-4e99-bda6-7922b6b4e91c',\n",
+       "started_at=datetime.datetime(2025, 3, 20, 22, 41, 0, 175804, tzinfo=TzInfo(UTC)),\n",
+       "steps=[\n",
+       "│   │   InferenceStep(\n",
+       "│   │   │   api_model_response=CompletionMessage(\n",
+       "│   │   │   │   content='Torchtune supports the following precision formats:\\n\\n* FP32 (32-bit floating point)\\n* FP16 (16-bit floating point)\\n* INT8 (8-bit integer)\\n* BF16 (Brain Floating Point 16, a 16-bit floating point format)\\n\\nThese precision formats can be used for model weights, activations, and gradients, allowing for flexible and efficient tuning of models for various hardware and performance requirements.',\n",
+       "│   │   │   │   role='assistant',\n",
+       "│   │   │   │   stop_reason='end_of_turn',\n",
+       "│   │   │   │   tool_calls=[]\n",
+       "│   │   │   ),\n",
+       "│   │   │   step_id='bf452f18-8fae-470e-9e97-b1af60628fc1',\n",
+       "│   │   │   step_type='inference',\n",
+       "│   │   │   turn_id='efb92c6d-d482-4dd2-ad4b-3250c1e9a231',\n",
+       "│   │   │   completed_at=datetime.datetime(2025, 3, 20, 22, 41, 1, 618765, tzinfo=TzInfo(UTC)),\n",
+       "│   │   │   started_at=datetime.datetime(2025, 3, 20, 22, 41, 0, 175855, tzinfo=TzInfo(UTC))\n",
+       "│   │   )\n",
+       "],\n",
+       "turn_id='efb92c6d-d482-4dd2-ad4b-3250c1e9a231',\n",
+       "completed_at=datetime.datetime(2025, 3, 20, 22, 41, 1, 631357, tzinfo=TzInfo(UTC)),\n",
+       "output_attachments=[]\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mTurn\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[33minput_messages\u001b[0m=\u001b[1m[\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1;35mUserMessage\u001b[0m\u001b[1m(\u001b[0m\u001b[33mcontent\u001b[0m=\u001b[32m'What precision formats does torchtune support?'\u001b[0m, \u001b[33mrole\u001b[0m=\u001b[32m'user'\u001b[0m, \u001b[33mcontext\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m]\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33moutput_message\u001b[0m=\u001b[1;35mCompletionMessage\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[33mcontent\u001b[0m=\u001b[32m'Torchtune supports the following precision formats:\\n\\n* FP32 \u001b[0m\u001b[32m(\u001b[0m\u001b[32m32-bit floating point\u001b[0m\u001b[32m)\u001b[0m\u001b[32m\\n* FP16 \u001b[0m\u001b[32m(\u001b[0m\u001b[32m16-bit floating point\u001b[0m\u001b[32m)\u001b[0m\u001b[32m\\n* INT8 \u001b[0m\u001b[32m(\u001b[0m\u001b[32m8-bit integer\u001b[0m\u001b[32m)\u001b[0m\u001b[32m\\n* BF16 \u001b[0m\u001b[32m(\u001b[0m\u001b[32mBrain Floating Point 16, a 16-bit floating point format\u001b[0m\u001b[32m)\u001b[0m\u001b[32m\\n\\nThese precision formats can be used for model weights, activations, and gradients, allowing for flexible and efficient tuning of models for various hardware and performance requirements.'\u001b[0m,\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[33mrole\u001b[0m=\u001b[32m'assistant'\u001b[0m,\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[33mstop_reason\u001b[0m=\u001b[32m'end_of_turn'\u001b[0m,\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[33mtool_calls\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33msession_id\u001b[0m=\u001b[32m'1c23c79b-3945-4e99-bda6-7922b6b4e91c'\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33mstarted_at\u001b[0m=\u001b[1;35mdatetime\u001b[0m\u001b[1;35m.datetime\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m2025\u001b[0m, \u001b[1;36m3\u001b[0m, \u001b[1;36m20\u001b[0m, \u001b[1;36m22\u001b[0m, \u001b[1;36m41\u001b[0m, \u001b[1;36m0\u001b[0m, \u001b[1;36m175804\u001b[0m, \u001b[33mtzinfo\u001b[0m=\u001b[1;35mTzInfo\u001b[0m\u001b[1m(\u001b[0mUTC\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33msteps\u001b[0m=\u001b[1m[\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1;35mInferenceStep\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mapi_model_response\u001b[0m=\u001b[1;35mCompletionMessage\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mcontent\u001b[0m=\u001b[32m'Torchtune supports the following precision formats:\\n\\n* FP32 \u001b[0m\u001b[32m(\u001b[0m\u001b[32m32-bit floating point\u001b[0m\u001b[32m)\u001b[0m\u001b[32m\\n* FP16 \u001b[0m\u001b[32m(\u001b[0m\u001b[32m16-bit floating point\u001b[0m\u001b[32m)\u001b[0m\u001b[32m\\n* INT8 \u001b[0m\u001b[32m(\u001b[0m\u001b[32m8-bit integer\u001b[0m\u001b[32m)\u001b[0m\u001b[32m\\n* BF16 \u001b[0m\u001b[32m(\u001b[0m\u001b[32mBrain Floating Point 16, a 16-bit floating point format\u001b[0m\u001b[32m)\u001b[0m\u001b[32m\\n\\nThese precision formats can be used for model weights, activations, and gradients, allowing for flexible and efficient tuning of models for various hardware and performance requirements.'\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mrole\u001b[0m=\u001b[32m'assistant'\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mstop_reason\u001b[0m=\u001b[32m'end_of_turn'\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mtool_calls\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mstep_id\u001b[0m=\u001b[32m'bf452f18-8fae-470e-9e97-b1af60628fc1'\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mstep_type\u001b[0m=\u001b[32m'inference'\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mturn_id\u001b[0m=\u001b[32m'efb92c6d-d482-4dd2-ad4b-3250c1e9a231'\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mcompleted_at\u001b[0m=\u001b[1;35mdatetime\u001b[0m\u001b[1;35m.datetime\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m2025\u001b[0m, \u001b[1;36m3\u001b[0m, \u001b[1;36m20\u001b[0m, \u001b[1;36m22\u001b[0m, \u001b[1;36m41\u001b[0m, \u001b[1;36m1\u001b[0m, \u001b[1;36m618765\u001b[0m, \u001b[33mtzinfo\u001b[0m=\u001b[1;35mTzInfo\u001b[0m\u001b[1m(\u001b[0mUTC\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mstarted_at\u001b[0m=\u001b[1;35mdatetime\u001b[0m\u001b[1;35m.datetime\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m2025\u001b[0m, \u001b[1;36m3\u001b[0m, \u001b[1;36m20\u001b[0m, \u001b[1;36m22\u001b[0m, \u001b[1;36m41\u001b[0m, \u001b[1;36m0\u001b[0m, \u001b[1;36m175855\u001b[0m, \u001b[33mtzinfo\u001b[0m=\u001b[1;35mTzInfo\u001b[0m\u001b[1m(\u001b[0mUTC\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m]\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33mturn_id\u001b[0m=\u001b[32m'efb92c6d-d482-4dd2-ad4b-3250c1e9a231'\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33mcompleted_at\u001b[0m=\u001b[1;35mdatetime\u001b[0m\u001b[1;35m.datetime\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m2025\u001b[0m, \u001b[1;36m3\u001b[0m, \u001b[1;36m20\u001b[0m, \u001b[1;36m22\u001b[0m, \u001b[1;36m41\u001b[0m, \u001b[1;36m1\u001b[0m, \u001b[1;36m631357\u001b[0m, \u001b[33mtzinfo\u001b[0m=\u001b[1;35mTzInfo\u001b[0m\u001b[1m(\u001b[0mUTC\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33moutput_attachments\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "simple_session_id = simple_agent.create_session(session_name=f\"simple_session_{uuid.uuid4()}\")\n", + "response = simple_agent.create_turn(\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"What precision formats does torchtune support?\"\n", + " }\n", + " ],\n", + " session_id=simple_session_id,\n", + " stream=False\n", + " )\n", + "\n", + "pprint(response)\n", + "\n", + "session_response = client.agents.session.retrieve(agent_id=simple_agent.agent_id, session_id=simple_session_id)\n", + "pprint(session_response)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "master", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 88b6e9697..a511e9ca5 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -40,10 +40,10 @@ from llama_stack.apis.agents import ( Turn, ) from llama_stack.apis.common.content_types import ( - URL, TextContentItem, ToolCallDelta, ToolCallParseStatus, + URL, ) from llama_stack.apis.inference import ( ChatCompletionResponseEventType, @@ -80,7 +80,9 @@ from .safety import SafetyException, ShieldRunnerMixin def make_random_string(length: int = 8): - return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) + return "".join( + secrets.choice(string.ascii_letters + string.digits) for _ in range(length) + ) TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") @@ -179,7 +181,9 @@ class ChatAgent(ShieldRunnerMixin): messages.extend(self.turn_to_messages(turn)) return messages - async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: + async def create_and_execute_turn( + self, request: AgentTurnCreateRequest + ) -> AsyncGenerator: await self._initialize_tools(request.toolgroups) async with tracing.span("create_and_execute_turn") as span: span.set_attribute("session_id", request.session_id) @@ -220,13 +224,16 @@ class ChatAgent(ShieldRunnerMixin): messages = await self.get_messages_from_turns(turns) if is_resume: tool_response_messages = [ - ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses + ToolResponseMessage(call_id=x.call_id, content=x.content) + for x in request.tool_responses ] messages.extend(tool_response_messages) last_turn = turns[-1] last_turn_messages = self.turn_to_messages(last_turn) last_turn_messages = [ - x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) + x + for x in last_turn_messages + if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) ] last_turn_messages.extend(tool_response_messages) @@ -236,17 +243,31 @@ class ChatAgent(ShieldRunnerMixin): # mark tool execution step as complete # if there's no tool execution in progress step (due to storage, or tool call parsing on client), # we'll create a new tool execution step with current time - in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( - request.session_id, request.turn_id + in_progress_tool_call_step = ( + await self.storage.get_in_progress_tool_call_step( + request.session_id, request.turn_id + ) ) now = datetime.now(timezone.utc).isoformat() tool_execution_step = ToolExecutionStep( - step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), + step_id=( + in_progress_tool_call_step.step_id + if in_progress_tool_call_step + else str(uuid.uuid4()) + ), turn_id=request.turn_id, - tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), + tool_calls=( + in_progress_tool_call_step.tool_calls + if in_progress_tool_call_step + else [] + ), tool_responses=request.tool_responses, completed_at=now, - started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), + started_at=( + in_progress_tool_call_step.started_at + if in_progress_tool_call_step + else now + ), ) steps.append(tool_execution_step) yield AgentTurnResponseStreamChunk( @@ -280,9 +301,14 @@ class ChatAgent(ShieldRunnerMixin): output_message = chunk continue - assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" + assert isinstance( + chunk, AgentTurnResponseStreamChunk + ), f"Unexpected type {type(chunk)}" event = chunk.event - if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: + if ( + event.payload.event_type + == AgentTurnResponseEventType.step_complete.value + ): steps.append(event.payload.step_details) yield chunk @@ -440,6 +466,18 @@ class ChatAgent(ShieldRunnerMixin): ) span.set_attribute("output", "no violations") + async def get_raw_document_text(self, document: Document) -> str: + if isinstance(document.content, URL): + return await load_data_from_url(document.content) + elif isinstance(document.content, str): + return document.content + elif isinstance(document.content, TextContentItem): + return document.content.text + else: + raise ValueError( + f"Unexpected document content type: {type(document.content)}" + ) + async def _run( self, session_id: str, @@ -449,8 +487,23 @@ class ChatAgent(ShieldRunnerMixin): stream: bool = False, documents: Optional[List[Document]] = None, ) -> AsyncGenerator: + # if documents: + # await self.handle_documents(session_id, documents, input_messages) + + # if document is passed in a turn, we parse the raw text of the document + # and sent it as a user message if documents: - await self.handle_documents(session_id, documents, input_messages) + contexts = [] + for document in documents: + raw_document_text = await self.get_raw_document_text(document) + contexts.append(TextContentItem(text=raw_document_text)) + # modify the last user message to include the document + input_messages.append( + ToolResponseMessage( + call_id=str(uuid.uuid4()), + content=contexts, + ) + ) session_info = await self.storage.get_session_info(session_id) # if the session has a memory bank id, let the memory tool use it @@ -458,13 +511,19 @@ class ChatAgent(ShieldRunnerMixin): for tool_name in self.tool_name_to_args.keys(): if tool_name == MEMORY_QUERY_TOOL: if "vector_db_ids" not in self.tool_name_to_args[tool_name]: - self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id] + self.tool_name_to_args[tool_name]["vector_db_ids"] = [ + session_info.vector_db_id + ] else: - self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id) + self.tool_name_to_args[tool_name]["vector_db_ids"].append( + session_info.vector_db_id + ) output_attachments = [] - n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0 + n_iter = ( + await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0 + ) # Build a map of custom tools to their definitions for faster lookup client_tools = {} @@ -487,6 +546,9 @@ class ChatAgent(ShieldRunnerMixin): stop_reason = None async with tracing.span("inference") as span: + from rich.pretty import pprint + + pprint(input_messages) async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, @@ -542,12 +604,16 @@ class ChatAgent(ShieldRunnerMixin): span.set_attribute("stop_reason", stop_reason) span.set_attribute( "input", - json.dumps([json.loads(m.model_dump_json()) for m in input_messages]), + json.dumps( + [json.loads(m.model_dump_json()) for m in input_messages] + ), ) output_attr = json.dumps( { "content": content, - "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls], + "tool_calls": [ + json.loads(t.model_dump_json()) for t in tool_calls + ], } ) span.set_attribute("output", output_attr) @@ -611,7 +677,9 @@ class ChatAgent(ShieldRunnerMixin): message.content = [message.content] + output_attachments yield message else: - logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}") + logger.debug( + f"completion message with EOM (iter: {n_iter}): {str(message)}" + ) input_messages = input_messages + [message] else: input_messages = input_messages + [message] @@ -660,7 +728,9 @@ class ChatAgent(ShieldRunnerMixin): "input": message.model_dump_json(), }, ) as span: - tool_execution_start_time = datetime.now(timezone.utc).isoformat() + tool_execution_start_time = datetime.now( + timezone.utc + ).isoformat() tool_result = await self.execute_tool_call_maybe( session_id, tool_call, @@ -709,7 +779,9 @@ class ChatAgent(ShieldRunnerMixin): # TODO: add tool-input touchpoint and a "start" event for this step also # but that needs a lot more refactoring of Tool code potentially if (type(result_message.content) is str) and ( - out_attachment := _interpret_content_as_attachment(result_message.content) + out_attachment := _interpret_content_as_attachment( + result_message.content + ) ): # NOTE: when we push this message back to the model, the model may ignore the # attached file path etc. since the model is trained to only provide a user message @@ -746,16 +818,24 @@ class ChatAgent(ShieldRunnerMixin): toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, ) -> None: toolgroup_to_args = {} - for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []): + for toolgroup in (self.agent_config.toolgroups or []) + ( + toolgroups_for_turn or [] + ): if isinstance(toolgroup, AgentToolGroupWithArgs): tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name) toolgroup_to_args[tool_group_name] = toolgroup.args # Determine which tools to include - tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or [] + tool_groups_to_include = ( + toolgroups_for_turn or self.agent_config.toolgroups or [] + ) agent_config_toolgroups = [] for toolgroup in tool_groups_to_include: - name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup + name = ( + toolgroup.name + if isinstance(toolgroup, AgentToolGroupWithArgs) + else toolgroup + ) if name not in agent_config_toolgroups: agent_config_toolgroups.append(name) @@ -781,20 +861,32 @@ class ChatAgent(ShieldRunnerMixin): }, ) for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: - toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) + toolgroup_name, input_tool_name = self._parse_toolgroup_name( + toolgroup_name_with_maybe_tool_name + ) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) if not tools.data: available_tool_groups = ", ".join( - [t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data] + [ + t.identifier + for t in (await self.tool_groups_api.list_tool_groups()).data + ] ) - raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}") - if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data): + raise ValueError( + f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}" + ) + if input_tool_name is not None and not any( + tool.identifier == input_tool_name for tool in tools.data + ): raise ValueError( f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}" ) for tool_def in tools.data: - if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP: + if ( + toolgroup_name.startswith("builtin") + and toolgroup_name != RAG_TOOL_GROUP + ): identifier: str | BuiltinTool | None = tool_def.identifier if identifier == "web_search": identifier = BuiltinTool.brave_search @@ -823,11 +915,18 @@ class ChatAgent(ShieldRunnerMixin): for param in tool_def.parameters }, ) - tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {}) + tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get( + toolgroup_name, {} + ) - self.tool_defs, self.tool_name_to_args = list(tool_name_to_def.values()), tool_name_to_args + self.tool_defs, self.tool_name_to_args = ( + list(tool_name_to_def.values()), + tool_name_to_args, + ) - def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]: + def _parse_toolgroup_name( + self, toolgroup_name_with_maybe_tool_name: str + ) -> tuple[str, Optional[str]]: """Parse a toolgroup name into its components. Args: @@ -863,7 +962,9 @@ class ChatAgent(ShieldRunnerMixin): else: tool_name_str = tool_name - logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}") + logger.info( + f"executing tool call: {tool_name_str} with args: {tool_call.arguments}" + ) result = await self.tool_runtime_api.invoke_tool( tool_name=tool_name_str, kwargs={ @@ -876,144 +977,142 @@ class ChatAgent(ShieldRunnerMixin): logger.debug(f"tool call {tool_name_str} completed with result: {result}") return result - async def handle_documents( - self, - session_id: str, - documents: List[Document], - input_messages: List[Message], - ) -> None: - memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs) - code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs) - content_items = [] - url_items = [] - pattern = re.compile("^(https?://|file://|data:)") - for d in documents: - if isinstance(d.content, URL): - url_items.append(d.content) - elif pattern.match(d.content): - url_items.append(URL(uri=d.content)) - else: - content_items.append(d) + # async def handle_documents( + # self, + # session_id: str, + # documents: List[Document], + # input_messages: List[Message], + # ) -> None: + # memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs) + # code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs) + # content_items = [] + # url_items = [] + # pattern = re.compile("^(https?://|file://|data:)") + # for d in documents: + # if isinstance(d.content, URL): + # url_items.append(d.content) + # elif pattern.match(d.content): + # url_items.append(URL(uri=d.content)) + # else: + # content_items.append(d) - # Save the contents to a tempdir and use its path as a URL if code interpreter is present - if code_interpreter_tool: - for c in content_items: - temp_file_path = os.path.join(self.tempdir, f"{make_random_string()}.txt") - with open(temp_file_path, "w") as temp_file: - temp_file.write(c.content) - url_items.append(URL(uri=f"file://{temp_file_path}")) + # # Save the contents to a tempdir and use its path as a URL if code interpreter is present + # if code_interpreter_tool: + # for c in content_items: + # temp_file_path = os.path.join(self.tempdir, f"{make_random_string()}.txt") + # with open(temp_file_path, "w") as temp_file: + # temp_file.write(c.content) + # url_items.append(URL(uri=f"file://{temp_file_path}")) - if memory_tool and code_interpreter_tool: - # if both memory and code_interpreter are available, we download the URLs - # and attach the data to the last message. - await attachment_message(self.tempdir, url_items, input_messages[-1]) - # Since memory is present, add all the data to the memory bank - await self.add_to_session_vector_db(session_id, documents) - elif code_interpreter_tool: - # if only code_interpreter is available, we download the URLs to a tempdir - # and attach the path to them as a message to inference with the - # assumption that the model invokes the code_interpreter tool with the path - await attachment_message(self.tempdir, url_items, input_messages[-1]) - elif memory_tool: - # if only memory is available, we load the data from the URLs and content items to the memory bank - await self.add_to_session_vector_db(session_id, documents) - else: - # if no memory or code_interpreter tool is available, - # we try to load the data from the URLs and content items as a message to inference - # and add it to the last message's context - input_messages[-1].context = "\n".join( - [doc.content for doc in content_items] + await load_data_from_urls(url_items) - ) + # if memory_tool and code_interpreter_tool: + # # if both memory and code_interpreter are available, we download the URLs + # # and attach the data to the last message. + # await attachment_message(self.tempdir, url_items, input_messages[-1]) + # # Since memory is present, add all the data to the memory bank + # await self.add_to_session_vector_db(session_id, documents) + # elif code_interpreter_tool: + # # if only code_interpreter is available, we download the URLs to a tempdir + # # and attach the path to them as a message to inference with the + # # assumption that the model invokes the code_interpreter tool with the path + # await attachment_message(self.tempdir, url_items, input_messages[-1]) + # elif memory_tool: + # # if only memory is available, we load the data from the URLs and content items to the memory bank + # await self.add_to_session_vector_db(session_id, documents) + # else: + # # if no memory or code_interpreter tool is available, + # # we try to load the data from the URLs and content items as a message to inference + # # and add it to the last message's context + # input_messages[-1].context = "\n".join( + # [doc.content for doc in content_items] + await load_data_from_urls(url_items) + # ) - async def _ensure_vector_db(self, session_id: str) -> str: - session_info = await self.storage.get_session_info(session_id) - if session_info is None: - raise ValueError(f"Session {session_id} not found") + # async def _ensure_vector_db(self, session_id: str) -> str: + # session_info = await self.storage.get_session_info(session_id) + # if session_info is None: + # raise ValueError(f"Session {session_id} not found") - if session_info.vector_db_id is None: - vector_db_id = f"vector_db_{session_id}" + # if session_info.vector_db_id is None: + # vector_db_id = f"vector_db_{session_id}" - # TODO: the semantic for registration is definitely not "creation" - # so we need to fix it if we expect the agent to create a new vector db - # for each session - await self.vector_io_api.register_vector_db( - vector_db_id=vector_db_id, - embedding_model="all-MiniLM-L6-v2", - ) - await self.storage.add_vector_db_to_session(session_id, vector_db_id) - else: - vector_db_id = session_info.vector_db_id + # # TODO: the semantic for registration is definitely not "creation" + # # so we need to fix it if we expect the agent to create a new vector db + # # for each session + # await self.vector_io_api.register_vector_db( + # vector_db_id=vector_db_id, + # embedding_model="all-MiniLM-L6-v2", + # ) + # await self.storage.add_vector_db_to_session(session_id, vector_db_id) + # else: + # vector_db_id = session_info.vector_db_id - return vector_db_id + # return vector_db_id - async def add_to_session_vector_db(self, session_id: str, data: List[Document]) -> None: - vector_db_id = await self._ensure_vector_db(session_id) - documents = [ - RAGDocument( - document_id=str(uuid.uuid4()), - content=a.content, - mime_type=a.mime_type, - metadata={}, - ) - for a in data - ] - await self.tool_runtime_api.rag_tool.insert( - documents=documents, - vector_db_id=vector_db_id, - chunk_size_in_tokens=512, - ) + # async def add_to_session_vector_db( + # self, session_id: str, data: List[Document] + # ) -> None: + # vector_db_id = await self._ensure_vector_db(session_id) + # documents = [ + # RAGDocument( + # document_id=str(uuid.uuid4()), + # content=a.content, + # mime_type=a.mime_type, + # metadata={}, + # ) + # for a in data + # ] + # await self.tool_runtime_api.rag_tool.insert( + # documents=documents, + # vector_db_id=vector_db_id, + # chunk_size_in_tokens=512, + # ) -async def load_data_from_urls(urls: List[URL]) -> List[str]: - data = [] - for url in urls: - uri = url.uri - if uri.startswith("file://"): - filepath = uri[len("file://") :] - with open(filepath, "r") as f: - data.append(f.read()) - elif uri.startswith("http"): - async with httpx.AsyncClient() as client: - r = await client.get(uri) - resp = r.text - data.append(resp) - return data +async def load_data_from_url(url: URL) -> str: + uri = url.uri + if uri.startswith("http"): + async with httpx.AsyncClient() as client: + r = await client.get(uri) + resp = r.text + return resp + return "" -async def attachment_message(tempdir: str, urls: List[URL], message: UserMessage) -> None: - contents = [] +# async def attachment_message( +# tempdir: str, urls: List[URL], message: UserMessage +# ) -> None: +# contents = [] - for url in urls: - uri = url.uri - if uri.startswith("file://"): - filepath = uri[len("file://") :] - elif uri.startswith("http"): - path = urlparse(uri).path - basename = os.path.basename(path) - filepath = f"{tempdir}/{make_random_string() + basename}" - logger.info(f"Downloading {url} -> {filepath}") +# for url in urls: +# uri = url.uri +# if uri.startswith("file://"): +# filepath = uri[len("file://") :] +# elif uri.startswith("http"): +# path = urlparse(uri).path +# basename = os.path.basename(path) +# filepath = f"{tempdir}/{make_random_string() + basename}" +# logger.info(f"Downloading {url} -> {filepath}") - async with httpx.AsyncClient() as client: - r = await client.get(uri) - resp = r.text - with open(filepath, "w") as fp: - fp.write(resp) - else: - raise ValueError(f"Unsupported URL {url}") +# async with httpx.AsyncClient() as client: +# r = await client.get(uri) +# resp = r.text +# with open(filepath, "w") as fp: +# fp.write(resp) +# else: +# raise ValueError(f"Unsupported URL {url}") - contents.append( - TextContentItem( - text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.' - ) - ) +# contents.append( +# TextContentItem( +# text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.' +# ) +# ) - if isinstance(message.content, list): - message.content.extend(contents) - else: - if isinstance(message.content, str): - message.content = [TextContentItem(text=message.content)] + contents - else: - message.content = [message.content] + contents +# if isinstance(message.content, list): +# message.content.extend(contents) +# else: +# if isinstance(message.content, str): +# message.content = [TextContentItem(text=message.content)] + contents +# else: +# message.content = [message.content] + contents def _interpret_content_as_attachment(