diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 9c1c3170f..dbe251921 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -10020,7 +10020,8 @@
"type": "object",
"properties": {
"content": {
- "$ref": "#/components/schemas/InterleavedContent"
+ "$ref": "#/components/schemas/InterleavedContent",
+ "description": "The content of the chunk, which can be interleaved text, images, or other types."
},
"metadata": {
"type": "object",
@@ -10045,7 +10046,15 @@
"type": "object"
}
]
- }
+ },
+ "description": "Metadata associated with the chunk, such as document ID, source, or other relevant information."
+ },
+ "embedding": {
+ "type": "array",
+ "items": {
+ "type": "number"
+ },
+ "description": "Optional embedding for the chunk. If not provided, it will be computed later."
}
},
"additionalProperties": false,
@@ -10053,9 +10062,10 @@
"content",
"metadata"
],
- "title": "Chunk"
+ "title": "Chunk",
+ "description": "A chunk of content that can be inserted into a vector database."
},
- "description": "The chunks to insert."
+ "description": "The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types. `metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional. If `metadata` is provided, you configure how Llama Stack formats the chunk during generation. If `embedding` is not provided, it will be computed later."
},
"ttl_seconds": {
"type": "integer",
@@ -12285,7 +12295,8 @@
"type": "object",
"properties": {
"content": {
- "$ref": "#/components/schemas/InterleavedContent"
+ "$ref": "#/components/schemas/InterleavedContent",
+ "description": "The content of the chunk, which can be interleaved text, images, or other types."
},
"metadata": {
"type": "object",
@@ -12310,7 +12321,15 @@
"type": "object"
}
]
- }
+ },
+ "description": "Metadata associated with the chunk, such as document ID, source, or other relevant information."
+ },
+ "embedding": {
+ "type": "array",
+ "items": {
+ "type": "number"
+ },
+ "description": "Optional embedding for the chunk. If not provided, it will be computed later."
}
},
"additionalProperties": false,
@@ -12318,7 +12337,8 @@
"content",
"metadata"
],
- "title": "Chunk"
+ "title": "Chunk",
+ "description": "A chunk of content that can be inserted into a vector database."
}
},
"scores": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 1afe870cf..2dfb037da 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -7024,6 +7024,9 @@ components:
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
+ description: >-
+ The content of the chunk, which can be interleaved text, images,
+ or other types.
metadata:
type: object
additionalProperties:
@@ -7034,12 +7037,29 @@ components:
- type: string
- type: array
- type: object
+ description: >-
+ Metadata associated with the chunk, such as document ID, source,
+ or other relevant information.
+ embedding:
+ type: array
+ items:
+ type: number
+ description: >-
+ Optional embedding for the chunk. If not provided, it will be computed
+ later.
additionalProperties: false
required:
- content
- metadata
title: Chunk
- description: The chunks to insert.
+ description: >-
+ A chunk of content that can be inserted into a vector database.
+    description: >-
+      The chunks to insert. Each `Chunk` must contain `content`, which can be
+      interleaved text, images, or other types; `metadata` (`dict[str, Any]`,
+      defaults to an empty dict) and `embedding` (`list[float]`) are optional.
+      If `metadata` is provided, it configures how Llama Stack formats the
+      chunk during generation. If `embedding` is not provided, it will be computed later.
ttl_seconds:
type: integer
description: The time to live of the chunks.
@@ -8537,6 +8557,9 @@ components:
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
+ description: >-
+ The content of the chunk, which can be interleaved text, images,
+ or other types.
metadata:
type: object
additionalProperties:
@@ -8547,11 +8570,23 @@ components:
- type: string
- type: array
- type: object
+ description: >-
+ Metadata associated with the chunk, such as document ID, source,
+ or other relevant information.
+ embedding:
+ type: array
+ items:
+ type: number
+ description: >-
+ Optional embedding for the chunk. If not provided, it will be computed
+ later.
additionalProperties: false
required:
- content
- metadata
title: Chunk
+ description: >-
+ A chunk of content that can be inserted into a vector database.
scores:
type: array
items:
diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md
index dbe90a7fc..289c38991 100644
--- a/docs/source/building_applications/rag.md
+++ b/docs/source/building_applications/rag.md
@@ -57,6 +57,55 @@ chunks = [
]
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
```
+
+#### Using Precomputed Embeddings
+If you decide to precompute embeddings for your documents, you can insert them directly into the vector database by
+including the embedding vectors in the chunk data. This is useful if you have a separate embedding service or if you
+want to customize the ingestion process.
+```python
+chunks_with_embeddings = [
+ {
+ "content": "First chunk of text",
+ "mime_type": "text/plain",
+ "embedding": [0.1, 0.2, 0.3, ...], # Your precomputed embedding vector
+ "metadata": {"document_id": "doc1", "section": "introduction"},
+ },
+ {
+ "content": "Second chunk of text",
+ "mime_type": "text/plain",
+ "embedding": [0.2, 0.3, 0.4, ...], # Your precomputed embedding vector
+ "metadata": {"document_id": "doc1", "section": "methodology"},
+ },
+]
+client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks_with_embeddings)
+```
+When providing precomputed embeddings, ensure the embedding dimension matches the `embedding_dimension` specified
+when registering the vector database.
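+
+To produce such vectors you might use an external embedding model. The following is a sketch using the
+`sentence-transformers` library (an assumption for illustration, not a Llama Stack dependency); its
+`all-MiniLM-L6-v2` model outputs 384-dimensional vectors, matching an `embedding_dimension` of 384:
+```python
+from sentence_transformers import SentenceTransformer
+
+# Hypothetical external embedding step; any model whose output dimension
+# matches the registered embedding_dimension works the same way.
+model = SentenceTransformer("all-MiniLM-L6-v2")
+texts = ["First chunk of text", "Second chunk of text"]
+vectors = model.encode(texts)  # numpy array of shape (2, 384)
+
+chunks = [
+    {
+        "content": text,
+        "mime_type": "text/plain",
+        "embedding": vector.tolist(),  # convert each numpy row to list[float]
+        "metadata": {"document_id": "doc1"},
+    }
+    for text, vector in zip(texts, vectors)
+]
+client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
+```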
+
### Retrieval
You can query the vector database to retrieve documents based on their embeddings.
```python
diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py
index 3ac62d42c..44cc8f904 100644
--- a/llama_stack/apis/vector_io/vector_io.py
+++ b/llama_stack/apis/vector_io/vector_io.py
@@ -19,8 +19,16 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class Chunk(BaseModel):
+ """
+ A chunk of content that can be inserted into a vector database.
+ :param content: The content of the chunk, which can be interleaved text, images, or other types.
+ :param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
+ :param metadata: Metadata associated with the chunk, such as document ID, source, or other relevant information.
+ """
+
content: InterleavedContent
metadata: dict[str, Any] = Field(default_factory=dict)
+ embedding: list[float] | None = None
@json_schema_type
@@ -50,7 +58,10 @@ class VectorIO(Protocol):
"""Insert chunks into a vector database.
:param vector_db_id: The identifier of the vector database to insert the chunks into.
- :param chunks: The chunks to insert.
+        :param chunks: The chunks to insert. Each `Chunk` must contain `content`, which can be interleaved text, images, or other types;
+            `metadata` (`dict[str, Any]`, defaults to an empty dict) and `embedding` (`list[float]`) are optional.
+            If `metadata` is provided, it configures how Llama Stack formats the chunk during generation.
+            If `embedding` is not provided, it will be computed later.
:param ttl_seconds: The time to live of the chunks.
"""
...
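
A minimal sketch of the updated `insert_chunks` contract from the caller's side (illustrative only: `client` stands for a configured Llama Stack client, and the 384-dimensional values mirror the integration test below; neither is part of this patch). Chunks with and without precomputed embeddings can be mixed in a single call, and `ttl_seconds` belongs to the same request schema:

```python
from llama_stack.apis.vector_io import Chunk

chunks = [
    # No embedding: the provider computes one during ingestion.
    Chunk(content="Plain text chunk", metadata={"document_id": "doc1"}),
    # Precomputed embedding: must match the registered embedding_dimension.
    Chunk(
        content="Chunk with a precomputed vector",
        metadata={"document_id": "doc2"},
        embedding=[0.1] * 384,
    ),
]

client.vector_io.insert(vector_db_id="my_vector_db", chunks=chunks, ttl_seconds=3600)
```
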
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index c2d264c91..4776d47d0 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -146,7 +146,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
]
for i, chunk in enumerate(chunks):
metadata = chunk.metadata
- tokens += metadata["token_count"]
+ tokens += metadata.get("token_count", 0)
tokens += metadata.get("metadata_token_count", 0)
if tokens > query_config.max_tokens_in_context:
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index 3655c7049..4cd15860b 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -171,6 +171,22 @@ def make_overlapped_chunks(
return chunks
+def _validate_embedding(embedding: NDArray | list[float], index: int, expected_dimension: int):
+    """Helper function to validate embedding format and dimensions."""
+    if not isinstance(embedding, (list, np.ndarray)):
+        raise ValueError(f"Embedding at index {index} must be a list or numpy array, got {type(embedding)}")
+
+    if isinstance(embedding, np.ndarray):
+        if not np.issubdtype(embedding.dtype, np.number):
+            raise ValueError(f"Embedding at index {index} contains non-numeric values")
+    else:
+        if not all(isinstance(e, (float, int, np.number)) for e in embedding):
+            raise ValueError(f"Embedding at index {index} contains non-numeric values")
+
+ if len(embedding) != expected_dimension:
+ raise ValueError(f"Embedding at index {index} has dimension {len(embedding)}, expected {expected_dimension}")
+
+
class EmbeddingIndex(ABC):
@abstractmethod
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
@@ -199,11 +215,22 @@ class VectorDBWithIndex:
self,
chunks: list[Chunk],
) -> None:
- embeddings_response = await self.inference_api.embeddings(
- self.vector_db.embedding_model, [x.content for x in chunks]
- )
- embeddings = np.array(embeddings_response.embeddings)
+ chunks_to_embed = []
+ for i, c in enumerate(chunks):
+ if c.embedding is None:
+ chunks_to_embed.append(c)
+ else:
+ _validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
+ if chunks_to_embed:
+ resp = await self.inference_api.embeddings(
+ self.vector_db.embedding_model,
+ [c.content for c in chunks_to_embed],
+ )
+ for c, embedding in zip(chunks_to_embed, resp.embeddings, strict=False):
+ c.embedding = embedding
+
+ embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
await self.index.add_chunks(chunks, embeddings)
async def query_chunks(
diff --git a/tests/integration/vector_io/test_vector_io.py b/tests/integration/vector_io/test_vector_io.py
index 90cb00313..f1cac9701 100644
--- a/tests/integration/vector_io/test_vector_io.py
+++ b/tests/integration/vector_io/test_vector_io.py
@@ -120,3 +120,37 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, sample_ch
top_match = response.chunks[0]
assert top_match is not None
assert top_match.metadata["document_id"] == expected_doc_id, f"Query '{query}' should match {expected_doc_id}"
+
+
+def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id):
+ vector_db_id = "test_precomputed_embeddings_db"
+ client_with_empty_registry.vector_dbs.register(
+ vector_db_id=vector_db_id,
+ embedding_model=embedding_model_id,
+ embedding_dimension=384,
+ )
+
+ chunks_with_embeddings = [
+ Chunk(
+ content="This is a test chunk with precomputed embedding.",
+ metadata={"document_id": "doc1", "source": "precomputed"},
+ embedding=[0.1] * 384,
+ ),
+ ]
+
+ client_with_empty_registry.vector_io.insert(
+ vector_db_id=vector_db_id,
+ chunks=chunks_with_embeddings,
+ )
+
+ # Query for the first document
+ response = client_with_empty_registry.vector_io.query(
+ vector_db_id=vector_db_id,
+ query="precomputed embedding test",
+ )
+
+ # Verify the top result is the expected document
+ assert response is not None
+ assert len(response.chunks) > 0
+ assert response.chunks[0].metadata["document_id"] == "doc1"
+ assert response.chunks[0].metadata["source"] == "precomputed"
diff --git a/tests/unit/providers/vector_io/test_qdrant.py b/tests/unit/providers/vector_io/test_qdrant.py
index 34df9b52f..607eccb24 100644
--- a/tests/unit/providers/vector_io/test_qdrant.py
+++ b/tests/unit/providers/vector_io/test_qdrant.py
@@ -50,6 +50,7 @@ def mock_vector_db(vector_db_id) -> MagicMock:
mock_vector_db = MagicMock(spec=VectorDB)
mock_vector_db.embedding_model = "embedding_model"
mock_vector_db.identifier = vector_db_id
+ mock_vector_db.embedding_dimension = 384
return mock_vector_db
diff --git a/tests/unit/rag/test_vector_store.py b/tests/unit/rag/test_vector_store.py
index f97808a6d..9d6b9ee67 100644
--- a/tests/unit/rag/test_vector_store.py
+++ b/tests/unit/rag/test_vector_store.py
@@ -8,11 +8,20 @@ import base64
import mimetypes
import os
from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock
+import numpy as np
import pytest
from llama_stack.apis.tools import RAGDocument
-from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc, make_overlapped_chunks
+from llama_stack.apis.vector_io import Chunk
+from llama_stack.providers.utils.memory.vector_store import (
+ URL,
+ VectorDBWithIndex,
+ _validate_embedding,
+ content_from_doc,
+ make_overlapped_chunks,
+)
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
# Depending on the machine, this can get parsed a couple of ways
@@ -36,6 +45,72 @@ def data_url_from_file(file_path: str) -> str:
return data_url
+class TestChunk:
+ def test_chunk(self):
+ chunk = Chunk(
+ content="Example chunk content",
+ metadata={"key": "value"},
+ embedding=[0.1, 0.2, 0.3],
+ )
+
+ assert chunk.content == "Example chunk content"
+ assert chunk.metadata == {"key": "value"}
+ assert chunk.embedding == [0.1, 0.2, 0.3]
+
+ chunk_no_embedding = Chunk(
+ content="Example chunk content",
+ metadata={"key": "value"},
+ )
+ assert chunk_no_embedding.embedding is None
+
+
+class TestValidateEmbedding:
+ def test_valid_list_embeddings(self):
+ _validate_embedding([0.1, 0.2, 0.3], 0, 3)
+ _validate_embedding([1, 2, 3], 1, 3)
+ _validate_embedding([0.1, 2, 3.5], 2, 3)
+
+ def test_valid_numpy_embeddings(self):
+ _validate_embedding(np.array([0.1, 0.2, 0.3], dtype=np.float32), 0, 3)
+ _validate_embedding(np.array([0.1, 0.2, 0.3], dtype=np.float64), 1, 3)
+ _validate_embedding(np.array([1, 2, 3], dtype=np.int32), 2, 3)
+ _validate_embedding(np.array([1, 2, 3], dtype=np.int64), 3, 3)
+
+ def test_invalid_embedding_type(self):
+ error_msg = "must be a list or numpy array"
+
+ with pytest.raises(ValueError, match=error_msg):
+ _validate_embedding("not a list", 0, 3)
+
+ with pytest.raises(ValueError, match=error_msg):
+ _validate_embedding(None, 1, 3)
+
+ with pytest.raises(ValueError, match=error_msg):
+ _validate_embedding(42, 2, 3)
+
+ def test_non_numeric_values(self):
+ error_msg = "contains non-numeric values"
+
+ with pytest.raises(ValueError, match=error_msg):
+ _validate_embedding([0.1, "string", 0.3], 0, 3)
+
+ with pytest.raises(ValueError, match=error_msg):
+ _validate_embedding([0.1, None, 0.3], 1, 3)
+
+ with pytest.raises(ValueError, match=error_msg):
+ _validate_embedding([1, {}, 3], 2, 3)
+
+ def test_wrong_dimension(self):
+ with pytest.raises(ValueError, match="has dimension 4, expected 3"):
+ _validate_embedding([0.1, 0.2, 0.3, 0.4], 0, 3)
+
+ with pytest.raises(ValueError, match="has dimension 2, expected 3"):
+ _validate_embedding([0.1, 0.2], 1, 3)
+
+ with pytest.raises(ValueError, match="has dimension 0, expected 3"):
+ _validate_embedding([], 2, 3)
+
+
class TestVectorStore:
@pytest.mark.asyncio
async def test_returns_content_from_pdf_data_uri(self):
@@ -126,3 +201,126 @@ class TestVectorStore:
assert str(excinfo.value) == "Failed to serialize metadata to string"
assert isinstance(excinfo.value.__cause__, TypeError)
assert str(excinfo.value.__cause__) == "Cannot convert to string"
+
+
+class TestVectorDBWithIndex:
+ @pytest.mark.asyncio
+ async def test_insert_chunks_without_embeddings(self):
+ mock_vector_db = MagicMock()
+ mock_vector_db.embedding_model = "test-model without embeddings"
+ mock_index = AsyncMock()
+ mock_inference_api = AsyncMock()
+
+ vector_db_with_index = VectorDBWithIndex(
+ vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
+ )
+
+ chunks = [
+ Chunk(content="Test 1", embedding=None, metadata={}),
+ Chunk(content="Test 2", embedding=None, metadata={}),
+ ]
+
+ mock_inference_api.embeddings.return_value.embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+
+ await vector_db_with_index.insert_chunks(chunks)
+
+ mock_inference_api.embeddings.assert_called_once_with("test-model without embeddings", ["Test 1", "Test 2"])
+ mock_index.add_chunks.assert_called_once()
+ args = mock_index.add_chunks.call_args[0]
+ assert args[0] == chunks
+ assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
+
+ @pytest.mark.asyncio
+ async def test_insert_chunks_with_valid_embeddings(self):
+ mock_vector_db = MagicMock()
+ mock_vector_db.embedding_model = "test-model with embeddings"
+ mock_vector_db.embedding_dimension = 3
+ mock_index = AsyncMock()
+ mock_inference_api = AsyncMock()
+
+ vector_db_with_index = VectorDBWithIndex(
+ vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
+ )
+
+ chunks = [
+ Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3], metadata={}),
+ Chunk(content="Test 2", embedding=[0.4, 0.5, 0.6], metadata={}),
+ ]
+
+ await vector_db_with_index.insert_chunks(chunks)
+
+ mock_inference_api.embeddings.assert_not_called()
+ mock_index.add_chunks.assert_called_once()
+ args = mock_index.add_chunks.call_args[0]
+ assert args[0] == chunks
+ assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
+
+ @pytest.mark.asyncio
+ async def test_insert_chunks_with_invalid_embeddings(self):
+ mock_vector_db = MagicMock()
+ mock_vector_db.embedding_dimension = 3
+ mock_vector_db.embedding_model = "test-model with invalid embeddings"
+ mock_index = AsyncMock()
+ mock_inference_api = AsyncMock()
+
+ vector_db_with_index = VectorDBWithIndex(
+ vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
+ )
+
+ # Verify Chunk raises ValueError for invalid embedding type
+ with pytest.raises(ValueError, match="Input should be a valid list"):
+ Chunk(content="Test 1", embedding="invalid_type", metadata={})
+
+ # Verify Chunk raises ValueError for invalid embedding type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
+ with pytest.raises(ValueError, match="Input should be a valid list"):
+ await vector_db_with_index.insert_chunks(
+ [
+ Chunk(content="Test 1", embedding=None, metadata={}),
+ Chunk(content="Test 2", embedding="invalid_type", metadata={}),
+ ]
+ )
+
+        # Verify Chunk raises ValueError for an invalid embedding element type (i.e., Chunk errors before insert_chunks is called)
+        with pytest.raises(ValueError, match="Input should be a valid number, unable to parse string as a number"):
+            await vector_db_with_index.insert_chunks(
+                [Chunk(content="Test 1", embedding=[0.1, "string", 0.3], metadata={})]
+            )
+
+ chunks_wrong_dim = [
+ Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3, 0.4], metadata={}),
+ ]
+ with pytest.raises(ValueError, match="has dimension 4, expected 3"):
+ await vector_db_with_index.insert_chunks(chunks_wrong_dim)
+
+ mock_inference_api.embeddings.assert_not_called()
+ mock_index.add_chunks.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_insert_chunks_with_partially_precomputed_embeddings(self):
+ mock_vector_db = MagicMock()
+ mock_vector_db.embedding_model = "test-model with partial embeddings"
+ mock_vector_db.embedding_dimension = 3
+ mock_index = AsyncMock()
+ mock_inference_api = AsyncMock()
+
+ vector_db_with_index = VectorDBWithIndex(
+ vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
+ )
+
+ chunks = [
+ Chunk(content="Test 1", embedding=None, metadata={}),
+ Chunk(content="Test 2", embedding=[0.2, 0.2, 0.2], metadata={}),
+ Chunk(content="Test 3", embedding=None, metadata={}),
+ ]
+
+ mock_inference_api.embeddings.return_value.embeddings = [[0.1, 0.1, 0.1], [0.3, 0.3, 0.3]]
+
+ await vector_db_with_index.insert_chunks(chunks)
+
+ mock_inference_api.embeddings.assert_called_once_with(
+ "test-model with partial embeddings", ["Test 1", "Test 3"]
+ )
+ mock_index.add_chunks.assert_called_once()
+ args = mock_index.add_chunks.call_args[0]
+ assert len(args[0]) == 3
+ assert np.array_equal(args[1], np.array([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]], dtype=np.float32))