mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
updates to fix pre-commits check
This commit is contained in:
parent
24707a1173
commit
47373a65cf
2 changed files with 46 additions and 41 deletions
|
@ -1,6 +1,14 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Patch for the provider impl to fix the health check for the FAISS provider.
|
||||
It is the workaround fix with current implementation if place for get_providers_health
|
||||
as it returns a dict mapping API names to a single health response, but list_providers
|
||||
expects a dict mapping API names to a dict of provider IDs to health responses.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
@ -17,6 +25,7 @@ original_list_providers = ProviderImpl.list_providers
|
|||
|
||||
VECTOR_DIMENSION = 128 # sample dimension
|
||||
|
||||
|
||||
# Helper method to check FAISS health directly
|
||||
async def check_faiss_health():
|
||||
"""Check the health of the FAISS vector database directly."""
|
||||
|
@ -25,10 +34,8 @@ async def check_faiss_health():
|
|||
faiss.IndexFlatL2(VECTOR_DIMENSION)
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
except Exception as e:
|
||||
return HealthResponse(
|
||||
status=HealthStatus.ERROR,
|
||||
message=f"FAISS health check failed: {str(e)}"
|
||||
)
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"FAISS health check failed: {str(e)}")
|
||||
|
||||
|
||||
async def patched_list_providers(self):
|
||||
"""Patched version of list_providers to include FAISS health check."""
|
||||
|
@ -44,7 +51,9 @@ async def patched_list_providers(self):
|
|||
logger.debug("Updated FAISS health to: %s", provider.health)
|
||||
return response
|
||||
|
||||
new_list_providers = patched_list_providers
|
||||
|
||||
# Apply the patch by replacing the original method with patched version
|
||||
ProviderImpl.list_providers = new_list_providers
|
||||
# Added type: ignore because mypy cannot infer the correct type
|
||||
# The typing error doesn't affect runtime behavior - it's only a static type check warning
|
||||
ProviderImpl.list_providers = patched_list_providers # type: ignore
|
||||
logger.debug("Successfully applied patch for FAISS provider health check")
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Unit tests for the FAISS provider health check implementation via provider patch.
|
||||
"""
|
||||
|
@ -54,7 +60,7 @@ class TestFaissProviderPatch(unittest.TestCase):
|
|||
]
|
||||
with patch(
|
||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
|
||||
self.original_list_providers
|
||||
self.original_list_providers,
|
||||
):
|
||||
result = await patched_list_providers(self.provider_impl)
|
||||
|
||||
|
@ -68,21 +74,19 @@ class TestFaissProviderPatch(unittest.TestCase):
|
|||
"""Test the patched_list_providers method when a FAISS provider is found."""
|
||||
# Create a mock FAISS provider
|
||||
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
|
||||
mock_faiss_provider.health = MagicMock(
|
||||
return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
|
||||
)
|
||||
mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
|
||||
# Set up the mock response with a FAISS provider
|
||||
self.mock_response.data = [
|
||||
MagicMock(provider_id="other", api="vector_io"),
|
||||
mock_faiss_provider,
|
||||
]
|
||||
with patch(
|
||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
|
||||
self.original_list_providers
|
||||
), \
|
||||
with (
|
||||
patch(
|
||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health"
|
||||
) as mock_health:
|
||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
|
||||
self.original_list_providers,
|
||||
),
|
||||
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
|
||||
):
|
||||
mock_health.return_value = HealthResponse(status=HealthStatus.OK)
|
||||
result = await patched_list_providers(self.provider_impl)
|
||||
self.assertEqual(result, self.mock_response)
|
||||
|
@ -96,24 +100,20 @@ class TestFaissProviderPatch(unittest.TestCase):
|
|||
async def test_patched_list_providers_with_faiss_health_failure(self):
|
||||
"""Test the patched_list_providers method when the FAISS health check fails."""
|
||||
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
|
||||
mock_faiss_provider.health = MagicMock(
|
||||
return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
|
||||
)
|
||||
mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
|
||||
self.mock_response.data = [
|
||||
MagicMock(provider_id="other", api="vector_io"),
|
||||
mock_faiss_provider,
|
||||
]
|
||||
with patch(
|
||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
|
||||
self.original_list_providers), \
|
||||
with (
|
||||
patch(
|
||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health"
|
||||
) as mock_health:
|
||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
|
||||
self.original_list_providers,
|
||||
),
|
||||
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
|
||||
):
|
||||
# Configure the mock health check to simulate a failure
|
||||
error_response = HealthResponse(
|
||||
status=HealthStatus.ERROR,
|
||||
message="FAISS health check failed: Test error"
|
||||
)
|
||||
error_response = HealthResponse(status=HealthStatus.ERROR, message="FAISS health check failed: Test error")
|
||||
mock_health.return_value = error_response
|
||||
|
||||
result = await patched_list_providers(self.provider_impl)
|
||||
|
@ -124,28 +124,24 @@ class TestFaissProviderPatch(unittest.TestCase):
|
|||
for provider in result.data:
|
||||
if provider.provider_id == "faiss" and provider.api == "vector_io":
|
||||
self.assertEqual(provider.health.status, HealthStatus.ERROR)
|
||||
self.assertEqual(
|
||||
provider.health.message, "FAISS health check failed: Test error"
|
||||
)
|
||||
self.assertEqual(provider.health.message, "FAISS health check failed: Test error")
|
||||
|
||||
async def test_patched_list_providers_with_exception(self):
|
||||
"""Test the patched_list_providers method when an exception occurs during health check."""
|
||||
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
|
||||
mock_faiss_provider.health = MagicMock(
|
||||
return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
|
||||
)
|
||||
mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
|
||||
|
||||
self.mock_response.data = [
|
||||
MagicMock(provider_id="other", api="vector_io"),
|
||||
mock_faiss_provider,
|
||||
]
|
||||
with patch(
|
||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
|
||||
self.original_list_providers
|
||||
), \
|
||||
with (
|
||||
patch(
|
||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health"
|
||||
) as mock_health:
|
||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
|
||||
self.original_list_providers,
|
||||
),
|
||||
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
|
||||
):
|
||||
# Configure the mock health check to raise an exception
|
||||
mock_health.side_effect = Exception("Unexpected error")
|
||||
result = await patched_list_providers(self.provider_impl)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue