diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py index a20ce70b6..44c1fafa7 100644 --- a/llama_stack/distribution/routers/vector_io.py +++ b/llama_stack/distribution/routers/vector_io.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio from typing import Any from llama_stack.apis.common.content_types import ( @@ -21,7 +22,7 @@ from llama_stack.apis.vector_io import ( ) from llama_stack.apis.vector_io.vector_io import VectorStoreChunkingStrategy, VectorStoreFileObject from llama_stack.log import get_logger -from llama_stack.providers.datatypes import RoutingTable +from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable logger = get_logger(name=__name__, category="core") @@ -276,3 +277,26 @@ class VectorIORouter(VectorIO): attributes=attributes, chunking_strategy=chunking_strategy, ) + + async def health(self) -> dict[str, HealthResponse]: + health_statuses = {} + timeout = 1 # increasing the timeout to 1 second for health checks + for provider_id, impl in self.routing_table.impls_by_provider_id.items(): + try: + # check if the provider has a health method + if not hasattr(impl, "health"): + continue + health = await asyncio.wait_for(impl.health(), timeout=timeout) + health_statuses[provider_id] = health + except (asyncio.TimeoutError, TimeoutError): + health_statuses[provider_id] = HealthResponse( + status=HealthStatus.ERROR, + message=f"Health check timed out after {timeout} seconds", + ) + except NotImplementedError: + health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED) + except Exception as e: + health_statuses[provider_id] = HealthResponse( + status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}" + ) + return health_statuses diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index a2f4417e0..0864ba3a7 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -24,7 +24,11 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) -from llama_stack.providers.datatypes import VectorDBsProtocolPrivate +from llama_stack.providers.datatypes import ( + HealthResponse, + HealthStatus, + VectorDBsProtocolPrivate, +) from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin @@ -175,6 +179,22 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr # Cleanup if needed pass + async def health(self) -> HealthResponse: + """ + Performs a health check by verifying connectivity to the inline faiss DB. + This method is used by the Provider API to verify + that the service is running correctly. + Returns: + + HealthResponse: A dictionary containing the health status. + """ + try: + vector_dimension = 128 # sample dimension + faiss.IndexFlatL2(vector_dimension) + return HealthResponse(status=HealthStatus.OK) + except Exception as e: + return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}") + async def register_vector_db( self, vector_db: VectorDB, diff --git a/tests/unit/providers/vector_io/test_faiss.py b/tests/unit/providers/vector_io/test_faiss.py index 62f9b3538..8348b84e3 100644 --- a/tests/unit/providers/vector_io/test_faiss.py +++ b/tests/unit/providers/vector_io/test_faiss.py @@ -11,9 +11,11 @@ import numpy as np import pytest import pytest_asyncio +from llama_stack.apis.files import Files from llama_stack.apis.inference import EmbeddingsResponse, Inference from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse +from llama_stack.providers.datatypes import HealthStatus from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.faiss import ( FaissIndex, @@ -76,6 +78,12 @@ def mock_inference_api(sample_embeddings): return mock_api +@pytest.fixture +def mock_files_api(): + mock_api = MagicMock(spec=Files) + return mock_api + + @pytest.fixture def faiss_config(): config = MagicMock(spec=FaissVectorIOConfig) @@ -90,11 +98,19 @@ async def faiss_index(embedding_dimension): @pytest_asyncio.fixture -async def faiss_adapter(faiss_config, mock_inference_api) -> FaissVectorIOAdapter: - adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api) - await adapter.initialize() - yield adapter - await adapter.shutdown() +async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> FaissVectorIOAdapter: + # Create the adapter + adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api, files_api=mock_files_api) + + # Create a mock KVStore + mock_kvstore = MagicMock() + mock_kvstore.values_in_range = AsyncMock(return_value=[]) + + # Patch the initialize method to avoid the kvstore_impl call + with patch.object(FaissVectorIOAdapter, "initialize"): + # Set the kvstore directly + adapter.kvstore = mock_kvstore + yield adapter @pytest.mark.asyncio @@ -118,3 +134,49 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_ assert response.chunks[0] == sample_chunks[0] assert response.chunks[1] == sample_chunks[1] + + +@pytest.mark.asyncio +async def test_health_success(): + """Test that the health check returns OK status when faiss is working correctly.""" + # Create a fresh instance of FaissVectorIOAdapter for testing + config = MagicMock() + inference_api = MagicMock() + files_api = MagicMock() + + with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat: + mock_index_flat.return_value = MagicMock() + adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api) + + # Calling the health method directly + response = await adapter.health() + + # Verifying the response + assert isinstance(response, dict) + assert response["status"] == HealthStatus.OK + assert "message" not in response + + # Verifying that IndexFlatL2 was called with the correct dimension + mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128 + + +@pytest.mark.asyncio +async def test_health_failure(): + """Test that the health check returns ERROR status when faiss encounters an error.""" + # Create a fresh instance of FaissVectorIOAdapter for testing + config = MagicMock() + inference_api = MagicMock() + files_api = MagicMock() + + with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat: + mock_index_flat.side_effect = Exception("Test error") + + adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api) + + # Calling the health method directly + response = await adapter.health() + + # Verifying the response + assert isinstance(response, dict) + assert response["status"] == HealthStatus.ERROR + assert response["message"] == "Health check failed: Test error"