feat: Enable ingestion of precomputed embeddings (#2317)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 3s
Integration Tests / test-matrix (http, inspect) (push) Failing after 9s
Integration Tests / test-matrix (http, post_training) (push) Failing after 9s
Integration Tests / test-matrix (http, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, datasets) (push) Failing after 10s
Integration Tests / test-matrix (http, inference) (push) Failing after 10s
Integration Tests / test-matrix (library, agents) (push) Failing after 9s
Integration Tests / test-matrix (http, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, datasets) (push) Failing after 8s
Integration Tests / test-matrix (http, providers) (push) Failing after 9s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 10s
Integration Tests / test-matrix (library, inference) (push) Failing after 9s
Test External Providers / test-external-providers (venv) (push) Failing after 6s
Integration Tests / test-matrix (library, inspect) (push) Failing after 8s
Integration Tests / test-matrix (library, providers) (push) Failing after 8s
Integration Tests / test-matrix (library, scoring) (push) Failing after 8s
Integration Tests / test-matrix (library, post_training) (push) Failing after 10s
Unit Tests / unit-tests (3.11) (push) Failing after 7s
Unit Tests / unit-tests (3.10) (push) Failing after 9s
Unit Tests / unit-tests (3.13) (push) Failing after 7s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 9s
Unit Tests / unit-tests (3.12) (push) Failing after 9s
Update ReadTheDocs / update-readthedocs (push) Failing after 7s
Pre-commit / pre-commit (push) Successful in 1m15s

This commit is contained in:
Francisco Arceo 2025-05-31 04:03:37 -06:00 committed by GitHub
parent 31ce208bda
commit f328436831
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 366 additions and 15 deletions

View file

@ -10020,7 +10020,8 @@
"type": "object", "type": "object",
"properties": { "properties": {
"content": { "content": {
"$ref": "#/components/schemas/InterleavedContent" "$ref": "#/components/schemas/InterleavedContent",
"description": "The content of the chunk, which can be interleaved text, images, or other types."
}, },
"metadata": { "metadata": {
"type": "object", "type": "object",
@ -10045,7 +10046,15 @@
"type": "object" "type": "object"
} }
] ]
} },
"description": "Metadata associated with the chunk, such as document ID, source, or other relevant information."
},
"embedding": {
"type": "array",
"items": {
"type": "number"
},
"description": "Optional embedding for the chunk. If not provided, it will be computed later."
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -10053,9 +10062,10 @@
"content", "content",
"metadata" "metadata"
], ],
"title": "Chunk" "title": "Chunk",
"description": "A chunk of content that can be inserted into a vector database."
}, },
"description": "The chunks to insert." "description": "The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types. `metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional. If `metadata` is provided, you configure how Llama Stack formats the chunk during generation. If `embedding` is not provided, it will be computed later."
}, },
"ttl_seconds": { "ttl_seconds": {
"type": "integer", "type": "integer",
@ -12285,7 +12295,8 @@
"type": "object", "type": "object",
"properties": { "properties": {
"content": { "content": {
"$ref": "#/components/schemas/InterleavedContent" "$ref": "#/components/schemas/InterleavedContent",
"description": "The content of the chunk, which can be interleaved text, images, or other types."
}, },
"metadata": { "metadata": {
"type": "object", "type": "object",
@ -12310,7 +12321,15 @@
"type": "object" "type": "object"
} }
] ]
} },
"description": "Metadata associated with the chunk, such as document ID, source, or other relevant information."
},
"embedding": {
"type": "array",
"items": {
"type": "number"
},
"description": "Optional embedding for the chunk. If not provided, it will be computed later."
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -12318,7 +12337,8 @@
"content", "content",
"metadata" "metadata"
], ],
"title": "Chunk" "title": "Chunk",
"description": "A chunk of content that can be inserted into a vector database."
} }
}, },
"scores": { "scores": {

View file

@ -7024,6 +7024,9 @@ components:
properties: properties:
content: content:
$ref: '#/components/schemas/InterleavedContent' $ref: '#/components/schemas/InterleavedContent'
description: >-
The content of the chunk, which can be interleaved text, images,
or other types.
metadata: metadata:
type: object type: object
additionalProperties: additionalProperties:
@ -7034,12 +7037,29 @@ components:
- type: string - type: string
- type: array - type: array
- type: object - type: object
description: >-
Metadata associated with the chunk, such as document ID, source,
or other relevant information.
embedding:
type: array
items:
type: number
description: >-
Optional embedding for the chunk. If not provided, it will be computed
later.
additionalProperties: false additionalProperties: false
required: required:
- content - content
- metadata - metadata
title: Chunk title: Chunk
description: The chunks to insert. description: >-
A chunk of content that can be inserted into a vector database.
description: >-
The chunks to insert. Each `Chunk` should contain content which can be
interleaved text, images, or other types. `metadata`: `dict[str, Any]`
and `embedding`: `List[float]` are optional. If `metadata` is provided,
you configure how Llama Stack formats the chunk during generation. If
`embedding` is not provided, it will be computed later.
ttl_seconds: ttl_seconds:
type: integer type: integer
description: The time to live of the chunks. description: The time to live of the chunks.
@ -8537,6 +8557,9 @@ components:
properties: properties:
content: content:
$ref: '#/components/schemas/InterleavedContent' $ref: '#/components/schemas/InterleavedContent'
description: >-
The content of the chunk, which can be interleaved text, images,
or other types.
metadata: metadata:
type: object type: object
additionalProperties: additionalProperties:
@ -8547,11 +8570,23 @@ components:
- type: string - type: string
- type: array - type: array
- type: object - type: object
description: >-
Metadata associated with the chunk, such as document ID, source,
or other relevant information.
embedding:
type: array
items:
type: number
description: >-
Optional embedding for the chunk. If not provided, it will be computed
later.
additionalProperties: false additionalProperties: false
required: required:
- content - content
- metadata - metadata
title: Chunk title: Chunk
description: >-
A chunk of content that can be inserted into a vector database.
scores: scores:
type: array type: array
items: items:

View file

@ -57,6 +57,31 @@ chunks = [
] ]
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks) client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
``` ```
#### Using Precomputed Embeddings
If you decide to precompute embeddings for your documents, you can insert them directly into the vector database by
including the embedding vectors in the chunk data. This is useful if you have a separate embedding service or if you
want to customize the ingestion process.
```python
chunks_with_embeddings = [
{
"content": "First chunk of text",
"mime_type": "text/plain",
"embedding": [0.1, 0.2, 0.3, ...], # Your precomputed embedding vector
"metadata": {"document_id": "doc1", "section": "introduction"},
},
{
"content": "Second chunk of text",
"mime_type": "text/plain",
"embedding": [0.2, 0.3, 0.4, ...], # Your precomputed embedding vector
"metadata": {"document_id": "doc1", "section": "methodology"},
},
]
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks_with_embeddings)
```
When providing precomputed embeddings, ensure that the length of each embedding vector matches the `embedding_dimension`
specified when registering the vector database.
### Retrieval ### Retrieval
You can query the vector database to retrieve documents based on their embeddings. You can query the vector database to retrieve documents based on their embeddings.
```python ```python

View file

@ -19,8 +19,16 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class Chunk(BaseModel): class Chunk(BaseModel):
"""
A chunk of content that can be inserted into a vector database.
:param content: The content of the chunk, which can be interleaved text, images, or other types.
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
:param metadata: Metadata associated with the chunk, such as document ID, source, or other relevant information.
"""
content: InterleavedContent content: InterleavedContent
metadata: dict[str, Any] = Field(default_factory=dict) metadata: dict[str, Any] = Field(default_factory=dict)
embedding: list[float] | None = None
@json_schema_type @json_schema_type
@ -50,7 +58,10 @@ class VectorIO(Protocol):
"""Insert chunks into a vector database. """Insert chunks into a vector database.
:param vector_db_id: The identifier of the vector database to insert the chunks into. :param vector_db_id: The identifier of the vector database to insert the chunks into.
:param chunks: The chunks to insert. :param chunks: The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types.
`metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional.
If `metadata` is provided, you configure how Llama Stack formats the chunk during generation.
If `embedding` is not provided, it will be computed later.
:param ttl_seconds: The time to live of the chunks. :param ttl_seconds: The time to live of the chunks.
""" """
... ...

View file

@ -146,7 +146,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
] ]
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
metadata = chunk.metadata metadata = chunk.metadata
tokens += metadata["token_count"] tokens += metadata.get("token_count", 0)
tokens += metadata.get("metadata_token_count", 0) tokens += metadata.get("metadata_token_count", 0)
if tokens > query_config.max_tokens_in_context: if tokens > query_config.max_tokens_in_context:

View file

@ -171,6 +171,22 @@ def make_overlapped_chunks(
return chunks return chunks
def _validate_embedding(embedding: NDArray, index: int, expected_dimension: int):
"""Helper method to validate embedding format and dimensions"""
if not isinstance(embedding, (list | np.ndarray)):
raise ValueError(f"Embedding at index {index} must be a list or numpy array, got {type(embedding)}")
if isinstance(embedding, np.ndarray):
if not np.issubdtype(embedding.dtype, np.number):
raise ValueError(f"Embedding at index {index} contains non-numeric values")
else:
if not all(isinstance(e, (float | int | np.number)) for e in embedding):
raise ValueError(f"Embedding at index {index} contains non-numeric values")
if len(embedding) != expected_dimension:
raise ValueError(f"Embedding at index {index} has dimension {len(embedding)}, expected {expected_dimension}")
class EmbeddingIndex(ABC): class EmbeddingIndex(ABC):
@abstractmethod @abstractmethod
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
@ -199,11 +215,22 @@ class VectorDBWithIndex:
self, self,
chunks: list[Chunk], chunks: list[Chunk],
) -> None: ) -> None:
embeddings_response = await self.inference_api.embeddings( chunks_to_embed = []
self.vector_db.embedding_model, [x.content for x in chunks] for i, c in enumerate(chunks):
) if c.embedding is None:
embeddings = np.array(embeddings_response.embeddings) chunks_to_embed.append(c)
else:
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
if chunks_to_embed:
resp = await self.inference_api.embeddings(
self.vector_db.embedding_model,
[c.content for c in chunks_to_embed],
)
for c, embedding in zip(chunks_to_embed, resp.embeddings, strict=False):
c.embedding = embedding
embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
await self.index.add_chunks(chunks, embeddings) await self.index.add_chunks(chunks, embeddings)
async def query_chunks( async def query_chunks(

View file

@ -120,3 +120,37 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, sample_ch
top_match = response.chunks[0] top_match = response.chunks[0]
assert top_match is not None assert top_match is not None
assert top_match.metadata["document_id"] == expected_doc_id, f"Query '{query}' should match {expected_doc_id}" assert top_match.metadata["document_id"] == expected_doc_id, f"Query '{query}' should match {expected_doc_id}"
def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id):
    """Insert a chunk that carries its own embedding and check that a query returns it."""
    client = client_with_empty_registry
    db_id = "test_precomputed_embeddings_db"
    dimension = 384

    client.vector_dbs.register(
        vector_db_id=db_id,
        embedding_model=embedding_model_id,
        embedding_dimension=dimension,
    )

    precomputed_chunk = Chunk(
        content="This is a test chunk with precomputed embedding.",
        metadata={"document_id": "doc1", "source": "precomputed"},
        embedding=[0.1] * dimension,
    )
    client.vector_io.insert(vector_db_id=db_id, chunks=[precomputed_chunk])

    # This test inserts a single chunk, so it is expected to be the top match.
    response = client.vector_io.query(
        vector_db_id=db_id,
        query="precomputed embedding test",
    )

    assert response is not None
    assert len(response.chunks) > 0
    top_metadata = response.chunks[0].metadata
    assert top_metadata["document_id"] == "doc1"
    assert top_metadata["source"] == "precomputed"

View file

@ -50,6 +50,7 @@ def mock_vector_db(vector_db_id) -> MagicMock:
mock_vector_db = MagicMock(spec=VectorDB) mock_vector_db = MagicMock(spec=VectorDB)
mock_vector_db.embedding_model = "embedding_model" mock_vector_db.embedding_model = "embedding_model"
mock_vector_db.identifier = vector_db_id mock_vector_db.identifier = vector_db_id
mock_vector_db.embedding_dimension = 384
return mock_vector_db return mock_vector_db

View file

@ -8,11 +8,20 @@ import base64
import mimetypes import mimetypes
import os import os
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest import pytest
from llama_stack.apis.tools import RAGDocument from llama_stack.apis.tools import RAGDocument
from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc, make_overlapped_chunks from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.utils.memory.vector_store import (
URL,
VectorDBWithIndex,
_validate_embedding,
content_from_doc,
make_overlapped_chunks,
)
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf" DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
# Depending on the machine, this can get parsed a couple of ways # Depending on the machine, this can get parsed a couple of ways
@ -36,6 +45,72 @@ def data_url_from_file(file_path: str) -> str:
return data_url return data_url
class TestChunk:
    def test_chunk(self):
        """A Chunk keeps the fields it is given and leaves `embedding` as None when omitted."""
        with_embedding = Chunk(
            content="Example chunk content",
            metadata={"key": "value"},
            embedding=[0.1, 0.2, 0.3],
        )
        expected_fields = {
            "content": "Example chunk content",
            "metadata": {"key": "value"},
            "embedding": [0.1, 0.2, 0.3],
        }
        assert with_embedding.content == expected_fields["content"]
        assert with_embedding.metadata == expected_fields["metadata"]
        assert with_embedding.embedding == expected_fields["embedding"]

        # Same chunk, but without the optional embedding argument.
        without_embedding = Chunk(
            content="Example chunk content",
            metadata={"key": "value"},
        )
        assert without_embedding.embedding is None
class TestValidateEmbedding:
    """Checks for `_validate_embedding` over accepted and rejected inputs."""

    def test_valid_list_embeddings(self):
        """Lists of floats, ints, or a mix of both are accepted."""
        accepted = [
            [0.1, 0.2, 0.3],
            [1, 2, 3],
            [0.1, 2, 3.5],
        ]
        for position, values in enumerate(accepted):
            _validate_embedding(values, position, 3)

    def test_valid_numpy_embeddings(self):
        """Numpy arrays with float or integer dtypes are accepted."""
        accepted = [
            np.array([0.1, 0.2, 0.3], dtype=np.float32),
            np.array([0.1, 0.2, 0.3], dtype=np.float64),
            np.array([1, 2, 3], dtype=np.int32),
            np.array([1, 2, 3], dtype=np.int64),
        ]
        for position, values in enumerate(accepted):
            _validate_embedding(values, position, 3)

    def test_invalid_embedding_type(self):
        """Anything that is neither a list nor a numpy array is rejected."""
        error_msg = "must be a list or numpy array"
        for position, candidate in enumerate(["not a list", None, 42]):
            with pytest.raises(ValueError, match=error_msg):
                _validate_embedding(candidate, position, 3)

    def test_non_numeric_values(self):
        """Lists holding a non-numeric element are rejected."""
        error_msg = "contains non-numeric values"
        rejected = [
            [0.1, "string", 0.3],
            [0.1, None, 0.3],
            [1, {}, 3],
        ]
        for position, candidate in enumerate(rejected):
            with pytest.raises(ValueError, match=error_msg):
                _validate_embedding(candidate, position, 3)

    def test_wrong_dimension(self):
        """Embeddings longer or shorter than the expected dimension are rejected."""
        cases = [
            ("has dimension 4, expected 3", [0.1, 0.2, 0.3, 0.4]),
            ("has dimension 2, expected 3", [0.1, 0.2]),
            ("has dimension 0, expected 3", []),
        ]
        for position, (message, candidate) in enumerate(cases):
            with pytest.raises(ValueError, match=message):
                _validate_embedding(candidate, position, 3)
class TestVectorStore: class TestVectorStore:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_returns_content_from_pdf_data_uri(self): async def test_returns_content_from_pdf_data_uri(self):
@ -126,3 +201,126 @@ class TestVectorStore:
assert str(excinfo.value) == "Failed to serialize metadata to string" assert str(excinfo.value) == "Failed to serialize metadata to string"
assert isinstance(excinfo.value.__cause__, TypeError) assert isinstance(excinfo.value.__cause__, TypeError)
assert str(excinfo.value.__cause__) == "Cannot convert to string" assert str(excinfo.value.__cause__) == "Cannot convert to string"
class TestVectorDBWithIndex:
    """Exercise `VectorDBWithIndex.insert_chunks` against a mocked index and inference API."""

    @pytest.mark.asyncio
    async def test_insert_chunks_without_embeddings(self):
        """Chunks without embeddings are all sent to the inference API, then indexed."""
        mock_vector_db = MagicMock()
        mock_vector_db.embedding_model = "test-model without embeddings"
        mock_index = AsyncMock()
        mock_inference_api = AsyncMock()

        vector_db_with_index = VectorDBWithIndex(
            vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
        )

        chunks = [
            Chunk(content="Test 1", embedding=None, metadata={}),
            Chunk(content="Test 2", embedding=None, metadata={}),
        ]

        mock_inference_api.embeddings.return_value.embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

        await vector_db_with_index.insert_chunks(chunks)

        mock_inference_api.embeddings.assert_called_once_with("test-model without embeddings", ["Test 1", "Test 2"])
        mock_index.add_chunks.assert_called_once()
        args = mock_index.add_chunks.call_args[0]
        assert args[0] == chunks
        assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))

    @pytest.mark.asyncio
    async def test_insert_chunks_with_valid_embeddings(self):
        """Chunks that all carry valid embeddings are indexed without calling the inference API."""
        mock_vector_db = MagicMock()
        mock_vector_db.embedding_model = "test-model with embeddings"
        mock_vector_db.embedding_dimension = 3
        mock_index = AsyncMock()
        mock_inference_api = AsyncMock()

        vector_db_with_index = VectorDBWithIndex(
            vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
        )

        chunks = [
            Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3], metadata={}),
            Chunk(content="Test 2", embedding=[0.4, 0.5, 0.6], metadata={}),
        ]

        await vector_db_with_index.insert_chunks(chunks)

        mock_inference_api.embeddings.assert_not_called()
        mock_index.add_chunks.assert_called_once()
        args = mock_index.add_chunks.call_args[0]
        assert args[0] == chunks
        assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))

    @pytest.mark.asyncio
    async def test_insert_chunks_with_invalid_embeddings(self):
        """Invalid embeddings raise ValueError and nothing reaches the inference API or the index."""
        mock_vector_db = MagicMock()
        mock_vector_db.embedding_dimension = 3
        mock_vector_db.embedding_model = "test-model with invalid embeddings"
        mock_index = AsyncMock()
        mock_inference_api = AsyncMock()

        vector_db_with_index = VectorDBWithIndex(
            vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
        )

        # Verify Chunk raises ValueError for invalid embedding type
        with pytest.raises(ValueError, match="Input should be a valid list"):
            Chunk(content="Test 1", embedding="invalid_type", metadata={})

        # Verify Chunk raises ValueError for invalid embedding type in insert_chunks
        # (i.e., Chunk errors before insert_chunks is called)
        with pytest.raises(ValueError, match="Input should be a valid list"):
            await vector_db_with_index.insert_chunks(
                [
                    Chunk(content="Test 1", embedding=None, metadata={}),
                    Chunk(content="Test 2", embedding="invalid_type", metadata={}),
                ]
            )

        # Verify Chunk raises ValueError for invalid embedding element type in insert_chunks
        # (i.e., Chunk errors before insert_chunks is called). The argument is wrapped in a
        # list so the call matches the `list[Chunk]` parameter of insert_chunks.
        with pytest.raises(ValueError, match=" Input should be a valid number, unable to parse string as a number "):
            await vector_db_with_index.insert_chunks(
                [Chunk(content="Test 1", embedding=[0.1, "string", 0.3], metadata={})]
            )

        # A well-formed Chunk whose embedding length differs from the DB dimension is
        # rejected by insert_chunks itself.
        chunks_wrong_dim = [
            Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3, 0.4], metadata={}),
        ]
        with pytest.raises(ValueError, match="has dimension 4, expected 3"):
            await vector_db_with_index.insert_chunks(chunks_wrong_dim)

        mock_inference_api.embeddings.assert_not_called()
        mock_index.add_chunks.assert_not_called()

    @pytest.mark.asyncio
    async def test_insert_chunks_with_partially_precomputed_embeddings(self):
        """Only chunks missing an embedding are embedded; original chunk order is kept."""
        mock_vector_db = MagicMock()
        mock_vector_db.embedding_model = "test-model with partial embeddings"
        mock_vector_db.embedding_dimension = 3
        mock_index = AsyncMock()
        mock_inference_api = AsyncMock()

        vector_db_with_index = VectorDBWithIndex(
            vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
        )

        chunks = [
            Chunk(content="Test 1", embedding=None, metadata={}),
            Chunk(content="Test 2", embedding=[0.2, 0.2, 0.2], metadata={}),
            Chunk(content="Test 3", embedding=None, metadata={}),
        ]

        mock_inference_api.embeddings.return_value.embeddings = [[0.1, 0.1, 0.1], [0.3, 0.3, 0.3]]

        await vector_db_with_index.insert_chunks(chunks)

        mock_inference_api.embeddings.assert_called_once_with(
            "test-model with partial embeddings", ["Test 1", "Test 3"]
        )
        mock_index.add_chunks.assert_called_once()
        args = mock_index.add_chunks.call_args[0]
        assert len(args[0]) == 3
        assert np.array_equal(args[1], np.array([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]], dtype=np.float32))