add concurrent file attaching

Swapna Lekkala 2025-10-03 11:01:19 -07:00
parent a5b71c2cc7
commit e58bf82581
9 changed files with 188 additions and 33 deletions


@@ -207,6 +207,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         self.kvstore: KVStore | None = None
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.openai_file_batches: dict[str, dict[str, Any]] = {}
+        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
         self._last_file_batch_cleanup_time = 0

     async def initialize(self) -> None:


@@ -416,6 +416,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         self.cache: dict[str, VectorDBWithIndex] = {}
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.openai_file_batches: dict[str, dict[str, Any]] = {}
+        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
         self._last_file_batch_cleanup_time = 0
         self.kvstore: KVStore | None = None


@@ -167,6 +167,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         self.client = chromadb.PersistentClient(path=self.config.db_path)
         self.openai_vector_stores = await self._load_openai_vector_stores()
         self.openai_file_batches: dict[str, dict[str, Any]] = {}
+        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
         self._last_file_batch_cleanup_time = 0

     async def shutdown(self) -> None:


@@ -318,6 +318,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         self.vector_db_store = None
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.openai_file_batches: dict[str, dict[str, Any]] = {}
+        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
         self._last_file_batch_cleanup_time = 0
         self.metadata_collection_name = "openai_vector_stores_metadata"


@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import asyncio
 import heapq
 from typing import Any

@@ -354,6 +355,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         self.vector_db_store = None
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.openai_file_batches: dict[str, dict[str, Any]] = {}
+        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
         self._last_file_batch_cleanup_time = 0
         self.metadata_collection_name = "openai_vector_stores_metadata"


@@ -171,6 +171,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         self.kvstore: KVStore | None = None
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.openai_file_batches: dict[str, dict[str, Any]] = {}
+        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
         self._last_file_batch_cleanup_time = 0
         self._qdrant_lock = asyncio.Lock()


@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import asyncio
 import json
 from typing import Any

@@ -293,6 +294,7 @@ class WeaviateVectorIOAdapter(
         self.vector_db_store = None
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.openai_file_batches: dict[str, dict[str, Any]] = {}
+        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
         self._last_file_batch_cleanup_time = 0
         self.metadata_collection_name = "openai_vector_stores_metadata"
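Note on the `_file_batch_tasks` dicts added across all seven adapters above: `asyncio.create_task` keeps only a weak reference to the task it returns, so an un-referenced background task can be garbage-collected before it finishes. Holding the task in a dict both prevents that and makes the task addressable for later cancellation. A minimal standalone sketch of the pattern (illustrative names, not code from this commit):

    import asyncio

    class Adapter:
        def __init__(self) -> None:
            # Strong references keep tasks alive and addressable for cancellation.
            self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}

        def start_batch(self, batch_id: str) -> None:
            task = asyncio.create_task(self._process(batch_id))
            self._file_batch_tasks[batch_id] = task  # without this, the task may be GC'd
            task.add_done_callback(lambda _: self._file_batch_tasks.pop(batch_id, None))

        async def _process(self, batch_id: str) -> None:
            await asyncio.sleep(0.1)  # stand-in for attaching files

    async def main() -> None:
        adapter = Adapter()
        adapter.start_batch("batch_1")
        await asyncio.sleep(0.2)
        assert "batch_1" not in adapter._file_batch_tasks  # done-callback cleaned up

    asyncio.run(main())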


@@ -53,6 +53,8 @@ logger = get_logger(name=__name__, category="providers::utils")

 # Constants for OpenAI vector stores
 CHUNK_MULTIPLIER = 5
 FILE_BATCH_CLEANUP_INTERVAL_SECONDS = 24 * 60 * 60  # 1 day in seconds
+MAX_CONCURRENT_FILES_PER_BATCH = 5  # Maximum concurrent file processing within a batch
+FILE_BATCH_CHUNK_SIZE = 10  # Process files in chunks of this size (2x concurrency)

 VERSION = "v3"
 VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
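A self-contained sketch (my example, not code from this commit) of how these two knobs interact: the semaphore caps in-flight work at MAX_CONCURRENT_FILES_PER_BATCH, while chunks of FILE_BATCH_CHUNK_SIZE bound how many task objects exist at once and give a natural point to persist progress between chunks:

    import asyncio

    MAX_CONCURRENT = 5
    CHUNK_SIZE = 2 * MAX_CONCURRENT  # 2x concurrency keeps the semaphore saturated

    async def worker(sem: asyncio.Semaphore, item: int) -> int:
        async with sem:  # at most MAX_CONCURRENT workers run at once
            await asyncio.sleep(0.01)  # stand-in for real file I/O
            return item

    async def run(items: list[int]) -> list[int]:
        sem = asyncio.Semaphore(MAX_CONCURRENT)
        results: list[int] = []
        for start in range(0, len(items), CHUNK_SIZE):
            chunk = items[start : start + CHUNK_SIZE]
            async with asyncio.TaskGroup() as tg:  # Python 3.11+
                tasks = [tg.create_task(worker(sem, i)) for i in chunk]
            results.extend(t.result() for t in tasks)
            # progress could be saved here, once per chunk
        return results

    print(asyncio.run(run(list(range(23)))))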
@@ -77,6 +79,8 @@ class OpenAIVectorStoreMixin(ABC):
     kvstore: KVStore | None
     # Track last cleanup time to throttle cleanup operations
     _last_file_batch_cleanup_time: int
+    # Track running file batch processing tasks
+    _file_batch_tasks: dict[str, asyncio.Task[None]]

     async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
         """Save vector store metadata to persistent storage."""
@@ -224,12 +228,14 @@ class OpenAIVectorStoreMixin(ABC):
             if batch_info["status"] == "in_progress":
                 logger.info(f"Resuming incomplete file batch: {batch_id}")
                 # Restart the background processing task
-                asyncio.create_task(self._process_file_batch_async(batch_id, batch_info))
+                task = asyncio.create_task(self._process_file_batch_async(batch_id, batch_info))
+                self._file_batch_tasks[batch_id] = task

     async def initialize_openai_vector_stores(self) -> None:
         """Load existing OpenAI vector stores and file batches into the in-memory cache."""
         self.openai_vector_stores = await self._load_openai_vector_stores()
         self.openai_file_batches = await self._load_openai_vector_store_file_batches()
+        self._file_batch_tasks = {}
         await self._resume_incomplete_batches()
         self._last_file_batch_cleanup_time = 0
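For context, a minimal sketch of the resume-on-startup flow the hunk above implements; the shapes of `batches` and `tasks` are assumed from the diff, and `process` stands in for `_process_file_batch_async`:

    import asyncio
    from collections.abc import Awaitable, Callable
    from typing import Any

    async def resume_incomplete_batches(
        batches: dict[str, dict[str, Any]],
        tasks: dict[str, asyncio.Task[None]],
        process: Callable[[str, dict[str, Any]], Awaitable[None]],
    ) -> None:
        # Only batches persisted as "in_progress" are restarted; completed,
        # failed, and cancelled batches are left untouched.
        for batch_id, info in batches.items():
            if info.get("status") == "in_progress":
                tasks[batch_id] = asyncio.create_task(process(batch_id, info))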
@@ -935,7 +941,8 @@ class OpenAIVectorStoreMixin(ABC):
         await self._save_openai_vector_store_file_batch(batch_id, batch_info)

         # Start background processing of files
-        asyncio.create_task(self._process_file_batch_async(batch_id, batch_info))
+        task = asyncio.create_task(self._process_file_batch_async(batch_id, batch_info))
+        self._file_batch_tasks[batch_id] = task

         # Run cleanup if needed (throttled to once every 1 day)
         current_time = int(time.time())
@@ -946,6 +953,72 @@ class OpenAIVectorStoreMixin(ABC):
         return batch_object

+    async def _process_files_with_concurrency(
+        self,
+        file_ids: list[str],
+        vector_store_id: str,
+        attributes: dict[str, Any],
+        chunking_strategy_obj: Any,
+        batch_id: str,
+        batch_info: dict[str, Any],
+    ) -> None:
+        """Process files with controlled concurrency and chunking."""
+        semaphore = asyncio.Semaphore(MAX_CONCURRENT_FILES_PER_BATCH)
+
+        async def process_single_file(file_id: str) -> tuple[str, bool]:
+            """Process a single file with concurrency control."""
+            async with semaphore:
+                try:
+                    await self.openai_attach_file_to_vector_store(
+                        vector_store_id=vector_store_id,
+                        file_id=file_id,
+                        attributes=attributes,
+                        chunking_strategy=chunking_strategy_obj,
+                    )
+                    return file_id, True
+                except Exception as e:
+                    logger.error(f"Failed to process file {file_id} in batch {batch_id}: {e}")
+                    return file_id, False
+
+        # Process files in chunks to avoid creating too many tasks at once
+        total_files = len(file_ids)
+        for chunk_start in range(0, total_files, FILE_BATCH_CHUNK_SIZE):
+            chunk_end = min(chunk_start + FILE_BATCH_CHUNK_SIZE, total_files)
+            chunk = file_ids[chunk_start:chunk_end]
+
+            logger.info(
+                f"Processing chunk {chunk_start // FILE_BATCH_CHUNK_SIZE + 1} of {(total_files + FILE_BATCH_CHUNK_SIZE - 1) // FILE_BATCH_CHUNK_SIZE} ({len(chunk)} files)"
+            )
+
+            async with asyncio.TaskGroup() as tg:
+                chunk_tasks = [tg.create_task(process_single_file(file_id)) for file_id in chunk]
+
+            chunk_results = [task.result() for task in chunk_tasks]
+
+            # Update counts after each chunk for progressive feedback
+            for _, success in chunk_results:
+                self._update_file_counts(batch_info, success=success)
+
+            # Save progress after each chunk
+            await self._save_openai_vector_store_file_batch(batch_id, batch_info)
+
+    def _update_file_counts(self, batch_info: dict[str, Any], success: bool) -> None:
+        """Update file counts based on processing result."""
+        if success:
+            batch_info["file_counts"]["completed"] += 1
+        else:
+            batch_info["file_counts"]["failed"] += 1
+        batch_info["file_counts"]["in_progress"] -= 1
+
+    def _update_batch_status(self, batch_info: dict[str, Any]) -> None:
+        """Update final batch status based on file processing results."""
+        if batch_info["file_counts"]["failed"] == 0:
+            batch_info["status"] = "completed"
+        elif batch_info["file_counts"]["completed"] == 0:
+            batch_info["status"] = "failed"
+        else:
+            batch_info["status"] = "completed"  # Partial success counts as completed
+
     async def _process_file_batch_async(
         self,
         batch_id: str,
@@ -956,40 +1029,34 @@ class OpenAIVectorStoreMixin(ABC):
         attributes = batch_info["attributes"]
         chunking_strategy = batch_info["chunking_strategy"]
         vector_store_id = batch_info["vector_store_id"]
+        chunking_strategy_adapter: TypeAdapter[VectorStoreChunkingStrategy] = TypeAdapter(VectorStoreChunkingStrategy)
+        chunking_strategy_obj = chunking_strategy_adapter.validate_python(chunking_strategy)

-        for file_id in file_ids:
-            try:
-                chunking_strategy_adapter: TypeAdapter[VectorStoreChunkingStrategy] = TypeAdapter(
-                    VectorStoreChunkingStrategy
-                )
-                chunking_strategy_obj = chunking_strategy_adapter.validate_python(chunking_strategy)
-                await self.openai_attach_file_to_vector_store(
-                    vector_store_id=vector_store_id,
-                    file_id=file_id,
-                    attributes=attributes,
-                    chunking_strategy=chunking_strategy_obj,
-                )
-                # Update counts atomically
-                batch_info["file_counts"]["completed"] += 1
-                batch_info["file_counts"]["in_progress"] -= 1
-            except Exception as e:
-                logger.error(f"Failed to process file {file_id} in batch {batch_id}: {e}")
-                batch_info["file_counts"]["failed"] += 1
-                batch_info["file_counts"]["in_progress"] -= 1
-
-        # Update final status when all files are processed
-        if batch_info["file_counts"]["failed"] == 0:
-            batch_info["status"] = "completed"
-        elif batch_info["file_counts"]["completed"] == 0:
-            batch_info["status"] = "failed"
-        else:
-            batch_info["status"] = "completed"  # Partial success counts as completed
-
-        await self._save_openai_vector_store_file_batch(batch_id, batch_info)
-        logger.info(f"File batch {batch_id} processing completed with status: {batch_info['status']}")
+        try:
+            # Process all files with controlled concurrency
+            await self._process_files_with_concurrency(
+                file_ids=file_ids,
+                vector_store_id=vector_store_id,
+                attributes=attributes,
+                chunking_strategy_obj=chunking_strategy_obj,
+                batch_id=batch_id,
+                batch_info=batch_info,
+            )
+
+            # Update final batch status
+            self._update_batch_status(batch_info)
+            await self._save_openai_vector_store_file_batch(batch_id, batch_info)
+            logger.info(f"File batch {batch_id} processing completed with status: {batch_info['status']}")
+        except asyncio.CancelledError:
+            logger.info(f"File batch {batch_id} processing was cancelled")
+            # Clean up task reference if it still exists
+            self._file_batch_tasks.pop(batch_id, None)
+            raise  # Re-raise to ensure proper cancellation propagation
+        finally:
+            # Always clean up task reference when processing ends
+            self._file_batch_tasks.pop(batch_id, None)

     def _get_and_validate_batch(self, batch_id: str, vector_store_id: str) -> dict[str, Any]:
         """Get and validate batch exists and belongs to vector store."""
@@ -1114,6 +1181,15 @@ class OpenAIVectorStoreMixin(ABC):
         if batch_info["status"] not in ["in_progress"]:
             raise ValueError(f"Cannot cancel batch {batch_id} with status {batch_info['status']}")

+        # Cancel the actual processing task if it exists
+        if batch_id in self._file_batch_tasks:
+            task = self._file_batch_tasks[batch_id]
+            if not task.done():
+                task.cancel()
+                logger.info(f"Cancelled processing task for file batch: {batch_id}")
+            # Remove from task tracking
+            del self._file_batch_tasks[batch_id]
+
         batch_info["status"] = "cancelled"
         await self._save_openai_vector_store_file_batch(batch_id, batch_info)
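A runnable sketch of the cancellation contract these hunks establish (assumed names that mirror, but are not, the mixin code): `task.cancel()` raises CancelledError inside the running task, the handler re-raises so awaiting callers observe the cancellation, and the finally block drops the task reference on every exit path:

    import asyncio

    tasks: dict[str, asyncio.Task[None]] = {}

    async def process(batch_id: str) -> None:
        try:
            await asyncio.sleep(3600)  # stand-in for long-running file processing
        except asyncio.CancelledError:
            print(f"batch {batch_id} cancelled")
            raise  # re-raise so cancellation propagates to awaiters
        finally:
            tasks.pop(batch_id, None)  # always drop the task reference

    async def main() -> None:
        task = asyncio.create_task(process("batch_1"))
        tasks["batch_1"] = task
        await asyncio.sleep(0)  # let the task start
        if not task.done():
            task.cancel()  # mirrors the cancel endpoint above
        await asyncio.gather(task, return_exceptions=True)
        assert "batch_1" not in tasks

    asyncio.run(main())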


@@ -6,7 +6,7 @@

 import json
 import time
-from unittest.mock import AsyncMock
+from unittest.mock import AsyncMock, patch

 import numpy as np
 import pytest
@@ -31,6 +31,24 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX
 # -v -s --tb=short --disable-warnings --asyncio-mode=auto

+@pytest.fixture(autouse=True)
+def mock_resume_file_batches(request):
+    """Mock the resume functionality to prevent stale file batches from being processed during tests."""
+    # Skip mocking for tests that specifically test the resume functionality
+    if any(
+        test_name in request.node.name
+        for test_name in ["test_only_in_progress_batches_resumed", "test_file_batch_persistence_across_restarts"]
+    ):
+        yield
+        return
+
+    with patch(
+        "llama_stack.providers.utils.memory.openai_vector_store_mixin.OpenAIVectorStoreMixin._resume_incomplete_batches",
+        new_callable=AsyncMock,
+    ):
+        yield
+

 async def test_initialize_index(vector_index):
     await vector_index.initialize()
@@ -918,3 +936,55 @@ async def test_expired_batch_access_error(vector_io_adapter):
     # Try to access expired batch
     with pytest.raises(ValueError, match="File batch batch_expired has expired after 7 days from creation"):
         vector_io_adapter._get_and_validate_batch("batch_expired", store_id)
+
+
+async def test_max_concurrent_files_per_batch(vector_io_adapter):
+    """Test that file batch processing respects MAX_CONCURRENT_FILES_PER_BATCH limit."""
+    import asyncio
+
+    store_id = "vs_1234"
+
+    # Setup vector store
+    vector_io_adapter.openai_vector_stores[store_id] = {
+        "id": store_id,
+        "name": "Test Store",
+        "files": {},
+        "file_ids": [],
+    }
+
+    active_files = 0
+
+    async def mock_attach_file_with_delay(vector_store_id: str, file_id: str, **kwargs):
+        """Mock that tracks concurrency and blocks indefinitely to test concurrency limit."""
+        nonlocal active_files
+        active_files += 1
+        # Block indefinitely to test concurrency limit
+        await asyncio.sleep(float("inf"))
+
+    # Replace the attachment method
+    vector_io_adapter.openai_attach_file_to_vector_store = mock_attach_file_with_delay
+
+    # Create a batch with more files than the concurrency limit
+    file_ids = [f"file_{i}" for i in range(8)]  # 8 files, but limit should be 5
+
+    batch = await vector_io_adapter.openai_create_vector_store_file_batch(
+        vector_store_id=store_id,
+        file_ids=file_ids,
+    )
+
+    # Give time for the semaphore logic to start processing files
+    await asyncio.sleep(0.2)
+
+    # Verify that only MAX_CONCURRENT_FILES_PER_BATCH files are processing concurrently
+    # The semaphore in _process_files_with_concurrency should limit this
+    from llama_stack.providers.utils.memory.openai_vector_store_mixin import MAX_CONCURRENT_FILES_PER_BATCH
+
+    assert active_files == MAX_CONCURRENT_FILES_PER_BATCH, (
+        f"Expected {MAX_CONCURRENT_FILES_PER_BATCH} active files, got {active_files}"
+    )
+
+    # Verify batch is in progress
+    assert batch.status == "in_progress"
+    assert batch.file_counts.total == 8
+    assert batch.file_counts.in_progress == 8
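The test above relies on `await asyncio.sleep(0.2)` to yield control so the first chunk's workers can start before the assertion runs. The same invariant can be checked in isolation; a standalone sketch (not part of the test suite) that measures peak concurrency under a semaphore:

    import asyncio

    async def measure_peak(limit: int, n_tasks: int) -> int:
        sem = asyncio.Semaphore(limit)
        active = 0
        peak = 0

        async def job() -> None:
            nonlocal active, peak
            async with sem:
                active += 1
                peak = max(peak, active)
                await asyncio.sleep(0.01)  # hold the slot briefly
                active -= 1

        async with asyncio.TaskGroup() as tg:
            for _ in range(n_tasks):
                tg.create_task(job())
        return peak

    assert asyncio.run(measure_peak(5, 8)) == 5  # never exceeds the limit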