updates to fix pre-commits check

This commit is contained in:
Sumit Jaiswal 2025-06-01 17:35:45 +05:30
parent 24707a1173
commit 47373a65cf
No known key found for this signature in database
GPG key ID: A4604B39D64D6AEC
2 changed files with 46 additions and 41 deletions

View file

@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Unit tests for the FAISS provider health check implementation via provider patch.
"""
@ -54,7 +60,7 @@ class TestFaissProviderPatch(unittest.TestCase):
]
with patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers
self.original_list_providers,
):
result = await patched_list_providers(self.provider_impl)
@ -68,21 +74,19 @@ class TestFaissProviderPatch(unittest.TestCase):
"""Test the patched_list_providers method when a FAISS provider is found."""
# Create a mock FAISS provider
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
mock_faiss_provider.health = MagicMock(
return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
)
mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
# Set up the mock response with a FAISS provider
self.mock_response.data = [
MagicMock(provider_id="other", api="vector_io"),
mock_faiss_provider,
]
with patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers
), \
with (
patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health"
) as mock_health:
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers,
),
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
):
mock_health.return_value = HealthResponse(status=HealthStatus.OK)
result = await patched_list_providers(self.provider_impl)
self.assertEqual(result, self.mock_response)
@ -96,24 +100,20 @@ class TestFaissProviderPatch(unittest.TestCase):
async def test_patched_list_providers_with_faiss_health_failure(self):
"""Test the patched_list_providers method when the FAISS health check fails."""
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
mock_faiss_provider.health = MagicMock(
return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
)
mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
self.mock_response.data = [
MagicMock(provider_id="other", api="vector_io"),
mock_faiss_provider,
]
with patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers), \
with (
patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health"
) as mock_health:
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers,
),
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
):
# Configure the mock health check to simulate a failure
error_response = HealthResponse(
status=HealthStatus.ERROR,
message="FAISS health check failed: Test error"
)
error_response = HealthResponse(status=HealthStatus.ERROR, message="FAISS health check failed: Test error")
mock_health.return_value = error_response
result = await patched_list_providers(self.provider_impl)
@ -124,28 +124,24 @@ class TestFaissProviderPatch(unittest.TestCase):
for provider in result.data:
if provider.provider_id == "faiss" and provider.api == "vector_io":
self.assertEqual(provider.health.status, HealthStatus.ERROR)
self.assertEqual(
provider.health.message, "FAISS health check failed: Test error"
)
self.assertEqual(provider.health.message, "FAISS health check failed: Test error")
async def test_patched_list_providers_with_exception(self):
"""Test the patched_list_providers method when an exception occurs during health check."""
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
mock_faiss_provider.health = MagicMock(
return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
)
mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
self.mock_response.data = [
MagicMock(provider_id="other", api="vector_io"),
mock_faiss_provider,
]
with patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers
), \
patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health"
) as mock_health:
with (
patch(
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
self.original_list_providers,
),
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
):
# Configure the mock health check to raise an exception
mock_health.side_effect = Exception("Unexpected error")
result = await patched_list_providers(self.provider_impl)