From bc8b377a7c2800e3df09db58c651422b33f36ffa Mon Sep 17 00:00:00 2001 From: Sumanth Kamenani Date: Wed, 15 Oct 2025 14:02:48 -0400 Subject: [PATCH] fix(vector-io): handle missing document_id in insert_chunks (#3521) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixed KeyError when chunks don't have document_id in metadata or chunk_metadata. Updated logging to safely extract document_id via the new Chunk.document_id property and RAG memory to handle different document_id locations. Added test for missing document_id scenarios. Fixes issue #3494 where /v1/vector-io/insert would crash with KeyError. # What does this PR do? Fixes a KeyError crash in `/v1/vector-io/insert` when chunks are missing `document_id` fields. The API was failing even though `document_id` is optional according to the schema. 
Closes #3494 ## Test Plan **Before fix:** - POST to `/v1/vector-io/insert` with chunks → 500 KeyError - Happened regardless of where `document_id` was placed **After fix:** - Same request works fine → 200 OK - Tested with Postman using FAISS backend - Added unit test covering missing `document_id` scenarios --- llama_stack/apis/vector_io/vector_io.py | 16 ++++++++++ llama_stack/core/routers/vector_io.py | 4 ++- .../inline/tool_runtime/rag/memory.py | 2 +- .../test_vector_io_openai_vector_stores.py | 31 +++++++++++++++++++ 4 files changed, 51 insertions(+), 2 deletions(-) diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py index 3ced81bdd..a309c47f9 100644 --- a/llama_stack/apis/vector_io/vector_io.py +++ b/llama_stack/apis/vector_io/vector_io.py @@ -93,6 +93,22 @@ class Chunk(BaseModel): return generate_chunk_id(str(uuid.uuid4()), str(self.content)) + @property + def document_id(self) -> str | None: + """Returns the document_id from either metadata or chunk_metadata, with metadata taking precedence.""" + # Check metadata first (takes precedence) + doc_id = self.metadata.get("document_id") + if doc_id is not None: + if not isinstance(doc_id, str): + raise TypeError(f"metadata['document_id'] must be a string, got {type(doc_id).__name__}: {doc_id!r}") + return doc_id + + # Fall back to chunk_metadata if available (Pydantic ensures type safety) + if self.chunk_metadata is not None: + return self.chunk_metadata.document_id + + return None + @json_schema_type class QueryChunksResponse(BaseModel): diff --git a/llama_stack/core/routers/vector_io.py b/llama_stack/core/routers/vector_io.py index 3bd8c3073..f4e871a40 100644 --- a/llama_stack/core/routers/vector_io.py +++ b/llama_stack/core/routers/vector_io.py @@ -93,8 +93,10 @@ class VectorIORouter(VectorIO): chunks: list[Chunk], ttl_seconds: int | None = None, ) -> None: + doc_ids = [chunk.document_id for chunk in chunks[:3]] logger.debug( - f"VectorIORouter.insert_chunks: 
{vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}", + f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, " + f"ttl_seconds={ttl_seconds}, chunk_ids={doc_ids}{' and more...' if len(chunks) > 3 else ''}" ) provider = await self.routing_table.get_provider_impl(vector_db_id) return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds) diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 3ccfd0bcb..dc3dfbbca 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -272,7 +272,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti return RAGQueryResult( content=picked, metadata={ - "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]], + "document_ids": [c.document_id for c in chunks[: len(picked)]], "chunks": [c.content for c in chunks[: len(picked)]], "scores": scores[: len(picked)], "vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]], diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py index 7038f8a41..32d59c91b 100644 --- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py +++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py @@ -128,6 +128,37 @@ async def test_insert_chunks_missing_db_raises(vector_io_adapter): await vector_io_adapter.insert_chunks("db_not_exist", []) +async def test_insert_chunks_with_missing_document_id(vector_io_adapter): + """Ensure no KeyError when document_id is missing or in different places.""" + from llama_stack.apis.vector_io import Chunk, ChunkMetadata + + fake_index = AsyncMock() + vector_io_adapter.cache["db1"] = 
fake_index + + # Various document_id scenarios that shouldn't crash + chunks = [ + Chunk(content="has doc_id in metadata", metadata={"document_id": "doc-1"}), + Chunk(content="no doc_id anywhere", metadata={"source": "test"}), + Chunk(content="doc_id in chunk_metadata", chunk_metadata=ChunkMetadata(document_id="doc-3")), + ] + + # Should work without KeyError + await vector_io_adapter.insert_chunks("db1", chunks) + fake_index.insert_chunks.assert_awaited_once() + + +async def test_document_id_with_invalid_type_raises_error(): + """Ensure TypeError is raised when document_id is not a string.""" + from llama_stack.apis.vector_io import Chunk + + # Integer document_id should raise TypeError + chunk = Chunk(content="test", metadata={"document_id": 12345}) + with pytest.raises(TypeError) as exc_info: + _ = chunk.document_id + assert "metadata['document_id'] must be a string" in str(exc_info.value) + assert "got int" in str(exc_info.value) + + async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter): expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1]) fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))