[memory refactor][2/n] Update faiss and make it pass tests (#830)

See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

Second part:

- updates routing table / router code 
- updates the faiss implementation


## Test Plan

```
pytest -s -v -k sentence test_vector_io.py --env EMBEDDING_DIMENSION=384
```
This commit is contained in:
Ashwin Bharambe 2025-01-22 10:02:15 -08:00 committed by GitHub
parent 3ae8585b65
commit 78a481bb22
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 343 additions and 353 deletions

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..conftest import (
get_provider_fixture_overrides,
get_provider_fixture_overrides_from_test_config,
get_test_config_for_api,
)
from ..inference.fixtures import INFERENCE_FIXTURES
from .fixtures import VECTOR_IO_FIXTURES
def _provider_combo(inference: str, vector_io: str, combo_id: str):
    """Build one pytest.param pairing an inference fixture with a vector_io fixture.

    The param id doubles as the marker name so a combination can be selected
    with ``-m <combo_id>``.
    """
    return pytest.param(
        {
            "inference": inference,
            "vector_io": vector_io,
        },
        id=combo_id,
        marks=getattr(pytest.mark, combo_id),
    )


# Default (inference, vector_io) provider pairings exercised by the suite.
DEFAULT_PROVIDER_COMBINATIONS = [
    _provider_combo("sentence_transformers", "faiss", "sentence_transformers"),
    _provider_combo("ollama", "faiss", "ollama"),
    _provider_combo("sentence_transformers", "chroma", "chroma"),
    _provider_combo("bedrock", "qdrant", "qdrant"),
    _provider_combo("fireworks", "weaviate", "weaviate"),
]
def pytest_configure(config):
    """Register one pytest marker per vector_io fixture so tests can be filtered with -m."""
    marker_lines = (
        f"{fixture_name}: marks tests as {fixture_name} specific"
        for fixture_name in VECTOR_IO_FIXTURES
    )
    for line in marker_lines:
        config.addinivalue_line("markers", line)
def pytest_generate_tests(metafunc):
    """Parametrize the ``embedding_model`` and ``vector_io_stack`` fixtures.

    The embedding model comes from the test config file when present, then the
    ``--embedding-model`` CLI option, then a built-in default. Provider
    combinations likewise fall back from config overrides, to CLI overrides,
    to DEFAULT_PROVIDER_COMBINATIONS.
    """
    test_config = get_test_config_for_api(metafunc.config, "vector_io")

    if "embedding_model" in metafunc.fixturenames:
        model = getattr(test_config, "embedding_model", None)
        # Fall back to the CLI option if the config file does not specify one
        if not model:
            model = metafunc.config.getoption("--embedding-model")
        chosen = model if model else "all-MiniLM-L6-v2"
        metafunc.parametrize(
            "embedding_model", [pytest.param(chosen, id="")], indirect=True
        )

    if "vector_io_stack" in metafunc.fixturenames:
        fixtures_by_api = {
            "inference": INFERENCE_FIXTURES,
            "vector_io": VECTOR_IO_FIXTURES,
        }
        combinations = get_provider_fixture_overrides_from_test_config(
            metafunc.config, "vector_io", DEFAULT_PROVIDER_COMBINATIONS
        )
        if not combinations:
            combinations = get_provider_fixture_overrides(
                metafunc.config, fixtures_by_api
            )
        if not combinations:
            combinations = DEFAULT_PROVIDER_COMBINATIONS
        metafunc.parametrize("vector_io_stack", combinations, indirect=True)

View file

@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import tempfile
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig
from llama_stack.providers.inline.vector_io.faiss import FaissImplConfig
from llama_stack.providers.remote.vector_io.chroma import ChromaRemoteImplConfig
from llama_stack.providers.remote.vector_io.pgvector import PGVectorConfig
from llama_stack.providers.remote.vector_io.weaviate import WeaviateConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@pytest.fixture(scope="session")
def embedding_model(request):
    """Embedding model id: direct parametrization wins over --embedding-model."""
    _unset = object()
    param = getattr(request, "param", _unset)
    if param is not _unset:
        return param
    return request.config.getoption("--embedding-model", None)
@pytest.fixture(scope="session")
def vector_io_remote() -> ProviderFixture:
    """Provider fixture that targets an already-running remote stack."""
    return remote_stack_fixture()
@pytest.fixture(scope="session")
def vector_io_faiss() -> ProviderFixture:
    """Inline faiss provider backed by a throwaway sqlite kvstore file."""
    # delete=False keeps the sqlite file alive for the whole test session.
    # NOTE(review): the temp file is never removed afterwards — confirm that
    # leaking one small file per session is acceptable.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="faiss",
                provider_type="inline::faiss",
                config=FaissImplConfig(
                    kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
                ).model_dump(),
            )
        ],
    )
@pytest.fixture(scope="session")
def vector_io_pgvector() -> ProviderFixture:
    """Remote pgvector provider configured from the environment.

    PGVECTOR_DB, PGVECTOR_USER and PGVECTOR_PASSWORD are required;
    host and port fall back to localhost:5432.
    """
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="pgvector",
                provider_type="remote::pgvector",
                config=PGVectorConfig(
                    host=os.getenv("PGVECTOR_HOST", "localhost"),
                    # os.getenv returns a str when the variable is set but the
                    # int default otherwise; coerce so the port type is stable.
                    port=int(os.getenv("PGVECTOR_PORT", 5432)),
                    db=get_env_or_fail("PGVECTOR_DB"),
                    user=get_env_or_fail("PGVECTOR_USER"),
                    password=get_env_or_fail("PGVECTOR_PASSWORD"),
                ).model_dump(),
            )
        ],
    )
@pytest.fixture(scope="session")
def vector_io_weaviate() -> ProviderFixture:
    """Remote weaviate provider; credentials travel via provider_data.

    Requires WEAVIATE_API_KEY and WEAVIATE_CLUSTER_URL in the environment.
    """
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="weaviate",
                provider_type="remote::weaviate",
                config=WeaviateConfig().model_dump(),
            )
        ],
        provider_data=dict(
            weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"),
            weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"),
        ),
    )
@pytest.fixture(scope="session")
def vector_io_chroma() -> ProviderFixture:
    """Chroma provider: remote when CHROMA_URL is set, otherwise inline on-disk."""
    remote_url = os.getenv("CHROMA_URL")
    if remote_url:
        provider_type = "remote::chromadb"
        config = ChromaRemoteImplConfig(url=remote_url)
    else:
        db_path = os.getenv("CHROMA_DB_PATH")
        if not db_path:
            raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set")
        provider_type = "inline::chromadb"
        config = ChromaInlineImplConfig(db_path=db_path)

    provider = Provider(
        provider_id="chroma",
        provider_type=provider_type,
        config=config.model_dump(),
    )
    return ProviderFixture(providers=[provider])
# Names of the vector_io provider fixtures above; conftest uses this list to
# register per-provider pytest markers and to enumerate available fixtures.
VECTOR_IO_FIXTURES = ["faiss", "pgvector", "weaviate", "chroma"]
@pytest_asyncio.fixture(scope="session")
async def vector_io_stack(embedding_model, request):
    """Build a test stack from the parametrized (inference, vector_io) pair.

    Registers a single embedding model on the stack and returns a tuple of
    (vector_io impl, vector_dbs impl).
    """
    fixture_dict = request.param
    providers = {}
    provider_data = {}
    for key in ["inference", "vector_io"]:
        # e.g. resolves the "vector_io_faiss" fixture for {"vector_io": "faiss"}
        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
        providers[key] = fixture.providers
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)
    test_stack = await construct_stack_for_test(
        [Api.vector_io, Api.inference],
        providers,
        provider_data,
        models=[
            ModelInput(
                model_id=embedding_model,
                model_type=ModelType.embedding,
                metadata={
                    # NOTE(review): get_env_or_fail returns a str; confirm the
                    # consumer coerces embedding_dimension to int.
                    "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
                },
            )
        ],
    )
    return test_stack.impls[Api.vector_io], test_stack.impls[Api.vector_dbs]

View file

@@ -0,0 +1,200 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
import pytest
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB
from llama_stack.apis.vector_io import QueryChunksResponse
from llama_stack.providers.utils.memory.vector_store import (
make_overlapped_chunks,
MemoryBankDocument,
)
# How to run this test:
#
# pytest llama_stack/providers/tests/memory/test_memory.py
# -m "sentence_transformers" --env EMBEDDING_DIMENSION=384
# -v -s --tb=short --disable-warnings
@pytest.fixture(scope="session")
def sample_chunks():
    """Overlapped text chunks derived from four small sample documents."""
    docs = [
        MemoryBankDocument(
            document_id="doc1",
            content="Python is a high-level programming language.",
            metadata={"category": "programming", "difficulty": "beginner"},
        ),
        MemoryBankDocument(
            document_id="doc2",
            content="Machine learning is a subset of artificial intelligence.",
            metadata={"category": "AI", "difficulty": "advanced"},
        ),
        MemoryBankDocument(
            document_id="doc3",
            content="Data structures are fundamental to computer science.",
            metadata={"category": "computer science", "difficulty": "intermediate"},
        ),
        MemoryBankDocument(
            document_id="doc4",
            content="Neural networks are inspired by biological neural networks.",
            metadata={"category": "AI", "difficulty": "advanced"},
        ),
    ]
    # Flatten the per-document chunk lists into one list.
    return [
        chunk
        for doc in docs
        for chunk in make_overlapped_chunks(
            doc.document_id, doc.content, window_len=512, overlap_len=64
        )
    ]
async def register_vector_db(vector_dbs_impl: VectorDB, embedding_model: str):
    """Register a uniquely-named vector DB and return the registration result.

    NOTE(review): the annotation says ``VectorDB`` but the argument the tests
    pass is the vector_dbs API impl (it exposes ``register_vector_db``) —
    confirm and correct the annotation.
    """
    # uuid4 suffix avoids collisions across tests sharing a session-scoped stack
    vector_db_id = f"test_vector_db_{uuid.uuid4().hex}"
    return await vector_dbs_impl.register_vector_db(
        vector_db_id=vector_db_id,
        embedding_model=embedding_model,
        embedding_dimension=384,
    )
class TestVectorIO:
    """End-to-end tests for vector DB registration and chunk insert/query."""

    @pytest.mark.asyncio
    async def test_banks_list(self, vector_io_stack, embedding_model):
        """A registered vector DB appears in list_vector_dbs and disappears
        again after unregistration."""
        _, vector_dbs_impl = vector_io_stack
        # Register a test bank
        registered_vector_db = await register_vector_db(
            vector_dbs_impl, embedding_model
        )
        try:
            # Verify our bank shows up in list
            response = await vector_dbs_impl.list_vector_dbs()
            assert isinstance(response, ListVectorDBsResponse)
            assert any(
                vector_db.vector_db_id == registered_vector_db.vector_db_id
                for vector_db in response.data
            )
        finally:
            # Clean up
            await vector_dbs_impl.unregister_vector_db(
                registered_vector_db.vector_db_id
            )
        # Verify our bank was removed
        response = await vector_dbs_impl.list_vector_dbs()
        assert isinstance(response, ListVectorDBsResponse)
        assert all(
            vector_db.vector_db_id != registered_vector_db.vector_db_id
            for vector_db in response.data
        )

    @pytest.mark.asyncio
    async def test_banks_register(self, vector_io_stack, embedding_model):
        """Registering the same vector DB id twice must not create a duplicate."""
        _, vector_dbs_impl = vector_io_stack
        vector_db_id = f"test_vector_db_{uuid.uuid4().hex}"
        try:
            # Register initial bank
            await vector_dbs_impl.register_vector_db(
                vector_db_id=vector_db_id,
                embedding_model=embedding_model,
                embedding_dimension=384,
            )
            # Verify our bank exists
            response = await vector_dbs_impl.list_vector_dbs()
            assert isinstance(response, ListVectorDBsResponse)
            assert any(
                vector_db.vector_db_id == vector_db_id for vector_db in response.data
            )
            # Try registering same bank again
            await vector_dbs_impl.register_vector_db(
                vector_db_id=vector_db_id,
                embedding_model=embedding_model,
                embedding_dimension=384,
            )
            # Verify still only one instance of our bank
            response = await vector_dbs_impl.list_vector_dbs()
            assert isinstance(response, ListVectorDBsResponse)
            assert (
                len(
                    [
                        vector_db
                        for vector_db in response.data
                        if vector_db.vector_db_id == vector_db_id
                    ]
                )
                == 1
            )
        finally:
            # Clean up
            await vector_dbs_impl.unregister_vector_db(vector_db_id)

    @pytest.mark.asyncio
    async def test_query_documents(
        self, vector_io_stack, embedding_model, sample_chunks
    ):
        """Insert sample chunks and exercise query paths: keyword match,
        semantic similarity, max_chunks limit and score_threshold filtering."""
        vector_io_impl, vector_dbs_impl = vector_io_stack
        # Inserting into an unregistered vector DB must fail
        with pytest.raises(ValueError):
            await vector_io_impl.insert_chunks("test_vector_db", sample_chunks)
        registered_db = await register_vector_db(vector_dbs_impl, embedding_model)
        await vector_io_impl.insert_chunks(registered_db.vector_db_id, sample_chunks)
        # Test case 1: Query expected to match doc1's content directly
        query1 = "programming language"
        response1 = await vector_io_impl.query_chunks(
            registered_db.vector_db_id, query1
        )
        assert_valid_response(response1)
        assert any("Python" in chunk.content for chunk in response1.chunks)
        # Test case 3: Query with semantic similarity
        query3 = "AI and brain-inspired computing"
        response3 = await vector_io_impl.query_chunks(
            registered_db.vector_db_id, query3
        )
        assert_valid_response(response3)
        assert any(
            "neural networks" in chunk.content.lower() for chunk in response3.chunks
        )
        # Test case 4: Query with limit on number of results
        query4 = "computer"
        params4 = {"max_chunks": 2}
        response4 = await vector_io_impl.query_chunks(
            registered_db.vector_db_id, query4, params4
        )
        assert_valid_response(response4)
        assert len(response4.chunks) <= 2
        # Test case 5: Query with threshold on similarity score
        query5 = "quantum computing"  # Not directly related to any document
        params5 = {"score_threshold": 0.01}
        response5 = await vector_io_impl.query_chunks(
            registered_db.vector_db_id, query5, params5
        )
        assert_valid_response(response5)
        print("The scores are:", response5.scores)
        assert all(score >= 0.01 for score in response5.scores)
def assert_valid_response(response: QueryChunksResponse):
    """Sanity-check a query result: non-empty, aligned chunks/scores, str content."""
    chunks, scores = response.chunks, response.scores
    assert len(chunks) > 0
    assert len(scores) > 0
    assert len(chunks) == len(scores)
    for item in chunks:
        assert isinstance(item.content, str)

View file

@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import mimetypes
import os
from pathlib import Path
import pytest
from llama_stack.providers.utils.memory.vector_store import (
content_from_doc,
MemoryBankDocument,
URL,
)
# Tiny checked-in PDF fixture used to verify PDF text extraction.
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
def read_file(file_path: str) -> bytes:
    """Return the raw bytes of the file at *file_path*."""
    return Path(file_path).read_bytes()
def data_url_from_file(file_path: str) -> str:
    """Encode the file at *file_path* as a ``data:`` URL.

    The MIME type is guessed from the file extension; unknown extensions fall
    back to ``application/octet-stream`` (previously the literal string
    ``None`` was interpolated, producing a malformed ``data:None;...`` URL).
    """
    with open(file_path, "rb") as file:
        file_content = file.read()
    base64_content = base64.b64encode(file_content).decode("utf-8")
    mime_type, _ = mimetypes.guess_type(file_path)
    if mime_type is None:
        # guess_type returns None for unrecognized extensions
        mime_type = "application/octet-stream"
    return f"data:{mime_type};base64,{base64_content}"
class TestVectorStore:
    """Tests for content_from_doc over data-URI and URL-backed PDF documents."""

    @pytest.mark.asyncio
    async def test_returns_content_from_pdf_data_uri(self):
        """A PDF embedded as a base64 data: URI is decoded to its text."""
        data_uri = data_url_from_file(DUMMY_PDF_PATH)
        doc = MemoryBankDocument(
            document_id="dummy",
            content=data_uri,
            mime_type="application/pdf",
            metadata={},
        )
        content = await content_from_doc(doc)
        assert content == "Dummy PDF file"

    @pytest.mark.asyncio
    async def test_downloads_pdf_and_returns_content(self):
        """A PDF referenced by a plain URL string is downloaded and parsed.

        NOTE(review): requires network access to raw.githubusercontent.com.
        """
        # Using GitHub to host the PDF file
        url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
        doc = MemoryBankDocument(
            document_id="dummy",
            content=url,
            mime_type="application/pdf",
            metadata={},
        )
        content = await content_from_doc(doc)
        assert content == "Dummy PDF file"

    @pytest.mark.asyncio
    async def test_downloads_pdf_and_returns_content_with_url_object(self):
        """Same as above but passing the URL wrapped in a URL object.

        NOTE(review): requires network access to raw.githubusercontent.com.
        """
        # Using GitHub to host the PDF file
        url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
        doc = MemoryBankDocument(
            document_id="dummy",
            content=URL(
                uri=url,
            ),
            mime_type="application/pdf",
            metadata={},
        )
        content = await content_from_doc(doc)
        assert content == "Dummy PDF file"