This commit is contained in:
Swapna Lekkala 2025-09-30 10:31:56 -07:00
parent 852c058806
commit ca02ce0a98
3 changed files with 41 additions and 205 deletions

View file

@ -55,9 +55,7 @@ VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = (
f"openai_vector_stores_files_contents:{VERSION}::"
)
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
class OpenAIVectorStoreMixin(ABC):
@ -74,9 +72,7 @@ class OpenAIVectorStoreMixin(ABC):
# KV store for persisting OpenAI vector store metadata
kvstore: KVStore | None
async def _save_openai_vector_store(
self, store_id: str, store_info: dict[str, Any]
) -> None:
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to persistent storage."""
assert self.kvstore
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
@ -97,9 +93,7 @@ class OpenAIVectorStoreMixin(ABC):
stores[info["id"]] = info
return stores
async def _update_openai_vector_store(
self, store_id: str, store_info: dict[str, Any]
) -> None:
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in persistent storage."""
assert self.kvstore
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
@ -126,26 +120,18 @@ class OpenAIVectorStoreMixin(ABC):
assert self.kvstore
meta_key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=meta_key, value=json.dumps(file_info))
contents_prefix = (
f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}:"
)
contents_prefix = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}:"
for idx, chunk in enumerate(file_contents):
await self.kvstore.set(
key=f"{contents_prefix}{idx}", value=json.dumps(chunk)
)
await self.kvstore.set(key=f"{contents_prefix}{idx}", value=json.dumps(chunk))
async def _load_openai_vector_store_file(
self, store_id: str, file_id: str
) -> dict[str, Any]:
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
"""Load vector store file metadata from persistent storage."""
assert self.kvstore
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
stored_data = await self.kvstore.get(key)
return json.loads(stored_data) if stored_data else {}
async def _load_openai_vector_store_file_contents(
self, store_id: str, file_id: str
) -> list[dict[str, Any]]:
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
"""Load vector store file contents from persistent storage."""
assert self.kvstore
prefix = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}:"
@ -153,26 +139,20 @@ class OpenAIVectorStoreMixin(ABC):
raw_items = await self.kvstore.values_in_range(prefix, end_key)
return [json.loads(item) for item in raw_items]
async def _update_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any]
) -> None:
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in persistent storage."""
assert self.kvstore
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=key, value=json.dumps(file_info))
async def _delete_openai_vector_store_file_from_storage(
self, store_id: str, file_id: str
) -> None:
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store file metadata from persistent storage."""
assert self.kvstore
meta_key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.delete(meta_key)
contents_prefix = (
f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}:"
)
contents_prefix = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}:"
end_key = f"{contents_prefix}\xff"
# load all stored chunk values (values_in_range is implemented by all backends)
raw_items = await self.kvstore.values_in_range(contents_prefix, end_key)
@ -185,9 +165,7 @@ class OpenAIVectorStoreMixin(ABC):
self.openai_vector_stores = await self._load_openai_vector_stores()
@abstractmethod
async def delete_chunks(
self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]
) -> None:
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a vector store."""
pass
@ -298,10 +276,7 @@ class OpenAIVectorStoreMixin(ABC):
# Now that our vector store is created, attach any files that were provided
file_ids = file_ids or []
tasks = [
self.openai_attach_file_to_vector_store(vector_db_id, file_id)
for file_id in file_ids
]
tasks = [self.openai_attach_file_to_vector_store(vector_db_id, file_id) for file_id in file_ids]
await asyncio.gather(*tasks)
# Get the updated store info and return it
@ -328,9 +303,7 @@ class OpenAIVectorStoreMixin(ABC):
# Apply cursor-based pagination
if after:
after_index = next(
(i for i, store in enumerate(all_stores) if store["id"] == after), -1
)
after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
if after_index >= 0:
all_stores = all_stores[after_index + 1 :]
@ -419,9 +392,7 @@ class OpenAIVectorStoreMixin(ABC):
try:
await self.unregister_vector_db(vector_store_id)
except Exception as e:
logger.warning(
f"Failed to delete underlying vector DB {vector_store_id}: {e}"
)
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
return VectorStoreDeleteResponse(
id=vector_store_id,
@ -446,9 +417,7 @@ class OpenAIVectorStoreMixin(ABC):
# Validate search_mode
valid_modes = {"keyword", "vector", "hybrid"}
if search_mode not in valid_modes:
raise ValueError(
f"search_mode must be one of {valid_modes}, got {search_mode}"
)
raise ValueError(f"search_mode must be one of {valid_modes}, got {search_mode}")
if vector_store_id not in self.openai_vector_stores:
raise VectorStoreNotFoundError(vector_store_id)
@ -516,9 +485,7 @@ class OpenAIVectorStoreMixin(ABC):
next_page=None,
)
def _matches_filters(
self, metadata: dict[str, Any], filters: dict[str, Any]
) -> bool:
def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
"""Check if metadata matches the provided filters."""
if not filters:
return True
@ -638,9 +605,7 @@ class OpenAIVectorStoreMixin(ABC):
try:
file_response = await self.files_api.openai_retrieve_file(file_id)
mime_type, _ = mimetypes.guess_type(file_response.filename)
content_response = await self.files_api.openai_retrieve_file_content(
file_id
)
content_response = await self.files_api.openai_retrieve_file_content(file_id)
content = content_from_data_and_mime_type(content_response.body, mime_type)
@ -679,9 +644,7 @@ class OpenAIVectorStoreMixin(ABC):
# Save vector store file to persistent storage (provider-specific)
dict_chunks = [c.model_dump() for c in chunks]
# This should be updated to include chunk_id
await self._save_openai_vector_store_file(
vector_store_id, file_id, file_info, dict_chunks
)
await self._save_openai_vector_store_file(vector_store_id, file_id, file_info, dict_chunks)
# Update file_ids and file_counts in vector store metadata
store_info = self.openai_vector_stores[vector_store_id].copy()
@ -717,9 +680,7 @@ class OpenAIVectorStoreMixin(ABC):
file_objects: list[VectorStoreFileObject] = []
for file_id in store_info["file_ids"]:
file_info = await self._load_openai_vector_store_file(
vector_store_id, file_id
)
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
file_object = VectorStoreFileObject(**file_info)
if filter and file_object.status != filter:
continue
@ -731,9 +692,7 @@ class OpenAIVectorStoreMixin(ABC):
# Apply cursor-based pagination
if after:
after_index = next(
(i for i, file in enumerate(file_objects) if file.id == after), -1
)
after_index = next((i for i, file in enumerate(file_objects) if file.id == after), -1)
if after_index >= 0:
file_objects = file_objects[after_index + 1 :]
@ -770,9 +729,7 @@ class OpenAIVectorStoreMixin(ABC):
store_info = self.openai_vector_stores[vector_store_id]
if file_id not in store_info["file_ids"]:
raise ValueError(
f"File {file_id} not found in vector store {vector_store_id}"
)
raise ValueError(f"File {file_id} not found in vector store {vector_store_id}")
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
return VectorStoreFileObject(**file_info)
@ -787,9 +744,7 @@ class OpenAIVectorStoreMixin(ABC):
raise VectorStoreNotFoundError(vector_store_id)
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
dict_chunks = await self._load_openai_vector_store_file_contents(
vector_store_id, file_id
)
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
chunks = [Chunk.model_validate(c) for c in dict_chunks]
content = []
for chunk in chunks:
@ -813,9 +768,7 @@ class OpenAIVectorStoreMixin(ABC):
store_info = self.openai_vector_stores[vector_store_id]
if file_id not in store_info["file_ids"]:
raise ValueError(
f"File {file_id} not found in vector store {vector_store_id}"
)
raise ValueError(f"File {file_id} not found in vector store {vector_store_id}")
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
file_info["attributes"] = attributes
@ -831,9 +784,7 @@ class OpenAIVectorStoreMixin(ABC):
if vector_store_id not in self.openai_vector_stores:
raise VectorStoreNotFoundError(vector_store_id)
dict_chunks = await self._load_openai_vector_store_file_contents(
vector_store_id, file_id
)
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
chunks = [Chunk.model_validate(c) for c in dict_chunks]
# Create ChunkForDeletion objects with both chunk_id and document_id
@ -844,15 +795,9 @@ class OpenAIVectorStoreMixin(ABC):
c.chunk_metadata.document_id if c.chunk_metadata else None
)
if document_id:
chunks_for_deletion.append(
ChunkForDeletion(
chunk_id=str(c.chunk_id), document_id=document_id
)
)
chunks_for_deletion.append(ChunkForDeletion(chunk_id=str(c.chunk_id), document_id=document_id))
else:
logger.warning(
f"Chunk {c.chunk_id} has no document_id, skipping deletion"
)
logger.warning(f"Chunk {c.chunk_id} has no document_id, skipping deletion")
if chunks_for_deletion:
await self.delete_chunks(vector_store_id, chunks_for_deletion)
@ -860,9 +805,7 @@ class OpenAIVectorStoreMixin(ABC):
store_info = self.openai_vector_stores[vector_store_id].copy()
file = await self.openai_retrieve_vector_store_file(vector_store_id, file_id)
await self._delete_openai_vector_store_file_from_storage(
vector_store_id, file_id
)
await self._delete_openai_vector_store_file_from_storage(vector_store_id, file_id)
# Update in-memory cache
store_info["file_ids"].remove(file_id)
@ -919,11 +862,7 @@ class OpenAIVectorStoreMixin(ABC):
}
# Start background processing of files
asyncio.create_task(
self._process_file_batch_async(
batch_id, file_ids, attributes, chunking_strategy
)
)
asyncio.create_task(self._process_file_batch_async(batch_id, file_ids, attributes, chunking_strategy))
return batch_object
@ -954,9 +893,7 @@ class OpenAIVectorStoreMixin(ABC):
batch_object.file_counts.in_progress -= 1
except Exception as e:
logger.error(
f"Failed to process file {file_id} in batch {batch_id}: {e}"
)
logger.error(f"Failed to process file {file_id} in batch {batch_id}: {e}")
batch_object.file_counts.failed += 1
batch_object.file_counts.in_progress -= 1
@ -968,9 +905,7 @@ class OpenAIVectorStoreMixin(ABC):
else:
batch_object.status = "completed" # Partial success counts as completed
logger.info(
f"File batch {batch_id} processing completed with status: {batch_object.status}"
)
logger.info(f"File batch {batch_id} processing completed with status: {batch_object.status}")
def _get_and_validate_batch(
self, batch_id: str, vector_store_id: str
@ -986,9 +921,7 @@ class OpenAIVectorStoreMixin(ABC):
batch_object = batch_info["batch_object"]
if batch_object.vector_store_id != vector_store_id:
raise ValueError(
f"File batch {batch_id} does not belong to vector store {vector_store_id}"
)
raise ValueError(f"File batch {batch_id} does not belong to vector store {vector_store_id}")
return batch_info, batch_object
@ -1056,9 +989,7 @@ class OpenAIVectorStoreMixin(ABC):
for file_id in batch_file_ids:
try:
file_info = await self._load_openai_vector_store_file(
vector_store_id, file_id
)
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
file_object = VectorStoreFileObject(**file_info)
# Apply status filter if provided
@ -1067,9 +998,7 @@ class OpenAIVectorStoreMixin(ABC):
batch_file_objects.append(file_object)
except Exception as e:
logger.warning(
f"Could not load file {file_id} from batch {batch_id}: {e}"
)
logger.warning(f"Could not load file {file_id} from batch {batch_id}: {e}")
continue
# Sort by created_at
@ -1077,9 +1006,7 @@ class OpenAIVectorStoreMixin(ABC):
batch_file_objects.sort(key=lambda x: x.created_at, reverse=reverse_order)
# Apply pagination using helper
paginated_files, has_more, first_id, last_id = self._paginate_objects(
batch_file_objects, limit, after, before
)
paginated_files, has_more, first_id, last_id = self._paginate_objects(batch_file_objects, limit, after, before)
return VectorStoreFilesListInBatchResponse(
data=paginated_files,
@ -1094,15 +1021,11 @@ class OpenAIVectorStoreMixin(ABC):
vector_store_id: str,
) -> VectorStoreFileBatchObject:
"""Cancel a vector store file batch."""
batch_info, batch_object = self._get_and_validate_batch(
batch_id, vector_store_id
)
batch_info, batch_object = self._get_and_validate_batch(batch_id, vector_store_id)
# Only allow cancellation if batch is in progress
if batch_object.status not in ["in_progress"]:
raise ValueError(
f"Cannot cancel batch {batch_id} with status {batch_object.status}"
)
raise ValueError(f"Cannot cancel batch {batch_id} with status {batch_object.status}")
# Create updated batch object with cancelled status
updated_batch = VectorStoreFileBatchObject(