diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 816d21f58..2a93e7b3f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -530,6 +530,8 @@ class ChatAgent(ShieldRunnerMixin): toolgroups.add(toolgroup) tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn) + if documents: + await self.handle_documents(session_id, documents, input_messages, tool_defs) output_attachments = [] @@ -911,7 +913,7 @@ class ChatAgent(ShieldRunnerMixin): async def handle_documents( self, - documents_db_id: str, + session_id: str, documents: List[Document], input_messages: List[Message], tool_defs: Dict[str, ToolDefinition], @@ -943,7 +945,7 @@ class ChatAgent(ShieldRunnerMixin): msg = await attachment_message(self.tempdir, url_items) input_messages.append(msg) # Since memory is present, add all the data to the memory bank - await self.add_to_session_vector_db(documents_db_id, documents) + await self.add_to_session_vector_db(session_id, documents) elif code_interpreter_tool: # if only code_interpreter is available, we download the URLs to a tempdir # and attach the path to them as a message to inference with the @@ -952,7 +954,7 @@ class ChatAgent(ShieldRunnerMixin): input_messages.append(msg) elif memory_tool: # if only memory is available, we load the data from the URLs and content items to the memory bank - await self.add_to_session_vector_db(documents_db_id, documents) + await self.add_to_session_vector_db(session_id, documents) else: # if no memory or code_interpreter tool is available, # we try to load the data from the URLs and content items as a message to inference @@ -961,7 +963,29 @@ class ChatAgent(ShieldRunnerMixin): [doc.content for doc in content_items] + await load_data_from_urls(url_items) ) - async def add_to_session_vector_db(self, vector_db_id: str, data: List[Document]) -> None: + async def _ensure_vector_db(self, session_id: str) -> str: + session_info = await self.storage.get_session_info(session_id) + if session_info is None: + raise ValueError(f"Session {session_id} not found") + + if session_info.vector_db_id is None: + vector_db_id = f"vector_db_{session_id}" + + # TODO: the semantic for registration is definitely not "creation" + # so we need to fix it if we expect the agent to create a new vector db + # for each session + await self.vector_io_api.register_vector_db( + vector_db_id=vector_db_id, + embedding_model="all-MiniLM-L6-v2", + ) + await self.storage.add_vector_db_to_session(session_id, vector_db_id) + else: + vector_db_id = session_info.vector_db_id + + return vector_db_id + + async def add_to_session_vector_db(self, session_id: str, data: List[Document]) -> None: + vector_db_id = await self._ensure_vector_db(session_id) documents = [ RAGDocument( document_id=str(uuid.uuid4()), diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index a8d5ce69e..9a2663567 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -53,6 +53,17 @@ class AgentPersistence: return AgentSessionInfo(**json.loads(value)) + async def add_vector_db_to_session(self, session_id: str, vector_db_id: str): + session_info = await self.get_session_info(session_id) + if session_info is None: + raise ValueError(f"Session {session_id} not found") + + session_info.vector_db_id = vector_db_id + await self.kvstore.set( + key=f"session:{self.agent_id}:{session_id}", + value=session_info.model_dump_json(), + ) + async def add_turn_to_session(self, session_id: str, turn: Turn): await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",