forked from phoenix-oss/llama-stack-mirror
feat: Enable ingestion of precomputed embeddings (#2317)
This commit is contained in:
parent
31ce208bda
commit
f328436831
9 changed files with 366 additions and 15 deletions
34
docs/_static/llama-stack-spec.html
vendored
34
docs/_static/llama-stack-spec.html
vendored
|
@ -10020,7 +10020,8 @@
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"content": {
|
"content": {
|
||||||
"$ref": "#/components/schemas/InterleavedContent"
|
"$ref": "#/components/schemas/InterleavedContent",
|
||||||
|
"description": "The content of the chunk, which can be interleaved text, images, or other types."
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
@ -10045,7 +10046,15 @@
|
||||||
"type": "object"
|
"type": "object"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
},
|
||||||
|
"description": "Metadata associated with the chunk, such as document ID, source, or other relevant information."
|
||||||
|
},
|
||||||
|
"embedding": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
"description": "Optional embedding for the chunk. If not provided, it will be computed later."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
@ -10053,9 +10062,10 @@
|
||||||
"content",
|
"content",
|
||||||
"metadata"
|
"metadata"
|
||||||
],
|
],
|
||||||
"title": "Chunk"
|
"title": "Chunk",
|
||||||
|
"description": "A chunk of content that can be inserted into a vector database."
|
||||||
},
|
},
|
||||||
"description": "The chunks to insert."
|
"description": "The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types. `metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional. If `metadata` is provided, you configure how Llama Stack formats the chunk during generation. If `embedding` is not provided, it will be computed later."
|
||||||
},
|
},
|
||||||
"ttl_seconds": {
|
"ttl_seconds": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
|
@ -12285,7 +12295,8 @@
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"content": {
|
"content": {
|
||||||
"$ref": "#/components/schemas/InterleavedContent"
|
"$ref": "#/components/schemas/InterleavedContent",
|
||||||
|
"description": "The content of the chunk, which can be interleaved text, images, or other types."
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
@ -12310,7 +12321,15 @@
|
||||||
"type": "object"
|
"type": "object"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
},
|
||||||
|
"description": "Metadata associated with the chunk, such as document ID, source, or other relevant information."
|
||||||
|
},
|
||||||
|
"embedding": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
"description": "Optional embedding for the chunk. If not provided, it will be computed later."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
@ -12318,7 +12337,8 @@
|
||||||
"content",
|
"content",
|
||||||
"metadata"
|
"metadata"
|
||||||
],
|
],
|
||||||
"title": "Chunk"
|
"title": "Chunk",
|
||||||
|
"description": "A chunk of content that can be inserted into a vector database."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"scores": {
|
"scores": {
|
||||||
|
|
37
docs/_static/llama-stack-spec.yaml
vendored
37
docs/_static/llama-stack-spec.yaml
vendored
|
@ -7024,6 +7024,9 @@ components:
|
||||||
properties:
|
properties:
|
||||||
content:
|
content:
|
||||||
$ref: '#/components/schemas/InterleavedContent'
|
$ref: '#/components/schemas/InterleavedContent'
|
||||||
|
description: >-
|
||||||
|
The content of the chunk, which can be interleaved text, images,
|
||||||
|
or other types.
|
||||||
metadata:
|
metadata:
|
||||||
type: object
|
type: object
|
||||||
additionalProperties:
|
additionalProperties:
|
||||||
|
@ -7034,12 +7037,29 @@ components:
|
||||||
- type: string
|
- type: string
|
||||||
- type: array
|
- type: array
|
||||||
- type: object
|
- type: object
|
||||||
|
description: >-
|
||||||
|
Metadata associated with the chunk, such as document ID, source,
|
||||||
|
or other relevant information.
|
||||||
|
embedding:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: number
|
||||||
|
description: >-
|
||||||
|
Optional embedding for the chunk. If not provided, it will be computed
|
||||||
|
later.
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- content
|
- content
|
||||||
- metadata
|
- metadata
|
||||||
title: Chunk
|
title: Chunk
|
||||||
description: The chunks to insert.
|
description: >-
|
||||||
|
A chunk of content that can be inserted into a vector database.
|
||||||
|
description: >-
|
||||||
|
The chunks to insert. Each `Chunk` should contain content which can be
|
||||||
|
interleaved text, images, or other types. `metadata`: `dict[str, Any]`
|
||||||
|
and `embedding`: `List[float]` are optional. If `metadata` is provided,
|
||||||
|
you configure how Llama Stack formats the chunk during generation. If
|
||||||
|
`embedding` is not provided, it will be computed later.
|
||||||
ttl_seconds:
|
ttl_seconds:
|
||||||
type: integer
|
type: integer
|
||||||
description: The time to live of the chunks.
|
description: The time to live of the chunks.
|
||||||
|
@ -8537,6 +8557,9 @@ components:
|
||||||
properties:
|
properties:
|
||||||
content:
|
content:
|
||||||
$ref: '#/components/schemas/InterleavedContent'
|
$ref: '#/components/schemas/InterleavedContent'
|
||||||
|
description: >-
|
||||||
|
The content of the chunk, which can be interleaved text, images,
|
||||||
|
or other types.
|
||||||
metadata:
|
metadata:
|
||||||
type: object
|
type: object
|
||||||
additionalProperties:
|
additionalProperties:
|
||||||
|
@ -8547,11 +8570,23 @@ components:
|
||||||
- type: string
|
- type: string
|
||||||
- type: array
|
- type: array
|
||||||
- type: object
|
- type: object
|
||||||
|
description: >-
|
||||||
|
Metadata associated with the chunk, such as document ID, source,
|
||||||
|
or other relevant information.
|
||||||
|
embedding:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: number
|
||||||
|
description: >-
|
||||||
|
Optional embedding for the chunk. If not provided, it will be computed
|
||||||
|
later.
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- content
|
- content
|
||||||
- metadata
|
- metadata
|
||||||
title: Chunk
|
title: Chunk
|
||||||
|
description: >-
|
||||||
|
A chunk of content that can be inserted into a vector database.
|
||||||
scores:
|
scores:
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
|
|
|
@ -57,6 +57,31 @@ chunks = [
|
||||||
]
|
]
|
||||||
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
|
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Using Precomputed Embeddings
|
||||||
|
If you decide to precompute embeddings for your documents, you can insert them directly into the vector database by
|
||||||
|
including the embedding vectors in the chunk data. This is useful if you have a separate embedding service or if you
|
||||||
|
want to customize the ingestion process.
|
||||||
|
```python
|
||||||
|
chunks_with_embeddings = [
|
||||||
|
{
|
||||||
|
"content": "First chunk of text",
|
||||||
|
"mime_type": "text/plain",
|
||||||
|
"embedding": [0.1, 0.2, 0.3, ...], # Your precomputed embedding vector
|
||||||
|
"metadata": {"document_id": "doc1", "section": "introduction"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"content": "Second chunk of text",
|
||||||
|
"mime_type": "text/plain",
|
||||||
|
"embedding": [0.2, 0.3, 0.4, ...], # Your precomputed embedding vector
|
||||||
|
"metadata": {"document_id": "doc1", "section": "methodology"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks_with_embeddings)
|
||||||
|
```
|
||||||
|
When providing precomputed embeddings, ensure the length of each embedding vector matches the `embedding_dimension` specified when
|
||||||
|
registering the vector database.
|
||||||
|
|
||||||
### Retrieval
|
### Retrieval
|
||||||
You can query the vector database to retrieve documents based on their embeddings.
|
You can query the vector database to retrieve documents based on their embeddings.
|
||||||
```python
|
```python
|
||||||
|
|
|
@ -19,8 +19,16 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
class Chunk(BaseModel):
|
class Chunk(BaseModel):
|
||||||
|
"""
|
||||||
|
A chunk of content that can be inserted into a vector database.
|
||||||
|
:param content: The content of the chunk, which can be interleaved text, images, or other types.
|
||||||
|
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
|
||||||
|
:param metadata: Metadata associated with the chunk, such as document ID, source, or other relevant information.
|
||||||
|
"""
|
||||||
|
|
||||||
content: InterleavedContent
|
content: InterleavedContent
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
embedding: list[float] | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -50,7 +58,10 @@ class VectorIO(Protocol):
|
||||||
"""Insert chunks into a vector database.
|
"""Insert chunks into a vector database.
|
||||||
|
|
||||||
:param vector_db_id: The identifier of the vector database to insert the chunks into.
|
:param vector_db_id: The identifier of the vector database to insert the chunks into.
|
||||||
:param chunks: The chunks to insert.
|
:param chunks: The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types.
|
||||||
|
`metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional.
|
||||||
|
If `metadata` is provided, you configure how Llama Stack formats the chunk during generation.
|
||||||
|
If `embedding` is not provided, it will be computed later.
|
||||||
:param ttl_seconds: The time to live of the chunks.
|
:param ttl_seconds: The time to live of the chunks.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
|
@ -146,7 +146,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
||||||
]
|
]
|
||||||
for i, chunk in enumerate(chunks):
|
for i, chunk in enumerate(chunks):
|
||||||
metadata = chunk.metadata
|
metadata = chunk.metadata
|
||||||
tokens += metadata["token_count"]
|
tokens += metadata.get("token_count", 0)
|
||||||
tokens += metadata.get("metadata_token_count", 0)
|
tokens += metadata.get("metadata_token_count", 0)
|
||||||
|
|
||||||
if tokens > query_config.max_tokens_in_context:
|
if tokens > query_config.max_tokens_in_context:
|
||||||
|
|
|
@ -171,6 +171,22 @@ def make_overlapped_chunks(
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_embedding(embedding: NDArray, index: int, expected_dimension: int):
|
||||||
|
"""Helper method to validate embedding format and dimensions"""
|
||||||
|
if not isinstance(embedding, (list | np.ndarray)):
|
||||||
|
raise ValueError(f"Embedding at index {index} must be a list or numpy array, got {type(embedding)}")
|
||||||
|
|
||||||
|
if isinstance(embedding, np.ndarray):
|
||||||
|
if not np.issubdtype(embedding.dtype, np.number):
|
||||||
|
raise ValueError(f"Embedding at index {index} contains non-numeric values")
|
||||||
|
else:
|
||||||
|
if not all(isinstance(e, (float | int | np.number)) for e in embedding):
|
||||||
|
raise ValueError(f"Embedding at index {index} contains non-numeric values")
|
||||||
|
|
||||||
|
if len(embedding) != expected_dimension:
|
||||||
|
raise ValueError(f"Embedding at index {index} has dimension {len(embedding)}, expected {expected_dimension}")
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingIndex(ABC):
|
class EmbeddingIndex(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||||
|
@ -199,11 +215,22 @@ class VectorDBWithIndex:
|
||||||
self,
|
self,
|
||||||
chunks: list[Chunk],
|
chunks: list[Chunk],
|
||||||
) -> None:
|
) -> None:
|
||||||
embeddings_response = await self.inference_api.embeddings(
|
chunks_to_embed = []
|
||||||
self.vector_db.embedding_model, [x.content for x in chunks]
|
for i, c in enumerate(chunks):
|
||||||
)
|
if c.embedding is None:
|
||||||
embeddings = np.array(embeddings_response.embeddings)
|
chunks_to_embed.append(c)
|
||||||
|
else:
|
||||||
|
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
|
||||||
|
|
||||||
|
if chunks_to_embed:
|
||||||
|
resp = await self.inference_api.embeddings(
|
||||||
|
self.vector_db.embedding_model,
|
||||||
|
[c.content for c in chunks_to_embed],
|
||||||
|
)
|
||||||
|
for c, embedding in zip(chunks_to_embed, resp.embeddings, strict=False):
|
||||||
|
c.embedding = embedding
|
||||||
|
|
||||||
|
embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
|
||||||
await self.index.add_chunks(chunks, embeddings)
|
await self.index.add_chunks(chunks, embeddings)
|
||||||
|
|
||||||
async def query_chunks(
|
async def query_chunks(
|
||||||
|
|
|
@ -120,3 +120,37 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, sample_ch
|
||||||
top_match = response.chunks[0]
|
top_match = response.chunks[0]
|
||||||
assert top_match is not None
|
assert top_match is not None
|
||||||
assert top_match.metadata["document_id"] == expected_doc_id, f"Query '{query}' should match {expected_doc_id}"
|
assert top_match.metadata["document_id"] == expected_doc_id, f"Query '{query}' should match {expected_doc_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id):
    """Insert a chunk that carries its own embedding, then check a query returns it."""
    client = client_with_empty_registry
    db_id = "test_precomputed_embeddings_db"

    client.vector_dbs.register(
        vector_db_id=db_id,
        embedding_model=embedding_model_id,
        embedding_dimension=384,
    )

    precomputed_chunk = Chunk(
        content="This is a test chunk with precomputed embedding.",
        metadata={"document_id": "doc1", "source": "precomputed"},
        embedding=[0.1] * 384,
    )
    client.vector_io.insert(vector_db_id=db_id, chunks=[precomputed_chunk])

    # Query the database that now holds only the precomputed chunk.
    response = client.vector_io.query(vector_db_id=db_id, query="precomputed embedding test")

    # The first hit must be the chunk inserted above, with its metadata intact.
    assert response is not None
    assert len(response.chunks) > 0
    top_metadata = response.chunks[0].metadata
    assert top_metadata["document_id"] == "doc1"
    assert top_metadata["source"] == "precomputed"
|
||||||
|
|
|
@ -50,6 +50,7 @@ def mock_vector_db(vector_db_id) -> MagicMock:
|
||||||
mock_vector_db = MagicMock(spec=VectorDB)
|
mock_vector_db = MagicMock(spec=VectorDB)
|
||||||
mock_vector_db.embedding_model = "embedding_model"
|
mock_vector_db.embedding_model = "embedding_model"
|
||||||
mock_vector_db.identifier = vector_db_id
|
mock_vector_db.identifier = vector_db_id
|
||||||
|
mock_vector_db.embedding_dimension = 384
|
||||||
return mock_vector_db
|
return mock_vector_db
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -8,11 +8,20 @@ import base64
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.tools import RAGDocument
|
from llama_stack.apis.tools import RAGDocument
|
||||||
from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc, make_overlapped_chunks
|
from llama_stack.apis.vector_io import Chunk
|
||||||
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
URL,
|
||||||
|
VectorDBWithIndex,
|
||||||
|
_validate_embedding,
|
||||||
|
content_from_doc,
|
||||||
|
make_overlapped_chunks,
|
||||||
|
)
|
||||||
|
|
||||||
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
|
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
|
||||||
# Depending on the machine, this can get parsed a couple of ways
|
# Depending on the machine, this can get parsed a couple of ways
|
||||||
|
@ -36,6 +45,72 @@ def data_url_from_file(file_path: str) -> str:
|
||||||
return data_url
|
return data_url
|
||||||
|
|
||||||
|
|
||||||
|
class TestChunk:
    def test_chunk(self):
        """A Chunk keeps the fields it is given and leaves `embedding` as None when omitted."""
        text = "Example chunk content"

        with_embedding = Chunk(content=text, metadata={"key": "value"}, embedding=[0.1, 0.2, 0.3])
        assert with_embedding.content == "Example chunk content"
        assert with_embedding.metadata == {"key": "value"}
        assert with_embedding.embedding == [0.1, 0.2, 0.3]

        without_embedding = Chunk(content=text, metadata={"key": "value"})
        assert without_embedding.embedding is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateEmbedding:
    def test_valid_list_embeddings(self):
        """Lists of floats, of ints, and of both mixed are accepted."""
        accepted = [[0.1, 0.2, 0.3], [1, 2, 3], [0.1, 2, 3.5]]
        for position, values in enumerate(accepted):
            _validate_embedding(values, position, 3)

    def test_valid_numpy_embeddings(self):
        """Numpy arrays with float or integer dtypes are accepted."""
        accepted = [
            ([0.1, 0.2, 0.3], np.float32),
            ([0.1, 0.2, 0.3], np.float64),
            ([1, 2, 3], np.int32),
            ([1, 2, 3], np.int64),
        ]
        for position, (values, dtype) in enumerate(accepted):
            _validate_embedding(np.array(values, dtype=dtype), position, 3)

    def test_invalid_embedding_type(self):
        """Anything that is neither a list nor a numpy array is rejected."""
        for position, rejected in enumerate(["not a list", None, 42]):
            with pytest.raises(ValueError, match="must be a list or numpy array"):
                _validate_embedding(rejected, position, 3)

    def test_non_numeric_values(self):
        """A list holding a non-numeric element is rejected."""
        rejected_lists = [[0.1, "string", 0.3], [0.1, None, 0.3], [1, {}, 3]]
        for position, rejected in enumerate(rejected_lists):
            with pytest.raises(ValueError, match="contains non-numeric values"):
                _validate_embedding(rejected, position, 3)

    def test_wrong_dimension(self):
        """An embedding that is too long, too short, or empty is rejected with its actual length."""
        cases = [
            ([0.1, 0.2, 0.3, 0.4], "has dimension 4, expected 3"),
            ([0.1, 0.2], "has dimension 2, expected 3"),
            ([], "has dimension 0, expected 3"),
        ]
        for position, (values, expected_message) in enumerate(cases):
            with pytest.raises(ValueError, match=expected_message):
                _validate_embedding(values, position, 3)
|
||||||
|
|
||||||
|
|
||||||
class TestVectorStore:
|
class TestVectorStore:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_returns_content_from_pdf_data_uri(self):
|
async def test_returns_content_from_pdf_data_uri(self):
|
||||||
|
@ -126,3 +201,126 @@ class TestVectorStore:
|
||||||
assert str(excinfo.value) == "Failed to serialize metadata to string"
|
assert str(excinfo.value) == "Failed to serialize metadata to string"
|
||||||
assert isinstance(excinfo.value.__cause__, TypeError)
|
assert isinstance(excinfo.value.__cause__, TypeError)
|
||||||
assert str(excinfo.value.__cause__) == "Cannot convert to string"
|
assert str(excinfo.value.__cause__) == "Cannot convert to string"
|
||||||
|
|
||||||
|
|
||||||
|
class TestVectorDBWithIndex:
    """Tests for `VectorDBWithIndex.insert_chunks` against a mocked index and inference API."""

    @staticmethod
    def _build(embedding_model: str, embedding_dimension: int = 3):
        """Return `(vector_db_with_index, mock_index, mock_inference_api)` wired around fresh mocks."""
        mock_vector_db = MagicMock()
        mock_vector_db.embedding_model = embedding_model
        mock_vector_db.embedding_dimension = embedding_dimension
        mock_index = AsyncMock()
        mock_inference_api = AsyncMock()
        vector_db_with_index = VectorDBWithIndex(
            vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
        )
        return vector_db_with_index, mock_index, mock_inference_api

    @pytest.mark.asyncio
    async def test_insert_chunks_without_embeddings(self):
        """Chunks without embeddings are all sent to the inference API and the results are indexed."""
        vector_db_with_index, mock_index, mock_inference_api = self._build("test-model without embeddings")

        chunks = [
            Chunk(content="Test 1", embedding=None, metadata={}),
            Chunk(content="Test 2", embedding=None, metadata={}),
        ]
        mock_inference_api.embeddings.return_value.embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

        await vector_db_with_index.insert_chunks(chunks)

        mock_inference_api.embeddings.assert_called_once_with("test-model without embeddings", ["Test 1", "Test 2"])
        mock_index.add_chunks.assert_called_once()
        args = mock_index.add_chunks.call_args[0]
        assert args[0] == chunks
        assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))

    @pytest.mark.asyncio
    async def test_insert_chunks_with_valid_embeddings(self):
        """Chunks that already carry valid embeddings are indexed without calling the inference API."""
        vector_db_with_index, mock_index, mock_inference_api = self._build("test-model with embeddings")

        chunks = [
            Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3], metadata={}),
            Chunk(content="Test 2", embedding=[0.4, 0.5, 0.6], metadata={}),
        ]

        await vector_db_with_index.insert_chunks(chunks)

        mock_inference_api.embeddings.assert_not_called()
        mock_index.add_chunks.assert_called_once()
        args = mock_index.add_chunks.call_args[0]
        assert args[0] == chunks
        assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))

    @pytest.mark.asyncio
    async def test_insert_chunks_with_invalid_embeddings(self):
        """Invalid embeddings raise ValueError and nothing reaches the inference API or the index."""
        vector_db_with_index, mock_index, mock_inference_api = self._build("test-model with invalid embeddings")

        # Chunk itself rejects an embedding that is not a list.
        with pytest.raises(ValueError, match="Input should be a valid list"):
            Chunk(content="Test 1", embedding="invalid_type", metadata={})

        # Same rejection when the bad Chunk is built inline: the error is raised while the
        # argument list is being constructed, i.e. before insert_chunks is actually called.
        with pytest.raises(ValueError, match="Input should be a valid list"):
            await vector_db_with_index.insert_chunks(
                [
                    Chunk(content="Test 1", embedding=None, metadata={}),
                    Chunk(content="Test 2", embedding="invalid_type", metadata={}),
                ]
            )

        # Chunk also rejects a non-numeric element. The chunk is wrapped in a list so the call
        # matches the insert_chunks signature; the pattern is unpadded because pytest.raises
        # searches the message and should not depend on its surrounding whitespace.
        with pytest.raises(ValueError, match="Input should be a valid number, unable to parse string as a number"):
            await vector_db_with_index.insert_chunks(
                [Chunk(content="Test 1", embedding=[0.1, "string", 0.3], metadata={})]
            )

        # A well-formed embedding of the wrong length passes Chunk validation and is rejected
        # by insert_chunks itself.
        chunks_wrong_dim = [
            Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3, 0.4], metadata={}),
        ]
        with pytest.raises(ValueError, match="has dimension 4, expected 3"):
            await vector_db_with_index.insert_chunks(chunks_wrong_dim)

        mock_inference_api.embeddings.assert_not_called()
        mock_index.add_chunks.assert_not_called()

    @pytest.mark.asyncio
    async def test_insert_chunks_with_partially_precomputed_embeddings(self):
        """Only chunks lacking an embedding are embedded; the indexed matrix keeps the input order."""
        vector_db_with_index, mock_index, mock_inference_api = self._build("test-model with partial embeddings")

        chunks = [
            Chunk(content="Test 1", embedding=None, metadata={}),
            Chunk(content="Test 2", embedding=[0.2, 0.2, 0.2], metadata={}),
            Chunk(content="Test 3", embedding=None, metadata={}),
        ]
        mock_inference_api.embeddings.return_value.embeddings = [[0.1, 0.1, 0.1], [0.3, 0.3, 0.3]]

        await vector_db_with_index.insert_chunks(chunks)

        mock_inference_api.embeddings.assert_called_once_with(
            "test-model with partial embeddings", ["Test 1", "Test 3"]
        )
        mock_index.add_chunks.assert_called_once()
        args = mock_index.add_chunks.call_args[0]
        assert len(args[0]) == 3
        assert np.array_equal(args[1], np.array([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]], dtype=np.float32))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue