diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 9c1c3170f..dbe251921 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -10020,7 +10020,8 @@
               "type": "object",
               "properties": {
                 "content": {
-                  "$ref": "#/components/schemas/InterleavedContent"
+                  "$ref": "#/components/schemas/InterleavedContent",
+                  "description": "The content of the chunk, which can be interleaved text, images, or other types."
                 },
                 "metadata": {
                   "type": "object",
@@ -10045,7 +10046,15 @@
                       "type": "object"
                     }
                   ]
-                }
+                },
+                "description": "Metadata associated with the chunk, such as document ID, source, or other relevant information."
+              },
+              "embedding": {
+                "type": "array",
+                "items": {
+                  "type": "number"
+                },
+                "description": "Optional embedding for the chunk. If not provided, it will be computed later."
               }
             },
             "additionalProperties": false,
@@ -10053,9 +10062,10 @@
               "content",
               "metadata"
             ],
-            "title": "Chunk"
+            "title": "Chunk",
+            "description": "A chunk of content that can be inserted into a vector database."
           },
-          "description": "The chunks to insert."
+          "description": "The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types. `metadata`: `dict[str, Any]` and `embedding`: `list[float]` are optional. If `metadata` is provided, it is used to configure how Llama Stack formats the chunk during generation. If `embedding` is not provided, it will be computed later."
         },
         "ttl_seconds": {
           "type": "integer",
@@ -12285,7 +12295,8 @@
               "type": "object",
               "properties": {
                 "content": {
-                  "$ref": "#/components/schemas/InterleavedContent"
+                  "$ref": "#/components/schemas/InterleavedContent",
+                  "description": "The content of the chunk, which can be interleaved text, images, or other types."
                 },
                 "metadata": {
                   "type": "object",
@@ -12310,7 +12321,15 @@
                       "type": "object"
                     }
                   ]
-                }
+                },
+                "description": "Metadata associated with the chunk, such as document ID, source, or other relevant information."
+              },
+              "embedding": {
+                "type": "array",
+                "items": {
+                  "type": "number"
+                },
+                "description": "Optional embedding for the chunk. If not provided, it will be computed later."
               }
             },
             "additionalProperties": false,
@@ -12318,7 +12337,8 @@
               "content",
               "metadata"
             ],
-            "title": "Chunk"
+            "title": "Chunk",
+            "description": "A chunk of content that can be inserted into a vector database."
           }
         },
         "scores": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 1afe870cf..2dfb037da 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -7024,6 +7024,9 @@ components:
           properties:
             content:
               $ref: '#/components/schemas/InterleavedContent'
+              description: >-
+                The content of the chunk, which can be interleaved text, images,
+                or other types.
             metadata:
               type: object
               additionalProperties:
@@ -7034,12 +7037,29 @@ components:
                   - type: string
                   - type: array
                   - type: object
+              description: >-
+                Metadata associated with the chunk, such as document ID, source,
+                or other relevant information.
+            embedding:
+              type: array
+              items:
+                type: number
+              description: >-
+                Optional embedding for the chunk. If not provided, it will be computed
+                later.
           additionalProperties: false
           required:
             - content
             - metadata
           title: Chunk
-          description: The chunks to insert.
+          description: >-
+            A chunk of content that can be inserted into a vector database.
+        description: >-
+          The chunks to insert. Each `Chunk` should contain content which can be
+          interleaved text, images, or other types. `metadata`: `dict[str, Any]`
+          and `embedding`: `list[float]` are optional. If `metadata` is provided,
+          it is used to configure how Llama Stack formats the chunk during
+          generation. If `embedding` is not provided, it will be computed later.
       ttl_seconds:
         type: integer
         description: The time to live of the chunks.
@@ -8537,6 +8557,9 @@ components:
           properties:
             content:
               $ref: '#/components/schemas/InterleavedContent'
+              description: >-
+                The content of the chunk, which can be interleaved text, images,
+                or other types.
             metadata:
               type: object
               additionalProperties:
@@ -8547,11 +8570,23 @@ components:
                   - type: string
                   - type: array
                   - type: object
+              description: >-
+                Metadata associated with the chunk, such as document ID, source,
+                or other relevant information.
+            embedding:
+              type: array
+              items:
+                type: number
+              description: >-
+                Optional embedding for the chunk. If not provided, it will be computed
+                later.
           additionalProperties: false
           required:
             - content
             - metadata
           title: Chunk
+          description: >-
+            A chunk of content that can be inserted into a vector database.
       scores:
         type: array
         items:
diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md
index dbe90a7fc..289c38991 100644
--- a/docs/source/building_applications/rag.md
+++ b/docs/source/building_applications/rag.md
@@ -57,6 +57,31 @@ chunks = [
 ]
 client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
 ```
+
+#### Using Precomputed Embeddings
+If you decide to precompute embeddings for your documents, you can insert them directly into the vector database by
+including the embedding vectors in the chunk data. This is useful if you have a separate embedding service or if you
+want to customize the ingestion process.
+```python
+chunks_with_embeddings = [
+    {
+        "content": "First chunk of text",
+        "mime_type": "text/plain",
+        "embedding": [0.1, 0.2, 0.3, ...],  # Your precomputed embedding vector
+        "metadata": {"document_id": "doc1", "section": "introduction"},
+    },
+    {
+        "content": "Second chunk of text",
+        "mime_type": "text/plain",
+        "embedding": [0.2, 0.3, 0.4, ...],  # Your precomputed embedding vector
+        "metadata": {"document_id": "doc1", "section": "methodology"},
+    },
+]
+client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks_with_embeddings)
+```
+When providing precomputed embeddings, ensure the embedding dimension matches the `embedding_dimension` specified when
+registering the vector database.
+
 ### Retrieval
 You can query the vector database to retrieve documents based on their embeddings.
 ```python
diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py
index 3ac62d42c..44cc8f904 100644
--- a/llama_stack/apis/vector_io/vector_io.py
+++ b/llama_stack/apis/vector_io/vector_io.py
@@ -19,8 +19,16 @@ from llama_stack.schema_utils import json_schema_type, webmethod
 
 
 class Chunk(BaseModel):
+    """
+    A chunk of content that can be inserted into a vector database.
+    :param content: The content of the chunk, which can be interleaved text, images, or other types.
+    :param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
+    :param metadata: Metadata associated with the chunk, such as document ID, source, or other relevant information.
+    """
+
     content: InterleavedContent
     metadata: dict[str, Any] = Field(default_factory=dict)
+    embedding: list[float] | None = None
 
 
 @json_schema_type
@@ -50,7 +58,10 @@ class VectorIO(Protocol):
         """Insert chunks into a vector database.
 
         :param vector_db_id: The identifier of the vector database to insert the chunks into.
-        :param chunks: The chunks to insert.
+        :param chunks: The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types.
+            `metadata`: `dict[str, Any]` and `embedding`: `list[float]` are optional.
+            If `metadata` is provided, it is used to configure how Llama Stack formats the chunk during generation.
+            If `embedding` is not provided, it will be computed later.
         :param ttl_seconds: The time to live of the chunks.
         """
         ...
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index c2d264c91..4776d47d0 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -146,7 +146,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
         ]
         for i, chunk in enumerate(chunks):
             metadata = chunk.metadata
-            tokens += metadata["token_count"]
+            tokens += metadata.get("token_count", 0)
             tokens += metadata.get("metadata_token_count", 0)
 
             if tokens > query_config.max_tokens_in_context:
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index 3655c7049..4cd15860b 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -171,6 +171,22 @@ def make_overlapped_chunks(
     return chunks
 
 
+def _validate_embedding(embedding: list[float] | NDArray, index: int, expected_dimension: int):
+    """Helper function to validate embedding format and dimensions."""
+    if not isinstance(embedding, (list | np.ndarray)):
+        raise ValueError(f"Embedding at index {index} must be a list or numpy array, got {type(embedding)}")
+
+    if isinstance(embedding, np.ndarray):
+        if not np.issubdtype(embedding.dtype, np.number):
+            raise ValueError(f"Embedding at index {index} contains non-numeric values")
+    else:
+        if not all(isinstance(e, (float | int | np.number)) for e in embedding):
+            raise ValueError(f"Embedding at index {index} contains non-numeric values")
+
+    if len(embedding) != expected_dimension:
+        raise ValueError(f"Embedding at index {index} has dimension {len(embedding)}, expected {expected_dimension}")
+
+
 class EmbeddingIndex(ABC):
     @abstractmethod
     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
@@ -199,11 +215,22 @@ class VectorDBWithIndex:
         self,
         chunks: list[Chunk],
     ) -> None:
-        embeddings_response = await self.inference_api.embeddings(
-            self.vector_db.embedding_model, [x.content for x in chunks]
-        )
-        embeddings = np.array(embeddings_response.embeddings)
+        chunks_to_embed = []
+        for i, c in enumerate(chunks):
+            if c.embedding is None:
+                chunks_to_embed.append(c)
+            else:
+                _validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
+
+        if chunks_to_embed:
+            resp = await self.inference_api.embeddings(
+                self.vector_db.embedding_model,
+                [c.content for c in chunks_to_embed],
+            )
+            for c, embedding in zip(chunks_to_embed, resp.embeddings, strict=False):
+                c.embedding = embedding
+
+        embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
         await self.index.add_chunks(chunks, embeddings)
 
     async def query_chunks(
diff --git a/tests/integration/vector_io/test_vector_io.py b/tests/integration/vector_io/test_vector_io.py
index 90cb00313..f1cac9701 100644
--- a/tests/integration/vector_io/test_vector_io.py
+++ b/tests/integration/vector_io/test_vector_io.py
@@ -120,3 +120,37 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, sample_ch
     top_match = response.chunks[0]
     assert top_match is not None
     assert top_match.metadata["document_id"] == expected_doc_id, f"Query '{query}' should match {expected_doc_id}"
+
+
+def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id):
+    vector_db_id = "test_precomputed_embeddings_db"
+    client_with_empty_registry.vector_dbs.register(
+        vector_db_id=vector_db_id,
+        embedding_model=embedding_model_id,
+        embedding_dimension=384,
+    )
+
+    chunks_with_embeddings = [
+        Chunk(
+            content="This is a test chunk with precomputed embedding.",
+            metadata={"document_id": "doc1", "source": "precomputed"},
+            embedding=[0.1] * 384,
+        ),
+    ]
+
+    client_with_empty_registry.vector_io.insert(
+        vector_db_id=vector_db_id,
+        chunks=chunks_with_embeddings,
+    )
+
+    # Query for the inserted document
+    response = client_with_empty_registry.vector_io.query(
+        vector_db_id=vector_db_id,
+        query="precomputed embedding test",
+    )
+
+    # Verify the top result is the expected document
+    assert response is not None
+    assert len(response.chunks) > 0
+    assert response.chunks[0].metadata["document_id"] == "doc1"
+    assert response.chunks[0].metadata["source"] == "precomputed"
diff --git a/tests/unit/providers/vector_io/test_qdrant.py b/tests/unit/providers/vector_io/test_qdrant.py
index 34df9b52f..607eccb24 100644
--- a/tests/unit/providers/vector_io/test_qdrant.py
+++ b/tests/unit/providers/vector_io/test_qdrant.py
@@ -50,6 +50,7 @@ def mock_vector_db(vector_db_id) -> MagicMock:
     mock_vector_db = MagicMock(spec=VectorDB)
     mock_vector_db.embedding_model = "embedding_model"
     mock_vector_db.identifier = vector_db_id
+    mock_vector_db.embedding_dimension = 384
     return mock_vector_db
 
 
diff --git a/tests/unit/rag/test_vector_store.py b/tests/unit/rag/test_vector_store.py
index f97808a6d..9d6b9ee67 100644
--- a/tests/unit/rag/test_vector_store.py
+++ b/tests/unit/rag/test_vector_store.py
@@ -8,11 +8,20 @@ import base64
 import mimetypes
 import os
 from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock
 
+import numpy as np
 import pytest
 
 from llama_stack.apis.tools import RAGDocument
-from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc, make_overlapped_chunks
+from llama_stack.apis.vector_io import Chunk
+from llama_stack.providers.utils.memory.vector_store import (
+    URL,
+    VectorDBWithIndex,
+    _validate_embedding,
+    content_from_doc,
+    make_overlapped_chunks,
+)
 
 DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
 # Depending on the machine, this can get parsed a couple of ways
@@ -36,6 +45,72 @@ def data_url_from_file(file_path: str) -> str:
     return data_url
 
 
+class TestChunk:
+    def test_chunk(self):
+        chunk = Chunk(
+            content="Example chunk content",
+            metadata={"key": "value"},
+            embedding=[0.1, 0.2, 0.3],
+        )
+
+        assert chunk.content == "Example chunk content"
+        assert chunk.metadata == {"key": "value"}
+        assert chunk.embedding == [0.1, 0.2, 0.3]
+
+        chunk_no_embedding = Chunk(
+            content="Example chunk content",
+            metadata={"key": "value"},
+        )
+        assert chunk_no_embedding.embedding is None
+
+
+class TestValidateEmbedding:
+    def test_valid_list_embeddings(self):
+        _validate_embedding([0.1, 0.2, 0.3], 0, 3)
+        _validate_embedding([1, 2, 3], 1, 3)
+        _validate_embedding([0.1, 2, 3.5], 2, 3)
+
+    def test_valid_numpy_embeddings(self):
+        _validate_embedding(np.array([0.1, 0.2, 0.3], dtype=np.float32), 0, 3)
+        _validate_embedding(np.array([0.1, 0.2, 0.3], dtype=np.float64), 1, 3)
+        _validate_embedding(np.array([1, 2, 3], dtype=np.int32), 2, 3)
+        _validate_embedding(np.array([1, 2, 3], dtype=np.int64), 3, 3)
+
+    def test_invalid_embedding_type(self):
+        error_msg = "must be a list or numpy array"
+
+        with pytest.raises(ValueError, match=error_msg):
+            _validate_embedding("not a list", 0, 3)
+
+        with pytest.raises(ValueError, match=error_msg):
+            _validate_embedding(None, 1, 3)
+
+        with pytest.raises(ValueError, match=error_msg):
+            _validate_embedding(42, 2, 3)
+
+    def test_non_numeric_values(self):
+        error_msg = "contains non-numeric values"
+
+        with pytest.raises(ValueError, match=error_msg):
+            _validate_embedding([0.1, "string", 0.3], 0, 3)
+
+        with pytest.raises(ValueError, match=error_msg):
+            _validate_embedding([0.1, None, 0.3], 1, 3)
+
+        with pytest.raises(ValueError, match=error_msg):
+            _validate_embedding([1, {}, 3], 2, 3)
+
+    def test_wrong_dimension(self):
+        with pytest.raises(ValueError, match="has dimension 4, expected 3"):
+            _validate_embedding([0.1, 0.2, 0.3, 0.4], 0, 3)
+
+        with pytest.raises(ValueError, match="has dimension 2, expected 3"):
+            _validate_embedding([0.1, 0.2], 1, 3)
+
+        with pytest.raises(ValueError, match="has dimension 0, expected 3"):
+            _validate_embedding([], 2, 3)
+
+
 class TestVectorStore:
     @pytest.mark.asyncio
     async def test_returns_content_from_pdf_data_uri(self):
@@ -126,3 +201,126 @@ class TestVectorStore:
         assert str(excinfo.value) == "Failed to serialize metadata to string"
         assert isinstance(excinfo.value.__cause__, TypeError)
         assert str(excinfo.value.__cause__) == "Cannot convert to string"
+
+
+class TestVectorDBWithIndex:
+    @pytest.mark.asyncio
+    async def test_insert_chunks_without_embeddings(self):
+        mock_vector_db = MagicMock()
+        mock_vector_db.embedding_model = "test-model without embeddings"
+        mock_index = AsyncMock()
+        mock_inference_api = AsyncMock()
+
+        vector_db_with_index = VectorDBWithIndex(
+            vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
+        )
+
+        chunks = [
+            Chunk(content="Test 1", embedding=None, metadata={}),
+            Chunk(content="Test 2", embedding=None, metadata={}),
+        ]
+
+        mock_inference_api.embeddings.return_value.embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+
+        await vector_db_with_index.insert_chunks(chunks)
+
+        mock_inference_api.embeddings.assert_called_once_with("test-model without embeddings", ["Test 1", "Test 2"])
+        mock_index.add_chunks.assert_called_once()
+        args = mock_index.add_chunks.call_args[0]
+        assert args[0] == chunks
+        assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
+
+    @pytest.mark.asyncio
+    async def test_insert_chunks_with_valid_embeddings(self):
+        mock_vector_db = MagicMock()
+        mock_vector_db.embedding_model = "test-model with embeddings"
+        mock_vector_db.embedding_dimension = 3
+        mock_index = AsyncMock()
+        mock_inference_api = AsyncMock()
+
+        vector_db_with_index = VectorDBWithIndex(
+            vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
+        )
+
+        chunks = [
+            Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3], metadata={}),
+            Chunk(content="Test 2", embedding=[0.4, 0.5, 0.6], metadata={}),
+        ]
+
+        await vector_db_with_index.insert_chunks(chunks)
+
+        mock_inference_api.embeddings.assert_not_called()
+        mock_index.add_chunks.assert_called_once()
+        args = mock_index.add_chunks.call_args[0]
+        assert args[0] == chunks
+        assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
+
+    @pytest.mark.asyncio
+    async def test_insert_chunks_with_invalid_embeddings(self):
+        mock_vector_db = MagicMock()
+        mock_vector_db.embedding_dimension = 3
+        mock_vector_db.embedding_model = "test-model with invalid embeddings"
+        mock_index = AsyncMock()
+        mock_inference_api = AsyncMock()
+
+        vector_db_with_index = VectorDBWithIndex(
+            vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
+        )
+
+        # An invalid embedding type raises a pydantic ValidationError (a ValueError) at Chunk construction
+        with pytest.raises(ValueError, match="Input should be a valid list"):
+            Chunk(content="Test 1", embedding="invalid_type", metadata={})
+
+        # The same error surfaces when the invalid Chunk is built inside the insert_chunks call
+        # (the Chunk constructor raises before insert_chunks runs)
+        with pytest.raises(ValueError, match="Input should be a valid list"):
+            await vector_db_with_index.insert_chunks(
+                [
+                    Chunk(content="Test 1", embedding=None, metadata={}),
+                    Chunk(content="Test 2", embedding="invalid_type", metadata={}),
+                ]
+            )
+
+        # Likewise for an invalid embedding element type
+        with pytest.raises(ValueError, match="Input should be a valid number, unable to parse string as a number"):
+            await vector_db_with_index.insert_chunks(
+                [Chunk(content="Test 1", embedding=[0.1, "string", 0.3], metadata={})]
+            )
+
+        chunks_wrong_dim = [
+            Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3, 0.4], metadata={}),
+        ]
+        with pytest.raises(ValueError, match="has dimension 4, expected 3"):
+            await vector_db_with_index.insert_chunks(chunks_wrong_dim)
+
+        mock_inference_api.embeddings.assert_not_called()
+        mock_index.add_chunks.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_insert_chunks_with_partially_precomputed_embeddings(self):
+        mock_vector_db = MagicMock()
+        mock_vector_db.embedding_model = "test-model with partial embeddings"
+        mock_vector_db.embedding_dimension = 3
+        mock_index = AsyncMock()
+        mock_inference_api = AsyncMock()
+
+        vector_db_with_index = VectorDBWithIndex(
+            vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
+        )
+
+        chunks = [
+            Chunk(content="Test 1", embedding=None, metadata={}),
+            Chunk(content="Test 2", embedding=[0.2, 0.2, 0.2], metadata={}),
+            Chunk(content="Test 3", embedding=None, metadata={}),
+        ]
+
+        mock_inference_api.embeddings.return_value.embeddings = [[0.1, 0.1, 0.1], [0.3, 0.3, 0.3]]
+
+        await vector_db_with_index.insert_chunks(chunks)
+
+        mock_inference_api.embeddings.assert_called_once_with(
+            "test-model with partial embeddings", ["Test 1", "Test 3"]
+        )
+        mock_index.add_chunks.assert_called_once()
+        args = mock_index.add_chunks.call_args[0]
+        assert len(args[0]) == 3
+        assert np.array_equal(args[1], np.array([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]], dtype=np.float32))
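
Below is a minimal end-to-end sketch of the behavior this patch enables: one `insert` call mixing precomputed and server-computed embeddings. It is not part of the patch; the base URL, model id, and 384-dimension value are illustrative assumptions.

```python
# Sketch: mixing precomputed and server-computed embeddings in one insert call.
# Assumes a running Llama Stack server; URL, model id, and dimension are examples.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

vector_db_id = "mixed_embeddings_db"
client.vector_dbs.register(
    vector_db_id=vector_db_id,
    embedding_model="all-MiniLM-L6-v2",  # assumed embedding model id
    embedding_dimension=384,
)

chunks = [
    # Precomputed: must be a 384-dim numeric vector, or the server rejects the batch.
    {
        "content": "Chunk with a precomputed embedding",
        "embedding": [0.1] * 384,
        "metadata": {"document_id": "doc1"},
    },
    # No "embedding" key: the server computes this one via the registered model.
    {
        "content": "Chunk embedded server-side",
        "metadata": {"document_id": "doc2"},
    },
]
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
```

Per the new `insert_chunks` logic, only the second chunk is sent to the inference API in a single batched `embeddings` call, and the resulting embedding matrix preserves the original chunk order.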
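If embeddings come from a separate service, a client-side pre-flight check that mirrors the server's `_validate_embedding` rules (type, numeric elements, dimension) can catch mismatches before the network round-trip. A hedged sketch, not part of the patch; `EXPECTED_DIM` stands in for whatever `embedding_dimension` was registered:

```python
# Hypothetical client-side check mirroring _validate_embedding's three rules.
import numpy as np

EXPECTED_DIM = 384  # stand-in for the registered embedding_dimension


def check_embeddings(chunks: list[dict]) -> None:
    for i, chunk in enumerate(chunks):
        emb = chunk.get("embedding")
        if emb is None:
            continue  # no embedding: the server will compute it
        if not isinstance(emb, (list, np.ndarray)):
            raise ValueError(f"Embedding at index {i} must be a list or numpy array, got {type(emb)}")
        arr = np.asarray(emb)
        if not np.issubdtype(arr.dtype, np.number):
            raise ValueError(f"Embedding at index {i} contains non-numeric values")
        if len(arr) != EXPECTED_DIM:
            raise ValueError(f"Embedding at index {i} has dimension {len(arr)}, expected {EXPECTED_DIM}")


# Passes silently for a well-formed chunk; raises on type, element, or dimension errors.
check_embeddings([{"content": "ok", "embedding": [0.0] * EXPECTED_DIM}])
```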