mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
refactor(test): unify vector_io tests and make them configurable
This commit is contained in:
parent
fd8c991393
commit
c43ed8d0e6
8 changed files with 90 additions and 666 deletions
|
@ -15,5 +15,5 @@ class SQLiteVectorIOConfig(BaseModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
|
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + "sqlite_vec.db",
|
"db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + "sqlite_vec.db",
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
|
@ -1,108 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from ..conftest import (
|
|
||||||
get_provider_fixture_overrides,
|
|
||||||
get_provider_fixture_overrides_from_test_config,
|
|
||||||
get_test_config_for_api,
|
|
||||||
)
|
|
||||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
|
||||||
from .fixtures import VECTOR_IO_FIXTURES
|
|
||||||
|
|
||||||
DEFAULT_PROVIDER_COMBINATIONS = [
|
|
||||||
pytest.param(
|
|
||||||
{
|
|
||||||
"inference": "sentence_transformers",
|
|
||||||
"vector_io": "faiss",
|
|
||||||
},
|
|
||||||
id="sentence_transformers",
|
|
||||||
marks=pytest.mark.sentence_transformers,
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
{
|
|
||||||
"inference": "ollama",
|
|
||||||
"vector_io": "pgvector",
|
|
||||||
},
|
|
||||||
id="pgvector",
|
|
||||||
marks=pytest.mark.pgvector,
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
{
|
|
||||||
"inference": "ollama",
|
|
||||||
"vector_io": "faiss",
|
|
||||||
},
|
|
||||||
id="ollama",
|
|
||||||
marks=pytest.mark.ollama,
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
{
|
|
||||||
"inference": "ollama",
|
|
||||||
"vector_io": "sqlite_vec",
|
|
||||||
},
|
|
||||||
id="sqlite_vec",
|
|
||||||
marks=pytest.mark.ollama,
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
{
|
|
||||||
"inference": "sentence_transformers",
|
|
||||||
"vector_io": "chroma",
|
|
||||||
},
|
|
||||||
id="chroma",
|
|
||||||
marks=pytest.mark.chroma,
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
{
|
|
||||||
"inference": "ollama",
|
|
||||||
"vector_io": "qdrant",
|
|
||||||
},
|
|
||||||
id="qdrant",
|
|
||||||
marks=pytest.mark.qdrant,
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
{
|
|
||||||
"inference": "fireworks",
|
|
||||||
"vector_io": "weaviate",
|
|
||||||
},
|
|
||||||
id="weaviate",
|
|
||||||
marks=pytest.mark.weaviate,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config):
|
|
||||||
for fixture_name in VECTOR_IO_FIXTURES:
|
|
||||||
config.addinivalue_line(
|
|
||||||
"markers",
|
|
||||||
f"{fixture_name}: marks tests as {fixture_name} specific",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
|
||||||
test_config = get_test_config_for_api(metafunc.config, "vector_io")
|
|
||||||
if "embedding_model" in metafunc.fixturenames:
|
|
||||||
model = getattr(test_config, "embedding_model", None)
|
|
||||||
# Fall back to the default if not specified by the config file
|
|
||||||
model = model or metafunc.config.getoption("--embedding-model")
|
|
||||||
if model:
|
|
||||||
params = [pytest.param(model, id="")]
|
|
||||||
else:
|
|
||||||
params = [pytest.param("all-minilm:l6-v2", id="")]
|
|
||||||
|
|
||||||
metafunc.parametrize("embedding_model", params, indirect=True)
|
|
||||||
|
|
||||||
if "vector_io_stack" in metafunc.fixturenames:
|
|
||||||
available_fixtures = {
|
|
||||||
"inference": INFERENCE_FIXTURES,
|
|
||||||
"vector_io": VECTOR_IO_FIXTURES,
|
|
||||||
}
|
|
||||||
combinations = (
|
|
||||||
get_provider_fixture_overrides_from_test_config(metafunc.config, "vector_io", DEFAULT_PROVIDER_COMBINATIONS)
|
|
||||||
or get_provider_fixture_overrides(metafunc.config, available_fixtures)
|
|
||||||
or DEFAULT_PROVIDER_COMBINATIONS
|
|
||||||
)
|
|
||||||
metafunc.parametrize("vector_io_stack", combinations, indirect=True)
|
|
|
@ -1,180 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
from llama_stack.apis.models import ModelInput, ModelType
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
|
||||||
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
|
|
||||||
from llama_stack.providers.inline.vector_io.faiss import FaissVectorIOConfig
|
|
||||||
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
|
|
||||||
from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig
|
|
||||||
from llama_stack.providers.remote.vector_io.pgvector import PGVectorVectorIOConfig
|
|
||||||
from llama_stack.providers.remote.vector_io.qdrant import QdrantVectorIOConfig
|
|
||||||
from llama_stack.providers.remote.vector_io.weaviate import WeaviateVectorIOConfig
|
|
||||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|
||||||
|
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
|
||||||
from ..env import get_env_or_fail
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def embedding_model(request):
|
|
||||||
if hasattr(request, "param"):
|
|
||||||
return request.param
|
|
||||||
return request.config.getoption("--embedding-model", None)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def vector_io_remote() -> ProviderFixture:
|
|
||||||
return remote_stack_fixture()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def vector_io_faiss() -> ProviderFixture:
|
|
||||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
|
||||||
return ProviderFixture(
|
|
||||||
providers=[
|
|
||||||
Provider(
|
|
||||||
provider_id="faiss",
|
|
||||||
provider_type="inline::faiss",
|
|
||||||
config=FaissVectorIOConfig(
|
|
||||||
kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
|
|
||||||
).model_dump(),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def vector_io_sqlite_vec() -> ProviderFixture:
|
|
||||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
|
||||||
return ProviderFixture(
|
|
||||||
providers=[
|
|
||||||
Provider(
|
|
||||||
provider_id="sqlite_vec",
|
|
||||||
provider_type="inline::sqlite-vec",
|
|
||||||
config=SQLiteVectorIOConfig(
|
|
||||||
kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
|
|
||||||
).model_dump(),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def vector_io_pgvector() -> ProviderFixture:
|
|
||||||
return ProviderFixture(
|
|
||||||
providers=[
|
|
||||||
Provider(
|
|
||||||
provider_id="pgvector",
|
|
||||||
provider_type="remote::pgvector",
|
|
||||||
config=PGVectorVectorIOConfig(
|
|
||||||
host=os.getenv("PGVECTOR_HOST", "localhost"),
|
|
||||||
port=os.getenv("PGVECTOR_PORT", 5432),
|
|
||||||
db=get_env_or_fail("PGVECTOR_DB"),
|
|
||||||
user=get_env_or_fail("PGVECTOR_USER"),
|
|
||||||
password=get_env_or_fail("PGVECTOR_PASSWORD"),
|
|
||||||
).model_dump(),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def vector_io_weaviate() -> ProviderFixture:
|
|
||||||
return ProviderFixture(
|
|
||||||
providers=[
|
|
||||||
Provider(
|
|
||||||
provider_id="weaviate",
|
|
||||||
provider_type="remote::weaviate",
|
|
||||||
config=WeaviateVectorIOConfig().model_dump(),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
provider_data=dict(
|
|
||||||
weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"),
|
|
||||||
weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def vector_io_chroma() -> ProviderFixture:
|
|
||||||
url = os.getenv("CHROMA_URL")
|
|
||||||
if url:
|
|
||||||
config = ChromaVectorIOConfig(url=url)
|
|
||||||
provider_type = "remote::chromadb"
|
|
||||||
else:
|
|
||||||
if not os.getenv("CHROMA_DB_PATH"):
|
|
||||||
raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set")
|
|
||||||
config = InlineChromaVectorIOConfig(db_path=os.getenv("CHROMA_DB_PATH"))
|
|
||||||
provider_type = "inline::chromadb"
|
|
||||||
return ProviderFixture(
|
|
||||||
providers=[
|
|
||||||
Provider(
|
|
||||||
provider_id="chroma",
|
|
||||||
provider_type=provider_type,
|
|
||||||
config=config.model_dump(),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def vector_io_qdrant() -> ProviderFixture:
|
|
||||||
url = os.getenv("QDRANT_URL")
|
|
||||||
if url:
|
|
||||||
config = QdrantVectorIOConfig(url=url)
|
|
||||||
provider_type = "remote::qdrant"
|
|
||||||
else:
|
|
||||||
raise ValueError("QDRANT_URL must be set")
|
|
||||||
return ProviderFixture(
|
|
||||||
providers=[
|
|
||||||
Provider(
|
|
||||||
provider_id="qdrant",
|
|
||||||
provider_type=provider_type,
|
|
||||||
config=config.model_dump(),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
VECTOR_IO_FIXTURES = ["faiss", "pgvector", "weaviate", "chroma", "qdrant", "sqlite_vec"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
|
||||||
async def vector_io_stack(embedding_model, request):
|
|
||||||
fixture_dict = request.param
|
|
||||||
|
|
||||||
providers = {}
|
|
||||||
provider_data = {}
|
|
||||||
for key in ["inference", "vector_io"]:
|
|
||||||
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
|
||||||
providers[key] = fixture.providers
|
|
||||||
if fixture.provider_data:
|
|
||||||
provider_data.update(fixture.provider_data)
|
|
||||||
|
|
||||||
test_stack = await construct_stack_for_test(
|
|
||||||
[Api.vector_io, Api.inference],
|
|
||||||
providers,
|
|
||||||
provider_data,
|
|
||||||
models=[
|
|
||||||
ModelInput(
|
|
||||||
model_id=embedding_model,
|
|
||||||
model_type=ModelType.embedding,
|
|
||||||
metadata={
|
|
||||||
"embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
return test_stack.impls[Api.vector_io], test_stack.impls[Api.vector_dbs]
|
|
|
@ -1,160 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import sqlite3
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
import sqlite_vec
|
|
||||||
|
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
|
||||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
|
||||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
|
|
||||||
SQLiteVecIndex,
|
|
||||||
SQLiteVecVectorIOAdapter,
|
|
||||||
generate_chunk_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# How to run this test:
|
|
||||||
#
|
|
||||||
# pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py \
|
|
||||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
|
||||||
|
|
||||||
SQLITE_VEC_PROVIDER = "sqlite_vec"
|
|
||||||
EMBEDDING_DIMENSION = 384
|
|
||||||
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def loop():
|
|
||||||
return asyncio.new_event_loop()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def sqlite_connection(loop):
|
|
||||||
conn = sqlite3.connect(":memory:")
|
|
||||||
try:
|
|
||||||
conn.enable_load_extension(True)
|
|
||||||
sqlite_vec.load(conn)
|
|
||||||
yield conn
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
async def sqlite_vec_index(sqlite_connection):
|
|
||||||
return await SQLiteVecIndex.create(dimension=EMBEDDING_DIMENSION, connection=sqlite_connection, bank_id="test_bank")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def sample_chunks():
|
|
||||||
"""Generates chunks that force multiple batches for a single document to expose ID conflicts."""
|
|
||||||
n, k = 10, 3
|
|
||||||
sample = [
|
|
||||||
Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
|
|
||||||
for j in range(k)
|
|
||||||
for i in range(n)
|
|
||||||
]
|
|
||||||
return sample
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def sample_embeddings(sample_chunks):
|
|
||||||
np.random.seed(42)
|
|
||||||
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
|
|
||||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
|
|
||||||
cur = sqlite_vec_index.connection.cursor()
|
|
||||||
cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
|
|
||||||
count = cur.fetchone()[0]
|
|
||||||
assert count == len(sample_chunks)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
|
|
||||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
|
||||||
query_embedding = np.random.rand(EMBEDDING_DIMENSION).astype(np.float32)
|
|
||||||
response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
|
|
||||||
assert isinstance(response, QueryChunksResponse)
|
|
||||||
assert len(response.chunks) == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks):
|
|
||||||
"""Test that chunk IDs do not conflict across batches when inserting chunks."""
|
|
||||||
# Reduce batch size to force multiple batches for same document
|
|
||||||
# since there are 10 chunks per document and batch size is 2
|
|
||||||
batch_size = 2
|
|
||||||
sample_embeddings = np.random.rand(len(sample_chunks), EMBEDDING_DIMENSION).astype(np.float32)
|
|
||||||
|
|
||||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
|
|
||||||
|
|
||||||
cur = sqlite_vec_index.connection.cursor()
|
|
||||||
|
|
||||||
# Retrieve all chunk IDs to check for duplicates
|
|
||||||
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
|
|
||||||
chunk_ids = [row[0] for row in cur.fetchall()]
|
|
||||||
cur.close()
|
|
||||||
|
|
||||||
# Ensure all chunk IDs are unique
|
|
||||||
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
async def sqlite_vec_adapter(sqlite_connection):
|
|
||||||
config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database
|
|
||||||
adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)
|
|
||||||
await adapter.initialize()
|
|
||||||
yield adapter
|
|
||||||
await adapter.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_register_vector_db(sqlite_vec_adapter):
|
|
||||||
vector_db = VectorDB(
|
|
||||||
identifier="test_db",
|
|
||||||
embedding_model=EMBEDDING_MODEL,
|
|
||||||
embedding_dimension=EMBEDDING_DIMENSION,
|
|
||||||
metadata={},
|
|
||||||
provider_id=SQLITE_VEC_PROVIDER,
|
|
||||||
)
|
|
||||||
await sqlite_vec_adapter.register_vector_db(vector_db)
|
|
||||||
vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
|
|
||||||
assert any(db.identifier == "test_db" for db in vector_dbs)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_unregister_vector_db(sqlite_vec_adapter):
|
|
||||||
vector_db = VectorDB(
|
|
||||||
identifier="test_db",
|
|
||||||
embedding_model=EMBEDDING_MODEL,
|
|
||||||
embedding_dimension=EMBEDDING_DIMENSION,
|
|
||||||
metadata={},
|
|
||||||
provider_id=SQLITE_VEC_PROVIDER,
|
|
||||||
)
|
|
||||||
await sqlite_vec_adapter.register_vector_db(vector_db)
|
|
||||||
await sqlite_vec_adapter.unregister_vector_db("test_db")
|
|
||||||
vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
|
|
||||||
assert not any(db.identifier == "test_db" for db in vector_dbs)
|
|
||||||
|
|
||||||
|
|
||||||
def test_generate_chunk_id():
|
|
||||||
chunks = [
|
|
||||||
Chunk(content="test", metadata={"document_id": "doc-1"}),
|
|
||||||
Chunk(content="test ", metadata={"document_id": "doc-1"}),
|
|
||||||
Chunk(content="test 3", metadata={"document_id": "doc-1"}),
|
|
||||||
]
|
|
||||||
|
|
||||||
chunk_ids = sorted([generate_chunk_id(chunk.metadata["document_id"], chunk.content) for chunk in chunks])
|
|
||||||
assert chunk_ids == [
|
|
||||||
"177a1368-f6a8-0c50-6e92-18677f2c3de3",
|
|
||||||
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
|
|
||||||
"f68df25d-d9aa-ab4d-5684-64a233add20d",
|
|
||||||
]
|
|
|
@ -1,160 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from llama_stack.apis.tools import RAGDocument
|
|
||||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB
|
|
||||||
from llama_stack.apis.vector_io import QueryChunksResponse
|
|
||||||
from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks
|
|
||||||
|
|
||||||
# How to run this test:
|
|
||||||
#
|
|
||||||
# pytest llama_stack/providers/tests/vector_io/test_vector_io.py \
|
|
||||||
# -m "pgvector" --env EMBEDDING_DIMENSION=384 PGVECTOR_PORT=7432 \
|
|
||||||
# -v -s --tb=short --disable-warnings
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def sample_chunks():
|
|
||||||
docs = [
|
|
||||||
RAGDocument(
|
|
||||||
document_id="doc1",
|
|
||||||
content="Python is a high-level programming language.",
|
|
||||||
metadata={"category": "programming", "difficulty": "beginner"},
|
|
||||||
),
|
|
||||||
RAGDocument(
|
|
||||||
document_id="doc2",
|
|
||||||
content="Machine learning is a subset of artificial intelligence.",
|
|
||||||
metadata={"category": "AI", "difficulty": "advanced"},
|
|
||||||
),
|
|
||||||
RAGDocument(
|
|
||||||
document_id="doc3",
|
|
||||||
content="Data structures are fundamental to computer science.",
|
|
||||||
metadata={"category": "computer science", "difficulty": "intermediate"},
|
|
||||||
),
|
|
||||||
RAGDocument(
|
|
||||||
document_id="doc4",
|
|
||||||
content="Neural networks are inspired by biological neural networks.",
|
|
||||||
metadata={"category": "AI", "difficulty": "advanced"},
|
|
||||||
),
|
|
||||||
]
|
|
||||||
chunks = []
|
|
||||||
for doc in docs:
|
|
||||||
chunks.extend(make_overlapped_chunks(doc.document_id, doc.content, window_len=512, overlap_len=64))
|
|
||||||
return chunks
|
|
||||||
|
|
||||||
|
|
||||||
async def register_vector_db(vector_dbs_impl: VectorDB, embedding_model: str):
|
|
||||||
vector_db_id = f"test_vector_db_{uuid.uuid4().hex}"
|
|
||||||
return await vector_dbs_impl.register_vector_db(
|
|
||||||
vector_db_id=vector_db_id,
|
|
||||||
embedding_model=embedding_model,
|
|
||||||
embedding_dimension=384,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestVectorIO:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_banks_list(self, vector_io_stack, embedding_model):
|
|
||||||
_, vector_dbs_impl = vector_io_stack
|
|
||||||
|
|
||||||
# Register a test bank
|
|
||||||
registered_vector_db = await register_vector_db(vector_dbs_impl, embedding_model)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Verify our bank shows up in list
|
|
||||||
response = await vector_dbs_impl.list_vector_dbs()
|
|
||||||
assert isinstance(response, ListVectorDBsResponse)
|
|
||||||
assert any(vector_db.vector_db_id == registered_vector_db.vector_db_id for vector_db in response.data)
|
|
||||||
finally:
|
|
||||||
# Clean up
|
|
||||||
await vector_dbs_impl.unregister_vector_db(registered_vector_db.vector_db_id)
|
|
||||||
|
|
||||||
# Verify our bank was removed
|
|
||||||
response = await vector_dbs_impl.list_vector_dbs()
|
|
||||||
assert isinstance(response, ListVectorDBsResponse)
|
|
||||||
assert all(vector_db.vector_db_id != registered_vector_db.vector_db_id for vector_db in response.data)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_banks_register(self, vector_io_stack, embedding_model):
|
|
||||||
_, vector_dbs_impl = vector_io_stack
|
|
||||||
|
|
||||||
vector_db_id = f"test_vector_db_{uuid.uuid4().hex}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Register initial bank
|
|
||||||
await vector_dbs_impl.register_vector_db(
|
|
||||||
vector_db_id=vector_db_id,
|
|
||||||
embedding_model=embedding_model,
|
|
||||||
embedding_dimension=384,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify our bank exists
|
|
||||||
response = await vector_dbs_impl.list_vector_dbs()
|
|
||||||
assert isinstance(response, ListVectorDBsResponse)
|
|
||||||
assert any(vector_db.vector_db_id == vector_db_id for vector_db in response.data)
|
|
||||||
|
|
||||||
# Try registering same bank again
|
|
||||||
await vector_dbs_impl.register_vector_db(
|
|
||||||
vector_db_id=vector_db_id,
|
|
||||||
embedding_model=embedding_model,
|
|
||||||
embedding_dimension=384,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify still only one instance of our bank
|
|
||||||
response = await vector_dbs_impl.list_vector_dbs()
|
|
||||||
assert isinstance(response, ListVectorDBsResponse)
|
|
||||||
assert len([vector_db for vector_db in response.data if vector_db.vector_db_id == vector_db_id]) == 1
|
|
||||||
finally:
|
|
||||||
# Clean up
|
|
||||||
await vector_dbs_impl.unregister_vector_db(vector_db_id)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_documents(self, vector_io_stack, embedding_model, sample_chunks):
|
|
||||||
vector_io_impl, vector_dbs_impl = vector_io_stack
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
await vector_io_impl.insert_chunks("test_vector_db", sample_chunks)
|
|
||||||
|
|
||||||
registered_db = await register_vector_db(vector_dbs_impl, embedding_model)
|
|
||||||
await vector_io_impl.insert_chunks(registered_db.vector_db_id, sample_chunks)
|
|
||||||
|
|
||||||
query1 = "programming language"
|
|
||||||
response1 = await vector_io_impl.query_chunks(registered_db.vector_db_id, query1)
|
|
||||||
assert_valid_response(response1)
|
|
||||||
assert any("Python" in chunk.content for chunk in response1.chunks)
|
|
||||||
|
|
||||||
# Test case 3: Query with semantic similarity
|
|
||||||
query3 = "AI and brain-inspired computing"
|
|
||||||
response3 = await vector_io_impl.query_chunks(registered_db.vector_db_id, query3)
|
|
||||||
assert_valid_response(response3)
|
|
||||||
assert any("neural networks" in chunk.content.lower() for chunk in response3.chunks)
|
|
||||||
|
|
||||||
# Test case 4: Query with limit on number of results
|
|
||||||
query4 = "computer"
|
|
||||||
params4 = {"max_chunks": 2}
|
|
||||||
response4 = await vector_io_impl.query_chunks(registered_db.vector_db_id, query4, params4)
|
|
||||||
assert_valid_response(response4)
|
|
||||||
assert len(response4.chunks) <= 2
|
|
||||||
|
|
||||||
# Test case 5: Query with threshold on similarity score
|
|
||||||
query5 = "quantum computing" # Not directly related to any document
|
|
||||||
params5 = {"score_threshold": 0.01}
|
|
||||||
response5 = await vector_io_impl.query_chunks(registered_db.vector_db_id, query5, params5)
|
|
||||||
assert_valid_response(response5)
|
|
||||||
print("The scores are:", response5.scores)
|
|
||||||
assert all(score >= 0.01 for score in response5.scores)
|
|
||||||
|
|
||||||
|
|
||||||
def assert_valid_response(response: QueryChunksResponse):
|
|
||||||
assert len(response.chunks) > 0
|
|
||||||
assert len(response.scores) > 0
|
|
||||||
assert len(response.chunks) == len(response.scores)
|
|
||||||
for chunk in response.chunks:
|
|
||||||
assert isinstance(chunk.content, str)
|
|
|
@ -128,6 +128,7 @@ def distro_from_adhoc_config_spec(adhoc_config_spec: str) -> str:
|
||||||
api_providers = adhoc_config_spec.replace(";", ",").split(",")
|
api_providers = adhoc_config_spec.replace(";", ",").split(",")
|
||||||
provider_registry = get_provider_registry()
|
provider_registry = get_provider_registry()
|
||||||
|
|
||||||
|
distro_dir = tempfile.mkdtemp()
|
||||||
provider_configs_by_api = {}
|
provider_configs_by_api = {}
|
||||||
for api_provider in api_providers:
|
for api_provider in api_providers:
|
||||||
api_str, provider = api_provider.split("=")
|
api_str, provider = api_provider.split("=")
|
||||||
|
@ -147,7 +148,7 @@ def distro_from_adhoc_config_spec(adhoc_config_spec: str) -> str:
|
||||||
|
|
||||||
# call method "sample_run_config" on the provider spec config class
|
# call method "sample_run_config" on the provider spec config class
|
||||||
provider_config_type = instantiate_class_type(provider_spec.config_class)
|
provider_config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
provider_config = replace_env_vars(provider_config_type.sample_run_config())
|
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
|
||||||
|
|
||||||
provider_configs_by_api[api_str] = [
|
provider_configs_by_api[api_str] = [
|
||||||
Provider(
|
Provider(
|
||||||
|
|
|
@ -4,83 +4,119 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import random
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
INLINE_VECTOR_DB_PROVIDERS = [
|
from llama_stack.apis.vector_io import Chunk
|
||||||
"faiss",
|
|
||||||
# TODO: add sqlite_vec to templates
|
|
||||||
# "sqlite_vec",
|
@pytest.fixture(scope="session")
|
||||||
|
def sample_chunks():
|
||||||
|
return [
|
||||||
|
Chunk(
|
||||||
|
content="Python is a high-level programming language that emphasizes code readability and allows programmers to express concepts in fewer lines of code than would be possible in languages such as C++ or Java.",
|
||||||
|
metadata={"document_id": "doc1"},
|
||||||
|
),
|
||||||
|
Chunk(
|
||||||
|
content="Machine learning is a subset of artificial intelligence that enables systems to automatically learn and improve from experience without being explicitly programmed, using statistical techniques to give computer systems the ability to progressively improve performance on a specific task.",
|
||||||
|
metadata={"document_id": "doc2"},
|
||||||
|
),
|
||||||
|
Chunk(
|
||||||
|
content="Data structures are fundamental to computer science because they provide organized ways to store and access data efficiently, enable faster processing of data through optimized algorithms, and form the building blocks for more complex software systems.",
|
||||||
|
metadata={"document_id": "doc3"},
|
||||||
|
),
|
||||||
|
Chunk(
|
||||||
|
content="Neural networks are inspired by biological neural networks found in animal brains, using interconnected nodes called artificial neurons to process information through weighted connections that can be trained to recognize patterns and solve complex problems through iterative learning.",
|
||||||
|
metadata={"document_id": "doc4"},
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def empty_vector_db_registry(llama_stack_client):
|
def client_with_empty_registry(client_with_models):
|
||||||
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
def clear_registry():
|
||||||
|
vector_dbs = [vector_db.identifier for vector_db in client_with_models.vector_dbs.list()]
|
||||||
for vector_db_id in vector_dbs:
|
for vector_db_id in vector_dbs:
|
||||||
llama_stack_client.vector_dbs.unregister(vector_db_id=vector_db_id)
|
client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id)
|
||||||
|
|
||||||
|
clear_registry()
|
||||||
|
yield client_with_models
|
||||||
|
|
||||||
|
# you must clean after the last test if you were running tests against
|
||||||
|
# a stateful server instance
|
||||||
|
clear_registry()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id):
|
||||||
def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry, provider_id):
|
|
||||||
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
|
|
||||||
llama_stack_client.vector_dbs.register(
|
|
||||||
vector_db_id=vector_db_id,
|
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
|
||||||
embedding_dimension=384,
|
|
||||||
provider_id=provider_id,
|
|
||||||
)
|
|
||||||
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
|
||||||
return vector_dbs
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS)
|
|
||||||
def test_vector_db_retrieve(llama_stack_client, embedding_model_id, empty_vector_db_registry, provider_id):
|
|
||||||
# Register a memory bank first
|
# Register a memory bank first
|
||||||
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
|
vector_db_id = "test_vector_db"
|
||||||
llama_stack_client.vector_dbs.register(
|
client_with_empty_registry.vector_dbs.register(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_id,
|
||||||
embedding_model=embedding_model_id,
|
embedding_model=embedding_model_id,
|
||||||
embedding_dimension=384,
|
embedding_dimension=384,
|
||||||
provider_id=provider_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Retrieve the memory bank and validate its properties
|
# Retrieve the memory bank and validate its properties
|
||||||
response = llama_stack_client.vector_dbs.retrieve(vector_db_id=vector_db_id)
|
response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=vector_db_id)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert response.identifier == vector_db_id
|
assert response.identifier == vector_db_id
|
||||||
assert response.embedding_model == embedding_model_id
|
assert response.embedding_model == embedding_model_id
|
||||||
assert response.provider_id == provider_id
|
|
||||||
assert response.provider_resource_id == vector_db_id
|
assert response.provider_resource_id == vector_db_id
|
||||||
|
|
||||||
|
|
||||||
def test_vector_db_list(llama_stack_client, empty_vector_db_registry):
|
def test_vector_db_register(client_with_empty_registry, embedding_model_id):
|
||||||
vector_dbs_after_register = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
vector_db_id = "test_vector_db"
|
||||||
assert len(vector_dbs_after_register) == 0
|
client_with_empty_registry.vector_dbs.register(
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS)
|
|
||||||
def test_vector_db_register(llama_stack_client, embedding_model_id, empty_vector_db_registry, provider_id):
|
|
||||||
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
|
|
||||||
llama_stack_client.vector_dbs.register(
|
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_id,
|
||||||
embedding_model=embedding_model_id,
|
embedding_model=embedding_model_id,
|
||||||
embedding_dimension=384,
|
embedding_dimension=384,
|
||||||
provider_id=provider_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_dbs_after_register = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||||
assert vector_dbs_after_register == [vector_db_id]
|
assert vector_dbs_after_register == [vector_db_id]
|
||||||
|
|
||||||
|
client_with_empty_registry.vector_dbs.unregister(vector_db_id=vector_db_id)
|
||||||
|
|
||||||
@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS)
|
vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||||
def test_vector_db_unregister(llama_stack_client, single_entry_vector_db_registry, provider_id):
|
|
||||||
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
|
||||||
assert len(vector_dbs) == 1
|
|
||||||
|
|
||||||
vector_db_id = vector_dbs[0]
|
|
||||||
llama_stack_client.vector_dbs.unregister(vector_db_id=vector_db_id)
|
|
||||||
|
|
||||||
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
|
||||||
assert len(vector_dbs) == 0
|
assert len(vector_dbs) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_case",
|
||||||
|
[
|
||||||
|
("What makes Python different from C++ and Java?", "doc1"),
|
||||||
|
("How do systems learn without explicit programming?", "doc2"),
|
||||||
|
("Why are data structures important in computer science?", "doc3"),
|
||||||
|
("What is the biological inspiration for neural networks?", "doc4"),
|
||||||
|
("How does machine learning improve over time?", "doc2"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_insert_chunks(client_with_empty_registry, embedding_model_id, sample_chunks, test_case):
|
||||||
|
vector_db_id = "test_vector_db"
|
||||||
|
client_with_empty_registry.vector_dbs.register(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
embedding_model=embedding_model_id,
|
||||||
|
embedding_dimension=384,
|
||||||
|
)
|
||||||
|
|
||||||
|
client_with_empty_registry.vector_io.insert(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
chunks=sample_chunks,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client_with_empty_registry.vector_io.query(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
query="What is the capital of France?",
|
||||||
|
)
|
||||||
|
assert response is not None
|
||||||
|
assert len(response.chunks) > 1
|
||||||
|
assert len(response.scores) > 1
|
||||||
|
|
||||||
|
query, expected_doc_id = test_case
|
||||||
|
response = client_with_empty_registry.vector_io.query(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
query=query,
|
||||||
|
)
|
||||||
|
assert response is not None
|
||||||
|
top_match = response.chunks[0]
|
||||||
|
assert top_match is not None
|
||||||
|
assert top_match.metadata["document_id"] == expected_doc_id, f"Query '{query}' should match {expected_doc_id}"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue