mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
fix(faiss): handle case where distance is 0 by setting d to minimum positive… (#2387)
# What does this PR do?

Guards the faiss `query_vector` function against the case where the distance between the query embedding and an embedding within the vector db is 0 (identical vectors). Instead of raising `ZeroDivisionError` when computing `1.0 / d`, it appends `float("inf")` to `scores` to represent maximum similarity.

<!-- If resolving an issue, uncomment and update the line below -->
Closes [#2381]

## Test Plan

Check out this PR, then execute the code below; it no longer raises a `ZeroDivisionError` exception:

```
from llama_stack_client import LlamaStackClient

base_url = "http://localhost:8321"
client = LlamaStackClient(base_url=base_url)
models = client.models.list()
embedding_model = (
    em := next(m for m in models if m.model_type == "embedding")
).identifier
embedding_dimension = 384

_ = client.vector_dbs.register(
    vector_db_id="foo_db",
    embedding_model=embedding_model,
    embedding_dimension=embedding_dimension,
    provider_id="faiss",
)

chunk = {
    "content": "foo",
    "mime_type": "text/plain",
    "metadata": {
        "document_id": "foo-id"
    }
}

client.vector_io.insert(vector_db_id="foo_db", chunks=[chunk])
client.vector_io.query(vector_db_id="foo_db", query="foo")
```

### Running unit tests

`uv run pytest tests/unit/rag/test_rag_query.py -v`

---------

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
33ecefd284
commit
a34cef925b
4 changed files with 156 additions and 1 deletions
|
@ -112,7 +112,7 @@ class FaissIndex(EmbeddingIndex):
|
||||||
if i < 0:
|
if i < 0:
|
||||||
continue
|
continue
|
||||||
chunks.append(self.chunk_by_index[int(i)])
|
chunks.append(self.chunk_by_index[int(i)])
|
||||||
scores.append(1.0 / float(d))
|
scores.append(1.0 / float(d) if d != 0 else float("inf"))
|
||||||
|
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
|
@ -84,6 +84,7 @@ unit = [
|
||||||
"sqlalchemy",
|
"sqlalchemy",
|
||||||
"sqlalchemy[asyncio]>=2.0.41",
|
"sqlalchemy[asyncio]>=2.0.41",
|
||||||
"blobfile",
|
"blobfile",
|
||||||
|
"faiss-cpu"
|
||||||
]
|
]
|
||||||
# These are the core dependencies required for running integration tests. They are shared across all
|
# These are the core dependencies required for running integration tests. They are shared across all
|
||||||
# providers. If a provider requires additional dependencies, please add them to your environment
|
# providers. If a provider requires additional dependencies, please add them to your environment
|
||||||
|
|
120
tests/unit/providers/vector_io/test_faiss.py
Normal file
120
tests/unit/providers/vector_io/test_faiss.py
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import EmbeddingsResponse, Inference
|
||||||
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
|
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||||
|
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||||
|
from llama_stack.providers.inline.vector_io.faiss.faiss import (
|
||||||
|
FaissIndex,
|
||||||
|
FaissVectorIOAdapter,
|
||||||
|
)
|
||||||
|
|
||||||
|
# This test is a unit test for the FaissVectorIOAdapter class. This should only contain
|
||||||
|
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||||
|
# tests/integration/vector_io/
|
||||||
|
#
|
||||||
|
# How to run this test:
|
||||||
|
#
|
||||||
|
# pytest tests/unit/providers/vector_io/test_faiss.py \
|
||||||
|
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||||
|
|
||||||
|
FAISS_PROVIDER = "faiss"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
def loop():
    """Provide one fresh asyncio event loop for the whole test session.

    Yields:
        The newly created event loop. It is closed during fixture teardown so
        the session does not leak the loop's resources.
    """
    event_loop = asyncio.new_event_loop()
    try:
        yield event_loop
    finally:
        # Close explicitly; an unclosed loop emits a ResourceWarning when collected.
        event_loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def embedding_dimension():
    """Embedding width used for the index, the mock vector DB and the sample vectors."""
    dimension = 384
    return dimension
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def vector_db_id():
    """Identifier assigned to the mocked vector DB."""
    identifier = "test_vector_db"
    return identifier
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_chunks():
    """Return two distinct plain-text chunks, each tagged with its own document id.

    The contents differ so that order-sensitive assertions can tell the chunks
    apart by their text as well as by their metadata.
    """
    return [
        Chunk(content="MOCK text content 1", mime_type="text/plain", metadata={"document_id": "mock-doc-1"}),
        Chunk(content="MOCK text content 2", mime_type="text/plain", metadata={"document_id": "mock-doc-2"}),
    ]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_embeddings(embedding_dimension):
    """Return two float32 embeddings of width ``embedding_dimension``.

    A fixed seed keeps the vectors identical between runs, so any failure that
    depends on the data is reproducible.
    """
    rng = np.random.default_rng(seed=0)
    # Generator.random draws from [0.0, 1.0), the same interval as np.random.rand.
    return rng.random((2, embedding_dimension), dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_vector_db(vector_db_id, embedding_dimension) -> MagicMock:
    """Build a ``VectorDB``-spec'd mock carrying the test id, model name and dimension."""
    return MagicMock(
        spec=VectorDB,
        embedding_model="mock_embedding_model",
        identifier=vector_db_id,
        embedding_dimension=embedding_dimension,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_inference_api(sample_embeddings):
    """Build an ``Inference``-spec'd mock whose ``embeddings`` coroutine returns the sample vectors."""
    canned_response = EmbeddingsResponse(embeddings=sample_embeddings)
    inference = MagicMock(spec=Inference)
    inference.embeddings = AsyncMock(return_value=canned_response)
    return inference
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def faiss_config():
    """Build a ``FaissVectorIOConfig``-spec'd mock with ``kvstore`` set to ``None``."""
    return MagicMock(spec=FaissVectorIOConfig, kvstore=None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def faiss_index(embedding_dimension):
    """Yield a ``FaissIndex`` created for ``embedding_dimension``-wide vectors."""
    yield await FaissIndex.create(dimension=embedding_dimension)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def faiss_adapter(faiss_config, mock_inference_api) -> FaissVectorIOAdapter:
    """Yield an initialized ``FaissVectorIOAdapter`` and shut it down afterwards."""
    provider = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api)
    await provider.initialize()

    yield provider

    # Teardown: runs once the consuming test has finished.
    await provider.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical(
    faiss_index, sample_chunks, sample_embeddings, embedding_dimension
):
    """A zero distance reported by the index search must map to an infinite score.

    The underlying ``search`` call is stubbed to report distances ``0.0`` and
    ``0.1`` for stored chunks 0 and 1, so the expected scores are ``inf`` and
    ``1.0 / 0.1``.
    """
    await faiss_index.add_chunks(sample_chunks, sample_embeddings)
    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)

    # (distances, indices) in the order the stubbed search hands them back.
    stubbed_result = (np.array([[0.0, 0.1]]), np.array([[0, 1]]))
    with patch.object(faiss_index.index, "search", return_value=stubbed_result):
        response = await faiss_index.query_vector(embedding=query_embedding, k=2, score_threshold=0.0)

    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) == 2
    assert len(response.scores) == 2

    assert response.scores[0] == float("inf")  # infinity (1.0 / 0.0)
    # Compare the reciprocal with a tolerance; exact float equality is brittle.
    assert response.scores[1] == pytest.approx(10.0)  # (1.0 / 0.1 = 10.0)

    assert response.chunks[0] == sample_chunks[0]
    assert response.chunks[1] == sample_chunks[1]
|
34
uv.lock
generated
34
uv.lock
generated
|
@ -715,6 +715,38 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702, upload-time = "2025-01-22T15:41:25.929Z" },
|
{ url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702, upload-time = "2025-01-22T15:41:25.929Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "faiss-cpu"
|
||||||
|
version = "1.11.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "numpy" },
|
||||||
|
{ name = "packaging" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/e7/9a/e33fc563f007924dd4ec3c5101fe5320298d6c13c158a24a9ed849058569/faiss_cpu-1.11.0.tar.gz", hash = "sha256:44877b896a2b30a61e35ea4970d008e8822545cb340eca4eff223ac7f40a1db9", size = 70218, upload-time = "2025-04-28T07:48:30.459Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ed/e5/7490368ec421e44efd60a21aa88d244653c674d8d6ee6bc455d8ee3d02ed/faiss_cpu-1.11.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:1995119152928c68096b0c1e5816e3ee5b1eebcf615b80370874523be009d0f6", size = 3307996, upload-time = "2025-04-28T07:47:29.126Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/dd/ac/a94fbbbf4f38c2ad11862af92c071ff346630ebf33f3d36fe75c3817c2f0/faiss_cpu-1.11.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:788d7bf24293fdecc1b93f1414ca5cc62ebd5f2fecfcbb1d77f0e0530621c95d", size = 7886309, upload-time = "2025-04-28T07:47:31.668Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/63/48/ad79f34f1b9eba58c32399ad4fbedec3f2a717d72fb03648e906aab48a52/faiss_cpu-1.11.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:73408d52429558f67889581c0c6d206eedcf6fabe308908f2bdcd28fd5e8be4a", size = 3778443, upload-time = "2025-04-28T07:47:33.685Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/95/67/3c6b94dd3223a8ecaff1c10c11b4ac6f3f13f1ba8ab6b6109c24b6e9b23d/faiss_cpu-1.11.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:1f53513682ca94c76472544fa5f071553e428a1453e0b9755c9673f68de45f12", size = 31295174, upload-time = "2025-04-28T07:47:36.309Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a4/2c/d843256aabdb7f20f0f87f61efe3fb7c2c8e7487915f560ba523cfcbab57/faiss_cpu-1.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:30489de0356d3afa0b492ca55da164d02453db2f7323c682b69334fde9e8d48e", size = 15003860, upload-time = "2025-04-28T07:47:39.381Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ed/83/8aefc4d07624a868e046cc23ede8a59bebda57f09f72aee2150ef0855a82/faiss_cpu-1.11.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:a90d1c81d0ecf2157e1d2576c482d734d10760652a5b2fcfa269916611e41f1c", size = 3307997, upload-time = "2025-04-28T07:47:41.905Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/2b/64/f97e91d89dc6327e08f619fe387d7d9945bc4be3b0f1ca1e494a41c92ebe/faiss_cpu-1.11.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:2c39a388b059fb82cd97fbaa7310c3580ced63bf285be531453bfffbe89ea3dd", size = 7886308, upload-time = "2025-04-28T07:47:44.677Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/44/0a/7c17b6df017b0bc127c6aa4066b028281e67ab83d134c7433c4e75cd6bb6/faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:a4e3433ffc7f9b8707a7963db04f8676a5756868d325644db2db9d67a618b7a0", size = 3778441, upload-time = "2025-04-28T07:47:46.914Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/53/45/7c85551025d9f0237d891b5cffdc5d4a366011d53b4b0a423b972cc52cea/faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:926645f1b6829623bc88e93bc8ca872504d604718ada3262e505177939aaee0a", size = 31295136, upload-time = "2025-04-28T07:47:49.299Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/7f/9a/accade34b8668b21206c0c4cf0b96cd0b750b693ba5b255c1c10cfee460f/faiss_cpu-1.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:931db6ed2197c03a7fdf833b057c13529afa2cec8a827aa081b7f0543e4e671b", size = 15003710, upload-time = "2025-04-28T07:47:52.226Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/3b/d3/7178fa07047fd770964a83543329bb5e3fc1447004cfd85186ccf65ec3ee/faiss_cpu-1.11.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:356437b9a46f98c25831cdae70ca484bd6c05065af6256d87f6505005e9135b9", size = 3313807, upload-time = "2025-04-28T07:47:54.533Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9e/71/25f5f7b70a9f22a3efe19e7288278da460b043a3b60ad98e4e47401ed5aa/faiss_cpu-1.11.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:c4a3d35993e614847f3221c6931529c0bac637a00eff0d55293e1db5cb98c85f", size = 7913537, upload-time = "2025-04-28T07:47:56.723Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b0/c8/a5cb8466c981ad47750e1d5fda3d4223c82f9da947538749a582b3a2d35c/faiss_cpu-1.11.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8f9af33e0b8324e8199b93eb70ac4a951df02802a9dcff88e9afc183b11666f0", size = 3785180, upload-time = "2025-04-28T07:47:59.004Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/7f/37/eaf15a7d80e1aad74f56cf737b31b4547a1a664ad3c6e4cfaf90e82454a8/faiss_cpu-1.11.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:48b7e7876829e6bdf7333041800fa3c1753bb0c47e07662e3ef55aca86981430", size = 31287630, upload-time = "2025-04-28T07:48:01.248Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ff/5c/902a78347e9c47baaf133e47863134e564c39f9afe105795b16ee986b0df/faiss_cpu-1.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:bdc199311266d2be9d299da52361cad981393327b2b8aa55af31a1b75eaaf522", size = 15005398, upload-time = "2025-04-28T07:48:04.232Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/92/90/d2329ce56423cc61f4c20ae6b4db001c6f88f28bf5a7ef7f8bbc246fd485/faiss_cpu-1.11.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:0c98e5feff83b87348e44eac4d578d6f201780dae6f27f08a11d55536a20b3a8", size = 3313807, upload-time = "2025-04-28T07:48:06.486Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/24/14/8af8f996d54e6097a86e6048b1a2c958c52dc985eb4f935027615079939e/faiss_cpu-1.11.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:796e90389427b1c1fb06abdb0427bb343b6350f80112a2e6090ac8f176ff7416", size = 7913539, upload-time = "2025-04-28T07:48:08.338Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b2/2b/437c2f36c3aa3cffe041479fced1c76420d3e92e1f434f1da3be3e6f32b1/faiss_cpu-1.11.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2b6e355dda72b3050991bc32031b558b8f83a2b3537a2b9e905a84f28585b47e", size = 3785181, upload-time = "2025-04-28T07:48:10.594Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/66/75/955527414371843f558234df66fa0b62c6e86e71e4022b1be9333ac6004c/faiss_cpu-1.11.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6c482d07194638c169b4422774366e7472877d09181ea86835e782e6304d4185", size = 31287635, upload-time = "2025-04-28T07:48:12.93Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/50/51/35b7a3f47f7859363a367c344ae5d415ea9eda65db0a7d497c7ea2c0b576/faiss_cpu-1.11.0-cp313-cp313-win_amd64.whl", hash = "sha256:13eac45299532b10e911bff1abbb19d1bf5211aa9e72afeade653c3f1e50e042", size = 15005455, upload-time = "2025-04-28T07:48:16.173Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastapi"
|
name = "fastapi"
|
||||||
version = "0.115.8"
|
version = "0.115.8"
|
||||||
|
@ -1549,6 +1581,7 @@ unit = [
|
||||||
{ name = "aiosqlite" },
|
{ name = "aiosqlite" },
|
||||||
{ name = "blobfile" },
|
{ name = "blobfile" },
|
||||||
{ name = "chardet" },
|
{ name = "chardet" },
|
||||||
|
{ name = "faiss-cpu" },
|
||||||
{ name = "mcp" },
|
{ name = "mcp" },
|
||||||
{ name = "openai" },
|
{ name = "openai" },
|
||||||
{ name = "opentelemetry-exporter-otlp-proto-http" },
|
{ name = "opentelemetry-exporter-otlp-proto-http" },
|
||||||
|
@ -1649,6 +1682,7 @@ unit = [
|
||||||
{ name = "aiosqlite" },
|
{ name = "aiosqlite" },
|
||||||
{ name = "blobfile" },
|
{ name = "blobfile" },
|
||||||
{ name = "chardet" },
|
{ name = "chardet" },
|
||||||
|
{ name = "faiss-cpu" },
|
||||||
{ name = "mcp" },
|
{ name = "mcp" },
|
||||||
{ name = "openai" },
|
{ name = "openai" },
|
||||||
{ name = "opentelemetry-exporter-otlp-proto-http" },
|
{ name = "opentelemetry-exporter-otlp-proto-http" },
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue