mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00

add concurrent file attaching

This commit is contained in:
parent a5b71c2cc7
commit e58bf82581

9 changed files with 188 additions and 33 deletions
@@ -207,6 +207,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         self.kvstore: KVStore | None = None
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.openai_file_batches: dict[str, dict[str, Any]] = {}
+        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
         self._last_file_batch_cleanup_time = 0
 
     async def initialize(self) -> None:
@@ -416,6 +416,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         self.cache: dict[str, VectorDBWithIndex] = {}
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.openai_file_batches: dict[str, dict[str, Any]] = {}
+        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
         self._last_file_batch_cleanup_time = 0
         self.kvstore: KVStore | None = None
@@ -167,6 +167,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         self.client = chromadb.PersistentClient(path=self.config.db_path)
         self.openai_vector_stores = await self._load_openai_vector_stores()
         self.openai_file_batches: dict[str, dict[str, Any]] = {}
+        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
         self._last_file_batch_cleanup_time = 0
 
     async def shutdown(self) -> None:
@@ -318,6 +318,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         self.vector_db_store = None
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.openai_file_batches: dict[str, dict[str, Any]] = {}
+        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
         self._last_file_batch_cleanup_time = 0
         self.metadata_collection_name = "openai_vector_stores_metadata"
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import heapq
 from typing import Any
 
@@ -354,6 +355,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         self.vector_db_store = None
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.openai_file_batches: dict[str, dict[str, Any]] = {}
+        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
         self._last_file_batch_cleanup_time = 0
         self.metadata_collection_name = "openai_vector_stores_metadata"
@@ -171,6 +171,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         self.kvstore: KVStore | None = None
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.openai_file_batches: dict[str, dict[str, Any]] = {}
+        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
         self._last_file_batch_cleanup_time = 0
         self._qdrant_lock = asyncio.Lock()
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import asyncio
 import json
 from typing import Any
 
@@ -293,6 +294,7 @@ class WeaviateVectorIOAdapter(
         self.vector_db_store = None
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.openai_file_batches: dict[str, dict[str, Any]] = {}
+        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
         self._last_file_batch_cleanup_time = 0
         self.metadata_collection_name = "openai_vector_stores_metadata"
@@ -53,6 +53,8 @@ logger = get_logger(name=__name__, category="providers::utils")
 # Constants for OpenAI vector stores
 CHUNK_MULTIPLIER = 5
 FILE_BATCH_CLEANUP_INTERVAL_SECONDS = 24 * 60 * 60  # 1 day in seconds
+MAX_CONCURRENT_FILES_PER_BATCH = 5  # Maximum concurrent file processing within a batch
+FILE_BATCH_CHUNK_SIZE = 10  # Process files in chunks of this size (2x concurrency)
 
 VERSION = "v3"
 VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
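Note on the two new constants: FILE_BATCH_CHUNK_SIZE is deliberately twice MAX_CONCURRENT_FILES_PER_BATCH, so each chunk carries enough work to keep the semaphore saturated while bounding how many tasks exist at once. A minimal self-contained sketch of the pattern follows; the names and the per-item body are illustrative, not taken from the diff:

import asyncio

MAX_CONCURRENT = 5  # mirrors MAX_CONCURRENT_FILES_PER_BATCH
CHUNK_SIZE = 10     # mirrors FILE_BATCH_CHUNK_SIZE (2x concurrency)

async def process_all(items: list[str]) -> list[tuple[str, bool]]:
    semaphore = asyncio.Semaphore(MAX_CONCURRENT)
    results: list[tuple[str, bool]] = []

    async def one(item: str) -> tuple[str, bool]:
        async with semaphore:          # at most MAX_CONCURRENT run at a time
            await asyncio.sleep(0.01)  # stand-in for real per-item work
            return item, True

    for start in range(0, len(items), CHUNK_SIZE):
        chunk = items[start : start + CHUNK_SIZE]
        async with asyncio.TaskGroup() as tg:  # only CHUNK_SIZE tasks exist at once
            tasks = [tg.create_task(one(i)) for i in chunk]
        # A real caller could persist progress here, once per chunk.
        results.extend(t.result() for t in tasks)
    return results

asyncio.run(process_all([f"file_{i}" for i in range(23)]))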
@@ -77,6 +79,8 @@ class OpenAIVectorStoreMixin(ABC):
     kvstore: KVStore | None
     # Track last cleanup time to throttle cleanup operations
     _last_file_batch_cleanup_time: int
+    # Track running file batch processing tasks
+    _file_batch_tasks: dict[str, asyncio.Task[None]]
 
     async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
         """Save vector store metadata to persistent storage."""
@@ -224,12 +228,14 @@ class OpenAIVectorStoreMixin(ABC):
             if batch_info["status"] == "in_progress":
                 logger.info(f"Resuming incomplete file batch: {batch_id}")
                 # Restart the background processing task
-                asyncio.create_task(self._process_file_batch_async(batch_id, batch_info))
+                task = asyncio.create_task(self._process_file_batch_async(batch_id, batch_info))
+                self._file_batch_tasks[batch_id] = task
 
     async def initialize_openai_vector_stores(self) -> None:
         """Load existing OpenAI vector stores and file batches into the in-memory cache."""
         self.openai_vector_stores = await self._load_openai_vector_stores()
         self.openai_file_batches = await self._load_openai_vector_store_file_batches()
+        self._file_batch_tasks = {}
         await self._resume_incomplete_batches()
         self._last_file_batch_cleanup_time = 0
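Note on the `task = asyncio.create_task(...)` change: besides enabling cancellation, storing the task in `self._file_batch_tasks` keeps a strong reference to it. The event loop holds only weak references to scheduled tasks, so a fire-and-forget `asyncio.create_task(...)` whose return value is discarded can be garbage-collected before it finishes. A sketch of the general pattern; the registry and helper names are illustrative, and the diff itself clears entries in the worker's `finally` block rather than via a callback:

import asyncio
from collections.abc import Coroutine

background_tasks: dict[str, asyncio.Task[None]] = {}  # stand-in for self._file_batch_tasks

def start_tracked(batch_id: str, coro: Coroutine) -> asyncio.Task[None]:
    task = asyncio.create_task(coro)
    background_tasks[batch_id] = task  # strong reference: prevents mid-flight GC
    # Alternative cleanup strategy: drop the reference as soon as the task ends.
    task.add_done_callback(lambda _t: background_tasks.pop(batch_id, None))
    return task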
@@ -935,7 +941,8 @@ class OpenAIVectorStoreMixin(ABC):
         await self._save_openai_vector_store_file_batch(batch_id, batch_info)
 
         # Start background processing of files
-        asyncio.create_task(self._process_file_batch_async(batch_id, batch_info))
+        task = asyncio.create_task(self._process_file_batch_async(batch_id, batch_info))
+        self._file_batch_tasks[batch_id] = task
 
         # Run cleanup if needed (throttled to once every 1 day)
         current_time = int(time.time())
@@ -946,6 +953,72 @@ class OpenAIVectorStoreMixin(ABC):
 
         return batch_object
 
+    async def _process_files_with_concurrency(
+        self,
+        file_ids: list[str],
+        vector_store_id: str,
+        attributes: dict[str, Any],
+        chunking_strategy_obj: Any,
+        batch_id: str,
+        batch_info: dict[str, Any],
+    ) -> None:
+        """Process files with controlled concurrency and chunking."""
+        semaphore = asyncio.Semaphore(MAX_CONCURRENT_FILES_PER_BATCH)
+
+        async def process_single_file(file_id: str) -> tuple[str, bool]:
+            """Process a single file with concurrency control."""
+            async with semaphore:
+                try:
+                    await self.openai_attach_file_to_vector_store(
+                        vector_store_id=vector_store_id,
+                        file_id=file_id,
+                        attributes=attributes,
+                        chunking_strategy=chunking_strategy_obj,
+                    )
+                    return file_id, True
+                except Exception as e:
+                    logger.error(f"Failed to process file {file_id} in batch {batch_id}: {e}")
+                    return file_id, False
+
+        # Process files in chunks to avoid creating too many tasks at once
+        total_files = len(file_ids)
+        for chunk_start in range(0, total_files, FILE_BATCH_CHUNK_SIZE):
+            chunk_end = min(chunk_start + FILE_BATCH_CHUNK_SIZE, total_files)
+            chunk = file_ids[chunk_start:chunk_end]
+
+            logger.info(
+                f"Processing chunk {chunk_start // FILE_BATCH_CHUNK_SIZE + 1} of {(total_files + FILE_BATCH_CHUNK_SIZE - 1) // FILE_BATCH_CHUNK_SIZE} ({len(chunk)} files)"
+            )
+
+            async with asyncio.TaskGroup() as tg:
+                chunk_tasks = [tg.create_task(process_single_file(file_id)) for file_id in chunk]
+
+            chunk_results = [task.result() for task in chunk_tasks]
+
+            # Update counts after each chunk for progressive feedback
+            for _, success in chunk_results:
+                self._update_file_counts(batch_info, success=success)
+
+            # Save progress after each chunk
+            await self._save_openai_vector_store_file_batch(batch_id, batch_info)
+
+    def _update_file_counts(self, batch_info: dict[str, Any], success: bool) -> None:
+        """Update file counts based on processing result."""
+        if success:
+            batch_info["file_counts"]["completed"] += 1
+        else:
+            batch_info["file_counts"]["failed"] += 1
+        batch_info["file_counts"]["in_progress"] -= 1
+
+    def _update_batch_status(self, batch_info: dict[str, Any]) -> None:
+        """Update final batch status based on file processing results."""
+        if batch_info["file_counts"]["failed"] == 0:
+            batch_info["status"] = "completed"
+        elif batch_info["file_counts"]["completed"] == 0:
+            batch_info["status"] = "failed"
+        else:
+            batch_info["status"] = "completed"  # Partial success counts as completed
+
     async def _process_file_batch_async(
         self,
         batch_id: str,
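Note on the TaskGroup usage above: if a child task raises, asyncio.TaskGroup cancels its remaining siblings and re-raises the failure as an ExceptionGroup. That is why process_single_file catches Exception and returns (file_id, False): one bad file must not abort the rest of its chunk. Only asyncio.CancelledError is left uncaught, so cancelling the batch still tears the group down. A small illustration of the difference; the flaky/safe coroutines are made up for this note:

import asyncio

async def flaky(i: int) -> int:
    if i == 2:
        raise RuntimeError("boom")
    await asyncio.sleep(0.01)
    return i

async def main() -> None:
    # Unshielded: the RuntimeError cancels the sibling tasks and surfaces
    # as an ExceptionGroup once the TaskGroup exits.
    try:
        async with asyncio.TaskGroup() as tg:
            for i in range(5):
                tg.create_task(flaky(i))
    except* RuntimeError as eg:
        print("group aborted:", eg.exceptions)

    # Shielded per task (the diff's approach): siblings keep running and
    # failures are reported as data instead of exceptions.
    async def safe(i: int) -> tuple[int, bool]:
        try:
            return await flaky(i), True
        except Exception:
            return i, False

    async with asyncio.TaskGroup() as tg:
        tasks = [tg.create_task(safe(i)) for i in range(5)]
    print([t.result() for t in tasks])  # [(0, True), (1, True), (2, False), (3, True), (4, True)]

asyncio.run(main())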
@@ -956,40 +1029,34 @@ class OpenAIVectorStoreMixin(ABC):
         attributes = batch_info["attributes"]
         chunking_strategy = batch_info["chunking_strategy"]
         vector_store_id = batch_info["vector_store_id"]
+        chunking_strategy_adapter: TypeAdapter[VectorStoreChunkingStrategy] = TypeAdapter(VectorStoreChunkingStrategy)
+        chunking_strategy_obj = chunking_strategy_adapter.validate_python(chunking_strategy)
 
-        for file_id in file_ids:
-            try:
-                chunking_strategy_adapter: TypeAdapter[VectorStoreChunkingStrategy] = TypeAdapter(
-                    VectorStoreChunkingStrategy
-                )
-                chunking_strategy_obj = chunking_strategy_adapter.validate_python(chunking_strategy)
-                await self.openai_attach_file_to_vector_store(
-                    vector_store_id=vector_store_id,
-                    file_id=file_id,
-                    attributes=attributes,
-                    chunking_strategy=chunking_strategy_obj,
-                )
+        try:
+            # Process all files with controlled concurrency
+            await self._process_files_with_concurrency(
+                file_ids=file_ids,
+                vector_store_id=vector_store_id,
+                attributes=attributes,
+                chunking_strategy_obj=chunking_strategy_obj,
+                batch_id=batch_id,
+                batch_info=batch_info,
+            )
 
-                # Update counts atomically
-                batch_info["file_counts"]["completed"] += 1
-                batch_info["file_counts"]["in_progress"] -= 1
+            # Update final batch status
+            self._update_batch_status(batch_info)
+            await self._save_openai_vector_store_file_batch(batch_id, batch_info)
 
-            except Exception as e:
-                logger.error(f"Failed to process file {file_id} in batch {batch_id}: {e}")
-                batch_info["file_counts"]["failed"] += 1
-                batch_info["file_counts"]["in_progress"] -= 1
+            logger.info(f"File batch {batch_id} processing completed with status: {batch_info['status']}")
 
-        # Update final status when all files are processed
-        if batch_info["file_counts"]["failed"] == 0:
-            batch_info["status"] = "completed"
-        elif batch_info["file_counts"]["completed"] == 0:
-            batch_info["status"] = "failed"
-        else:
-            batch_info["status"] = "completed"  # Partial success counts as completed
-
-        await self._save_openai_vector_store_file_batch(batch_id, batch_info)
-
-        logger.info(f"File batch {batch_id} processing completed with status: {batch_info['status']}")
+        except asyncio.CancelledError:
+            logger.info(f"File batch {batch_id} processing was cancelled")
+            # Clean up task reference if it still exists
+            self._file_batch_tasks.pop(batch_id, None)
+            raise  # Re-raise to ensure proper cancellation propagation
+        finally:
+            # Always clean up task reference when processing ends
+            self._file_batch_tasks.pop(batch_id, None)
 
     def _get_and_validate_batch(self, batch_id: str, vector_store_id: str) -> dict[str, Any]:
         """Get and validate batch exists and belongs to vector store."""
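Note on the cancellation handling above: task.cancel() is cooperative, so the CancelledError is delivered at the worker's next await. The except branch re-raises after cleanup (swallowing it would make the task look like it hangs from the caller's side), and the finally pop is idempotent, so running both paths is harmless. A self-contained sketch of the flow, with illustrative names:

import asyncio

tasks: dict[str, asyncio.Task[None]] = {}  # stand-in for self._file_batch_tasks

async def batch_worker(batch_id: str) -> None:
    try:
        await asyncio.sleep(3600)  # stand-in for attaching files
    except asyncio.CancelledError:
        print(f"{batch_id}: cancelled, cleaning up")
        raise  # propagate so the caller sees the task as cancelled
    finally:
        tasks.pop(batch_id, None)  # idempotent: safe even if already removed

async def main() -> None:
    t = asyncio.create_task(batch_worker("batch_1"))
    tasks["batch_1"] = t
    await asyncio.sleep(0)  # let the worker reach its first await
    t.cancel()              # a request: takes effect at that await point
    try:
        await t
    except asyncio.CancelledError:
        print("caller observed cancellation")

asyncio.run(main())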
@@ -1114,6 +1181,15 @@ class OpenAIVectorStoreMixin(ABC):
         if batch_info["status"] not in ["in_progress"]:
             raise ValueError(f"Cannot cancel batch {batch_id} with status {batch_info['status']}")
 
+        # Cancel the actual processing task if it exists
+        if batch_id in self._file_batch_tasks:
+            task = self._file_batch_tasks[batch_id]
+            if not task.done():
+                task.cancel()
+                logger.info(f"Cancelled processing task for file batch: {batch_id}")
+            # Remove from task tracking
+            del self._file_batch_tasks[batch_id]
+
         batch_info["status"] = "cancelled"
 
         await self._save_openai_vector_store_file_batch(batch_id, batch_info)
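The guard-then-cancel sequence above condenses to the following shape; this is a rough standalone restatement for reference, not the mixin's actual API surface:

import asyncio

def cancel_batch(batch_id: str, batches: dict, tasks: dict[str, asyncio.Task]) -> dict:
    info = batches[batch_id]
    if info["status"] != "in_progress":  # completed/failed/cancelled batches are final
        raise ValueError(f"Cannot cancel batch {batch_id} with status {info['status']}")
    task = tasks.pop(batch_id, None)
    if task is not None and not task.done():
        task.cancel()               # a request: the worker unwinds at its next await
    info["status"] = "cancelled"    # recorded immediately, independent of unwind timing
    return info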
@@ -6,7 +6,7 @@
 
 import json
 import time
-from unittest.mock import AsyncMock
+from unittest.mock import AsyncMock, patch
 
 import numpy as np
 import pytest
@@ -31,6 +31,24 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX
 # -v -s --tb=short --disable-warnings --asyncio-mode=auto
 
 
+@pytest.fixture(autouse=True)
+def mock_resume_file_batches(request):
+    """Mock the resume functionality to prevent stale file batches from being processed during tests."""
+    # Skip mocking for tests that specifically test the resume functionality
+    if any(
+        test_name in request.node.name
+        for test_name in ["test_only_in_progress_batches_resumed", "test_file_batch_persistence_across_restarts"]
+    ):
+        yield
+        return
+
+    with patch(
+        "llama_stack.providers.utils.memory.openai_vector_store_mixin.OpenAIVectorStoreMixin._resume_incomplete_batches",
+        new_callable=AsyncMock,
+    ):
+        yield
+
+
 async def test_initialize_index(vector_index):
     await vector_index.initialize()
@@ -918,3 +936,55 @@ async def test_expired_batch_access_error(vector_io_adapter):
     # Try to access expired batch
     with pytest.raises(ValueError, match="File batch batch_expired has expired after 7 days from creation"):
         vector_io_adapter._get_and_validate_batch("batch_expired", store_id)
+
+
+async def test_max_concurrent_files_per_batch(vector_io_adapter):
+    """Test that file batch processing respects MAX_CONCURRENT_FILES_PER_BATCH limit."""
+    import asyncio
+
+    store_id = "vs_1234"
+
+    # Setup vector store
+    vector_io_adapter.openai_vector_stores[store_id] = {
+        "id": store_id,
+        "name": "Test Store",
+        "files": {},
+        "file_ids": [],
+    }
+
+    active_files = 0
+
+    async def mock_attach_file_with_delay(vector_store_id: str, file_id: str, **kwargs):
+        """Mock that tracks concurrency and blocks indefinitely to test concurrency limit."""
+        nonlocal active_files
+        active_files += 1
+
+        # Block indefinitely to test concurrency limit
+        await asyncio.sleep(float("inf"))
+
+    # Replace the attachment method
+    vector_io_adapter.openai_attach_file_to_vector_store = mock_attach_file_with_delay
+
+    # Create a batch with more files than the concurrency limit
+    file_ids = [f"file_{i}" for i in range(8)]  # 8 files, but limit should be 5
+
+    batch = await vector_io_adapter.openai_create_vector_store_file_batch(
+        vector_store_id=store_id,
+        file_ids=file_ids,
+    )
+
+    # Give time for the semaphore logic to start processing files
+    await asyncio.sleep(0.2)
+
+    # Verify that only MAX_CONCURRENT_FILES_PER_BATCH files are processing concurrently
+    # The semaphore in _process_files_with_concurrency should limit this
+    from llama_stack.providers.utils.memory.openai_vector_store_mixin import MAX_CONCURRENT_FILES_PER_BATCH
+
+    assert active_files == MAX_CONCURRENT_FILES_PER_BATCH, (
+        f"Expected {MAX_CONCURRENT_FILES_PER_BATCH} active files, got {active_files}"
+    )
+
+    # Verify batch is in progress
+    assert batch.status == "in_progress"
+    assert batch.file_counts.total == 8
+    assert batch.file_counts.in_progress == 8
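One caveat with this test as written: mock_attach_file_with_delay blocks on asyncio.sleep(float("inf")), so the batch task and its five in-flight attachments are still alive when the assertions pass. A possible teardown (not part of the diff; it assumes the returned batch object exposes .id) cancels the tracked task so the blocked mocks do not outlive the test:

import contextlib

task = vector_io_adapter._file_batch_tasks.get(batch.id)
if task is not None:
    task.cancel()  # cancelling the outer task also cancels the TaskGroup children
    with contextlib.suppress(asyncio.CancelledError):
        await task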