fix faiss db health check correctly via review

This commit is contained in:
Sumit Jaiswal 2025-06-16 23:17:56 +05:30
parent 47373a65cf
commit be6ef8bc17
No known key found for this signature in database
GPG key ID: A4604B39D64D6AEC
5 changed files with 113 additions and 227 deletions

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any
from llama_stack.apis.common.content_types import (
@ -21,7 +22,7 @@ from llama_stack.apis.vector_io import (
)
from llama_stack.apis.vector_io.vector_io import VectorStoreChunkingStrategy, VectorStoreFileObject
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
logger = get_logger(name=__name__, category="core")
@ -272,3 +273,26 @@ class VectorIORouter(VectorIO):
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def health(self) -> dict[str, HealthResponse]:
health_statuses = {}
timeout = 1 # increasing the timeout to 1 second for health checks
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
try:
# check if the provider has a health method
if not hasattr(impl, "health"):
continue
health = await asyncio.wait_for(impl.health(), timeout=timeout)
health_statuses[provider_id] = health
except (asyncio.TimeoutError, TimeoutError):
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR,
message=f"Health check timed out after {timeout} seconds",
)
except NotImplementedError:
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
except Exception as e:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
)
return health_statuses

View file

@ -24,8 +24,11 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.faiss.provider_patch import * # noqa: F403
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
VectorDBsProtocolPrivate,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@ -42,6 +45,7 @@ VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
VECTOR_DIMENSION = 128 # sample dimension
class FaissIndex(EmbeddingIndex):
@ -176,6 +180,21 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
# Cleanup if needed
pass
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the inline faiss DB.
This method is used by the Provider API to verify
that the service is running correctly.
Returns:
HealthResponse: A dictionary containing the health status.
"""
try:
faiss.IndexFlatL2(VECTOR_DIMENSION)
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def register_vector_db(
self,
vector_db: VectorDB,

View file

@ -1,59 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Patch for the provider impl to fix the health check for the FAISS provider.
It is the workaround fix with current implementation if place for get_providers_health
as it returns a dict mapping API names to a single health response, but list_providers
expects a dict mapping API names to a dict of provider IDs to health responses.
"""
import logging
import faiss
from llama_stack.distribution.providers import ProviderImpl
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
logger = logging.getLogger(__name__)
# Store the original methods
original_list_providers = ProviderImpl.list_providers
VECTOR_DIMENSION = 128 # sample dimension
# Helper method to check FAISS health directly
async def check_faiss_health():
"""Check the health of the FAISS vector database directly."""
try:
# Create FAISS index to verify readiness
faiss.IndexFlatL2(VECTOR_DIMENSION)
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"FAISS health check failed: {str(e)}")
async def patched_list_providers(self):
"""Patched version of list_providers to include FAISS health check."""
logger.debug("Using patched list_providers method")
# Get the original response
response = await original_list_providers(self)
# To find the FAISS provider in the response
for provider in response.data:
if provider.provider_id == "faiss" and provider.api == "vector_io":
health_result = await check_faiss_health()
logger.debug("FAISS health check result: %s", health_result)
provider.health = health_result
logger.debug("Updated FAISS health to: %s", provider.health)
return response
# Apply the patch by replacing the original method with patched version
# Added type: ignore because mypy cannot infer the correct type
# The typing error doesn't affect runtime behavior - it's only a static type check warning
ProviderImpl.list_providers = patched_list_providers # type: ignore
logger.debug("Successfully applied patch for FAISS provider health check")