mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
add health check for faiss inline vector_io provider
This commit is contained in:
parent
436c7aa751
commit
13c3bcc275
2 changed files with 52 additions and 0 deletions
|
@ -25,6 +25,7 @@ from llama_stack.apis.vector_io import (
|
|||
VectorIO,
|
||||
)
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.faiss.provider_patch import * # noqa: F403
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
|
||||
"""
|
||||
Patch for the provider impl to fix the health check for the FAISS provider.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import faiss
|
||||
|
||||
from llama_stack.distribution.providers import ProviderImpl
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger("faiss_provider_patch")
|
||||
|
||||
# Store the original methods
|
||||
original_list_providers = ProviderImpl.list_providers
|
||||
|
||||
VECTOR_DIMENSION = 128 # sample dimension
|
||||
|
||||
# Helper method to check FAISS health directly
|
||||
async def check_faiss_health():
|
||||
"""Check the health of the FAISS vector database directly."""
|
||||
try:
|
||||
# Create FAISS index to verify readiness
|
||||
faiss.IndexFlatL2(VECTOR_DIMENSION)
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
except Exception as e:
|
||||
return HealthResponse(
|
||||
status=HealthStatus.ERROR,
|
||||
message=f"FAISS health check failed: {str(e)}"
|
||||
)
|
||||
|
||||
async def patched_list_providers(self):
|
||||
"""Patched version of list_providers to include FAISS health check."""
|
||||
logger.debug("Using patched list_providers method")
|
||||
# Get the original response
|
||||
response = await original_list_providers(self)
|
||||
# To find the FAISS provider in the response
|
||||
for provider in response.data:
|
||||
if provider.provider_id == "faiss" and provider.api == "vector_io":
|
||||
health_result = await check_faiss_health()
|
||||
logger.debug("FAISS health check result: %s", health_result)
|
||||
provider.health = health_result
|
||||
logger.debug("Updated FAISS health to: %s", provider.health)
|
||||
return response
|
||||
|
||||
new_list_providers = patched_list_providers
|
||||
# Apply the patch by replacing the original method with patched version
|
||||
ProviderImpl.list_providers = new_list_providers
|
||||
logger.debug("Successfully applied patch for FAISS provider health check")
|
Loading…
Add table
Add a link
Reference in a new issue