chore: fix flaky unit test and add proper shutdown for file batches (#3725)

# What does this PR do?
Have been running into flaky unit test failures:
5217035494
Fixing below
1. Shutting down properly by cancelling any stale file batches tasks
running in background.
2. Also, use unique_kvstore_config, so the test dont use same db path
and maintain test isolation
## Test Plan
Ran unit test locally and CI
This commit is contained in:
slekkala1 2025-10-07 14:23:14 -07:00 committed by GitHub
parent 1970b4aa4b
commit c2d97a9db9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 39 additions and 17 deletions

View file

@ -225,8 +225,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
await self.initialize_openai_vector_stores() await self.initialize_openai_vector_stores()
async def shutdown(self) -> None: async def shutdown(self) -> None:
# Cleanup if needed # Clean up mixin resources (file batch tasks)
pass await super().shutdown()
async def health(self) -> HealthResponse: async def health(self) -> HealthResponse:
""" """

View file

@ -434,8 +434,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
await self.initialize_openai_vector_stores() await self.initialize_openai_vector_stores()
async def shutdown(self) -> None: async def shutdown(self) -> None:
# nothing to do since we don't maintain a persistent connection # Clean up mixin resources (file batch tasks)
pass await super().shutdown()
async def list_vector_dbs(self) -> list[VectorDB]: async def list_vector_dbs(self) -> list[VectorDB]:
return [v.vector_db for v in self.cache.values()] return [v.vector_db for v in self.cache.values()]

View file

@ -167,7 +167,8 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.openai_vector_stores = await self._load_openai_vector_stores() self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass # Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db( async def register_vector_db(
self, self,

View file

@ -349,6 +349,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def shutdown(self) -> None: async def shutdown(self) -> None:
self.client.close() self.client.close()
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db( async def register_vector_db(
self, self,

View file

@ -390,6 +390,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
if self.conn is not None: if self.conn is not None:
self.conn.close() self.conn.close()
log.info("Connection to PGVector database server closed") log.info("Connection to PGVector database server closed")
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(self, vector_db: VectorDB) -> None: async def register_vector_db(self, vector_db: VectorDB) -> None:
# Persist vector DB metadata in the KV store # Persist vector DB metadata in the KV store

View file

@ -191,6 +191,8 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def shutdown(self) -> None: async def shutdown(self) -> None:
await self.client.close() await self.client.close()
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db( async def register_vector_db(
self, self,

View file

@ -347,6 +347,8 @@ class WeaviateVectorIOAdapter(
async def shutdown(self) -> None: async def shutdown(self) -> None:
for client in self.client_cache.values(): for client in self.client_cache.values():
client.close() client.close()
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db( async def register_vector_db(
self, self,

View file

@ -293,6 +293,19 @@ class OpenAIVectorStoreMixin(ABC):
await self._resume_incomplete_batches() await self._resume_incomplete_batches()
self._last_file_batch_cleanup_time = 0 self._last_file_batch_cleanup_time = 0
async def shutdown(self) -> None:
"""Clean up mixin resources including background tasks."""
# Cancel any running file batch tasks gracefully
if hasattr(self, "_file_batch_tasks"):
tasks_to_cancel = list(self._file_batch_tasks.items())
for _, task in tasks_to_cancel:
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
@abstractmethod @abstractmethod
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a vector store.""" """Delete chunks from a vector store."""

View file

@ -145,10 +145,10 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
@pytest.fixture @pytest.fixture
async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension): async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
config = SQLiteVectorIOConfig( config = SQLiteVectorIOConfig(
db_path=sqlite_vec_db_path, db_path=sqlite_vec_db_path,
kvstore=SqliteKVStoreConfig(), kvstore=unique_kvstore_config,
) )
adapter = SQLiteVecVectorIOAdapter( adapter = SQLiteVecVectorIOAdapter(
config=config, config=config,
@ -187,10 +187,10 @@ async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
@pytest.fixture @pytest.fixture
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api): async def milvus_vec_adapter(milvus_vec_db_path, unique_kvstore_config, mock_inference_api):
config = MilvusVectorIOConfig( config = MilvusVectorIOConfig(
db_path=milvus_vec_db_path, db_path=milvus_vec_db_path,
kvstore=SqliteKVStoreConfig(), kvstore=unique_kvstore_config,
) )
adapter = MilvusVectorIOAdapter( adapter = MilvusVectorIOAdapter(
config=config, config=config,
@ -264,10 +264,10 @@ async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
@pytest.fixture @pytest.fixture
async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension): async def chroma_vec_adapter(chroma_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
config = ChromaVectorIOConfig( config = ChromaVectorIOConfig(
db_path=chroma_vec_db_path, db_path=chroma_vec_db_path,
kvstore=SqliteKVStoreConfig(), kvstore=unique_kvstore_config,
) )
adapter = ChromaVectorIOAdapter( adapter = ChromaVectorIOAdapter(
config=config, config=config,
@ -296,12 +296,12 @@ def qdrant_vec_db_path(tmp_path_factory):
@pytest.fixture @pytest.fixture
async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension): async def qdrant_vec_adapter(qdrant_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
import uuid import uuid
config = QdrantVectorIOConfig( config = QdrantVectorIOConfig(
db_path=qdrant_vec_db_path, db_path=qdrant_vec_db_path,
kvstore=SqliteKVStoreConfig(), kvstore=unique_kvstore_config,
) )
adapter = QdrantVectorIOAdapter( adapter = QdrantVectorIOAdapter(
config=config, config=config,
@ -386,14 +386,14 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
@pytest.fixture @pytest.fixture
async def pgvector_vec_adapter(mock_inference_api, embedding_dimension): async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
config = PGVectorVectorIOConfig( config = PGVectorVectorIOConfig(
host="localhost", host="localhost",
port=5432, port=5432,
db="test_db", db="test_db",
user="test_user", user="test_user",
password="test_password", password="test_password",
kvstore=SqliteKVStoreConfig(), kvstore=unique_kvstore_config,
) )
adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None) adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
@ -476,7 +476,7 @@ async def weaviate_vec_index(weaviate_vec_db_path):
@pytest.fixture @pytest.fixture
async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension): async def weaviate_vec_adapter(weaviate_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
import pytest_socket import pytest_socket
import weaviate import weaviate
@ -492,7 +492,7 @@ async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embeddi
config = WeaviateVectorIOConfig( config = WeaviateVectorIOConfig(
weaviate_cluster_url="localhost:8080", weaviate_cluster_url="localhost:8080",
weaviate_api_key=None, weaviate_api_key=None,
kvstore=SqliteKVStoreConfig(), kvstore=unique_kvstore_config,
) )
adapter = WeaviateVectorIOAdapter( adapter = WeaviateVectorIOAdapter(
config=config, config=config,