feat: Enable ingestion of precomputed embeddings (#2317)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 3s
Integration Tests / test-matrix (http, inspect) (push) Failing after 9s
Integration Tests / test-matrix (http, post_training) (push) Failing after 9s
Integration Tests / test-matrix (http, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, datasets) (push) Failing after 10s
Integration Tests / test-matrix (http, inference) (push) Failing after 10s
Integration Tests / test-matrix (library, agents) (push) Failing after 9s
Integration Tests / test-matrix (http, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, datasets) (push) Failing after 8s
Integration Tests / test-matrix (http, providers) (push) Failing after 9s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 10s
Integration Tests / test-matrix (library, inference) (push) Failing after 9s
Test External Providers / test-external-providers (venv) (push) Failing after 6s
Integration Tests / test-matrix (library, inspect) (push) Failing after 8s
Integration Tests / test-matrix (library, providers) (push) Failing after 8s
Integration Tests / test-matrix (library, scoring) (push) Failing after 8s
Integration Tests / test-matrix (library, post_training) (push) Failing after 10s
Unit Tests / unit-tests (3.11) (push) Failing after 7s
Unit Tests / unit-tests (3.10) (push) Failing after 9s
Unit Tests / unit-tests (3.13) (push) Failing after 7s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 9s
Unit Tests / unit-tests (3.12) (push) Failing after 9s
Update ReadTheDocs / update-readthedocs (push) Failing after 7s
Pre-commit / pre-commit (push) Successful in 1m15s

This commit is contained in:
Francisco Arceo 2025-05-31 04:03:37 -06:00 committed by GitHub
parent 31ce208bda
commit f328436831
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 366 additions and 15 deletions

View file

@ -120,3 +120,37 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, sample_ch
top_match = response.chunks[0]
assert top_match is not None
assert top_match.metadata["document_id"] == expected_doc_id, f"Query '{query}' should match {expected_doc_id}"
def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id):
vector_db_id = "test_precomputed_embeddings_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=384,
)
chunks_with_embeddings = [
Chunk(
content="This is a test chunk with precomputed embedding.",
metadata={"document_id": "doc1", "source": "precomputed"},
embedding=[0.1] * 384,
),
]
client_with_empty_registry.vector_io.insert(
vector_db_id=vector_db_id,
chunks=chunks_with_embeddings,
)
# Query for the first document
response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id,
query="precomputed embedding test",
)
# Verify the top result is the expected document
assert response is not None
assert len(response.chunks) > 0
assert response.chunks[0].metadata["document_id"] == "doc1"
assert response.chunks[0].metadata["source"] == "precomputed"

View file

@ -50,6 +50,7 @@ def mock_vector_db(vector_db_id) -> MagicMock:
mock_vector_db = MagicMock(spec=VectorDB)
mock_vector_db.embedding_model = "embedding_model"
mock_vector_db.identifier = vector_db_id
mock_vector_db.embedding_dimension = 384
return mock_vector_db

View file

@ -8,11 +8,20 @@ import base64
import mimetypes
import os
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from llama_stack.apis.tools import RAGDocument
from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc, make_overlapped_chunks
from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.utils.memory.vector_store import (
URL,
VectorDBWithIndex,
_validate_embedding,
content_from_doc,
make_overlapped_chunks,
)
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
# Depending on the machine, this can get parsed a couple of ways
@ -36,6 +45,72 @@ def data_url_from_file(file_path: str) -> str:
return data_url
class TestChunk:
def test_chunk(self):
chunk = Chunk(
content="Example chunk content",
metadata={"key": "value"},
embedding=[0.1, 0.2, 0.3],
)
assert chunk.content == "Example chunk content"
assert chunk.metadata == {"key": "value"}
assert chunk.embedding == [0.1, 0.2, 0.3]
chunk_no_embedding = Chunk(
content="Example chunk content",
metadata={"key": "value"},
)
assert chunk_no_embedding.embedding is None
class TestValidateEmbedding:
def test_valid_list_embeddings(self):
_validate_embedding([0.1, 0.2, 0.3], 0, 3)
_validate_embedding([1, 2, 3], 1, 3)
_validate_embedding([0.1, 2, 3.5], 2, 3)
def test_valid_numpy_embeddings(self):
_validate_embedding(np.array([0.1, 0.2, 0.3], dtype=np.float32), 0, 3)
_validate_embedding(np.array([0.1, 0.2, 0.3], dtype=np.float64), 1, 3)
_validate_embedding(np.array([1, 2, 3], dtype=np.int32), 2, 3)
_validate_embedding(np.array([1, 2, 3], dtype=np.int64), 3, 3)
def test_invalid_embedding_type(self):
error_msg = "must be a list or numpy array"
with pytest.raises(ValueError, match=error_msg):
_validate_embedding("not a list", 0, 3)
with pytest.raises(ValueError, match=error_msg):
_validate_embedding(None, 1, 3)
with pytest.raises(ValueError, match=error_msg):
_validate_embedding(42, 2, 3)
def test_non_numeric_values(self):
error_msg = "contains non-numeric values"
with pytest.raises(ValueError, match=error_msg):
_validate_embedding([0.1, "string", 0.3], 0, 3)
with pytest.raises(ValueError, match=error_msg):
_validate_embedding([0.1, None, 0.3], 1, 3)
with pytest.raises(ValueError, match=error_msg):
_validate_embedding([1, {}, 3], 2, 3)
def test_wrong_dimension(self):
with pytest.raises(ValueError, match="has dimension 4, expected 3"):
_validate_embedding([0.1, 0.2, 0.3, 0.4], 0, 3)
with pytest.raises(ValueError, match="has dimension 2, expected 3"):
_validate_embedding([0.1, 0.2], 1, 3)
with pytest.raises(ValueError, match="has dimension 0, expected 3"):
_validate_embedding([], 2, 3)
class TestVectorStore:
@pytest.mark.asyncio
async def test_returns_content_from_pdf_data_uri(self):
@ -126,3 +201,126 @@ class TestVectorStore:
assert str(excinfo.value) == "Failed to serialize metadata to string"
assert isinstance(excinfo.value.__cause__, TypeError)
assert str(excinfo.value.__cause__) == "Cannot convert to string"
class TestVectorDBWithIndex:
@pytest.mark.asyncio
async def test_insert_chunks_without_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model without embeddings"
mock_index = AsyncMock()
mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
)
chunks = [
Chunk(content="Test 1", embedding=None, metadata={}),
Chunk(content="Test 2", embedding=None, metadata={}),
]
mock_inference_api.embeddings.return_value.embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
await vector_db_with_index.insert_chunks(chunks)
mock_inference_api.embeddings.assert_called_once_with("test-model without embeddings", ["Test 1", "Test 2"])
mock_index.add_chunks.assert_called_once()
args = mock_index.add_chunks.call_args[0]
assert args[0] == chunks
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
@pytest.mark.asyncio
async def test_insert_chunks_with_valid_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model with embeddings"
mock_vector_db.embedding_dimension = 3
mock_index = AsyncMock()
mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
)
chunks = [
Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3], metadata={}),
Chunk(content="Test 2", embedding=[0.4, 0.5, 0.6], metadata={}),
]
await vector_db_with_index.insert_chunks(chunks)
mock_inference_api.embeddings.assert_not_called()
mock_index.add_chunks.assert_called_once()
args = mock_index.add_chunks.call_args[0]
assert args[0] == chunks
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
@pytest.mark.asyncio
async def test_insert_chunks_with_invalid_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_dimension = 3
mock_vector_db.embedding_model = "test-model with invalid embeddings"
mock_index = AsyncMock()
mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
)
# Verify Chunk raises ValueError for invalid embedding type
with pytest.raises(ValueError, match="Input should be a valid list"):
Chunk(content="Test 1", embedding="invalid_type", metadata={})
# Verify Chunk raises ValueError for invalid embedding type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
with pytest.raises(ValueError, match="Input should be a valid list"):
await vector_db_with_index.insert_chunks(
[
Chunk(content="Test 1", embedding=None, metadata={}),
Chunk(content="Test 2", embedding="invalid_type", metadata={}),
]
)
# Verify Chunk raises ValueError for invalid embedding element type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
with pytest.raises(ValueError, match=" Input should be a valid number, unable to parse string as a number "):
await vector_db_with_index.insert_chunks(
Chunk(content="Test 1", embedding=[0.1, "string", 0.3], metadata={})
)
chunks_wrong_dim = [
Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3, 0.4], metadata={}),
]
with pytest.raises(ValueError, match="has dimension 4, expected 3"):
await vector_db_with_index.insert_chunks(chunks_wrong_dim)
mock_inference_api.embeddings.assert_not_called()
mock_index.add_chunks.assert_not_called()
@pytest.mark.asyncio
async def test_insert_chunks_with_partially_precomputed_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model with partial embeddings"
mock_vector_db.embedding_dimension = 3
mock_index = AsyncMock()
mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
)
chunks = [
Chunk(content="Test 1", embedding=None, metadata={}),
Chunk(content="Test 2", embedding=[0.2, 0.2, 0.2], metadata={}),
Chunk(content="Test 3", embedding=None, metadata={}),
]
mock_inference_api.embeddings.return_value.embeddings = [[0.1, 0.1, 0.1], [0.3, 0.3, 0.3]]
await vector_db_with_index.insert_chunks(chunks)
mock_inference_api.embeddings.assert_called_once_with(
"test-model with partial embeddings", ["Test 1", "Test 3"]
)
mock_index.add_chunks.assert_called_once()
args = mock_index.add_chunks.call_args[0]
assert len(args[0]) == 3
assert np.array_equal(args[1], np.array([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]], dtype=np.float32))