From c2d97a9db95965d4272c4bf4ee2a70f57637e456 Mon Sep 17 00:00:00 2001 From: slekkala1 Date: Tue, 7 Oct 2025 14:23:14 -0700 Subject: [PATCH] chore: fix flaky unit test and add proper shutdown for file batches (#3725) # What does this PR do? Have been running into flaky unit test failures: https://github.com/llamastack/llama-stack/actions/runs/18319987543/job/52170354944?pr=3711 Fixes below: 1. Shutting down properly by cancelling any stale file batch tasks running in the background. 2. Also, use unique_kvstore_config, so the tests don't use the same db path and test isolation is maintained ## Test Plan Ran unit tests locally and in CI --- .../providers/inline/vector_io/faiss/faiss.py | 4 ++-- .../inline/vector_io/sqlite_vec/sqlite_vec.py | 4 ++-- .../remote/vector_io/chroma/chroma.py | 3 ++- .../remote/vector_io/milvus/milvus.py | 2 ++ .../remote/vector_io/pgvector/pgvector.py | 2 ++ .../remote/vector_io/qdrant/qdrant.py | 2 ++ .../remote/vector_io/weaviate/weaviate.py | 2 ++ .../utils/memory/openai_vector_store_mixin.py | 13 ++++++++++ tests/unit/providers/vector_io/conftest.py | 24 +++++++++---------- 9 files changed, 39 insertions(+), 17 deletions(-) diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 405c134e5..5a456c7c9 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -225,8 +225,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr await self.initialize_openai_vector_stores() async def shutdown(self) -> None: - # Cleanup if needed - pass + # Clean up mixin resources (file batch tasks) + await super().shutdown() async def health(self) -> HealthResponse: """ diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 26231a9b7..a433257b2 100644 --- 
a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -434,8 +434,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc await self.initialize_openai_vector_stores() async def shutdown(self) -> None: - # nothing to do since we don't maintain a persistent connection - pass + # Clean up mixin resources (file batch tasks) + await super().shutdown() async def list_vector_dbs(self) -> list[VectorDB]: return [v.vector_db for v in self.cache.values()] diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 511123d6e..331e5432e 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -167,7 +167,8 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP self.openai_vector_stores = await self._load_openai_vector_stores() async def shutdown(self) -> None: - pass + # Clean up mixin resources (file batch tasks) + await super().shutdown() async def register_vector_db( self, diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 0acc90595..029eacfe3 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -349,6 +349,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP async def shutdown(self) -> None: self.client.close() + # Clean up mixin resources (file batch tasks) + await super().shutdown() async def register_vector_db( self, diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index dfdfef6eb..21c388b1d 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ 
b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -390,6 +390,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco if self.conn is not None: self.conn.close() log.info("Connection to PGVector database server closed") + # Clean up mixin resources (file batch tasks) + await super().shutdown() async def register_vector_db(self, vector_db: VectorDB) -> None: # Persist vector DB metadata in the KV store diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 6b386840c..021938afd 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -191,6 +191,8 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP async def shutdown(self) -> None: await self.client.close() + # Clean up mixin resources (file batch tasks) + await super().shutdown() async def register_vector_db( self, diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 54ac6f8d3..21df3bc45 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -347,6 +347,8 @@ class WeaviateVectorIOAdapter( async def shutdown(self) -> None: for client in self.client_cache.values(): client.close() + # Clean up mixin resources (file batch tasks) + await super().shutdown() async def register_vector_db( self, diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 97079c3b3..2a5177f93 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -293,6 +293,19 @@ class OpenAIVectorStoreMixin(ABC): await self._resume_incomplete_batches() 
self._last_file_batch_cleanup_time = 0 + async def shutdown(self) -> None: + """Clean up mixin resources including background tasks.""" + # Cancel any running file batch tasks gracefully + if hasattr(self, "_file_batch_tasks"): + tasks_to_cancel = list(self._file_batch_tasks.items()) + for _, task in tasks_to_cancel: + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + @abstractmethod async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Delete chunks from a vector store.""" diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index 70ace695e..d122f9323 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -145,10 +145,10 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory): @pytest.fixture -async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension): +async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension): config = SQLiteVectorIOConfig( db_path=sqlite_vec_db_path, - kvstore=SqliteKVStoreConfig(), + kvstore=unique_kvstore_config, ) adapter = SQLiteVecVectorIOAdapter( config=config, @@ -187,10 +187,10 @@ async def milvus_vec_index(milvus_vec_db_path, embedding_dimension): @pytest.fixture -async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api): +async def milvus_vec_adapter(milvus_vec_db_path, unique_kvstore_config, mock_inference_api): config = MilvusVectorIOConfig( db_path=milvus_vec_db_path, - kvstore=SqliteKVStoreConfig(), + kvstore=unique_kvstore_config, ) adapter = MilvusVectorIOAdapter( config=config, @@ -264,10 +264,10 @@ async def chroma_vec_index(chroma_vec_db_path, embedding_dimension): @pytest.fixture -async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension): +async def chroma_vec_adapter(chroma_vec_db_path, 
unique_kvstore_config, mock_inference_api, embedding_dimension): config = ChromaVectorIOConfig( db_path=chroma_vec_db_path, - kvstore=SqliteKVStoreConfig(), + kvstore=unique_kvstore_config, ) adapter = ChromaVectorIOAdapter( config=config, @@ -296,12 +296,12 @@ def qdrant_vec_db_path(tmp_path_factory): @pytest.fixture -async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension): +async def qdrant_vec_adapter(qdrant_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension): import uuid config = QdrantVectorIOConfig( db_path=qdrant_vec_db_path, - kvstore=SqliteKVStoreConfig(), + kvstore=unique_kvstore_config, ) adapter = QdrantVectorIOAdapter( config=config, @@ -386,14 +386,14 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection): @pytest.fixture -async def pgvector_vec_adapter(mock_inference_api, embedding_dimension): +async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension): config = PGVectorVectorIOConfig( host="localhost", port=5432, db="test_db", user="test_user", password="test_password", - kvstore=SqliteKVStoreConfig(), + kvstore=unique_kvstore_config, ) adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None) @@ -476,7 +476,7 @@ async def weaviate_vec_index(weaviate_vec_db_path): @pytest.fixture -async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension): +async def weaviate_vec_adapter(weaviate_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension): import pytest_socket import weaviate @@ -492,7 +492,7 @@ async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embeddi config = WeaviateVectorIOConfig( weaviate_cluster_url="localhost:8080", weaviate_api_key=None, - kvstore=SqliteKVStoreConfig(), + kvstore=unique_kvstore_config, ) adapter = WeaviateVectorIOAdapter( config=config,