refactor(test): unify vector_io tests and make them configurable (#1398)

## Test Plan


`LLAMA_STACK_CONFIG=inference=sentence-transformers,vector_io=sqlite-vec
pytest -s -v test_vector_io.py --embedding-model all-miniLM-L6-V2
--inference-model='' --vision-inference-model=''`

```
test_vector_io.py::test_vector_db_retrieve[txt=:vis=:emb=all-miniLM-L6-V2] PASSED
test_vector_io.py::test_vector_db_register[txt=:vis=:emb=all-miniLM-L6-V2] PASSED
test_vector_io.py::test_insert_chunks[txt=:vis=:emb=all-miniLM-L6-V2-test_case0] PASSED
test_vector_io.py::test_insert_chunks[txt=:vis=:emb=all-miniLM-L6-V2-test_case1] PASSED
test_vector_io.py::test_insert_chunks[txt=:vis=:emb=all-miniLM-L6-V2-test_case2] PASSED
test_vector_io.py::test_insert_chunks[txt=:vis=:emb=all-miniLM-L6-V2-test_case3] PASSED
test_vector_io.py::test_insert_chunks[txt=:vis=:emb=all-miniLM-L6-V2-test_case4] PASSED
```

The same results were obtained with:
- LLAMA_STACK_CONFIG=inference=sentence-transformers,vector_io=faiss
- LLAMA_STACK_CONFIG=fireworks

(Note: the ergonomics of the command-line options and environment
variables will be improved soon.)
This commit is contained in:
Ashwin Bharambe 2025-03-04 13:37:45 -08:00 committed by GitHub
parent fd8c991393
commit dd0db8038b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 117 additions and 559 deletions

View file

@ -248,7 +248,7 @@ def _generate_run_config(
config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
if hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(__distro_dir__=f"distributions/{image_name}")
config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
else:
config = {}

View file

@ -44,9 +44,9 @@ class TelemetryConfig(BaseModel):
return v
@classmethod
def sample_run_config(cls, __distro_dir__: str = "runtime", db_name: str = "trace_store.db") -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:~/.llama/" + __distro_dir__ + "/" + db_name + "}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
}

View file

@ -15,5 +15,5 @@ class SQLiteVectorIOConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
return {
"db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + "sqlite_vec.db",
"db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + "sqlite_vec.db",
}

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,108 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..conftest import (
get_provider_fixture_overrides,
get_provider_fixture_overrides_from_test_config,
get_test_config_for_api,
)
from ..inference.fixtures import INFERENCE_FIXTURES
from .fixtures import VECTOR_IO_FIXTURES
# Provider pairings used when neither the test config nor the command line
# selects a combination. Each row is:
#   (inference fixture, vector_io fixture, test id, pytest mark name)
# NOTE(review): the "sqlite_vec" row is selected with the `ollama` mark, not a
# `sqlite_vec` mark -- confirm this is intentional.
DEFAULT_PROVIDER_COMBINATIONS = [
    pytest.param(
        {"inference": inference, "vector_io": vector_io},
        id=case_id,
        marks=getattr(pytest.mark, mark_name),
    )
    for inference, vector_io, case_id, mark_name in (
        ("sentence_transformers", "faiss", "sentence_transformers", "sentence_transformers"),
        ("ollama", "pgvector", "pgvector", "pgvector"),
        ("ollama", "faiss", "ollama", "ollama"),
        ("ollama", "sqlite_vec", "sqlite_vec", "ollama"),
        ("sentence_transformers", "chroma", "chroma", "chroma"),
        ("ollama", "qdrant", "qdrant", "qdrant"),
        ("fireworks", "weaviate", "weaviate", "weaviate"),
    )
]
def pytest_configure(config):
    """Register one pytest marker for every name in VECTOR_IO_FIXTURES."""
    for name in VECTOR_IO_FIXTURES:
        marker_spec = f"{name}: marks tests as {name} specific"
        config.addinivalue_line("markers", marker_spec)
def pytest_generate_tests(metafunc):
    """Parametrize the `embedding_model` and `vector_io_stack` fixtures.

    Both are parametrized indirectly, so the chosen value reaches the fixture
    through `request.param`.
    """
    test_config = get_test_config_for_api(metafunc.config, "vector_io")
    requested = metafunc.fixturenames

    if "embedding_model" in requested:
        # Precedence: test config file, then --embedding-model, then the
        # built-in default.
        chosen_model = (
            getattr(test_config, "embedding_model", None)
            or metafunc.config.getoption("--embedding-model")
            or "all-minilm:l6-v2"
        )
        metafunc.parametrize(
            "embedding_model",
            [pytest.param(chosen_model, id="")],
            indirect=True,
        )

    if "vector_io_stack" in requested:
        fixtures_by_api = {
            "inference": INFERENCE_FIXTURES,
            "vector_io": VECTOR_IO_FIXTURES,
        }
        # Precedence: test config file, then command-line overrides, then the
        # default combinations. Each later source is consulted only when the
        # earlier one yields nothing.
        combinations = get_provider_fixture_overrides_from_test_config(
            metafunc.config, "vector_io", DEFAULT_PROVIDER_COMBINATIONS
        )
        if not combinations:
            combinations = get_provider_fixture_overrides(metafunc.config, fixtures_by_api)
        if not combinations:
            combinations = DEFAULT_PROVIDER_COMBINATIONS
        metafunc.parametrize("vector_io_stack", combinations, indirect=True)

View file

@ -1,180 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import tempfile
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.qdrant import QdrantVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate import WeaviateVectorIOConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@pytest.fixture(scope="session")
def embedding_model(request):
    """Return the embedding model for the session.

    Uses the indirect parametrization value when one is present on `request`,
    otherwise the `--embedding-model` option (None when that is not set).
    """
    return (
        request.param
        if hasattr(request, "param")
        else request.config.getoption("--embedding-model", None)
    )
@pytest.fixture(scope="session")
def vector_io_remote() -> ProviderFixture:
    """Provide the fixture returned by `remote_stack_fixture()` as the vector_io provider."""
    remote_fixture = remote_stack_fixture()
    return remote_fixture
@pytest.fixture(scope="session")
def vector_io_faiss() -> ProviderFixture:
    """Provide an inline FAISS vector_io provider backed by a temporary SQLite kvstore file."""
    # Only the path is needed. delete=False keeps the file on disk after the
    # handle is closed, and closing it here avoids leaking an open file
    # descriptor for the whole test session.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as temp_file:
        db_path = temp_file.name

    kvstore = SqliteKVStoreConfig(db_path=db_path).model_dump()
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="faiss",
                provider_type="inline::faiss",
                config=FaissVectorIOConfig(kvstore=kvstore).model_dump(),
            )
        ],
    )
@pytest.fixture(scope="session")
def vector_io_sqlite_vec() -> ProviderFixture:
    """Provide an inline sqlite-vec vector_io provider backed by a temporary database file."""
    # Only the path is needed. delete=False keeps the file on disk after the
    # handle is closed, and closing it here avoids leaking an open file
    # descriptor for the whole test session.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as temp_file:
        db_path = temp_file.name

    # NOTE(review): SQLiteVectorIOConfig.sample_run_config produces a `db_path`
    # entry, not `kvstore` -- confirm this constructor argument is what the
    # config model actually expects.
    kvstore = SqliteKVStoreConfig(db_path=db_path).model_dump()
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="sqlite_vec",
                provider_type="inline::sqlite-vec",
                config=SQLiteVectorIOConfig(kvstore=kvstore).model_dump(),
            )
        ],
    )
@pytest.fixture(scope="session")
def vector_io_pgvector() -> ProviderFixture:
    """Provide a remote pgvector provider configured from PGVECTOR_* environment variables.

    Host and port fall back to "localhost" and 5432. The database, user and
    password are read through `get_env_or_fail`.
    """
    pg_config = PGVectorVectorIOConfig(
        host=os.getenv("PGVECTOR_HOST", "localhost"),
        # NOTE(review): this is a str when the variable is set and an int
        # otherwise -- presumably the config model coerces it; confirm.
        port=os.getenv("PGVECTOR_PORT", 5432),
        db=get_env_or_fail("PGVECTOR_DB"),
        user=get_env_or_fail("PGVECTOR_USER"),
        password=get_env_or_fail("PGVECTOR_PASSWORD"),
    )
    pg_provider = Provider(
        provider_id="pgvector",
        provider_type="remote::pgvector",
        config=pg_config.model_dump(),
    )
    return ProviderFixture(providers=[pg_provider])
@pytest.fixture(scope="session")
def vector_io_weaviate() -> ProviderFixture:
    """Provide a remote Weaviate provider.

    The API key and cluster URL are passed as provider data rather than in the
    provider config, and both are read through `get_env_or_fail`.
    """
    weaviate_provider = Provider(
        provider_id="weaviate",
        provider_type="remote::weaviate",
        config=WeaviateVectorIOConfig().model_dump(),
    )
    credentials = {
        "weaviate_api_key": get_env_or_fail("WEAVIATE_API_KEY"),
        "weaviate_cluster_url": get_env_or_fail("WEAVIATE_CLUSTER_URL"),
    }
    return ProviderFixture(providers=[weaviate_provider], provider_data=credentials)
@pytest.fixture(scope="session")
def vector_io_chroma() -> ProviderFixture:
    """Provide a Chroma provider: remote when CHROMA_URL is set, inline otherwise.

    Raises:
        ValueError: if neither CHROMA_URL nor CHROMA_DB_PATH is set.
    """
    remote_url = os.getenv("CHROMA_URL")
    if remote_url:
        provider_type = "remote::chromadb"
        config = ChromaVectorIOConfig(url=remote_url)
    else:
        local_path = os.getenv("CHROMA_DB_PATH")
        if not local_path:
            raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set")
        provider_type = "inline::chromadb"
        config = InlineChromaVectorIOConfig(db_path=local_path)

    chroma_provider = Provider(
        provider_id="chroma",
        provider_type=provider_type,
        config=config.model_dump(),
    )
    return ProviderFixture(providers=[chroma_provider])
@pytest.fixture(scope="session")
def vector_io_qdrant() -> ProviderFixture:
    """Provide a remote Qdrant provider pointed at QDRANT_URL.

    Raises:
        ValueError: if QDRANT_URL is unset or empty.
    """
    qdrant_url = os.getenv("QDRANT_URL")
    if not qdrant_url:
        raise ValueError("QDRANT_URL must be set")

    qdrant_provider = Provider(
        provider_id="qdrant",
        provider_type="remote::qdrant",
        config=QdrantVectorIOConfig(url=qdrant_url).model_dump(),
    )
    return ProviderFixture(providers=[qdrant_provider])
# Names of the vector_io provider fixtures. Each name corresponds to a
# `vector_io_<name>` fixture in this module, which `vector_io_stack` looks up
# by that composed name.
VECTOR_IO_FIXTURES = ["faiss", "pgvector", "weaviate", "chroma", "qdrant", "sqlite_vec"]
@pytest_asyncio.fixture(scope="session")
async def vector_io_stack(embedding_model, request):
    """Build a test stack for the selected inference and vector_io providers.

    `request.param` maps "inference" and "vector_io" to fixture name suffixes.
    The matching `<api>_<suffix>` fixtures supply the providers and any
    provider data. The embedding model is registered with the dimension taken
    from the EMBEDDING_DIMENSION environment variable.

    Returns:
        A `(vector_io_impl, vector_dbs_impl)` tuple taken from the stack's impls.
    """
    selection = request.param

    providers = {}
    provider_data = {}
    for api_name in ("inference", "vector_io"):
        provider_fixture = request.getfixturevalue(f"{api_name}_{selection[api_name]}")
        providers[api_name] = provider_fixture.providers
        if provider_fixture.provider_data:
            provider_data.update(provider_fixture.provider_data)

    embedding_model_input = ModelInput(
        model_id=embedding_model,
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
        },
    )
    test_stack = await construct_stack_for_test(
        [Api.vector_io, Api.inference],
        providers,
        provider_data,
        models=[embedding_model_input],
    )

    impls = test_stack.impls
    return impls[Api.vector_io], impls[Api.vector_dbs]

View file

@ -1,160 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
import pytest
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB
from llama_stack.apis.vector_io import QueryChunksResponse
from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks
# How to run this test:
#
# pytest llama_stack/providers/tests/vector_io/test_vector_io.py \
# -m "pgvector" --env EMBEDDING_DIMENSION=384 PGVECTOR_PORT=7432 \
# -v -s --tb=short --disable-warnings
@pytest.fixture(scope="session")
def sample_chunks():
    """Return the chunks produced from four short sample documents.

    Only each document's id and content are passed to `make_overlapped_chunks`
    (window 512, overlap 64). The document metadata is not forwarded.
    """
    corpus = (
        (
            "doc1",
            "Python is a high-level programming language.",
            {"category": "programming", "difficulty": "beginner"},
        ),
        (
            "doc2",
            "Machine learning is a subset of artificial intelligence.",
            {"category": "AI", "difficulty": "advanced"},
        ),
        (
            "doc3",
            "Data structures are fundamental to computer science.",
            {"category": "computer science", "difficulty": "intermediate"},
        ),
        (
            "doc4",
            "Neural networks are inspired by biological neural networks.",
            {"category": "AI", "difficulty": "advanced"},
        ),
    )
    docs = [
        RAGDocument(document_id=doc_id, content=text, metadata=meta)
        for doc_id, text, meta in corpus
    ]
    return [
        chunk
        for doc in docs
        for chunk in make_overlapped_chunks(doc.document_id, doc.content, window_len=512, overlap_len=64)
    ]
async def register_vector_db(
    vector_dbs_impl: "VectorDB",
    embedding_model: str,
    embedding_dimension: int = 384,
):
    """Register a vector DB under a fresh random id and return the registration result.

    Args:
        vector_dbs_impl: object exposing an async `register_vector_db` method.
            NOTE(review): the annotation is kept from the original; the object
            is used as a registry, not as a single vector DB -- confirm the
            intended type.
        embedding_model: identifier of the embedding model to associate.
        embedding_dimension: dimension of the embedding vectors. Defaults to
            384, the value that was previously hard-coded.

    Returns:
        Whatever `vector_dbs_impl.register_vector_db` returns.
    """
    # A random id keeps registrations from colliding on a shared stack.
    vector_db_id = f"test_vector_db_{uuid.uuid4().hex}"
    return await vector_dbs_impl.register_vector_db(
        vector_db_id=vector_db_id,
        embedding_model=embedding_model,
        embedding_dimension=embedding_dimension,
    )
class TestVectorIO:
    """Tests for vector DB registration and for chunk insertion and querying."""

    @pytest.mark.asyncio
    async def test_banks_list(self, vector_io_stack, embedding_model):
        """A registered vector DB is listed, and is gone from the listing once unregistered."""
        _, vector_dbs_impl = vector_io_stack

        # Register a test bank
        registered_vector_db = await register_vector_db(vector_dbs_impl, embedding_model)

        try:
            # Verify our bank shows up in list
            response = await vector_dbs_impl.list_vector_dbs()
            assert isinstance(response, ListVectorDBsResponse)
            assert any(vector_db.vector_db_id == registered_vector_db.vector_db_id for vector_db in response.data)
        finally:
            # Clean up
            await vector_dbs_impl.unregister_vector_db(registered_vector_db.vector_db_id)

        # Verify our bank was removed
        response = await vector_dbs_impl.list_vector_dbs()
        assert isinstance(response, ListVectorDBsResponse)
        assert all(vector_db.vector_db_id != registered_vector_db.vector_db_id for vector_db in response.data)

    @pytest.mark.asyncio
    async def test_banks_register(self, vector_io_stack, embedding_model):
        """Registering the same vector DB id twice leaves exactly one entry in the listing."""
        _, vector_dbs_impl = vector_io_stack

        vector_db_id = f"test_vector_db_{uuid.uuid4().hex}"

        try:
            # Register initial bank
            await vector_dbs_impl.register_vector_db(
                vector_db_id=vector_db_id,
                embedding_model=embedding_model,
                embedding_dimension=384,
            )

            # Verify our bank exists
            response = await vector_dbs_impl.list_vector_dbs()
            assert isinstance(response, ListVectorDBsResponse)
            assert any(vector_db.vector_db_id == vector_db_id for vector_db in response.data)

            # Try registering same bank again
            await vector_dbs_impl.register_vector_db(
                vector_db_id=vector_db_id,
                embedding_model=embedding_model,
                embedding_dimension=384,
            )

            # Verify still only one instance of our bank
            response = await vector_dbs_impl.list_vector_dbs()
            assert isinstance(response, ListVectorDBsResponse)
            assert len([vector_db for vector_db in response.data if vector_db.vector_db_id == vector_db_id]) == 1
        finally:
            # Clean up
            await vector_dbs_impl.unregister_vector_db(vector_db_id)

    @pytest.mark.asyncio
    async def test_query_documents(self, vector_io_stack, embedding_model, sample_chunks):
        """Inserted chunks can be queried by topic, with a result limit, and with a score threshold."""
        vector_io_impl, vector_dbs_impl = vector_io_stack

        # Inserting into a vector DB that was never registered must be rejected.
        with pytest.raises(ValueError):
            await vector_io_impl.insert_chunks("test_vector_db", sample_chunks)

        registered_db = await register_vector_db(vector_dbs_impl, embedding_model)

        try:
            await vector_io_impl.insert_chunks(registered_db.vector_db_id, sample_chunks)

            # Test case 1: Query that shares wording with one document
            query1 = "programming language"
            response1 = await vector_io_impl.query_chunks(registered_db.vector_db_id, query1)
            assert_valid_response(response1)
            assert any("Python" in chunk.content for chunk in response1.chunks)

            # Test case 2: Query with semantic similarity
            query3 = "AI and brain-inspired computing"
            response3 = await vector_io_impl.query_chunks(registered_db.vector_db_id, query3)
            assert_valid_response(response3)
            assert any("neural networks" in chunk.content.lower() for chunk in response3.chunks)

            # Test case 3: Query with limit on number of results
            query4 = "computer"
            params4 = {"max_chunks": 2}
            response4 = await vector_io_impl.query_chunks(registered_db.vector_db_id, query4, params4)
            assert_valid_response(response4)
            assert len(response4.chunks) <= 2

            # Test case 4: Query with threshold on similarity score
            query5 = "quantum computing"  # Not directly related to any document
            params5 = {"score_threshold": 0.01}
            response5 = await vector_io_impl.query_chunks(registered_db.vector_db_id, query5, params5)
            assert_valid_response(response5)
            print("The scores are:", response5.scores)
            assert all(score >= 0.01 for score in response5.scores)
        finally:
            # Clean up, as the other tests in this class do, so the vector DB
            # does not stay registered on the stack after this test.
            await vector_dbs_impl.unregister_vector_db(registered_db.vector_db_id)
def assert_valid_response(response: QueryChunksResponse):
    """Assert that a query response has chunks, one score per chunk, and string chunk content."""
    chunk_count = len(response.chunks)
    assert chunk_count > 0
    score_count = len(response.scores)
    assert score_count > 0
    assert chunk_count == score_count
    assert all(isinstance(chunk.content, str) for chunk in response.chunks)

View file

@ -55,11 +55,11 @@ class SqliteKVStoreConfig(CommonConfig):
)
@classmethod
def sample_run_config(cls, __distro_dir__: str = "runtime", db_name: str = "kvstore.db"):
def sample_run_config(cls, __distro_dir__: str, db_name: str = "kvstore.db"):
return {
"type": "sqlite",
"namespace": None,
"db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + db_name,
"db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
}

View file

@ -34,7 +34,7 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
available_models = {

View file

@ -62,7 +62,7 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
default_tool_groups = [
ToolGroupInput(

View file

@ -48,7 +48,7 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_provider = Provider(
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
embedding_provider = Provider(
provider_id="sentence-transformers",

View file

@ -100,7 +100,7 @@ def get_distribution_template() -> DistributionTemplate:
Provider(
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.ENABLE_CHROMADB+chromadb}",

View file

@ -56,7 +56,7 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
available_models = {

View file

@ -51,7 +51,7 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
inference_model = ModelInput(

View file

@ -52,7 +52,7 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
inference_model = ModelInput(

View file

@ -58,7 +58,7 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
inference_model = ModelInput(

View file

@ -67,7 +67,7 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
inference_model = ModelInput(

View file

@ -45,7 +45,7 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_provider_sqlite = Provider(
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
inference_model = ModelInput(

View file

@ -55,7 +55,7 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
inference_model = ModelInput(

View file

@ -46,7 +46,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(
__distro_dir__=f"distributions/{name}",
__distro_dir__=f"~/.llama/distributions/{name}",
),
),
Provider(

View file

@ -86,7 +86,7 @@ class RunConfigSettings(BaseModel):
config_class = instantiate_class_type(config_class)
if hasattr(config_class, "sample_run_config"):
config = config_class.sample_run_config(__distro_dir__=f"distributions/{name}")
config = config_class.sample_run_config(__distro_dir__=f"~/.llama/distributions/{name}")
else:
config = {}
@ -107,7 +107,7 @@ class RunConfigSettings(BaseModel):
apis=apis,
providers=provider_configs,
metadata_store=SqliteKVStoreConfig.sample_run_config(
__distro_dir__=f"distributions/{name}",
__distro_dir__=f"~/.llama/distributions/{name}",
db_name="registry.db",
),
models=self.default_models or [],

View file

@ -55,7 +55,7 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
inference_model = ModelInput(

View file

@ -49,7 +49,7 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
embedding_provider = Provider(
provider_id="sentence-transformers",

View file

@ -46,7 +46,7 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
embedding_provider = Provider(
provider_id="sentence-transformers",

View file

@ -128,6 +128,7 @@ def distro_from_adhoc_config_spec(adhoc_config_spec: str) -> str:
api_providers = adhoc_config_spec.replace(";", ",").split(",")
provider_registry = get_provider_registry()
distro_dir = tempfile.mkdtemp()
provider_configs_by_api = {}
for api_provider in api_providers:
api_str, provider = api_provider.split("=")
@ -147,7 +148,7 @@ def distro_from_adhoc_config_spec(adhoc_config_spec: str) -> str:
# call method "sample_run_config" on the provider spec config class
provider_config_type = instantiate_class_type(provider_spec.config_class)
provider_config = replace_env_vars(provider_config_type.sample_run_config())
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
provider_configs_by_api[api_str] = [
Provider(

View file

@ -4,83 +4,119 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import random
import pytest
INLINE_VECTOR_DB_PROVIDERS = [
"faiss",
# TODO: add sqlite_vec to templates
# "sqlite_vec",
from llama_stack.apis.vector_io import Chunk
@pytest.fixture(scope="session")
def sample_chunks():
return [
Chunk(
content="Python is a high-level programming language that emphasizes code readability and allows programmers to express concepts in fewer lines of code than would be possible in languages such as C++ or Java.",
metadata={"document_id": "doc1"},
),
Chunk(
content="Machine learning is a subset of artificial intelligence that enables systems to automatically learn and improve from experience without being explicitly programmed, using statistical techniques to give computer systems the ability to progressively improve performance on a specific task.",
metadata={"document_id": "doc2"},
),
Chunk(
content="Data structures are fundamental to computer science because they provide organized ways to store and access data efficiently, enable faster processing of data through optimized algorithms, and form the building blocks for more complex software systems.",
metadata={"document_id": "doc3"},
),
Chunk(
content="Neural networks are inspired by biological neural networks found in animal brains, using interconnected nodes called artificial neurons to process information through weighted connections that can be trained to recognize patterns and solve complex problems through iterative learning.",
metadata={"document_id": "doc4"},
),
]
@pytest.fixture(scope="function")
def empty_vector_db_registry(llama_stack_client):
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
def client_with_empty_registry(client_with_models):
def clear_registry():
vector_dbs = [vector_db.identifier for vector_db in client_with_models.vector_dbs.list()]
for vector_db_id in vector_dbs:
llama_stack_client.vector_dbs.unregister(vector_db_id=vector_db_id)
client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id)
clear_registry()
yield client_with_models
# you must clean after the last test if you were running tests against
# a stateful server instance
clear_registry()
@pytest.fixture(scope="function")
def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry, provider_id):
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
llama_stack_client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_id=provider_id,
)
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
return vector_dbs
@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS)
def test_vector_db_retrieve(llama_stack_client, embedding_model_id, empty_vector_db_registry, provider_id):
def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id):
# Register a memory bank first
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
llama_stack_client.vector_dbs.register(
vector_db_id = "test_vector_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=384,
provider_id=provider_id,
)
# Retrieve the memory bank and validate its properties
response = llama_stack_client.vector_dbs.retrieve(vector_db_id=vector_db_id)
response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=vector_db_id)
assert response is not None
assert response.identifier == vector_db_id
assert response.embedding_model == embedding_model_id
assert response.provider_id == provider_id
assert response.provider_resource_id == vector_db_id
def test_vector_db_list(llama_stack_client, empty_vector_db_registry):
vector_dbs_after_register = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
assert len(vector_dbs_after_register) == 0
@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS)
def test_vector_db_register(llama_stack_client, embedding_model_id, empty_vector_db_registry, provider_id):
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
llama_stack_client.vector_dbs.register(
def test_vector_db_register(client_with_empty_registry, embedding_model_id):
vector_db_id = "test_vector_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=384,
provider_id=provider_id,
)
vector_dbs_after_register = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
assert vector_dbs_after_register == [vector_db_id]
client_with_empty_registry.vector_dbs.unregister(vector_db_id=vector_db_id)
@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS)
def test_vector_db_unregister(llama_stack_client, single_entry_vector_db_registry, provider_id):
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
assert len(vector_dbs) == 1
vector_db_id = vector_dbs[0]
llama_stack_client.vector_dbs.unregister(vector_db_id=vector_db_id)
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
assert len(vector_dbs) == 0
@pytest.mark.parametrize(
"test_case",
[
("What makes Python different from C++ and Java?", "doc1"),
("How do systems learn without explicit programming?", "doc2"),
("Why are data structures important in computer science?", "doc3"),
("What is the biological inspiration for neural networks?", "doc4"),
("How does machine learning improve over time?", "doc2"),
],
)
def test_insert_chunks(client_with_empty_registry, embedding_model_id, sample_chunks, test_case):
vector_db_id = "test_vector_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=384,
)
client_with_empty_registry.vector_io.insert(
vector_db_id=vector_db_id,
chunks=sample_chunks,
)
response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id,
query="What is the capital of France?",
)
assert response is not None
assert len(response.chunks) > 1
assert len(response.scores) > 1
query, expected_doc_id = test_case
response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id,
query=query,
)
assert response is not None
top_match = response.chunks[0]
assert top_match is not None
assert top_match.metadata["document_id"] == expected_doc_id, f"Query '{query}' should match {expected_doc_id}"

View file

@ -11,7 +11,6 @@ import numpy as np
import pytest
import sqlite_vec
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
SQLiteVecIndex,
@ -19,9 +18,13 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
generate_chunk_id,
)
# This test is a unit test for the SQLiteVecVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py \
# pytest tests/unit/providers/vector_io/test_sqlite_vec.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
SQLITE_VEC_PROVIDER = "sqlite_vec"
@ -116,35 +119,6 @@ async def sqlite_vec_adapter(sqlite_connection):
await adapter.shutdown()
@pytest.mark.asyncio
async def test_register_vector_db(sqlite_vec_adapter):
vector_db = VectorDB(
identifier="test_db",
embedding_model=EMBEDDING_MODEL,
embedding_dimension=EMBEDDING_DIMENSION,
metadata={},
provider_id=SQLITE_VEC_PROVIDER,
)
await sqlite_vec_adapter.register_vector_db(vector_db)
vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
assert any(db.identifier == "test_db" for db in vector_dbs)
@pytest.mark.asyncio
async def test_unregister_vector_db(sqlite_vec_adapter):
vector_db = VectorDB(
identifier="test_db",
embedding_model=EMBEDDING_MODEL,
embedding_dimension=EMBEDDING_DIMENSION,
metadata={},
provider_id=SQLITE_VEC_PROVIDER,
)
await sqlite_vec_adapter.register_vector_db(vector_db)
await sqlite_vec_adapter.unregister_vector_db("test_db")
vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
assert not any(db.identifier == "test_db" for db in vector_dbs)
def test_generate_chunk_id():
chunks = [
Chunk(content="test", metadata={"document_id": "doc-1"}),