mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
updated tests and refactored the validation for readability
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
73456878e5
commit
681e697fff
2 changed files with 143 additions and 2 deletions
|
@ -171,6 +171,22 @@ def make_overlapped_chunks(
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_embedding(embedding: NDArray, index: int, expected_dimension: int):
|
||||||
|
"""Helper method to validate embedding format and dimensions"""
|
||||||
|
if not isinstance(embedding, (list | np.ndarray)):
|
||||||
|
raise ValueError(f"Embedding at index {index} must be a list or numpy array, got {type(embedding)}")
|
||||||
|
|
||||||
|
if isinstance(embedding, np.ndarray):
|
||||||
|
if not np.issubdtype(embedding.dtype, np.number):
|
||||||
|
raise ValueError(f"Embedding at index {index} contains non-numeric values")
|
||||||
|
else:
|
||||||
|
if not all(isinstance(e, (float | int | np.number)) for e in embedding):
|
||||||
|
raise ValueError(f"Embedding at index {index} contains non-numeric values")
|
||||||
|
|
||||||
|
if len(embedding) != expected_dimension:
|
||||||
|
raise ValueError(f"Embedding at index {index} has dimension {len(embedding)}, expected {expected_dimension}")
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingIndex(ABC):
|
class EmbeddingIndex(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||||
|
@ -199,7 +215,13 @@ class VectorDBWithIndex:
|
||||||
self,
|
self,
|
||||||
chunks: list[Chunk],
|
chunks: list[Chunk],
|
||||||
) -> None:
|
) -> None:
|
||||||
chunks_to_embed = [c for c in chunks if c.embedding is None]
|
chunks_to_embed = []
|
||||||
|
for i, c in enumerate(chunks):
|
||||||
|
if c.embedding is None:
|
||||||
|
chunks_to_embed.append(c)
|
||||||
|
else:
|
||||||
|
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
|
||||||
|
|
||||||
if chunks_to_embed:
|
if chunks_to_embed:
|
||||||
resp = await self.inference_api.embeddings(
|
resp = await self.inference_api.embeddings(
|
||||||
self.vector_db.embedding_model,
|
self.vector_db.embedding_model,
|
||||||
|
|
|
@ -18,6 +18,7 @@ from llama_stack.apis.vector_io import Chunk
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
URL,
|
URL,
|
||||||
VectorDBWithIndex,
|
VectorDBWithIndex,
|
||||||
|
_validate_embedding,
|
||||||
content_from_doc,
|
content_from_doc,
|
||||||
make_overlapped_chunks,
|
make_overlapped_chunks,
|
||||||
)
|
)
|
||||||
|
@ -63,6 +64,53 @@ class TestChunk:
|
||||||
assert chunk_no_embedding.embedding is None
|
assert chunk_no_embedding.embedding is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateEmbedding:
|
||||||
|
def test_valid_list_embeddings(self):
|
||||||
|
_validate_embedding([0.1, 0.2, 0.3], 0, 3)
|
||||||
|
_validate_embedding([1, 2, 3], 1, 3)
|
||||||
|
_validate_embedding([0.1, 2, 3.5], 2, 3)
|
||||||
|
|
||||||
|
def test_valid_numpy_embeddings(self):
|
||||||
|
_validate_embedding(np.array([0.1, 0.2, 0.3], dtype=np.float32), 0, 3)
|
||||||
|
_validate_embedding(np.array([0.1, 0.2, 0.3], dtype=np.float64), 1, 3)
|
||||||
|
_validate_embedding(np.array([1, 2, 3], dtype=np.int32), 2, 3)
|
||||||
|
_validate_embedding(np.array([1, 2, 3], dtype=np.int64), 3, 3)
|
||||||
|
|
||||||
|
def test_invalid_embedding_type(self):
|
||||||
|
error_msg = "must be a list or numpy array"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=error_msg):
|
||||||
|
_validate_embedding("not a list", 0, 3)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=error_msg):
|
||||||
|
_validate_embedding(None, 1, 3)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=error_msg):
|
||||||
|
_validate_embedding(42, 2, 3)
|
||||||
|
|
||||||
|
def test_non_numeric_values(self):
|
||||||
|
error_msg = "contains non-numeric values"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=error_msg):
|
||||||
|
_validate_embedding([0.1, "string", 0.3], 0, 3)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=error_msg):
|
||||||
|
_validate_embedding([0.1, None, 0.3], 1, 3)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=error_msg):
|
||||||
|
_validate_embedding([1, {}, 3], 2, 3)
|
||||||
|
|
||||||
|
def test_wrong_dimension(self):
|
||||||
|
with pytest.raises(ValueError, match="has dimension 4, expected 3"):
|
||||||
|
_validate_embedding([0.1, 0.2, 0.3, 0.4], 0, 3)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="has dimension 2, expected 3"):
|
||||||
|
_validate_embedding([0.1, 0.2], 1, 3)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="has dimension 0, expected 3"):
|
||||||
|
_validate_embedding([], 2, 3)
|
||||||
|
|
||||||
|
|
||||||
class TestVectorStore:
|
class TestVectorStore:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_returns_content_from_pdf_data_uri(self):
|
async def test_returns_content_from_pdf_data_uri(self):
|
||||||
|
@ -183,9 +231,10 @@ class TestVectorDBWithIndex:
|
||||||
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_insert_chunks_with_embeddings(self):
|
async def test_insert_chunks_with_valid_embeddings(self):
|
||||||
mock_vector_db = MagicMock()
|
mock_vector_db = MagicMock()
|
||||||
mock_vector_db.embedding_model = "test-model with embeddings"
|
mock_vector_db.embedding_model = "test-model with embeddings"
|
||||||
|
mock_vector_db.embedding_dimension = 3
|
||||||
mock_index = AsyncMock()
|
mock_index = AsyncMock()
|
||||||
mock_inference_api = AsyncMock()
|
mock_inference_api = AsyncMock()
|
||||||
|
|
||||||
|
@ -205,3 +254,73 @@ class TestVectorDBWithIndex:
|
||||||
args = mock_index.add_chunks.call_args[0]
|
args = mock_index.add_chunks.call_args[0]
|
||||||
assert args[0] == chunks
|
assert args[0] == chunks
|
||||||
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_insert_chunks_with_invalid_embeddings(self):
|
||||||
|
mock_vector_db = MagicMock()
|
||||||
|
mock_vector_db.embedding_dimension = 3
|
||||||
|
mock_vector_db.embedding_model = "test-model with invalid embeddings"
|
||||||
|
mock_index = AsyncMock()
|
||||||
|
mock_inference_api = AsyncMock()
|
||||||
|
|
||||||
|
vector_db_with_index = VectorDBWithIndex(
|
||||||
|
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify Chunk raises ValueError for invalid embedding type
|
||||||
|
with pytest.raises(ValueError, match="Input should be a valid list"):
|
||||||
|
Chunk(content="Test 1", embedding="invalid_type", metadata={})
|
||||||
|
|
||||||
|
# Verify Chunk raises ValueError for invalid embedding type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
|
||||||
|
with pytest.raises(ValueError, match="Input should be a valid list"):
|
||||||
|
await vector_db_with_index.insert_chunks(
|
||||||
|
[
|
||||||
|
Chunk(content="Test 1", embedding=None, metadata={}),
|
||||||
|
Chunk(content="Test 2", embedding="invalid_type", metadata={}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify Chunk raises ValueError for invalid embedding element type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
|
||||||
|
with pytest.raises(ValueError, match=" Input should be a valid number, unable to parse string as a number "):
|
||||||
|
await vector_db_with_index.insert_chunks(
|
||||||
|
Chunk(content="Test 1", embedding=[0.1, "string", 0.3], metadata={})
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks_wrong_dim = [
|
||||||
|
Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3, 0.4], metadata={}),
|
||||||
|
]
|
||||||
|
with pytest.raises(ValueError, match="has dimension 4, expected 3"):
|
||||||
|
await vector_db_with_index.insert_chunks(chunks_wrong_dim)
|
||||||
|
|
||||||
|
mock_inference_api.embeddings.assert_not_called()
|
||||||
|
mock_index.add_chunks.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_insert_chunks_with_partially_precomputed_embeddings(self):
|
||||||
|
mock_vector_db = MagicMock()
|
||||||
|
mock_vector_db.embedding_model = "test-model with partial embeddings"
|
||||||
|
mock_vector_db.embedding_dimension = 3
|
||||||
|
mock_index = AsyncMock()
|
||||||
|
mock_inference_api = AsyncMock()
|
||||||
|
|
||||||
|
vector_db_with_index = VectorDBWithIndex(
|
||||||
|
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = [
|
||||||
|
Chunk(content="Test 1", embedding=None, metadata={}),
|
||||||
|
Chunk(content="Test 2", embedding=[0.2, 0.2, 0.2], metadata={}),
|
||||||
|
Chunk(content="Test 3", embedding=None, metadata={}),
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_inference_api.embeddings.return_value.embeddings = [[0.1, 0.1, 0.1], [0.3, 0.3, 0.3]]
|
||||||
|
|
||||||
|
await vector_db_with_index.insert_chunks(chunks)
|
||||||
|
|
||||||
|
mock_inference_api.embeddings.assert_called_once_with(
|
||||||
|
"test-model with partial embeddings", ["Test 1", "Test 3"]
|
||||||
|
)
|
||||||
|
mock_index.add_chunks.assert_called_once()
|
||||||
|
args = mock_index.add_chunks.call_args[0]
|
||||||
|
assert len(args[0]) == 3
|
||||||
|
assert np.array_equal(args[1], np.array([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]], dtype=np.float32))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue