Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
refactor(test): unify vector_io tests and make them configurable (#1398)
## Test Plan

`LLAMA_STACK_CONFIG=inference=sentence-transformers,vector_io=sqlite-vec pytest -s -v test_vector_io.py --embedding-model all-miniLM-L6-V2 --inference-model='' --vision-inference-model=''`

```
test_vector_io.py::test_vector_db_retrieve[txt=:vis=:emb=all-miniLM-L6-V2] PASSED
test_vector_io.py::test_vector_db_register[txt=:vis=:emb=all-miniLM-L6-V2] PASSED
test_vector_io.py::test_insert_chunks[txt=:vis=:emb=all-miniLM-L6-V2-test_case0] PASSED
test_vector_io.py::test_insert_chunks[txt=:vis=:emb=all-miniLM-L6-V2-test_case1] PASSED
test_vector_io.py::test_insert_chunks[txt=:vis=:emb=all-miniLM-L6-V2-test_case2] PASSED
test_vector_io.py::test_insert_chunks[txt=:vis=:emb=all-miniLM-L6-V2-test_case3] PASSED
test_vector_io.py::test_insert_chunks[txt=:vis=:emb=all-miniLM-L6-V2-test_case4] PASSED
```

The same run passes with:

- `LLAMA_STACK_CONFIG=inference=sentence-transformers,vector_io=faiss`
- `LLAMA_STACK_CONFIG=fireworks`

(Note that the ergonomics around command-line options and environment variables will be improved soon.)
This commit is contained in:

parent fd8c991393
commit dd0db8038b

27 changed files with 117 additions and 559 deletions
```diff
@@ -44,9 +44,9 @@ class TelemetryConfig(BaseModel):
         return v
 
     @classmethod
-    def sample_run_config(cls, __distro_dir__: str = "runtime", db_name: str = "trace_store.db") -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
         return {
             "service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
             "sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
-            "sqlite_db_path": "${env.SQLITE_DB_PATH:~/.llama/" + __distro_dir__ + "/" + db_name + "}",
+            "sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
         }
```
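To make the effect of the signature change concrete, here is a small self-contained sketch that mirrors the "+" side of the hunk above; the distribution directory passed in is purely illustrative and not a path taken from this repository.

```python
# Sketch only: reproduces the new method body so the change is easy to see.
from typing import Any, Dict


def sample_run_config(__distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
    # __distro_dir__ is now required and is no longer prefixed with "~/.llama/" here;
    # the caller decides the full directory that ends up in the env-var default.
    return {
        "service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
        "sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
        "sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
    }


print(sample_run_config("~/.llama/distributions/fireworks")["sqlite_db_path"])
# ${env.SQLITE_DB_PATH:~/.llama/distributions/fireworks/trace_store.db}
```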
```diff
@@ -15,5 +15,5 @@ class SQLiteVectorIOConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
         return {
-            "db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + "sqlite_vec.db",
+            "db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + "sqlite_vec.db",
         }
```
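For context, these `${env.VAR:default}` placeholders follow the usual env-with-fallback convention. The snippet below is an illustrative resolver written for this note, not llama-stack's actual substitution code, and the directory it uses is hypothetical.

```python
# Illustrative only: resolve "${env.VAR:default}" placeholders like the one above.
import os
import re


def resolve_env_template(value: str) -> str:
    def _sub(match: re.Match) -> str:
        var, default = match.group(1), match.group(2)
        # Use the environment variable if set, otherwise fall back to the default.
        return os.environ.get(var, default)

    return re.sub(r"\$\{env\.([A-Z0-9_]+):([^}]*)\}", _sub, value)


template = "${env.SQLITE_STORE_DIR:~/.llama/distributions/sqlite-vec}/sqlite_vec.db"
print(resolve_env_template(template))
# -> "~/.llama/distributions/sqlite-vec/sqlite_vec.db" unless SQLITE_STORE_DIR is set
```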
```diff
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
```
```diff
@@ -1,108 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import pytest
-
-from ..conftest import (
-    get_provider_fixture_overrides,
-    get_provider_fixture_overrides_from_test_config,
-    get_test_config_for_api,
-)
-from ..inference.fixtures import INFERENCE_FIXTURES
-from .fixtures import VECTOR_IO_FIXTURES
-
-DEFAULT_PROVIDER_COMBINATIONS = [
-    pytest.param(
-        {
-            "inference": "sentence_transformers",
-            "vector_io": "faiss",
-        },
-        id="sentence_transformers",
-        marks=pytest.mark.sentence_transformers,
-    ),
-    pytest.param(
-        {
-            "inference": "ollama",
-            "vector_io": "pgvector",
-        },
-        id="pgvector",
-        marks=pytest.mark.pgvector,
-    ),
-    pytest.param(
-        {
-            "inference": "ollama",
-            "vector_io": "faiss",
-        },
-        id="ollama",
-        marks=pytest.mark.ollama,
-    ),
-    pytest.param(
-        {
-            "inference": "ollama",
-            "vector_io": "sqlite_vec",
-        },
-        id="sqlite_vec",
-        marks=pytest.mark.ollama,
-    ),
-    pytest.param(
-        {
-            "inference": "sentence_transformers",
-            "vector_io": "chroma",
-        },
-        id="chroma",
-        marks=pytest.mark.chroma,
-    ),
-    pytest.param(
-        {
-            "inference": "ollama",
-            "vector_io": "qdrant",
-        },
-        id="qdrant",
-        marks=pytest.mark.qdrant,
-    ),
-    pytest.param(
-        {
-            "inference": "fireworks",
-            "vector_io": "weaviate",
-        },
-        id="weaviate",
-        marks=pytest.mark.weaviate,
-    ),
-]
-
-
-def pytest_configure(config):
-    for fixture_name in VECTOR_IO_FIXTURES:
-        config.addinivalue_line(
-            "markers",
-            f"{fixture_name}: marks tests as {fixture_name} specific",
-        )
-
-
-def pytest_generate_tests(metafunc):
-    test_config = get_test_config_for_api(metafunc.config, "vector_io")
-    if "embedding_model" in metafunc.fixturenames:
-        model = getattr(test_config, "embedding_model", None)
-        # Fall back to the default if not specified by the config file
-        model = model or metafunc.config.getoption("--embedding-model")
-        if model:
-            params = [pytest.param(model, id="")]
-        else:
-            params = [pytest.param("all-minilm:l6-v2", id="")]
-
-        metafunc.parametrize("embedding_model", params, indirect=True)
-
-    if "vector_io_stack" in metafunc.fixturenames:
-        available_fixtures = {
-            "inference": INFERENCE_FIXTURES,
-            "vector_io": VECTOR_IO_FIXTURES,
-        }
-        combinations = (
-            get_provider_fixture_overrides_from_test_config(metafunc.config, "vector_io", DEFAULT_PROVIDER_COMBINATIONS)
-            or get_provider_fixture_overrides(metafunc.config, available_fixtures)
-            or DEFAULT_PROVIDER_COMBINATIONS
-        )
-        metafunc.parametrize("vector_io_stack", combinations, indirect=True)
```
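The deleted conftest above drove the old test matrix through pytest's indirect parametrization. As a reminder of how that mechanism works in isolation, here is a minimal, self-contained sketch with invented fixture and test names; it is not code from this repository.

```python
# Minimal illustration of pytest_generate_tests + indirect parametrization, the
# mechanism the removed conftest relied on. All names here are made up.
import pytest

COMBINATIONS = [
    pytest.param({"inference": "sentence_transformers", "vector_io": "faiss"}, id="faiss"),
    pytest.param({"inference": "ollama", "vector_io": "sqlite_vec"}, id="sqlite_vec"),
]


def pytest_generate_tests(metafunc):
    # Feed each provider combination into the fixture below; "indirect" means the
    # fixture receives the value via request.param instead of the test function.
    if "provider_stack" in metafunc.fixturenames:
        metafunc.parametrize("provider_stack", COMBINATIONS, indirect=True)


@pytest.fixture
def provider_stack(request):
    # request.param is one of the dicts above; a real harness would construct the
    # corresponding providers here.
    return request.param


def test_uses_stack(provider_stack):
    assert {"inference", "vector_io"} <= provider_stack.keys()
```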
```diff
@@ -1,180 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import os
-import tempfile
-
-import pytest
-import pytest_asyncio
-
-from llama_stack.apis.models import ModelInput, ModelType
-from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
-from llama_stack.providers.inline.vector_io.faiss import FaissVectorIOConfig
-from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
-from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig
-from llama_stack.providers.remote.vector_io.pgvector import PGVectorVectorIOConfig
-from llama_stack.providers.remote.vector_io.qdrant import QdrantVectorIOConfig
-from llama_stack.providers.remote.vector_io.weaviate import WeaviateVectorIOConfig
-from llama_stack.providers.tests.resolver import construct_stack_for_test
-from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
-
-from ..conftest import ProviderFixture, remote_stack_fixture
-from ..env import get_env_or_fail
-
-
-@pytest.fixture(scope="session")
-def embedding_model(request):
-    if hasattr(request, "param"):
-        return request.param
-    return request.config.getoption("--embedding-model", None)
-
-
-@pytest.fixture(scope="session")
-def vector_io_remote() -> ProviderFixture:
-    return remote_stack_fixture()
-
-
-@pytest.fixture(scope="session")
-def vector_io_faiss() -> ProviderFixture:
-    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
-    return ProviderFixture(
-        providers=[
-            Provider(
-                provider_id="faiss",
-                provider_type="inline::faiss",
-                config=FaissVectorIOConfig(
-                    kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
-                ).model_dump(),
-            )
-        ],
-    )
-
-
-@pytest.fixture(scope="session")
-def vector_io_sqlite_vec() -> ProviderFixture:
-    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
-    return ProviderFixture(
-        providers=[
-            Provider(
-                provider_id="sqlite_vec",
-                provider_type="inline::sqlite-vec",
-                config=SQLiteVectorIOConfig(
-                    kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
-                ).model_dump(),
-            )
-        ],
-    )
-
-
-@pytest.fixture(scope="session")
-def vector_io_pgvector() -> ProviderFixture:
-    return ProviderFixture(
-        providers=[
-            Provider(
-                provider_id="pgvector",
-                provider_type="remote::pgvector",
-                config=PGVectorVectorIOConfig(
-                    host=os.getenv("PGVECTOR_HOST", "localhost"),
-                    port=os.getenv("PGVECTOR_PORT", 5432),
-                    db=get_env_or_fail("PGVECTOR_DB"),
-                    user=get_env_or_fail("PGVECTOR_USER"),
-                    password=get_env_or_fail("PGVECTOR_PASSWORD"),
-                ).model_dump(),
-            )
-        ],
-    )
-
-
-@pytest.fixture(scope="session")
-def vector_io_weaviate() -> ProviderFixture:
-    return ProviderFixture(
-        providers=[
-            Provider(
-                provider_id="weaviate",
-                provider_type="remote::weaviate",
-                config=WeaviateVectorIOConfig().model_dump(),
-            )
-        ],
-        provider_data=dict(
-            weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"),
-            weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"),
-        ),
-    )
-
-
-@pytest.fixture(scope="session")
-def vector_io_chroma() -> ProviderFixture:
-    url = os.getenv("CHROMA_URL")
-    if url:
-        config = ChromaVectorIOConfig(url=url)
-        provider_type = "remote::chromadb"
-    else:
-        if not os.getenv("CHROMA_DB_PATH"):
-            raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set")
-        config = InlineChromaVectorIOConfig(db_path=os.getenv("CHROMA_DB_PATH"))
-        provider_type = "inline::chromadb"
-    return ProviderFixture(
-        providers=[
-            Provider(
-                provider_id="chroma",
-                provider_type=provider_type,
-                config=config.model_dump(),
-            )
-        ]
-    )
-
-
-@pytest.fixture(scope="session")
-def vector_io_qdrant() -> ProviderFixture:
-    url = os.getenv("QDRANT_URL")
-    if url:
-        config = QdrantVectorIOConfig(url=url)
-        provider_type = "remote::qdrant"
-    else:
-        raise ValueError("QDRANT_URL must be set")
-    return ProviderFixture(
-        providers=[
-            Provider(
-                provider_id="qdrant",
-                provider_type=provider_type,
-                config=config.model_dump(),
-            )
-        ]
-    )
-
-
-VECTOR_IO_FIXTURES = ["faiss", "pgvector", "weaviate", "chroma", "qdrant", "sqlite_vec"]
-
-
-@pytest_asyncio.fixture(scope="session")
-async def vector_io_stack(embedding_model, request):
-    fixture_dict = request.param
-
-    providers = {}
-    provider_data = {}
-    for key in ["inference", "vector_io"]:
-        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
-        providers[key] = fixture.providers
-        if fixture.provider_data:
-            provider_data.update(fixture.provider_data)
-
-    test_stack = await construct_stack_for_test(
-        [Api.vector_io, Api.inference],
-        providers,
-        provider_data,
-        models=[
-            ModelInput(
-                model_id=embedding_model,
-                model_type=ModelType.embedding,
-                metadata={
-                    "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
-                },
-            )
-        ],
-    )
-
-    return test_stack.impls[Api.vector_io], test_stack.impls[Api.vector_dbs]
```
```diff
@@ -1,160 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import asyncio
-import sqlite3
-
-import numpy as np
-import pytest
-import sqlite_vec
-
-from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
-from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
-    SQLiteVecIndex,
-    SQLiteVecVectorIOAdapter,
-    generate_chunk_id,
-)
-
-# How to run this test:
-#
-# pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py \
-#     -v -s --tb=short --disable-warnings --asyncio-mode=auto
-
-SQLITE_VEC_PROVIDER = "sqlite_vec"
-EMBEDDING_DIMENSION = 384
-EMBEDDING_MODEL = "all-MiniLM-L6-v2"
-
-
-@pytest.fixture(scope="session")
-def loop():
-    return asyncio.new_event_loop()
-
-
-@pytest.fixture(scope="session", autouse=True)
-def sqlite_connection(loop):
-    conn = sqlite3.connect(":memory:")
-    try:
-        conn.enable_load_extension(True)
-        sqlite_vec.load(conn)
-        yield conn
-    finally:
-        conn.close()
-
-
-@pytest.fixture(scope="session", autouse=True)
-async def sqlite_vec_index(sqlite_connection):
-    return await SQLiteVecIndex.create(dimension=EMBEDDING_DIMENSION, connection=sqlite_connection, bank_id="test_bank")
-
-
-@pytest.fixture(scope="session")
-def sample_chunks():
-    """Generates chunks that force multiple batches for a single document to expose ID conflicts."""
-    n, k = 10, 3
-    sample = [
-        Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
-        for j in range(k)
-        for i in range(n)
-    ]
-    return sample
-
-
-@pytest.fixture(scope="session")
-def sample_embeddings(sample_chunks):
-    np.random.seed(42)
-    return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
-
-
-@pytest.mark.asyncio
-async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
-    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
-    cur = sqlite_vec_index.connection.cursor()
-    cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
-    count = cur.fetchone()[0]
-    assert count == len(sample_chunks)
-
-
-@pytest.mark.asyncio
-async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
-    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
-    query_embedding = np.random.rand(EMBEDDING_DIMENSION).astype(np.float32)
-    response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
-    assert isinstance(response, QueryChunksResponse)
-    assert len(response.chunks) == 2
-
-
-@pytest.mark.asyncio
-async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks):
-    """Test that chunk IDs do not conflict across batches when inserting chunks."""
-    # Reduce batch size to force multiple batches for same document
-    # since there are 10 chunks per document and batch size is 2
-    batch_size = 2
-    sample_embeddings = np.random.rand(len(sample_chunks), EMBEDDING_DIMENSION).astype(np.float32)
-
-    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
-
-    cur = sqlite_vec_index.connection.cursor()
-
-    # Retrieve all chunk IDs to check for duplicates
-    cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
-    chunk_ids = [row[0] for row in cur.fetchall()]
-    cur.close()
-
-    # Ensure all chunk IDs are unique
-    assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
-
-
-@pytest.fixture(scope="session")
-async def sqlite_vec_adapter(sqlite_connection):
-    config = type("Config", (object,), {"db_path": ":memory:"})  # Mock config with in-memory database
-    adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)
-    await adapter.initialize()
-    yield adapter
-    await adapter.shutdown()
-
-
-@pytest.mark.asyncio
-async def test_register_vector_db(sqlite_vec_adapter):
-    vector_db = VectorDB(
-        identifier="test_db",
-        embedding_model=EMBEDDING_MODEL,
-        embedding_dimension=EMBEDDING_DIMENSION,
-        metadata={},
-        provider_id=SQLITE_VEC_PROVIDER,
-    )
-    await sqlite_vec_adapter.register_vector_db(vector_db)
-    vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
-    assert any(db.identifier == "test_db" for db in vector_dbs)
-
-
-@pytest.mark.asyncio
-async def test_unregister_vector_db(sqlite_vec_adapter):
-    vector_db = VectorDB(
-        identifier="test_db",
-        embedding_model=EMBEDDING_MODEL,
-        embedding_dimension=EMBEDDING_DIMENSION,
-        metadata={},
-        provider_id=SQLITE_VEC_PROVIDER,
-    )
-    await sqlite_vec_adapter.register_vector_db(vector_db)
-    await sqlite_vec_adapter.unregister_vector_db("test_db")
-    vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
-    assert not any(db.identifier == "test_db" for db in vector_dbs)
-
-
-def test_generate_chunk_id():
-    chunks = [
-        Chunk(content="test", metadata={"document_id": "doc-1"}),
-        Chunk(content="test ", metadata={"document_id": "doc-1"}),
-        Chunk(content="test 3", metadata={"document_id": "doc-1"}),
-    ]
-
-    chunk_ids = sorted([generate_chunk_id(chunk.metadata["document_id"], chunk.content) for chunk in chunks])
-    assert chunk_ids == [
-        "177a1368-f6a8-0c50-6e92-18677f2c3de3",
-        "bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
-        "f68df25d-d9aa-ab4d-5684-64a233add20d",
-    ]
```
```diff
@@ -1,160 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import uuid
-
-import pytest
-
-from llama_stack.apis.tools import RAGDocument
-from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB
-from llama_stack.apis.vector_io import QueryChunksResponse
-from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks
-
-# How to run this test:
-#
-# pytest llama_stack/providers/tests/vector_io/test_vector_io.py \
-#     -m "pgvector" --env EMBEDDING_DIMENSION=384 PGVECTOR_PORT=7432 \
-#     -v -s --tb=short --disable-warnings
-
-
-@pytest.fixture(scope="session")
-def sample_chunks():
-    docs = [
-        RAGDocument(
-            document_id="doc1",
-            content="Python is a high-level programming language.",
-            metadata={"category": "programming", "difficulty": "beginner"},
-        ),
-        RAGDocument(
-            document_id="doc2",
-            content="Machine learning is a subset of artificial intelligence.",
-            metadata={"category": "AI", "difficulty": "advanced"},
-        ),
-        RAGDocument(
-            document_id="doc3",
-            content="Data structures are fundamental to computer science.",
-            metadata={"category": "computer science", "difficulty": "intermediate"},
-        ),
-        RAGDocument(
-            document_id="doc4",
-            content="Neural networks are inspired by biological neural networks.",
-            metadata={"category": "AI", "difficulty": "advanced"},
-        ),
-    ]
-    chunks = []
-    for doc in docs:
-        chunks.extend(make_overlapped_chunks(doc.document_id, doc.content, window_len=512, overlap_len=64))
-    return chunks
-
-
-async def register_vector_db(vector_dbs_impl: VectorDB, embedding_model: str):
-    vector_db_id = f"test_vector_db_{uuid.uuid4().hex}"
-    return await vector_dbs_impl.register_vector_db(
-        vector_db_id=vector_db_id,
-        embedding_model=embedding_model,
-        embedding_dimension=384,
-    )
-
-
-class TestVectorIO:
-    @pytest.mark.asyncio
-    async def test_banks_list(self, vector_io_stack, embedding_model):
-        _, vector_dbs_impl = vector_io_stack
-
-        # Register a test bank
-        registered_vector_db = await register_vector_db(vector_dbs_impl, embedding_model)
-
-        try:
-            # Verify our bank shows up in list
-            response = await vector_dbs_impl.list_vector_dbs()
-            assert isinstance(response, ListVectorDBsResponse)
-            assert any(vector_db.vector_db_id == registered_vector_db.vector_db_id for vector_db in response.data)
-        finally:
-            # Clean up
-            await vector_dbs_impl.unregister_vector_db(registered_vector_db.vector_db_id)
-
-        # Verify our bank was removed
-        response = await vector_dbs_impl.list_vector_dbs()
-        assert isinstance(response, ListVectorDBsResponse)
-        assert all(vector_db.vector_db_id != registered_vector_db.vector_db_id for vector_db in response.data)
-
-    @pytest.mark.asyncio
-    async def test_banks_register(self, vector_io_stack, embedding_model):
-        _, vector_dbs_impl = vector_io_stack
-
-        vector_db_id = f"test_vector_db_{uuid.uuid4().hex}"
-
-        try:
-            # Register initial bank
-            await vector_dbs_impl.register_vector_db(
-                vector_db_id=vector_db_id,
-                embedding_model=embedding_model,
-                embedding_dimension=384,
-            )
-
-            # Verify our bank exists
-            response = await vector_dbs_impl.list_vector_dbs()
-            assert isinstance(response, ListVectorDBsResponse)
-            assert any(vector_db.vector_db_id == vector_db_id for vector_db in response.data)
-
-            # Try registering same bank again
-            await vector_dbs_impl.register_vector_db(
-                vector_db_id=vector_db_id,
-                embedding_model=embedding_model,
-                embedding_dimension=384,
-            )
-
-            # Verify still only one instance of our bank
-            response = await vector_dbs_impl.list_vector_dbs()
-            assert isinstance(response, ListVectorDBsResponse)
-            assert len([vector_db for vector_db in response.data if vector_db.vector_db_id == vector_db_id]) == 1
-        finally:
-            # Clean up
-            await vector_dbs_impl.unregister_vector_db(vector_db_id)
-
-    @pytest.mark.asyncio
-    async def test_query_documents(self, vector_io_stack, embedding_model, sample_chunks):
-        vector_io_impl, vector_dbs_impl = vector_io_stack
-
-        with pytest.raises(ValueError):
-            await vector_io_impl.insert_chunks("test_vector_db", sample_chunks)
-
-        registered_db = await register_vector_db(vector_dbs_impl, embedding_model)
-        await vector_io_impl.insert_chunks(registered_db.vector_db_id, sample_chunks)
-
-        query1 = "programming language"
-        response1 = await vector_io_impl.query_chunks(registered_db.vector_db_id, query1)
-        assert_valid_response(response1)
-        assert any("Python" in chunk.content for chunk in response1.chunks)
-
-        # Test case 3: Query with semantic similarity
-        query3 = "AI and brain-inspired computing"
-        response3 = await vector_io_impl.query_chunks(registered_db.vector_db_id, query3)
-        assert_valid_response(response3)
-        assert any("neural networks" in chunk.content.lower() for chunk in response3.chunks)
-
-        # Test case 4: Query with limit on number of results
-        query4 = "computer"
-        params4 = {"max_chunks": 2}
-        response4 = await vector_io_impl.query_chunks(registered_db.vector_db_id, query4, params4)
-        assert_valid_response(response4)
-        assert len(response4.chunks) <= 2
-
-        # Test case 5: Query with threshold on similarity score
-        query5 = "quantum computing"  # Not directly related to any document
-        params5 = {"score_threshold": 0.01}
-        response5 = await vector_io_impl.query_chunks(registered_db.vector_db_id, query5, params5)
-        assert_valid_response(response5)
-        print("The scores are:", response5.scores)
-        assert all(score >= 0.01 for score in response5.scores)
-
-
-def assert_valid_response(response: QueryChunksResponse):
-    assert len(response.chunks) > 0
-    assert len(response.scores) > 0
-    assert len(response.chunks) == len(response.scores)
-    for chunk in response.chunks:
-        assert isinstance(chunk.content, str)
```
```diff
@@ -55,11 +55,11 @@ class SqliteKVStoreConfig(CommonConfig):
     )
 
     @classmethod
-    def sample_run_config(cls, __distro_dir__: str = "runtime", db_name: str = "kvstore.db"):
+    def sample_run_config(cls, __distro_dir__: str, db_name: str = "kvstore.db"):
         return {
             "type": "sqlite",
             "namespace": None,
-            "db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + db_name,
+            "db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
         }
```
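A tiny sketch of what this shifts onto callers: previously `sample_run_config` prepended `~/.llama/` itself, while now the distribution template must pass the full directory. The helper names and the directory below are made up for illustration only.

```python
# Old behavior: the method adds the "~/.llama/" prefix to whatever the caller passes.
def old_db_path(distro_dir: str, db_name: str = "kvstore.db") -> str:
    return "${env.SQLITE_STORE_DIR:~/.llama/" + distro_dir + "}/" + db_name


# New behavior: the caller supplies the full directory for the env-var default.
def new_db_path(distro_dir: str, db_name: str = "kvstore.db") -> str:
    return "${env.SQLITE_STORE_DIR:" + distro_dir + "}/" + db_name


# The same default path is produced only when the caller passes the full directory.
assert old_db_path("distributions/ollama") == new_db_path("~/.llama/distributions/ollama")
```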