mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
Handle attaching files at vector store creation
This wires up the `file_ids` from the vector store create request to actually attach all those files at creation time. This also required smarter handling of the file_ids and file_count metadata handling to ensure we update the in-memory cache and persistent representation of those as we attach those files. And, expand that logic to handle errors during file attachment to persist the failed status. Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
459d50a365
commit
a2f0f608db
3 changed files with 97 additions and 42 deletions
|
@ -660,7 +660,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
(store_id, file_id),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
print(f"!!! row is {row}")
|
||||
if row is None:
|
||||
return None
|
||||
(metadata,) = row
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import time
|
||||
|
@ -160,15 +161,15 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
|
||||
# Create OpenAI vector store metadata
|
||||
status = "completed"
|
||||
file_ids = file_ids or []
|
||||
|
||||
# Start with no files attached and update later
|
||||
file_counts = VectorStoreFileCounts(
|
||||
cancelled=0,
|
||||
completed=len(file_ids),
|
||||
completed=0,
|
||||
failed=0,
|
||||
in_progress=0,
|
||||
total=len(file_ids),
|
||||
total=0,
|
||||
)
|
||||
# TODO: actually attach these files to the vector store...
|
||||
store_info = {
|
||||
"id": store_id,
|
||||
"object": "vector_store",
|
||||
|
@ -180,7 +181,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
"expires_after": expires_after,
|
||||
"expires_at": None,
|
||||
"last_active_at": created_at,
|
||||
"file_ids": file_ids,
|
||||
"file_ids": [],
|
||||
"chunking_strategy": chunking_strategy,
|
||||
}
|
||||
|
||||
|
@ -198,18 +199,14 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
# Store in memory cache
|
||||
self.openai_vector_stores[store_id] = store_info
|
||||
|
||||
return VectorStoreObject(
|
||||
id=store_id,
|
||||
created_at=created_at,
|
||||
name=store_id,
|
||||
usage_bytes=0,
|
||||
file_counts=file_counts,
|
||||
status=status,
|
||||
expires_after=expires_after,
|
||||
expires_at=None,
|
||||
last_active_at=created_at,
|
||||
metadata=metadata,
|
||||
)
|
||||
# Now that our vector store is created, attach any files that were provided
|
||||
file_ids = file_ids or []
|
||||
tasks = [self.openai_attach_file_to_vector_store(store_id, file_id) for file_id in file_ids]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Get the updated store info and return it
|
||||
store_info = self.openai_vector_stores[store_id]
|
||||
return VectorStoreObject.model_validate(store_info)
|
||||
|
||||
async def openai_list_vector_stores(
|
||||
self,
|
||||
|
@ -491,8 +488,6 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
if vector_store_id not in self.openai_vector_stores:
|
||||
raise ValueError(f"Vector store {vector_store_id} not found")
|
||||
|
||||
store_info = self.openai_vector_stores[vector_store_id].copy()
|
||||
|
||||
attributes = attributes or {}
|
||||
chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()
|
||||
created_at = int(time.time())
|
||||
|
@ -543,26 +538,12 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
code="server_error",
|
||||
message="No chunks were generated from the file",
|
||||
)
|
||||
return vector_store_file_object
|
||||
|
||||
else:
|
||||
await self.insert_chunks(
|
||||
vector_db_id=vector_store_id,
|
||||
chunks=chunks,
|
||||
)
|
||||
vector_store_file_object.status = "completed"
|
||||
|
||||
# Create OpenAI vector store file metadata
|
||||
file_info = vector_store_file_object.model_dump(exclude={"last_error"})
|
||||
|
||||
# Save to persistent storage (provider-specific)
|
||||
await self._save_openai_vector_store_file(vector_store_id, file_id, file_info)
|
||||
|
||||
# Update in-memory cache
|
||||
store_info["file_ids"].append(file_id)
|
||||
store_info["file_counts"]["completed"] += 1
|
||||
store_info["file_counts"]["total"] += 1
|
||||
self.openai_vector_stores[vector_store_id] = store_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error attaching file to vector store: {e}")
|
||||
vector_store_file_object.status = "failed"
|
||||
|
@ -570,7 +551,24 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
code="server_error",
|
||||
message=str(e),
|
||||
)
|
||||
return vector_store_file_object
|
||||
|
||||
# Create OpenAI vector store file metadata
|
||||
file_info = vector_store_file_object.model_dump(exclude={"last_error"})
|
||||
|
||||
# Save vector store file to persistent storage (provider-specific)
|
||||
await self._save_openai_vector_store_file(vector_store_id, file_id, file_info)
|
||||
|
||||
# Update file_ids and file_counts in vector store metadata
|
||||
store_info = self.openai_vector_stores[vector_store_id].copy()
|
||||
store_info["file_ids"].append(file_id)
|
||||
store_info["file_counts"]["total"] += 1
|
||||
store_info["file_counts"][vector_store_file_object.status] += 1
|
||||
|
||||
# Save updated vector store to persistent storage
|
||||
await self._save_openai_vector_store(vector_store_id, store_info)
|
||||
|
||||
# Update vector store in-memory cache
|
||||
self.openai_vector_stores[vector_store_id] = store_info
|
||||
|
||||
return vector_store_file_object
|
||||
|
||||
|
@ -643,12 +641,17 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
file = await self.openai_retrieve_vector_store_file(vector_store_id, file_id)
|
||||
await self._delete_openai_vector_store_file_from_storage(vector_store_id, file_id)
|
||||
|
||||
# TODO: We need to actually delete the embeddings from the underlying vector store...
|
||||
|
||||
# Update in-memory cache
|
||||
store_info["file_ids"].remove(file_id)
|
||||
store_info["file_counts"][file.status] -= 1
|
||||
store_info["file_counts"]["total"] -= 1
|
||||
self.openai_vector_stores[vector_store_id] = store_info
|
||||
|
||||
# Save updated vector store to persistent storage
|
||||
await self._save_openai_vector_store(vector_store_id, store_info)
|
||||
|
||||
return VectorStoreFileDeleteResponse(
|
||||
id=file_id,
|
||||
deleted=True,
|
||||
|
|
|
@ -481,6 +481,59 @@ def test_openai_vector_store_attach_file_response_attributes(compat_client_with_
|
|||
assert updated_vector_store.file_counts.in_progress == 0
|
||||
|
||||
|
||||
def test_openai_vector_store_attach_files_on_creation(compat_client_with_empty_stores, client_with_models):
|
||||
"""Test OpenAI vector store attach files on creation."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
|
||||
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
|
||||
pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient")
|
||||
|
||||
compat_client = compat_client_with_empty_stores
|
||||
|
||||
# Create some files and attach them to the vector store
|
||||
valid_file_ids = []
|
||||
for i in range(3):
|
||||
with BytesIO(f"This is a test file {i}".encode()) as file_buffer:
|
||||
file_buffer.name = f"openai_test_{i}.txt"
|
||||
file = compat_client.files.create(file=file_buffer, purpose="assistants")
|
||||
valid_file_ids.append(file.id)
|
||||
|
||||
# include an invalid file ID so we can test failed status
|
||||
file_ids = valid_file_ids + ["invalid_file_id"]
|
||||
|
||||
# Create a vector store
|
||||
vector_store = compat_client.vector_stores.create(
|
||||
name="test_store",
|
||||
file_ids=file_ids,
|
||||
)
|
||||
|
||||
assert vector_store.file_counts.completed == len(valid_file_ids)
|
||||
assert vector_store.file_counts.total == len(file_ids)
|
||||
assert vector_store.file_counts.cancelled == 0
|
||||
assert vector_store.file_counts.failed == len(file_ids) - len(valid_file_ids)
|
||||
assert vector_store.file_counts.in_progress == 0
|
||||
|
||||
files_list = compat_client.vector_stores.files.list(vector_store_id=vector_store.id)
|
||||
assert len(files_list.data) == len(file_ids)
|
||||
assert set(file_ids) == {file.id for file in files_list.data}
|
||||
for file in files_list.data:
|
||||
if file.id in valid_file_ids:
|
||||
assert file.status == "completed"
|
||||
else:
|
||||
assert file.status == "failed"
|
||||
|
||||
# Delete the invalid file
|
||||
delete_response = compat_client.vector_stores.files.delete(
|
||||
vector_store_id=vector_store.id, file_id="invalid_file_id"
|
||||
)
|
||||
assert delete_response.id == "invalid_file_id"
|
||||
|
||||
updated_vector_store = compat_client.vector_stores.retrieve(vector_store_id=vector_store.id)
|
||||
assert updated_vector_store.file_counts.completed == len(valid_file_ids)
|
||||
assert updated_vector_store.file_counts.total == len(valid_file_ids)
|
||||
assert updated_vector_store.file_counts.failed == 0
|
||||
|
||||
|
||||
def test_openai_vector_store_list_files(compat_client_with_empty_stores, client_with_models):
|
||||
"""Test OpenAI vector store list files."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
|
@ -511,7 +564,7 @@ def test_openai_vector_store_list_files(compat_client_with_empty_stores, client_
|
|||
assert files_list.object == "list"
|
||||
assert files_list.data
|
||||
assert len(files_list.data) == 3
|
||||
assert file_ids == [file.id for file in files_list.data]
|
||||
assert set(file_ids) == {file.id for file in files_list.data}
|
||||
assert files_list.data[0].object == "vector_store.file"
|
||||
assert files_list.data[0].vector_store_id == vector_store.id
|
||||
assert files_list.data[0].status == "completed"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue