diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py
index 8eb56b7ca..6160a22f5 100644
--- a/llama_stack/distribution/routers/vector_io.py
+++ b/llama_stack/distribution/routers/vector_io.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 from typing import Any
 
 from llama_stack.apis.common.content_types import (
@@ -21,7 +22,7 @@ from llama_stack.apis.vector_io import (
 )
 from llama_stack.apis.vector_io.vector_io import VectorStoreChunkingStrategy, VectorStoreFileObject
 from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import RoutingTable
+from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
 
 logger = get_logger(name=__name__, category="core")
 
@@ -272,3 +273,41 @@ class VectorIORouter(VectorIO):
             attributes=attributes,
             chunking_strategy=chunking_strategy,
         )
+
+    async def health(self) -> dict[str, HealthResponse]:
+        """Collect a health report from every provider that exposes ``health()``.
+
+        Providers are probed concurrently, each under its own timeout, so the
+        waits of slow providers overlap instead of adding up one after
+        another.
+
+        Returns:
+            A mapping of provider ID to that provider's health response.
+            Providers whose implementation has no ``health`` attribute are
+            omitted.
+        """
+        timeout = 1  # seconds allowed for each provider's health check
+
+        async def check(impl: Any) -> HealthResponse:
+            # Convert every failure into a HealthResponse so that one bad
+            # provider cannot abort the whole report.
+            try:
+                return await asyncio.wait_for(impl.health(), timeout=timeout)
+            except (asyncio.TimeoutError, TimeoutError):
+                return HealthResponse(
+                    status=HealthStatus.ERROR,
+                    message=f"Health check timed out after {timeout} seconds",
+                )
+            except NotImplementedError:
+                return HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
+            except Exception as e:
+                return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
+
+        checkable = {
+            provider_id: impl
+            for provider_id, impl in self.routing_table.impls_by_provider_id.items()
+            if hasattr(impl, "health")
+        }
+        # gather() returns results in argument order, which matches the dict's key order.
+        results = await asyncio.gather(*(check(impl) for impl in checkable.values()))
+        return dict(zip(checkable, results))
diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py 
b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index 56fd25725..8758d4c9a 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -24,8 +24,11 @@ from llama_stack.apis.vector_io import (
     QueryChunksResponse,
     VectorIO,
 )
-from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
-from llama_stack.providers.inline.vector_io.faiss.provider_patch import *  # noqa: F403
+from llama_stack.providers.datatypes import (
+    HealthResponse,
+    HealthStatus,
+    VectorDBsProtocolPrivate,
+)
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@@ -42,6 +45,7 @@ VERSION = "v3"
 VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
 FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
 OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
+VECTOR_DIMENSION = 128  # sample dimension
 
 
 class FaissIndex(EmbeddingIndex):
@@ -176,6 +180,23 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
         # Cleanup if needed
         pass
 
+    async def health(self) -> HealthResponse:
+        """
+        Report whether the faiss library is usable in this process.
+
+        The check only constructs a throwaway flat L2 index; it does not read
+        any registered vector DB or stored data.
+
+        Returns:
+            HealthResponse: status OK when the index could be constructed,
+            otherwise status ERROR with the exception text in ``message``.
+        """
+        try:
+            faiss.IndexFlatL2(VECTOR_DIMENSION)
+        except Exception as e:
+            return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
+        return HealthResponse(status=HealthStatus.OK)
+
     async def register_vector_db(
         self,
         vector_db: VectorDB,
diff --git a/llama_stack/providers/inline/vector_io/faiss/provider_patch.py b/llama_stack/providers/inline/vector_io/faiss/provider_patch.py
deleted file mode 100644
index 94f31c6e5..000000000
--- a/llama_stack/providers/inline/vector_io/faiss/provider_patch.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-"""
-Patch for the provider impl to fix the health check for the FAISS provider.
-It is the workaround fix with current implementation if place for get_providers_health
-as it returns a dict mapping API names to a single health response, but list_providers
-expects a dict mapping API names to a dict of provider IDs to health responses.
-""" - -import logging - -import faiss - -from llama_stack.distribution.providers import ProviderImpl -from llama_stack.providers.datatypes import HealthResponse, HealthStatus - -logger = logging.getLogger(__name__) - -# Store the original methods -original_list_providers = ProviderImpl.list_providers - -VECTOR_DIMENSION = 128 # sample dimension - - -# Helper method to check FAISS health directly -async def check_faiss_health(): - """Check the health of the FAISS vector database directly.""" - try: - # Create FAISS index to verify readiness - faiss.IndexFlatL2(VECTOR_DIMENSION) - return HealthResponse(status=HealthStatus.OK) - except Exception as e: - return HealthResponse(status=HealthStatus.ERROR, message=f"FAISS health check failed: {str(e)}") - - -async def patched_list_providers(self): - """Patched version of list_providers to include FAISS health check.""" - logger.debug("Using patched list_providers method") - # Get the original response - response = await original_list_providers(self) - # To find the FAISS provider in the response - for provider in response.data: - if provider.provider_id == "faiss" and provider.api == "vector_io": - health_result = await check_faiss_health() - logger.debug("FAISS health check result: %s", health_result) - provider.health = health_result - logger.debug("Updated FAISS health to: %s", provider.health) - return response - - -# Apply the patch by replacing the original method with patched version -# Added type: ignore because mypy cannot infer the correct type -# The typing error doesn't affect runtime behavior - it's only a static type check warning -ProviderImpl.list_providers = patched_list_providers # type: ignore -logger.debug("Successfully applied patch for FAISS provider health check") diff --git a/tests/unit/providers/vector_io/test_faiss.py b/tests/unit/providers/vector_io/test_faiss.py index 62f9b3538..8348b84e3 100644 --- a/tests/unit/providers/vector_io/test_faiss.py +++ 
b/tests/unit/providers/vector_io/test_faiss.py
@@ -11,9 +11,12 @@ import numpy as np
 import pytest
 import pytest_asyncio
 
+from llama_stack.apis.files import Files
 from llama_stack.apis.inference import EmbeddingsResponse, Inference
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
+from llama_stack.providers.datatypes import HealthStatus
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.inline.vector_io.faiss.faiss import (
+    VECTOR_DIMENSION,
     FaissIndex,
@@ -76,6 +79,12 @@ def mock_inference_api(sample_embeddings):
     return mock_api
 
 
+@pytest.fixture
+def mock_files_api():
+    mock_api = MagicMock(spec=Files)
+    return mock_api
+
+
 @pytest.fixture
 def faiss_config():
     config = MagicMock(spec=FaissVectorIOConfig)
@@ -90,11 +99,18 @@ async def faiss_index(embedding_dimension):
 
 
 @pytest_asyncio.fixture
-async def faiss_adapter(faiss_config, mock_inference_api) -> FaissVectorIOAdapter:
-    adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api)
-    await adapter.initialize()
-    yield adapter
-    await adapter.shutdown()
+async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> FaissVectorIOAdapter:
+    """Yield an adapter with a mocked kvstore; ``initialize`` stays stubbed while the test runs."""
+    adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api, files_api=mock_files_api)
+
+    # Stand-in kvstore whose values_in_range() resolves to an empty list.
+    mock_kvstore = MagicMock()
+    mock_kvstore.values_in_range = AsyncMock(return_value=[])
+
+    # Stub initialize() for the duration of the test (intent: avoid the kvstore_impl call).
+    with patch.object(FaissVectorIOAdapter, "initialize"):
+        adapter.kvstore = mock_kvstore
+        yield adapter
 
 
 @pytest.mark.asyncio
@@ -118,3 +134,33 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
 
     assert response.chunks[0] == sample_chunks[0]
     assert response.chunks[1] == sample_chunks[1]
+
+
+@pytest.mark.asyncio
+async def test_health_success():
+    """health() reports OK when the faiss index constructor succeeds."""
+    with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
+        mock_index_flat.return_value = MagicMock()
+        adapter = FaissVectorIOAdapter(config=MagicMock(), inference_api=MagicMock(), files_api=MagicMock())
+        response = await adapter.health()
+
+    assert isinstance(response, dict)
+    assert response["status"] == HealthStatus.OK
+    assert "message" not in response
+    # Compare against the module constant so the test follows the implementation.
+    mock_index_flat.assert_called_once_with(VECTOR_DIMENSION)
+
+
+@pytest.mark.asyncio
+async def test_health_failure():
+    """health() reports ERROR, with the exception text, when the faiss index constructor raises."""
+    with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
+        mock_index_flat.side_effect = Exception("Test error")
+        adapter = FaissVectorIOAdapter(config=MagicMock(), inference_api=MagicMock(), files_api=MagicMock())
+        response = await adapter.health()
+
+    assert isinstance(response, dict)
+    assert response["status"] == HealthStatus.ERROR
+    assert response["message"] == "Health check failed: Test error"
+    # The reported error must come from the patched constructor being invoked.
+    mock_index_flat.assert_called_once_with(VECTOR_DIMENSION)
diff --git a/tests/unit/providers/vector_io/test_faiss_provider_patch.py b/tests/unit/providers/vector_io/test_faiss_provider_patch.py
deleted file mode 100644
index d6c10c8b6..000000000
--- a/tests/unit/providers/vector_io/test_faiss_provider_patch.py
+++ /dev/null
@@ -1,160 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -""" -Unit tests for the FAISS provider health check implementation via provider patch. -""" - -import unittest -from unittest.mock import AsyncMock, MagicMock, patch - -from llama_stack.distribution.providers import ProviderImpl -from llama_stack.providers.datatypes import HealthResponse, HealthStatus -from llama_stack.providers.inline.vector_io.faiss.provider_patch import ( - check_faiss_health, - patched_list_providers, -) - - -class TestFaissProviderPatch(unittest.TestCase): - """Test cases for the FAISS provider patch.""" - - def setUp(self): - """Set up test fixtures.""" - self.provider_impl = MagicMock(spec=ProviderImpl) - self.mock_response = MagicMock() - self.mock_response.data = [] - # Set up the original list_providers method - self.original_list_providers = AsyncMock(return_value=self.mock_response) - - async def test_check_faiss_health_success(self): - """Test the check_faiss_health function when FAISS is working properly.""" - with patch("faiss.IndexFlatL2") as mock_index: - mock_index.return_value = MagicMock() - # Call the health check function - result = await check_faiss_health() - - self.assertEqual(result.status, HealthStatus.OK) - mock_index.assert_called_once() - - async def test_check_faiss_health_failure(self): - """Test the check_faiss_health function when FAISS fails.""" - with patch("faiss.IndexFlatL2") as mock_index: - # Configure the mock to simulate a failure - mock_index.side_effect = Exception("FAISS initialization failed") - result = await check_faiss_health() - - self.assertEqual(result.status, HealthStatus.ERROR) - self.assertIn("FAISS health check failed", result.message) - mock_index.assert_called_once() - - async def test_patched_list_providers_no_faiss(self): - """Test the patched_list_providers method when no FAISS provider is found.""" - # Set up the mock response with NO 
FAISS provider - self.mock_response.data = [ - MagicMock(provider_id="other", api="vector_io"), - MagicMock(provider_id="faiss", api="other_api"), - ] - with patch( - "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers", - self.original_list_providers, - ): - result = await patched_list_providers(self.provider_impl) - - self.assertEqual(result, self.mock_response) - self.original_list_providers.assert_called_once_with(self.provider_impl) - # Verify that no health checks were performed - for provider in result.data: - self.assertNotEqual(provider.provider_id, "faiss") - - async def test_patched_list_providers_with_faiss(self): - """Test the patched_list_providers method when a FAISS provider is found.""" - # Create a mock FAISS provider - mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io") - mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)) - # Set up the mock response with a FAISS provider - self.mock_response.data = [ - MagicMock(provider_id="other", api="vector_io"), - mock_faiss_provider, - ] - with ( - patch( - "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers", - self.original_list_providers, - ), - patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health, - ): - mock_health.return_value = HealthResponse(status=HealthStatus.OK) - result = await patched_list_providers(self.provider_impl) - self.assertEqual(result, self.mock_response) - self.original_list_providers.assert_called_once_with(self.provider_impl) - mock_health.assert_called_once() - # Verify that the FAISS provider's health was updated - for provider in result.data: - if provider.provider_id == "faiss" and provider.api == "vector_io": - self.assertEqual(provider.health.status, HealthStatus.OK) - - async def test_patched_list_providers_with_faiss_health_failure(self): - """Test the patched_list_providers method when 
the FAISS health check fails.""" - mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io") - mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)) - self.mock_response.data = [ - MagicMock(provider_id="other", api="vector_io"), - mock_faiss_provider, - ] - with ( - patch( - "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers", - self.original_list_providers, - ), - patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health, - ): - # Configure the mock health check to simulate a failure - error_response = HealthResponse(status=HealthStatus.ERROR, message="FAISS health check failed: Test error") - mock_health.return_value = error_response - - result = await patched_list_providers(self.provider_impl) - self.assertEqual(result, self.mock_response) - self.original_list_providers.assert_called_once_with(self.provider_impl) - mock_health.assert_called_once() - # Verify that the FAISS provider's health was updated with the error - for provider in result.data: - if provider.provider_id == "faiss" and provider.api == "vector_io": - self.assertEqual(provider.health.status, HealthStatus.ERROR) - self.assertEqual(provider.health.message, "FAISS health check failed: Test error") - - async def test_patched_list_providers_with_exception(self): - """Test the patched_list_providers method when an exception occurs during health check.""" - mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io") - mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)) - - self.mock_response.data = [ - MagicMock(provider_id="other", api="vector_io"), - mock_faiss_provider, - ] - with ( - patch( - "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers", - self.original_list_providers, - ), - 
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health, - ): - # Configure the mock health check to raise an exception - mock_health.side_effect = Exception("Unexpected error") - result = await patched_list_providers(self.provider_impl) - - self.assertEqual(result, self.mock_response) - self.original_list_providers.assert_called_once_with(self.provider_impl) - mock_health.assert_called_once() - # Verify that the FAISS provider's health was updated with an error - for provider in result.data: - if provider.provider_id == "faiss" and provider.api == "vector_io": - self.assertEqual(provider.health.status, HealthStatus.ERROR) - self.assertIn("Failed to check FAISS health", provider.health.message) - - -if __name__ == "__main__": - unittest.main()