updates to fix pre-commits check

This commit is contained in:
Sumit Jaiswal 2025-06-01 17:35:45 +05:30
parent 24707a1173
commit 47373a65cf
No known key found for this signature in database
GPG key ID: A4604B39D64D6AEC
2 changed files with 46 additions and 41 deletions

View file

@ -1,6 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
""" """
Patch for the provider impl to fix the health check for the FAISS provider. Patch for the provider impl to fix the health check for the FAISS provider.
It is the workaround fix with current implementation if place for get_providers_health
as it returns a dict mapping API names to a single health response, but list_providers
expects a dict mapping API names to a dict of provider IDs to health responses.
""" """
import logging import logging
@ -17,6 +25,7 @@ original_list_providers = ProviderImpl.list_providers
VECTOR_DIMENSION = 128 # sample dimension VECTOR_DIMENSION = 128 # sample dimension
# Helper method to check FAISS health directly # Helper method to check FAISS health directly
async def check_faiss_health(): async def check_faiss_health():
"""Check the health of the FAISS vector database directly.""" """Check the health of the FAISS vector database directly."""
@ -25,10 +34,8 @@ async def check_faiss_health():
faiss.IndexFlatL2(VECTOR_DIMENSION) faiss.IndexFlatL2(VECTOR_DIMENSION)
return HealthResponse(status=HealthStatus.OK) return HealthResponse(status=HealthStatus.OK)
except Exception as e: except Exception as e:
return HealthResponse( return HealthResponse(status=HealthStatus.ERROR, message=f"FAISS health check failed: {str(e)}")
status=HealthStatus.ERROR,
message=f"FAISS health check failed: {str(e)}"
)
async def patched_list_providers(self): async def patched_list_providers(self):
"""Patched version of list_providers to include FAISS health check.""" """Patched version of list_providers to include FAISS health check."""
@ -44,7 +51,9 @@ async def patched_list_providers(self):
logger.debug("Updated FAISS health to: %s", provider.health) logger.debug("Updated FAISS health to: %s", provider.health)
return response return response
new_list_providers = patched_list_providers
# Apply the patch by replacing the original method with patched version # Apply the patch by replacing the original method with patched version
ProviderImpl.list_providers = new_list_providers # Added type: ignore because mypy cannot infer the correct type
# The typing error doesn't affect runtime behavior - it's only a static type check warning
ProviderImpl.list_providers = patched_list_providers # type: ignore
logger.debug("Successfully applied patch for FAISS provider health check") logger.debug("Successfully applied patch for FAISS provider health check")

View file

@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
""" """
Unit tests for the FAISS provider health check implementation via provider patch. Unit tests for the FAISS provider health check implementation via provider patch.
""" """
@ -54,7 +60,7 @@ class TestFaissProviderPatch(unittest.TestCase):
] ]
with patch( with patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers", "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers self.original_list_providers,
): ):
result = await patched_list_providers(self.provider_impl) result = await patched_list_providers(self.provider_impl)
@ -68,21 +74,19 @@ class TestFaissProviderPatch(unittest.TestCase):
"""Test the patched_list_providers method when a FAISS provider is found.""" """Test the patched_list_providers method when a FAISS provider is found."""
# Create a mock FAISS provider # Create a mock FAISS provider
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io") mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
mock_faiss_provider.health = MagicMock( mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
)
# Set up the mock response with a FAISS provider # Set up the mock response with a FAISS provider
self.mock_response.data = [ self.mock_response.data = [
MagicMock(provider_id="other", api="vector_io"), MagicMock(provider_id="other", api="vector_io"),
mock_faiss_provider, mock_faiss_provider,
] ]
with patch( with (
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers
), \
patch( patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health" "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
) as mock_health: self.original_list_providers,
),
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
):
mock_health.return_value = HealthResponse(status=HealthStatus.OK) mock_health.return_value = HealthResponse(status=HealthStatus.OK)
result = await patched_list_providers(self.provider_impl) result = await patched_list_providers(self.provider_impl)
self.assertEqual(result, self.mock_response) self.assertEqual(result, self.mock_response)
@ -96,24 +100,20 @@ class TestFaissProviderPatch(unittest.TestCase):
async def test_patched_list_providers_with_faiss_health_failure(self): async def test_patched_list_providers_with_faiss_health_failure(self):
"""Test the patched_list_providers method when the FAISS health check fails.""" """Test the patched_list_providers method when the FAISS health check fails."""
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io") mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
mock_faiss_provider.health = MagicMock( mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
)
self.mock_response.data = [ self.mock_response.data = [
MagicMock(provider_id="other", api="vector_io"), MagicMock(provider_id="other", api="vector_io"),
mock_faiss_provider, mock_faiss_provider,
] ]
with patch( with (
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers), \
patch( patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health" "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
) as mock_health: self.original_list_providers,
),
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
):
# Configure the mock health check to simulate a failure # Configure the mock health check to simulate a failure
error_response = HealthResponse( error_response = HealthResponse(status=HealthStatus.ERROR, message="FAISS health check failed: Test error")
status=HealthStatus.ERROR,
message="FAISS health check failed: Test error"
)
mock_health.return_value = error_response mock_health.return_value = error_response
result = await patched_list_providers(self.provider_impl) result = await patched_list_providers(self.provider_impl)
@ -124,28 +124,24 @@ class TestFaissProviderPatch(unittest.TestCase):
for provider in result.data: for provider in result.data:
if provider.provider_id == "faiss" and provider.api == "vector_io": if provider.provider_id == "faiss" and provider.api == "vector_io":
self.assertEqual(provider.health.status, HealthStatus.ERROR) self.assertEqual(provider.health.status, HealthStatus.ERROR)
self.assertEqual( self.assertEqual(provider.health.message, "FAISS health check failed: Test error")
provider.health.message, "FAISS health check failed: Test error"
)
async def test_patched_list_providers_with_exception(self): async def test_patched_list_providers_with_exception(self):
"""Test the patched_list_providers method when an exception occurs during health check.""" """Test the patched_list_providers method when an exception occurs during health check."""
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io") mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
mock_faiss_provider.health = MagicMock( mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
)
self.mock_response.data = [ self.mock_response.data = [
MagicMock(provider_id="other", api="vector_io"), MagicMock(provider_id="other", api="vector_io"),
mock_faiss_provider, mock_faiss_provider,
] ]
with patch( with (
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers", patch(
self.original_list_providers "llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
), \ self.original_list_providers,
patch( ),
"llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health" patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
) as mock_health: ):
# Configure the mock health check to raise an exception # Configure the mock health check to raise an exception
mock_health.side_effect = Exception("Unexpected error") mock_health.side_effect = Exception("Unexpected error")
result = await patched_list_providers(self.provider_impl) result = await patched_list_providers(self.provider_impl)