updated tests and refactored the validation for readability

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-05-30 17:07:20 -04:00
parent 73456878e5
commit 681e697fff
2 changed files with 143 additions and 2 deletions

View file

@ -171,6 +171,22 @@ def make_overlapped_chunks(
return chunks
def _validate_embedding(embedding: NDArray, index: int, expected_dimension: int):
"""Helper method to validate embedding format and dimensions"""
if not isinstance(embedding, (list | np.ndarray)):
raise ValueError(f"Embedding at index {index} must be a list or numpy array, got {type(embedding)}")
if isinstance(embedding, np.ndarray):
if not np.issubdtype(embedding.dtype, np.number):
raise ValueError(f"Embedding at index {index} contains non-numeric values")
else:
if not all(isinstance(e, (float | int | np.number)) for e in embedding):
raise ValueError(f"Embedding at index {index} contains non-numeric values")
if len(embedding) != expected_dimension:
raise ValueError(f"Embedding at index {index} has dimension {len(embedding)}, expected {expected_dimension}")
class EmbeddingIndex(ABC):
@abstractmethod
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
@ -199,7 +215,13 @@ class VectorDBWithIndex:
self,
chunks: list[Chunk],
) -> None:
chunks_to_embed = [c for c in chunks if c.embedding is None]
chunks_to_embed = []
for i, c in enumerate(chunks):
if c.embedding is None:
chunks_to_embed.append(c)
else:
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
if chunks_to_embed:
resp = await self.inference_api.embeddings(
self.vector_db.embedding_model,