feat: Enable ingestion of custom embeddings

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Author: Francisco Javier Arceo, 2025-05-29 20:58:41 -04:00 (committed by Francisco Arceo)
parent 2603f10f95
commit 73456878e5
8 changed files with 224 additions and 15 deletions


@@ -10020,7 +10020,8 @@
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent"
"$ref": "#/components/schemas/InterleavedContent",
"description": "The content of the chunk, which can be interleaved text, images, or other types."
},
"metadata": {
"type": "object",
@@ -10045,7 +10046,15 @@
"type": "object"
}
]
}
},
"description": "Metadata associated with the chunk, such as document ID, source, or other relevant information."
},
"embedding": {
"type": "array",
"items": {
"type": "number"
},
"description": "Optional embedding for the chunk. If not provided, it will be computed later."
}
},
"additionalProperties": false,
@@ -10053,9 +10062,10 @@
"content",
"metadata"
],
"title": "Chunk"
"title": "Chunk",
"description": "A chunk of content that can be inserted into a vector database."
},
"description": "The chunks to insert."
"description": "The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types. `metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional. If `metadata` is provided, you configure how Llama Stack formats the chunk during generation. If `embedding` is not provided, it will be computed later."
},
"ttl_seconds": {
"type": "integer",
@@ -12285,7 +12295,8 @@
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent"
"$ref": "#/components/schemas/InterleavedContent",
"description": "The content of the chunk, which can be interleaved text, images, or other types."
},
"metadata": {
"type": "object",
@@ -12310,7 +12321,15 @@
"type": "object"
}
]
}
},
"description": "Metadata associated with the chunk, such as document ID, source, or other relevant information."
},
"embedding": {
"type": "array",
"items": {
"type": "number"
},
"description": "Optional embedding for the chunk. If not provided, it will be computed later."
}
},
"additionalProperties": false,
@@ -12318,7 +12337,8 @@
"content",
"metadata"
],
"title": "Chunk"
"title": "Chunk",
"description": "A chunk of content that can be inserted into a vector database."
}
},
"scores": {


@@ -7024,6 +7024,9 @@ components:
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
The content of the chunk, which can be interleaved text, images,
or other types.
metadata:
type: object
additionalProperties:
@@ -7034,12 +7037,29 @@ components:
- type: string
- type: array
- type: object
description: >-
Metadata associated with the chunk, such as document ID, source,
or other relevant information.
embedding:
type: array
items:
type: number
description: >-
Optional embedding for the chunk. If not provided, it will be computed
later.
additionalProperties: false
required:
- content
- metadata
title: Chunk
description: The chunks to insert.
description: >-
A chunk of content that can be inserted into a vector database.
description: >-
The chunks to insert. Each `Chunk` should contain content, which can
be interleaved text, images, or other types. `metadata`: `dict[str, Any]`
and `embedding`: `List[float]` are optional. If `metadata` is provided,
you can configure how Llama Stack formats the chunk during generation.
If `embedding` is not provided, it will be computed later.
ttl_seconds:
type: integer
description: The time to live of the chunks.
@@ -8537,6 +8557,9 @@ components:
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
The content of the chunk, which can be interleaved text, images,
or other types.
metadata:
type: object
additionalProperties:
@@ -8547,11 +8570,23 @@ components:
- type: string
- type: array
- type: object
description: >-
Metadata associated with the chunk, such as document ID, source,
or other relevant information.
embedding:
type: array
items:
type: number
description: >-
Optional embedding for the chunk. If not provided, it will be computed
later.
additionalProperties: false
required:
- content
- metadata
title: Chunk
description: >-
A chunk of content that can be inserted into a vector database.
scores:
type: array
items:


@@ -57,6 +57,31 @@ chunks = [
]
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
```
#### Using Precomputed Embeddings
If you decide to precompute embeddings for your documents, you can insert them directly into the vector database by
including the embedding vectors in the chunk data. This is useful if you have a separate embedding service or if you
want to customize the ingestion process.
```python
chunks_with_embeddings = [
{
"content": "First chunk of text",
"mime_type": "text/plain",
"embedding": [0.1, 0.2, 0.3, ...], # Your precomputed embedding vector
"metadata": {"document_id": "doc1", "section": "introduction"},
},
{
"content": "Second chunk of text",
"mime_type": "text/plain",
"embedding": [0.2, 0.3, 0.4, ...], # Your precomputed embedding vector
"metadata": {"document_id": "doc1", "section": "methodology"},
},
]
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks_with_embeddings)
```
When providing precomputed embeddings, ensure the embedding dimension matches the `embedding_dimension` specified when
registering the vector database.
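A cheap guard against dimension mismatches is to validate each vector before inserting; a sketch, where `expected_dim` stands for whatever `embedding_dimension` you registered (e.g. 384):
```python
expected_dim = 384  # must equal the embedding_dimension used at registration

for chunk in chunks_with_embeddings:
    actual_dim = len(chunk["embedding"])
    if actual_dim != expected_dim:
        raise ValueError(f"Chunk embedding has dimension {actual_dim}, expected {expected_dim}")
```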
### Retrieval
You can query the vector database to retrieve documents based on their embeddings.
```python


@@ -19,8 +19,16 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class Chunk(BaseModel):
"""
A chunk of content that can be inserted into a vector database.
:param content: The content of the chunk, which can be interleaved text, images, or other types.
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
:param metadata: Metadata associated with the chunk, such as document ID, source, or other relevant information.
"""
content: InterleavedContent
metadata: dict[str, Any] = Field(default_factory=dict)
embedding: list[float] | None = None
@json_schema_type
@@ -50,7 +58,10 @@ class VectorIO(Protocol):
"""Insert chunks into a vector database.
:param vector_db_id: The identifier of the vector database to insert the chunks into.
:param chunks: The chunks to insert.
:param chunks: The chunks to insert. Each `Chunk` should contain content, which can be interleaved text, images, or other types.
`metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional.
If `metadata` is provided, you can configure how Llama Stack formats the chunk during generation.
If `embedding` is not provided, it will be computed later.
:param ttl_seconds: The time to live of the chunks.
"""
...
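As a usage sketch of the new optional field (assuming a `client` and a registered `vector_db_id` as in the docs above), chunks with and without embeddings can be mixed in a single call:
```python
from llama_stack.apis.vector_io import Chunk

# No embedding: the provider computes one at insert time.
auto_chunk = Chunk(content="First chunk of text", metadata={"document_id": "doc1"})

# Precomputed embedding: stored as-is, skipping the inference call.
manual_chunk = Chunk(
    content="Second chunk of text",
    metadata={"document_id": "doc1"},
    embedding=[0.1, 0.2, 0.3],  # length must match the registered embedding_dimension
)

client.vector_io.insert(vector_db_id=vector_db_id, chunks=[auto_chunk, manual_chunk])
```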


@@ -146,7 +146,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
]
for i, chunk in enumerate(chunks):
metadata = chunk.metadata
tokens += metadata["token_count"]
tokens += metadata.get("token_count", 0)
tokens += metadata.get("metadata_token_count", 0)
if tokens > query_config.max_tokens_in_context:
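The switch from `metadata["token_count"]` to `dict.get` matters now that chunks can be user-ingested: their metadata may lack the token-count keys that the RAG tool's own ingestion path adds. A minimal illustration with hypothetical metadata:
```python
metadata = {"document_id": "doc1"}  # user-supplied chunk without token counts

tokens = 0
# The old lookup, tokens += metadata["token_count"], raised KeyError here.
tokens += metadata.get("token_count", 0)
tokens += metadata.get("metadata_token_count", 0)
assert tokens == 0  # missing counts are treated as zero
```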


@@ -199,11 +199,16 @@ class VectorDBWithIndex:
self,
chunks: list[Chunk],
) -> None:
embeddings_response = await self.inference_api.embeddings(
self.vector_db.embedding_model, [x.content for x in chunks]
)
embeddings = np.array(embeddings_response.embeddings)
chunks_to_embed = [c for c in chunks if c.embedding is None]
if chunks_to_embed:
resp = await self.inference_api.embeddings(
self.vector_db.embedding_model,
[c.content for c in chunks_to_embed],
)
for c, embedding in zip(chunks_to_embed, resp.embeddings, strict=False):
c.embedding = embedding
embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
await self.index.add_chunks(chunks, embeddings)
async def query_chunks(
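The rewritten `insert_chunks` only embeds chunks that arrive without vectors, so a single call can mix precomputed and to-be-computed chunks. A sketch of that behavior, assuming a configured `VectorDBWithIndex` instance named `vdb` (hypothetical) and an `embedding_dimension` of 384:
```python
chunks = [
    Chunk(content="cached text", embedding=[0.1] * 384, metadata={}),  # vector kept as-is
    Chunk(content="fresh text", embedding=None, metadata={}),  # embedded via inference
]
# Only "fresh text" is sent to inference_api.embeddings; the cached vector is reused.
await vdb.insert_chunks(chunks)
```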


@@ -120,3 +120,37 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, sample_ch
top_match = response.chunks[0]
assert top_match is not None
assert top_match.metadata["document_id"] == expected_doc_id, f"Query '{query}' should match {expected_doc_id}"
def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id):
vector_db_id = "test_precomputed_embeddings_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=384,
)
chunks_with_embeddings = [
Chunk(
content="This is a test chunk with precomputed embedding.",
metadata={"document_id": "doc1", "source": "precomputed"},
embedding=[0.1] * 384,
),
]
client_with_empty_registry.vector_io.insert(
vector_db_id=vector_db_id,
chunks=chunks_with_embeddings,
)
# Query for the first document
response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id,
query="precomputed embedding test",
)
# Verify the top result is the expected document
assert response is not None
assert len(response.chunks) > 0
assert response.chunks[0].metadata["document_id"] == "doc1"
assert response.chunks[0].metadata["source"] == "precomputed"


@@ -8,11 +8,19 @@ import base64
import mimetypes
import os
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from llama_stack.apis.tools import RAGDocument
from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc, make_overlapped_chunks
from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.utils.memory.vector_store import (
URL,
VectorDBWithIndex,
content_from_doc,
make_overlapped_chunks,
)
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
# Depending on the machine, this can get parsed a couple of ways
@@ -36,6 +44,25 @@ def data_url_from_file(file_path: str) -> str:
return data_url
class TestChunk:
def test_chunk(self):
chunk = Chunk(
content="Example chunk content",
metadata={"key": "value"},
embedding=[0.1, 0.2, 0.3],
)
assert chunk.content == "Example chunk content"
assert chunk.metadata == {"key": "value"}
assert chunk.embedding == [0.1, 0.2, 0.3]
chunk_no_embedding = Chunk(
content="Example chunk content",
metadata={"key": "value"},
)
assert chunk_no_embedding.embedding is None
class TestVectorStore:
@pytest.mark.asyncio
async def test_returns_content_from_pdf_data_uri(self):
@@ -126,3 +153,55 @@ class TestVectorStore:
assert str(excinfo.value) == "Failed to serialize metadata to string"
assert isinstance(excinfo.value.__cause__, TypeError)
assert str(excinfo.value.__cause__) == "Cannot convert to string"
class TestVectorDBWithIndex:
@pytest.mark.asyncio
async def test_insert_chunks_without_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model without embeddings"
mock_index = AsyncMock()
mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
)
chunks = [
Chunk(content="Test 1", embedding=None, metadata={}),
Chunk(content="Test 2", embedding=None, metadata={}),
]
mock_inference_api.embeddings.return_value.embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
await vector_db_with_index.insert_chunks(chunks)
mock_inference_api.embeddings.assert_called_once_with("test-model without embeddings", ["Test 1", "Test 2"])
mock_index.add_chunks.assert_called_once()
args = mock_index.add_chunks.call_args[0]
assert args[0] == chunks
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
@pytest.mark.asyncio
async def test_insert_chunks_with_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model with embeddings"
mock_index = AsyncMock()
mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
)
chunks = [
Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3], metadata={}),
Chunk(content="Test 2", embedding=[0.4, 0.5, 0.6], metadata={}),
]
await vector_db_with_index.insert_chunks(chunks)
mock_inference_api.embeddings.assert_not_called()
mock_index.add_chunks.assert_called_once()
args = mock_index.add_chunks.call_args[0]
assert args[0] == chunks
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
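A natural follow-on test (not part of this commit) would cover a mixed batch, where only the chunk arriving without a vector should reach the inference API; a sketch under the same mocking assumptions, with a hypothetical test name:
```python
import numpy as np
import pytest
from unittest.mock import AsyncMock, MagicMock

from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.utils.memory.vector_store import VectorDBWithIndex


@pytest.mark.asyncio
async def test_insert_chunks_with_mixed_embeddings():
    mock_vector_db = MagicMock()
    mock_vector_db.embedding_model = "test-model mixed"
    mock_index = AsyncMock()
    mock_inference_api = AsyncMock()
    vdb = VectorDBWithIndex(vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api)

    chunks = [
        Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3], metadata={}),
        Chunk(content="Test 2", embedding=None, metadata={}),
    ]
    # Only the chunk arriving without a vector should hit the inference API.
    mock_inference_api.embeddings.return_value.embeddings = [[0.4, 0.5, 0.6]]

    await vdb.insert_chunks(chunks)

    mock_inference_api.embeddings.assert_called_once_with("test-model mixed", ["Test 2"])
    args = mock_index.add_chunks.call_args[0]
    assert args[0] == chunks
    assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
```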