fix faiss db health check correctly via review

This commit is contained in:
Sumit Jaiswal 2025-06-16 23:17:56 +05:30
parent 47373a65cf
commit be6ef8bc17
No known key found for this signature in database
GPG key ID: A4604B39D64D6AEC
5 changed files with 113 additions and 227 deletions

View file

@ -11,9 +11,11 @@ import numpy as np
import pytest
import pytest_asyncio
from llama_stack.apis.files import Files
from llama_stack.apis.inference import EmbeddingsResponse, Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import (
FaissIndex,
@ -76,6 +78,12 @@ def mock_inference_api(sample_embeddings):
return mock_api
@pytest.fixture
def mock_files_api():
mock_api = MagicMock(spec=Files)
return mock_api
@pytest.fixture
def faiss_config():
config = MagicMock(spec=FaissVectorIOConfig)
@ -90,11 +98,19 @@ async def faiss_index(embedding_dimension):
@pytest_asyncio.fixture
async def faiss_adapter(faiss_config, mock_inference_api) -> FaissVectorIOAdapter:
adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api)
await adapter.initialize()
yield adapter
await adapter.shutdown()
async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> FaissVectorIOAdapter:
# Create the adapter
adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api, files_api=mock_files_api)
# Create a mock KVStore
mock_kvstore = MagicMock()
mock_kvstore.values_in_range = AsyncMock(return_value=[])
# Patch the initialize method to avoid the kvstore_impl call
with patch.object(FaissVectorIOAdapter, "initialize"):
# Set the kvstore directly
adapter.kvstore = mock_kvstore
yield adapter
@pytest.mark.asyncio
@ -118,3 +134,49 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
assert response.chunks[0] == sample_chunks[0]
assert response.chunks[1] == sample_chunks[1]
@pytest.mark.asyncio
async def test_health_success():
"""Test that the health check returns OK status when faiss is working correctly."""
# Create a fresh instance of FaissVectorIOAdapter for testing
config = MagicMock()
inference_api = MagicMock()
files_api = MagicMock()
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
mock_index_flat.return_value = MagicMock()
adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
# Calling the health method directly
response = await adapter.health()
# Verifying the response
assert isinstance(response, dict)
assert response["status"] == HealthStatus.OK
assert "message" not in response
# Verifying that IndexFlatL2 was called with the correct dimension
mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128
@pytest.mark.asyncio
async def test_health_failure():
"""Test that the health check returns ERROR status when faiss encounters an error."""
# Create a fresh instance of FaissVectorIOAdapter for testing
config = MagicMock()
inference_api = MagicMock()
files_api = MagicMock()
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
mock_index_flat.side_effect = Exception("Test error")
adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
# Calling the health method directly
response = await adapter.health()
# Verifying the response
assert isinstance(response, dict)
assert response["status"] == HealthStatus.ERROR
assert response["message"] == "Health check failed: Test error"

View file

@ -1,160 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Unit tests for the FAISS provider health check implementation via provider patch.
"""
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
from llama_stack.distribution.providers import ProviderImpl
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
from llama_stack.providers.inline.vector_io.faiss.provider_patch import (
check_faiss_health,
patched_list_providers,
)
class TestFaissProviderPatch(unittest.TestCase):
"""Test cases for the FAISS provider patch."""
def setUp(self):
"""Set up test fixtures."""
self.provider_impl = MagicMock(spec=ProviderImpl)
self.mock_response = MagicMock()
self.mock_response.data = []
# Set up the original list_providers method
self.original_list_providers = AsyncMock(return_value=self.mock_response)
async def test_check_faiss_health_success(self):
"""Test the check_faiss_health function when FAISS is working properly."""
with patch("faiss.IndexFlatL2") as mock_index:
mock_index.return_value = MagicMock()
# Call the health check function
result = await check_faiss_health()
self.assertEqual(result.status, HealthStatus.OK)
mock_index.assert_called_once()
async def test_check_faiss_health_failure(self):
"""Test the check_faiss_health function when FAISS fails."""
with patch("faiss.IndexFlatL2") as mock_index:
# Configure the mock to simulate a failure
mock_index.side_effect = Exception("FAISS initialization failed")
result = await check_faiss_health()
self.assertEqual(result.status, HealthStatus.ERROR)
self.assertIn("FAISS health check failed", result.message)
mock_index.assert_called_once()
async def test_patched_list_providers_no_faiss(self):
"""Test the patched_list_providers method when no FAISS provider is found."""
# Set up the mock response with NO FAISS provider
self.mock_response.data = [
MagicMock(provider_id="other", api="vector_io"),
MagicMock(provider_id="faiss", api="other_api"),
]
with patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers,
):
result = await patched_list_providers(self.provider_impl)
self.assertEqual(result, self.mock_response)
self.original_list_providers.assert_called_once_with(self.provider_impl)
# Verify that no health checks were performed
for provider in result.data:
self.assertNotEqual(provider.provider_id, "faiss")
async def test_patched_list_providers_with_faiss(self):
"""Test the patched_list_providers method when a FAISS provider is found."""
# Create a mock FAISS provider
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
# Set up the mock response with a FAISS provider
self.mock_response.data = [
MagicMock(provider_id="other", api="vector_io"),
mock_faiss_provider,
]
with (
patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers,
),
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
):
mock_health.return_value = HealthResponse(status=HealthStatus.OK)
result = await patched_list_providers(self.provider_impl)
self.assertEqual(result, self.mock_response)
self.original_list_providers.assert_called_once_with(self.provider_impl)
mock_health.assert_called_once()
# Verify that the FAISS provider's health was updated
for provider in result.data:
if provider.provider_id == "faiss" and provider.api == "vector_io":
self.assertEqual(provider.health.status, HealthStatus.OK)
async def test_patched_list_providers_with_faiss_health_failure(self):
"""Test the patched_list_providers method when the FAISS health check fails."""
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
self.mock_response.data = [
MagicMock(provider_id="other", api="vector_io"),
mock_faiss_provider,
]
with (
patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers,
),
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
):
# Configure the mock health check to simulate a failure
error_response = HealthResponse(status=HealthStatus.ERROR, message="FAISS health check failed: Test error")
mock_health.return_value = error_response
result = await patched_list_providers(self.provider_impl)
self.assertEqual(result, self.mock_response)
self.original_list_providers.assert_called_once_with(self.provider_impl)
mock_health.assert_called_once()
# Verify that the FAISS provider's health was updated with the error
for provider in result.data:
if provider.provider_id == "faiss" and provider.api == "vector_io":
self.assertEqual(provider.health.status, HealthStatus.ERROR)
self.assertEqual(provider.health.message, "FAISS health check failed: Test error")
async def test_patched_list_providers_with_exception(self):
"""Test the patched_list_providers method when an exception occurs during health check."""
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
self.mock_response.data = [
MagicMock(provider_id="other", api="vector_io"),
mock_faiss_provider,
]
with (
patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers,
),
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
):
# Configure the mock health check to raise an exception
mock_health.side_effect = Exception("Unexpected error")
result = await patched_list_providers(self.provider_impl)
self.assertEqual(result, self.mock_response)
self.original_list_providers.assert_called_once_with(self.provider_impl)
mock_health.assert_called_once()
# Verify that the FAISS provider's health was updated with an error
for provider in result.data:
if provider.provider_id == "faiss" and provider.api == "vector_io":
self.assertEqual(provider.health.status, HealthStatus.ERROR)
self.assertIn("Failed to check FAISS health", provider.health.message)
if __name__ == "__main__":
unittest.main()