diff --git a/llama_stack/providers/inline/vector_io/faiss/provider_patch.py b/llama_stack/providers/inline/vector_io/faiss/provider_patch.py index 5fb4a7594..94f31c6e5 100644 --- a/llama_stack/providers/inline/vector_io/faiss/provider_patch.py +++ b/llama_stack/providers/inline/vector_io/faiss/provider_patch.py @@ -1,6 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. """ Patch for the provider impl to fix the health check for the FAISS provider. +It is the workaround fix with current implementation if place for get_providers_health +as it returns a dict mapping API names to a single health response, but list_providers +expects a dict mapping API names to a dict of provider IDs to health responses. """ import logging @@ -17,6 +25,7 @@ original_list_providers = ProviderImpl.list_providers VECTOR_DIMENSION = 128 # sample dimension + # Helper method to check FAISS health directly async def check_faiss_health(): """Check the health of the FAISS vector database directly.""" @@ -25,10 +34,8 @@ async def check_faiss_health(): faiss.IndexFlatL2(VECTOR_DIMENSION) return HealthResponse(status=HealthStatus.OK) except Exception as e: - return HealthResponse( - status=HealthStatus.ERROR, - message=f"FAISS health check failed: {str(e)}" - ) + return HealthResponse(status=HealthStatus.ERROR, message=f"FAISS health check failed: {str(e)}") + async def patched_list_providers(self): """Patched version of list_providers to include FAISS health check.""" @@ -44,7 +51,9 @@ async def patched_list_providers(self): logger.debug("Updated FAISS health to: %s", provider.health) return response -new_list_providers = patched_list_providers + # Apply the patch by replacing the original method with patched version -ProviderImpl.list_providers = new_list_providers +# Added type: ignore because mypy cannot infer the correct type +# The typing error doesn't affect runtime behavior - it's only a static type check warning +ProviderImpl.list_providers = patched_list_providers # type: ignore logger.debug("Successfully applied patch for FAISS provider health check") diff --git a/tests/unit/providers/vector_io/test_faiss_provider_patch.py b/tests/unit/providers/vector_io/test_faiss_provider_patch.py index 6d0e1488e..d6c10c8b6 100644 --- a/tests/unit/providers/vector_io/test_faiss_provider_patch.py +++ b/tests/unit/providers/vector_io/test_faiss_provider_patch.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + """ Unit tests for the FAISS provider health check implementation via provider patch. """ @@ -54,7 +60,7 @@ class TestFaissProviderPatch(unittest.TestCase): ] with patch( "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers", - self.original_list_providers + self.original_list_providers, ): result = await patched_list_providers(self.provider_impl) @@ -68,21 +74,19 @@ class TestFaissProviderPatch(unittest.TestCase): """Test the patched_list_providers method when a FAISS provider is found.""" # Create a mock FAISS provider mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io") - mock_faiss_provider.health = MagicMock( - return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED) - ) + mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)) # Set up the mock response with a FAISS provider self.mock_response.data = [ MagicMock(provider_id="other", api="vector_io"), mock_faiss_provider, ] - with patch( - "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers", - self.original_list_providers - ), \ + with ( patch( - "llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health" - ) as mock_health: + "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers", + self.original_list_providers, + ), + patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health, + ): mock_health.return_value = HealthResponse(status=HealthStatus.OK) result = await patched_list_providers(self.provider_impl) self.assertEqual(result, self.mock_response) @@ -96,24 +100,20 @@ class TestFaissProviderPatch(unittest.TestCase): async def test_patched_list_providers_with_faiss_health_failure(self): """Test the patched_list_providers method when the FAISS health check fails.""" mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io") - mock_faiss_provider.health = MagicMock( - return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED) - ) + mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)) self.mock_response.data = [ MagicMock(provider_id="other", api="vector_io"), mock_faiss_provider, ] - with patch( - "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers", - self.original_list_providers), \ + with ( patch( - "llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health" - ) as mock_health: + "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers", + self.original_list_providers, + ), + patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health, + ): # Configure the mock health check to simulate a failure - error_response = HealthResponse( - status=HealthStatus.ERROR, - message="FAISS health check failed: Test error" - ) + error_response = HealthResponse(status=HealthStatus.ERROR, message="FAISS health check failed: Test error") mock_health.return_value = error_response result = await patched_list_providers(self.provider_impl) @@ -124,28 +124,24 @@ class TestFaissProviderPatch(unittest.TestCase): for provider in result.data: if provider.provider_id == "faiss" and provider.api == "vector_io": self.assertEqual(provider.health.status, HealthStatus.ERROR) - self.assertEqual( - provider.health.message, "FAISS health check failed: Test error" - ) + self.assertEqual(provider.health.message, "FAISS health check failed: Test error") async def test_patched_list_providers_with_exception(self): """Test the patched_list_providers method when an exception occurs during health check.""" mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io") - mock_faiss_provider.health = MagicMock( - return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED) - ) + mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)) self.mock_response.data = [ MagicMock(provider_id="other", api="vector_io"), mock_faiss_provider, ] - with patch( - "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers", - self.original_list_providers - ), \ - patch( - "llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health" - ) as mock_health: + with ( + patch( + "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers", + self.original_list_providers, + ), + patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health, + ): # Configure the mock health check to raise an exception mock_health.side_effect = Exception("Unexpected error") result = await patched_list_providers(self.provider_impl)