From aa546de8d631fd2ac4fc2397c74d7e747f86dfee Mon Sep 17 00:00:00 2001 From: Daniele Martinoli Date: Fri, 28 Feb 2025 08:32:02 +0100 Subject: [PATCH] renamed insert_vector_db_id to documents_db_id, removed vector_db_id from session info Signed-off-by: Daniele Martinoli --- docs/source/building_applications/rag.md | 11 +++++-- .../agents/meta_reference/agent_instance.py | 30 +++---------------- .../agents/meta_reference/persistence.py | 11 ------- 3 files changed, 13 insertions(+), 39 deletions(-) diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md index e2e5fd6b5..08d68fefa 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -93,7 +93,14 @@ agent_config = AgentConfig( { "name": "builtin::rag/knowledge_search", "args": { - "vector_db_ids": [vector_db_id], + # 'documents_db_id' holds the ID of the registered vector database + # where the provided documents will be ingested. This argument is mandatory + # when the 'documents' parameter is provided in a 'create_turn' invocation. + # When provided, 'documents_db_id' will also be used to extract contextual information + # for the query. + "documents_db_id": vector_db_id, + # Optionally, the 'vector_db_ids' argument can specify additional vector databases + # to use at query time. }, } ], @@ -109,7 +116,7 @@ response = agent.create_turn( ], documents=[ { - "content": "https://raw.githubusercontent.com/example/doc.rst", + "content": "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/memory_optimizations.rst", "mime_type": "text/plain", } ], diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index cf3cc8ef7..816d21f58 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -911,7 +911,7 @@ class ChatAgent(ShieldRunnerMixin): async def handle_documents( self, - session_id: str, + documents_db_id: str, documents: List[Document], input_messages: List[Message], tool_defs: Dict[str, ToolDefinition], @@ -943,7 +943,7 @@ class ChatAgent(ShieldRunnerMixin): msg = await attachment_message(self.tempdir, url_items) input_messages.append(msg) # Since memory is present, add all the data to the memory bank - await self.add_to_session_vector_db(session_id, documents) + await self.add_to_session_vector_db(documents_db_id, documents) elif code_interpreter_tool: # if only code_interpreter is available, we download the URLs to a tempdir # and attach the path to them as a message to inference with the @@ -952,7 +952,7 @@ class ChatAgent(ShieldRunnerMixin): input_messages.append(msg) elif memory_tool: # if only memory is available, we load the data from the URLs and content items to the memory bank - await self.add_to_session_vector_db(session_id, documents) + await self.add_to_session_vector_db(documents_db_id, documents) else: # if no memory or code_interpreter tool is available, # we try to load the data from the URLs and content items as a message to inference @@ -961,29 +961,7 @@ class ChatAgent(ShieldRunnerMixin): [doc.content for doc in content_items] + await load_data_from_urls(url_items) ) - async def _ensure_vector_db(self, session_id: str) -> str: - session_info = await self.storage.get_session_info(session_id) - if session_info is None: - raise ValueError(f"Session {session_id} not found") - - if session_info.vector_db_id is None: - vector_db_id = f"vector_db_{session_id}" - - # TODO: the semantic for registration is definitely not "creation" - # so we need to fix it if we expect the agent to create a new vector db - # for each session - await self.vector_io_api.register_vector_db( - vector_db_id=vector_db_id, - embedding_model="all-MiniLM-L6-v2", - ) - await self.storage.add_vector_db_to_session(session_id, vector_db_id) - else: - vector_db_id = session_info.vector_db_id - - return vector_db_id - - async def add_to_session_vector_db(self, session_id: str, data: List[Document]) -> None: - vector_db_id = await self._ensure_vector_db(session_id) + async def add_to_session_vector_db(self, vector_db_id: str, data: List[Document]) -> None: documents = [ RAGDocument( document_id=str(uuid.uuid4()), diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 9a2663567..a8d5ce69e 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -53,17 +53,6 @@ class AgentPersistence: return AgentSessionInfo(**json.loads(value)) - async def add_vector_db_to_session(self, session_id: str, vector_db_id: str): - session_info = await self.get_session_info(session_id) - if session_info is None: - raise ValueError(f"Session {session_id} not found") - - session_info.vector_db_id = vector_db_id - await self.kvstore.set( - key=f"session:{self.agent_id}:{session_id}", - value=session_info.model_dump_json(), - ) - async def add_turn_to_session(self, session_id: str, turn: Turn): await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",