mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
updates to fix pre-commits check
This commit is contained in:
parent
24707a1173
commit
47373a65cf
2 changed files with 46 additions and 41 deletions
|
@ -1,6 +1,14 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Patch for the provider impl to fix the health check for the FAISS provider.
|
Patch for the provider impl to fix the health check for the FAISS provider.
|
||||||
|
It is the workaround fix with current implementation if place for get_providers_health
|
||||||
|
as it returns a dict mapping API names to a single health response, but list_providers
|
||||||
|
expects a dict mapping API names to a dict of provider IDs to health responses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
@ -17,6 +25,7 @@ original_list_providers = ProviderImpl.list_providers
|
||||||
|
|
||||||
VECTOR_DIMENSION = 128 # sample dimension
|
VECTOR_DIMENSION = 128 # sample dimension
|
||||||
|
|
||||||
|
|
||||||
# Helper method to check FAISS health directly
|
# Helper method to check FAISS health directly
|
||||||
async def check_faiss_health():
|
async def check_faiss_health():
|
||||||
"""Check the health of the FAISS vector database directly."""
|
"""Check the health of the FAISS vector database directly."""
|
||||||
|
@ -25,10 +34,8 @@ async def check_faiss_health():
|
||||||
faiss.IndexFlatL2(VECTOR_DIMENSION)
|
faiss.IndexFlatL2(VECTOR_DIMENSION)
|
||||||
return HealthResponse(status=HealthStatus.OK)
|
return HealthResponse(status=HealthStatus.OK)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return HealthResponse(
|
return HealthResponse(status=HealthStatus.ERROR, message=f"FAISS health check failed: {str(e)}")
|
||||||
status=HealthStatus.ERROR,
|
|
||||||
message=f"FAISS health check failed: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def patched_list_providers(self):
|
async def patched_list_providers(self):
|
||||||
"""Patched version of list_providers to include FAISS health check."""
|
"""Patched version of list_providers to include FAISS health check."""
|
||||||
|
@ -44,7 +51,9 @@ async def patched_list_providers(self):
|
||||||
logger.debug("Updated FAISS health to: %s", provider.health)
|
logger.debug("Updated FAISS health to: %s", provider.health)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
new_list_providers = patched_list_providers
|
|
||||||
# Apply the patch by replacing the original method with patched version
|
# Apply the patch by replacing the original method with patched version
|
||||||
ProviderImpl.list_providers = new_list_providers
|
# Added type: ignore because mypy cannot infer the correct type
|
||||||
|
# The typing error doesn't affect runtime behavior - it's only a static type check warning
|
||||||
|
ProviderImpl.list_providers = patched_list_providers # type: ignore
|
||||||
logger.debug("Successfully applied patch for FAISS provider health check")
|
logger.debug("Successfully applied patch for FAISS provider health check")
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Unit tests for the FAISS provider health check implementation via provider patch.
|
Unit tests for the FAISS provider health check implementation via provider patch.
|
||||||
"""
|
"""
|
||||||
|
@ -54,7 +60,7 @@ class TestFaissProviderPatch(unittest.TestCase):
|
||||||
]
|
]
|
||||||
with patch(
|
with patch(
|
||||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
|
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
|
||||||
self.original_list_providers
|
self.original_list_providers,
|
||||||
):
|
):
|
||||||
result = await patched_list_providers(self.provider_impl)
|
result = await patched_list_providers(self.provider_impl)
|
||||||
|
|
||||||
|
@ -68,21 +74,19 @@ class TestFaissProviderPatch(unittest.TestCase):
|
||||||
"""Test the patched_list_providers method when a FAISS provider is found."""
|
"""Test the patched_list_providers method when a FAISS provider is found."""
|
||||||
# Create a mock FAISS provider
|
# Create a mock FAISS provider
|
||||||
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
|
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
|
||||||
mock_faiss_provider.health = MagicMock(
|
mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
|
||||||
return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
|
|
||||||
)
|
|
||||||
# Set up the mock response with a FAISS provider
|
# Set up the mock response with a FAISS provider
|
||||||
self.mock_response.data = [
|
self.mock_response.data = [
|
||||||
MagicMock(provider_id="other", api="vector_io"),
|
MagicMock(provider_id="other", api="vector_io"),
|
||||||
mock_faiss_provider,
|
mock_faiss_provider,
|
||||||
]
|
]
|
||||||
with patch(
|
with (
|
||||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
|
|
||||||
self.original_list_providers
|
|
||||||
), \
|
|
||||||
patch(
|
patch(
|
||||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health"
|
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
|
||||||
) as mock_health:
|
self.original_list_providers,
|
||||||
|
),
|
||||||
|
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
|
||||||
|
):
|
||||||
mock_health.return_value = HealthResponse(status=HealthStatus.OK)
|
mock_health.return_value = HealthResponse(status=HealthStatus.OK)
|
||||||
result = await patched_list_providers(self.provider_impl)
|
result = await patched_list_providers(self.provider_impl)
|
||||||
self.assertEqual(result, self.mock_response)
|
self.assertEqual(result, self.mock_response)
|
||||||
|
@ -96,24 +100,20 @@ class TestFaissProviderPatch(unittest.TestCase):
|
||||||
async def test_patched_list_providers_with_faiss_health_failure(self):
|
async def test_patched_list_providers_with_faiss_health_failure(self):
|
||||||
"""Test the patched_list_providers method when the FAISS health check fails."""
|
"""Test the patched_list_providers method when the FAISS health check fails."""
|
||||||
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
|
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
|
||||||
mock_faiss_provider.health = MagicMock(
|
mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
|
||||||
return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
|
|
||||||
)
|
|
||||||
self.mock_response.data = [
|
self.mock_response.data = [
|
||||||
MagicMock(provider_id="other", api="vector_io"),
|
MagicMock(provider_id="other", api="vector_io"),
|
||||||
mock_faiss_provider,
|
mock_faiss_provider,
|
||||||
]
|
]
|
||||||
with patch(
|
with (
|
||||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
|
|
||||||
self.original_list_providers), \
|
|
||||||
patch(
|
patch(
|
||||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health"
|
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
|
||||||
) as mock_health:
|
self.original_list_providers,
|
||||||
|
),
|
||||||
|
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
|
||||||
|
):
|
||||||
# Configure the mock health check to simulate a failure
|
# Configure the mock health check to simulate a failure
|
||||||
error_response = HealthResponse(
|
error_response = HealthResponse(status=HealthStatus.ERROR, message="FAISS health check failed: Test error")
|
||||||
status=HealthStatus.ERROR,
|
|
||||||
message="FAISS health check failed: Test error"
|
|
||||||
)
|
|
||||||
mock_health.return_value = error_response
|
mock_health.return_value = error_response
|
||||||
|
|
||||||
result = await patched_list_providers(self.provider_impl)
|
result = await patched_list_providers(self.provider_impl)
|
||||||
|
@ -124,28 +124,24 @@ class TestFaissProviderPatch(unittest.TestCase):
|
||||||
for provider in result.data:
|
for provider in result.data:
|
||||||
if provider.provider_id == "faiss" and provider.api == "vector_io":
|
if provider.provider_id == "faiss" and provider.api == "vector_io":
|
||||||
self.assertEqual(provider.health.status, HealthStatus.ERROR)
|
self.assertEqual(provider.health.status, HealthStatus.ERROR)
|
||||||
self.assertEqual(
|
self.assertEqual(provider.health.message, "FAISS health check failed: Test error")
|
||||||
provider.health.message, "FAISS health check failed: Test error"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def test_patched_list_providers_with_exception(self):
|
async def test_patched_list_providers_with_exception(self):
|
||||||
"""Test the patched_list_providers method when an exception occurs during health check."""
|
"""Test the patched_list_providers method when an exception occurs during health check."""
|
||||||
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
|
mock_faiss_provider = MagicMock(provider_id="faiss", api="vector_io")
|
||||||
mock_faiss_provider.health = MagicMock(
|
mock_faiss_provider.health = MagicMock(return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED))
|
||||||
return_value=HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.mock_response.data = [
|
self.mock_response.data = [
|
||||||
MagicMock(provider_id="other", api="vector_io"),
|
MagicMock(provider_id="other", api="vector_io"),
|
||||||
mock_faiss_provider,
|
mock_faiss_provider,
|
||||||
]
|
]
|
||||||
with patch(
|
with (
|
||||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
|
patch(
|
||||||
self.original_list_providers
|
"llama_stack.providers.inline.vector_io.faiss.provider_patch.original_list_providers",
|
||||||
), \
|
self.original_list_providers,
|
||||||
patch(
|
),
|
||||||
"llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health"
|
patch("llama_stack.providers.inline.vector_io.faiss.provider_patch.check_faiss_health") as mock_health,
|
||||||
) as mock_health:
|
):
|
||||||
# Configure the mock health check to raise an exception
|
# Configure the mock health check to raise an exception
|
||||||
mock_health.side_effect = Exception("Unexpected error")
|
mock_health.side_effect = Exception("Unexpected error")
|
||||||
result = await patched_list_providers(self.provider_impl)
|
result = await patched_list_providers(self.provider_impl)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue