fix faiss db health check correctly via review

This commit is contained in:
Sumit Jaiswal 2025-06-16 23:17:56 +05:30
parent 47373a65cf
commit be6ef8bc17
No known key found for this signature in database
GPG key ID: A4604B39D64D6AEC
5 changed files with 113 additions and 227 deletions

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any
from llama_stack.apis.common.content_types import (
@ -21,7 +22,7 @@ from llama_stack.apis.vector_io import (
)
from llama_stack.apis.vector_io.vector_io import VectorStoreChunkingStrategy, VectorStoreFileObject
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
logger = get_logger(name=__name__, category="core")
@ -272,3 +273,26 @@ class VectorIORouter(VectorIO):
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def health(self) -> dict[str, HealthResponse]:
    """Ask every routed provider for its health and collect the answers.

    Providers are probed one after another, each call bounded by a
    one-second timeout. A provider that has no ``health`` attribute is
    left out of the result entirely.

    Returns:
        A mapping of provider ID to that provider's health response.
        A timeout or an unexpected exception is reported as
        ``HealthStatus.ERROR``; a ``NotImplementedError`` is reported as
        ``HealthStatus.NOT_IMPLEMENTED``.
    """
    probe_timeout = 1  # seconds allowed for each individual provider probe
    results = {}
    providers = self.routing_table.impls_by_provider_id
    for provider_id, impl in providers.items():
        try:
            # Providers that expose no health hook are silently omitted.
            if hasattr(impl, "health"):
                results[provider_id] = await asyncio.wait_for(impl.health(), timeout=probe_timeout)
        except (asyncio.TimeoutError, TimeoutError):
            results[provider_id] = HealthResponse(
                status=HealthStatus.ERROR,
                message=f"Health check timed out after {probe_timeout} seconds",
            )
        except NotImplementedError:
            results[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
        except Exception as exc:
            # One failing provider must not prevent the others from reporting.
            results[provider_id] = HealthResponse(
                status=HealthStatus.ERROR, message=f"Health check failed: {str(exc)}"
            )
    return results

View file

@ -24,8 +24,11 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.faiss.provider_patch import * # noqa: F403
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
VectorDBsProtocolPrivate,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@ -42,6 +45,7 @@ VERSION = "v3"
# Versioned string prefixes.
# NOTE(review): presumably key namespaces inside the kvstore — confirm against usage.
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
# Arbitrary dimension for the throwaway index built by FaissVectorIOAdapter.health()
VECTOR_DIMENSION = 128
class FaissIndex(EmbeddingIndex):
@ -176,6 +180,21 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
# Cleanup if needed
pass
async def health(self) -> HealthResponse:
    """
    Report whether the inline FAISS library is usable.

    The check instantiates a flat L2 index of ``VECTOR_DIMENSION`` dimensions
    and discards it; any exception raised while doing so is turned into an
    error response instead of propagating.

    Returns:
        HealthResponse: ``HealthStatus.OK`` when the index could be built,
        otherwise ``HealthStatus.ERROR`` with the exception text in ``message``.
    """
    try:
        # The index itself is never used; only its successful construction matters.
        faiss.IndexFlatL2(VECTOR_DIMENSION)
    except Exception as exc:
        return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(exc)}")
    return HealthResponse(status=HealthStatus.OK)
async def register_vector_db(
self,
vector_db: VectorDB,

View file

@ -1,59 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Patch for the provider impl to fix the health check for the FAISS provider.
It is a workaround for the current implementation of get_providers_health,
which returns a dict mapping API names to a single health response, whereas list_providers
expects a dict mapping API names to a dict of provider IDs to health responses.
"""
import logging
import faiss
from llama_stack.distribution.providers import ProviderImpl
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
logger = logging.getLogger(__name__)
# Keep a handle on the unpatched method so patched_list_providers can delegate to it
original_list_providers = ProviderImpl.list_providers
# Arbitrary dimension for the throwaway index that check_faiss_health builds
VECTOR_DIMENSION = 128
# Standalone FAISS probe used by the patched provider listing
async def check_faiss_health():
    """Probe FAISS by constructing a flat L2 index and report the outcome.

    Returns:
        ``HealthResponse`` with ``HealthStatus.OK`` when the index could be
        created, otherwise ``HealthStatus.ERROR`` carrying the exception text.
    """
    try:
        # Successful construction is the readiness signal; the index is discarded.
        faiss.IndexFlatL2(VECTOR_DIMENSION)
    except Exception as exc:
        return HealthResponse(status=HealthStatus.ERROR, message=f"FAISS health check failed: {str(exc)}")
    return HealthResponse(status=HealthStatus.OK)
async def patched_list_providers(self):
    """Return the provider listing with the FAISS entry's health refreshed.

    Delegates to the unpatched ``list_providers`` and then overwrites the
    ``health`` of every entry whose ``provider_id`` is ``"faiss"`` and whose
    ``api`` is ``"vector_io"`` with the result of ``check_faiss_health``.
    All other entries are returned untouched.
    """
    logger.debug("Using patched list_providers method")
    listing = await original_list_providers(self)

    # Lazily select the FAISS vector_io entries; everything else is skipped.
    faiss_entries = (
        entry for entry in listing.data if entry.provider_id == "faiss" and entry.api == "vector_io"
    )
    for entry in faiss_entries:
        probe = await check_faiss_health()
        logger.debug("FAISS health check result: %s", probe)
        entry.health = probe
        logger.debug("Updated FAISS health to: %s", entry.health)

    return listing
# Install the patch: this runs at import time, so importing this module is enough
# to replace ProviderImpl.list_providers with the patched version for all instances.
# The "type: ignore" silences mypy, which cannot infer a matching type for the
# replacement; it is a static-check concern only and has no effect at runtime.
ProviderImpl.list_providers = patched_list_providers # type: ignore
logger.debug("Successfully applied patch for FAISS provider health check")

View file

@ -11,9 +11,11 @@ import numpy as np
import pytest
import pytest_asyncio
from llama_stack.apis.files import Files
from llama_stack.apis.inference import EmbeddingsResponse, Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import (
FaissIndex,
@ -76,6 +78,12 @@ def mock_inference_api(sample_embeddings):
return mock_api
@pytest.fixture
def mock_files_api():
    """Provide a ``MagicMock`` restricted to the ``Files`` API surface."""
    return MagicMock(spec=Files)
@pytest.fixture
def faiss_config():
config = MagicMock(spec=FaissVectorIOConfig)
@ -90,11 +98,19 @@ async def faiss_index(embedding_dimension):
@pytest_asyncio.fixture
async def faiss_adapter(faiss_config, mock_inference_api) -> FaissVectorIOAdapter:
    """Yield an initialized adapter and shut it down after the test has used it."""
    faiss_io = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api)
    await faiss_io.initialize()
    yield faiss_io
    # Runs only once the consuming test has finished with the adapter.
    await faiss_io.shutdown()
async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> FaissVectorIOAdapter:
    """Yield an adapter wired to a mocked KVStore instead of a real one."""
    # Stand-in KVStore whose range scan reports no stored entries.
    fake_kvstore = MagicMock()
    fake_kvstore.values_in_range = AsyncMock(return_value=[])

    faiss_io = FaissVectorIOAdapter(
        config=faiss_config,
        inference_api=mock_inference_api,
        files_api=mock_files_api,
    )

    # ``initialize`` stays patched while the fixture is live so that nothing
    # reaches the kvstore_impl call; the mocked store is injected directly.
    with patch.object(FaissVectorIOAdapter, "initialize"):
        faiss_io.kvstore = fake_kvstore
        yield faiss_io
@pytest.mark.asyncio
@ -118,3 +134,49 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
assert response.chunks[0] == sample_chunks[0]
assert response.chunks[1] == sample_chunks[1]
@pytest.mark.asyncio
async def test_health_success():
    """health() reports OK, with no message, when the FAISS index can be built."""
    # Collaborators are plain mocks; health() is exercised on a fresh adapter.
    stub_config, stub_inference, stub_files = MagicMock(), MagicMock(), MagicMock()

    with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as index_ctor:
        index_ctor.return_value = MagicMock()
        adapter = FaissVectorIOAdapter(config=stub_config, inference_api=stub_inference, files_api=stub_files)

        result = await adapter.health()

        # The response is dict-shaped and carries only the status on success.
        assert isinstance(result, dict)
        assert result["status"] == HealthStatus.OK
        assert "message" not in result

        # 128 mirrors VECTOR_DIMENSION in the provider module.
        index_ctor.assert_called_once_with(128)
@pytest.mark.asyncio
async def test_health_failure():
    """health() reports ERROR with the exception text when index creation raises."""
    # Collaborators are plain mocks; health() is exercised on a fresh adapter.
    stub_config, stub_inference, stub_files = MagicMock(), MagicMock(), MagicMock()

    with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as index_ctor:
        # Make index construction blow up to drive the error path.
        index_ctor.side_effect = Exception("Test error")
        adapter = FaissVectorIOAdapter(config=stub_config, inference_api=stub_inference, files_api=stub_files)

        result = await adapter.health()

        assert isinstance(result, dict)
        assert result["status"] == HealthStatus.ERROR
        assert result["message"] == "Health check failed: Test error"

View file

@ -1,160 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Unit tests for the FAISS provider health check implementation via provider patch.
"""
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
from llama_stack.distribution.providers import ProviderImpl
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
from llama_stack.providers.inline.vector_io.faiss.provider_patch import (
check_faiss_health,
patched_list_providers,
)
class TestFaissProviderPatch(unittest.IsolatedAsyncioTestCase):
    """Test cases for the FAISS provider patch.

    Every test is a coroutine, so the class derives from
    ``IsolatedAsyncioTestCase``. On a plain ``TestCase`` the coroutines are
    never awaited, which makes each test pass without running its body.

    ``HealthResponse`` values are dict-shaped, so their fields are read with
    subscript access rather than attribute access.
    """

    def setUp(self):
        """Set up test fixtures."""
        self.provider_impl = MagicMock(spec=ProviderImpl)
        self.mock_response = MagicMock()
        self.mock_response.data = []

        # Stand-in for the unpatched list_providers that the patch delegates to
        self.original_list_providers = AsyncMock(return_value=self.mock_response)

    async def test_check_faiss_health_success(self):
        """Test the check_faiss_health function when FAISS is working properly."""
        with patch("faiss.IndexFlatL2") as mock_index:
            mock_index.return_value = MagicMock()

            result = await check_faiss_health()

            self.assertEqual(result["status"], HealthStatus.OK)
            mock_index.assert_called_once()

    async def test_check_faiss_health_failure(self):
        """Test the check_faiss_health function when FAISS fails."""
        with patch("faiss.IndexFlatL2") as mock_index:
            # Configure the mock to simulate a failure
            mock_index.side_effect = Exception("FAISS initialization failed")

            result = await check_faiss_health()

            self.assertEqual(result["status"], HealthStatus.ERROR)
            self.assertIn("FAISS health check failed", result["message"])
            mock_index.assert_called_once()

    async def test_patched_list_providers_no_faiss(self):
        """Test the patched_list_providers method when no FAISS vector_io provider is found."""
        # Neither entry matches both provider_id == "faiss" and api == "vector_io"
        self.mock_response.data = [
            MagicMock(provider_id="other", api="vector_io"),
            MagicMock(provider_id="faiss", api="other_api"),
        ]

        with (
            patch(
                "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
                self.original_list_providers,
            ),
            patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
        ):
            result = await patched_list_providers(self.provider_impl)

            self.assertEqual(result, self.mock_response)
            self.original_list_providers.assert_called_once_with(self.provider_impl)

            # Verify that no health checks were performed
            mock_health.assert_not_called()

    async def test_patched_list_providers_with_faiss(self):
        """Test the patched_list_providers method when a FAISS provider is found."""
        # Create a mock FAISS provider
        mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
        mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))

        # Set up the mock response with a FAISS provider
        self.mock_response.data = [
            MagicMock(provider_id="other", api="vector_io"),
            mock_faiss_provider,
        ]

        with (
            patch(
                "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
                self.original_list_providers,
            ),
            patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
        ):
            mock_health.return_value = HealthResponse(status=HealthStatus.OK)

            result = await patched_list_providers(self.provider_impl)

            self.assertEqual(result, self.mock_response)
            self.original_list_providers.assert_called_once_with(self.provider_impl)
            mock_health.assert_called_once()

            # Verify that the FAISS provider's health was updated
            self.assertEqual(mock_faiss_provider.health["status"], HealthStatus.OK)

    async def test_patched_list_providers_with_faiss_health_failure(self):
        """Test the patched_list_providers method when the FAISS health check fails."""
        mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
        mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))

        self.mock_response.data = [
            MagicMock(provider_id="other", api="vector_io"),
            mock_faiss_provider,
        ]

        with (
            patch(
                "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
                self.original_list_providers,
            ),
            patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
        ):
            # Configure the mock health check to simulate a failure
            error_response = HealthResponse(status=HealthStatus.ERROR, message="FAISS health check failed: Test error")
            mock_health.return_value = error_response

            result = await patched_list_providers(self.provider_impl)

            self.assertEqual(result, self.mock_response)
            self.original_list_providers.assert_called_once_with(self.provider_impl)
            mock_health.assert_called_once()

            # Verify that the FAISS provider's health was updated with the error
            self.assertEqual(mock_faiss_provider.health["status"], HealthStatus.ERROR)
            self.assertEqual(mock_faiss_provider.health["message"], "FAISS health check failed: Test error")

    async def test_patched_list_providers_with_exception(self):
        """Test the patched_list_providers method when the health check itself raises.

        patched_list_providers does not catch exceptions from
        check_faiss_health, so the error is expected to propagate to the
        caller and the provider's health must be left as it was.
        """
        mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
        untouched_health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
        mock_faiss_provider.health = untouched_health

        self.mock_response.data = [
            MagicMock(provider_id="other", api="vector_io"),
            mock_faiss_provider,
        ]

        with (
            patch(
                "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
                self.original_list_providers,
            ),
            patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
        ):
            # Configure the mock health check to raise an exception
            mock_health.side_effect = Exception("Unexpected error")

            with self.assertRaises(Exception) as raised:
                await patched_list_providers(self.provider_impl)

            self.assertEqual(str(raised.exception), "Unexpected error")
            self.original_list_providers.assert_called_once_with(self.provider_impl)
            mock_health.assert_called_once()

            # The failed check must not have overwritten the provider's health
            self.assertIs(mock_faiss_provider.health, untouched_health)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()