updates to fix pre-commits check

This commit is contained in:
Sumit Jaiswal 2025-06-01 17:35:45 +05:30
parent 24707a1173
commit 47373a65cf
No known key found for this signature in database
GPG key ID: A4604B39D64D6AEC
2 changed files with 46 additions and 41 deletions

View file

@ -1,6 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Patch for the provider impl to fix the health check for the FAISS provider.
It is the workaround fix with current implementation if place for get_providers_health
as it returns a dict mapping API names to a single health response, but list_providers
expects a dict mapping API names to a dict of provider IDs to health responses.
"""
import logging
@ -17,6 +25,7 @@ original_list_providers = ProviderImpl.list_providers
VECTOR_DIMENSION = 128 # sample dimension
# Helper method to check FAISS health directly
async def check_faiss_health():
"""Check the health of the FAISS vector database directly."""
@ -25,10 +34,8 @@ async def check_faiss_health():
faiss.IndexFlatL2(VECTOR_DIMENSION)
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(
status=HealthStatus.ERROR,
message=f"FAISS health check failed: {str(e)}"
)
return HealthResponse(status=HealthStatus.ERROR, message=f"FAISS health check failed: {str(e)}")
async def patched_list_providers(self):
"""Patched version of list_providers to include FAISS health check."""
@ -44,7 +51,9 @@ async def patched_list_providers(self):
logger.debug("Updated FAISS health to: %s", provider.health)
return response
new_list_providers = patched_list_providers
# Apply the patch by replacing the original method with patched version
ProviderImpl.list_providers = new_list_providers
# Added type: ignore because mypy cannot infer the correct type
# The typing error doesn't affect runtime behavior - it's only a static type check warning
ProviderImpl.list_providers = patched_list_providers # type: ignore
logger.debug("Successfully applied patch for FAISS provider health check")

View file

@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Unit tests for the FAISS provider health check implementation via provider patch.
"""
@ -54,7 +60,7 @@ class TestFaissProviderPatch(unittest.TestCase):
]
with patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers
self.original_list_providers,
):
result = await patched_list_providers(self.provider_impl)
@ -68,21 +74,19 @@ class TestFaissProviderPatch(unittest.TestCase):
"""Test the patched_list_providers method when a FAISS provider is found."""
# Create a mock FAISS provider
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
mock_faiss_provider.health = MagicMock(
return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
)
mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
# Set up the mock response with a FAISS provider
self.mock_response.data = [
MagicMock(provider_id="other", api="vector_io"),
mock_faiss_provider,
]
with patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers
), \
with (
patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health"
) as mock_health:
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers,
),
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
):
mock_health.return_value = HealthResponse(status=HealthStatus.OK)
result = await patched_list_providers(self.provider_impl)
self.assertEqual(result, self.mock_response)
@ -96,24 +100,20 @@ class TestFaissProviderPatch(unittest.TestCase):
async def test_patched_list_providers_with_faiss_health_failure(self):
"""Test the patched_list_providers method when the FAISS health check fails."""
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
mock_faiss_provider.health = MagicMock(
return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
)
mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
self.mock_response.data = [
MagicMock(provider_id="other", api="vector_io"),
mock_faiss_provider,
]
with patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers), \
with (
patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health"
) as mock_health:
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers,
),
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
):
# Configure the mock health check to simulate a failure
error_response = HealthResponse(
status=HealthStatus.ERROR,
message="FAISS health check failed: Test error"
)
error_response = HealthResponse(status=HealthStatus.ERROR, message="FAISS health check failed: Test error")
mock_health.return_value = error_response
result = await patched_list_providers(self.provider_impl)
@ -124,28 +124,24 @@ class TestFaissProviderPatch(unittest.TestCase):
for provider in result.data:
if provider.provider_id == "faiss" and provider.api == "vector_io":
self.assertEqual(provider.health.status, HealthStatus.ERROR)
self.assertEqual(
provider.health.message, "FAISS health check failed: Test error"
)
self.assertEqual(provider.health.message, "FAISS health check failed: Test error")
async def test_patched_list_providers_with_exception(self):
"""Test the patched_list_providers method when an exception occurs during health check."""
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
mock_faiss_provider.health = MagicMock(
return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
)
mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
self.mock_response.data = [
MagicMock(provider_id="other", api="vector_io"),
mock_faiss_provider,
]
with patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers
), \
patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health"
) as mock_health:
with (
patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers,
),
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
):
# Configure the mock health check to raise an exception
mock_health.side_effect = Exception("Unexpected error")
result = await patched_list_providers(self.provider_impl)