mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 10:54:19 +00:00
Handle attaching files at vector store creation
This wires up the `file_ids` from the vector store create request so that all of those files are actually attached at creation time. This also required smarter handling of the `file_ids` and `file_counts` metadata to ensure we update both the in-memory cache and the persistent representation of those as we attach the files. It also expands that logic to handle errors during file attachment so that the failed status is persisted. Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
459d50a365
commit
a2f0f608db
3 changed files with 97 additions and 42 deletions
|
@ -660,7 +660,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
||||||
(store_id, file_id),
|
(store_id, file_id),
|
||||||
)
|
)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
print(f"!!! row is {row}")
|
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
(metadata,) = row
|
(metadata,) = row
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import time
|
import time
|
||||||
|
@ -160,15 +161,15 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
|
|
||||||
# Create OpenAI vector store metadata
|
# Create OpenAI vector store metadata
|
||||||
status = "completed"
|
status = "completed"
|
||||||
file_ids = file_ids or []
|
|
||||||
|
# Start with no files attached and update later
|
||||||
file_counts = VectorStoreFileCounts(
|
file_counts = VectorStoreFileCounts(
|
||||||
cancelled=0,
|
cancelled=0,
|
||||||
completed=len(file_ids),
|
completed=0,
|
||||||
failed=0,
|
failed=0,
|
||||||
in_progress=0,
|
in_progress=0,
|
||||||
total=len(file_ids),
|
total=0,
|
||||||
)
|
)
|
||||||
# TODO: actually attach these files to the vector store...
|
|
||||||
store_info = {
|
store_info = {
|
||||||
"id": store_id,
|
"id": store_id,
|
||||||
"object": "vector_store",
|
"object": "vector_store",
|
||||||
|
@ -180,7 +181,7 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
"expires_after": expires_after,
|
"expires_after": expires_after,
|
||||||
"expires_at": None,
|
"expires_at": None,
|
||||||
"last_active_at": created_at,
|
"last_active_at": created_at,
|
||||||
"file_ids": file_ids,
|
"file_ids": [],
|
||||||
"chunking_strategy": chunking_strategy,
|
"chunking_strategy": chunking_strategy,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -198,18 +199,14 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
# Store in memory cache
|
# Store in memory cache
|
||||||
self.openai_vector_stores[store_id] = store_info
|
self.openai_vector_stores[store_id] = store_info
|
||||||
|
|
||||||
return VectorStoreObject(
|
# Now that our vector store is created, attach any files that were provided
|
||||||
id=store_id,
|
file_ids = file_ids or []
|
||||||
created_at=created_at,
|
tasks = [self.openai_attach_file_to_vector_store(store_id, file_id) for file_id in file_ids]
|
||||||
name=store_id,
|
await asyncio.gather(*tasks)
|
||||||
usage_bytes=0,
|
|
||||||
file_counts=file_counts,
|
# Get the updated store info and return it
|
||||||
status=status,
|
store_info = self.openai_vector_stores[store_id]
|
||||||
expires_after=expires_after,
|
return VectorStoreObject.model_validate(store_info)
|
||||||
expires_at=None,
|
|
||||||
last_active_at=created_at,
|
|
||||||
metadata=metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def openai_list_vector_stores(
|
async def openai_list_vector_stores(
|
||||||
self,
|
self,
|
||||||
|
@ -491,8 +488,6 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
if vector_store_id not in self.openai_vector_stores:
|
if vector_store_id not in self.openai_vector_stores:
|
||||||
raise ValueError(f"Vector store {vector_store_id} not found")
|
raise ValueError(f"Vector store {vector_store_id} not found")
|
||||||
|
|
||||||
store_info = self.openai_vector_stores[vector_store_id].copy()
|
|
||||||
|
|
||||||
attributes = attributes or {}
|
attributes = attributes or {}
|
||||||
chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()
|
chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()
|
||||||
created_at = int(time.time())
|
created_at = int(time.time())
|
||||||
|
@ -543,26 +538,12 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
code="server_error",
|
code="server_error",
|
||||||
message="No chunks were generated from the file",
|
message="No chunks were generated from the file",
|
||||||
)
|
)
|
||||||
return vector_store_file_object
|
else:
|
||||||
|
|
||||||
await self.insert_chunks(
|
await self.insert_chunks(
|
||||||
vector_db_id=vector_store_id,
|
vector_db_id=vector_store_id,
|
||||||
chunks=chunks,
|
chunks=chunks,
|
||||||
)
|
)
|
||||||
vector_store_file_object.status = "completed"
|
vector_store_file_object.status = "completed"
|
||||||
|
|
||||||
# Create OpenAI vector store file metadata
|
|
||||||
file_info = vector_store_file_object.model_dump(exclude={"last_error"})
|
|
||||||
|
|
||||||
# Save to persistent storage (provider-specific)
|
|
||||||
await self._save_openai_vector_store_file(vector_store_id, file_id, file_info)
|
|
||||||
|
|
||||||
# Update in-memory cache
|
|
||||||
store_info["file_ids"].append(file_id)
|
|
||||||
store_info["file_counts"]["completed"] += 1
|
|
||||||
store_info["file_counts"]["total"] += 1
|
|
||||||
self.openai_vector_stores[vector_store_id] = store_info
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error attaching file to vector store: {e}")
|
logger.error(f"Error attaching file to vector store: {e}")
|
||||||
vector_store_file_object.status = "failed"
|
vector_store_file_object.status = "failed"
|
||||||
|
@ -570,7 +551,24 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
code="server_error",
|
code="server_error",
|
||||||
message=str(e),
|
message=str(e),
|
||||||
)
|
)
|
||||||
return vector_store_file_object
|
|
||||||
|
# Create OpenAI vector store file metadata
|
||||||
|
file_info = vector_store_file_object.model_dump(exclude={"last_error"})
|
||||||
|
|
||||||
|
# Save vector store file to persistent storage (provider-specific)
|
||||||
|
await self._save_openai_vector_store_file(vector_store_id, file_id, file_info)
|
||||||
|
|
||||||
|
# Update file_ids and file_counts in vector store metadata
|
||||||
|
store_info = self.openai_vector_stores[vector_store_id].copy()
|
||||||
|
store_info["file_ids"].append(file_id)
|
||||||
|
store_info["file_counts"]["total"] += 1
|
||||||
|
store_info["file_counts"][vector_store_file_object.status] += 1
|
||||||
|
|
||||||
|
# Save updated vector store to persistent storage
|
||||||
|
await self._save_openai_vector_store(vector_store_id, store_info)
|
||||||
|
|
||||||
|
# Update vector store in-memory cache
|
||||||
|
self.openai_vector_stores[vector_store_id] = store_info
|
||||||
|
|
||||||
return vector_store_file_object
|
return vector_store_file_object
|
||||||
|
|
||||||
|
@ -643,12 +641,17 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
file = await self.openai_retrieve_vector_store_file(vector_store_id, file_id)
|
file = await self.openai_retrieve_vector_store_file(vector_store_id, file_id)
|
||||||
await self._delete_openai_vector_store_file_from_storage(vector_store_id, file_id)
|
await self._delete_openai_vector_store_file_from_storage(vector_store_id, file_id)
|
||||||
|
|
||||||
|
# TODO: We need to actually delete the embeddings from the underlying vector store...
|
||||||
|
|
||||||
# Update in-memory cache
|
# Update in-memory cache
|
||||||
store_info["file_ids"].remove(file_id)
|
store_info["file_ids"].remove(file_id)
|
||||||
store_info["file_counts"][file.status] -= 1
|
store_info["file_counts"][file.status] -= 1
|
||||||
store_info["file_counts"]["total"] -= 1
|
store_info["file_counts"]["total"] -= 1
|
||||||
self.openai_vector_stores[vector_store_id] = store_info
|
self.openai_vector_stores[vector_store_id] = store_info
|
||||||
|
|
||||||
|
# Save updated vector store to persistent storage
|
||||||
|
await self._save_openai_vector_store(vector_store_id, store_info)
|
||||||
|
|
||||||
return VectorStoreFileDeleteResponse(
|
return VectorStoreFileDeleteResponse(
|
||||||
id=file_id,
|
id=file_id,
|
||||||
deleted=True,
|
deleted=True,
|
||||||
|
|
|
@ -481,6 +481,59 @@ def test_openai_vector_store_attach_file_response_attributes(compat_client_with_
|
||||||
assert updated_vector_store.file_counts.in_progress == 0
|
assert updated_vector_store.file_counts.in_progress == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_vector_store_attach_files_on_creation(compat_client_with_empty_stores, client_with_models):
|
||||||
|
"""Test OpenAI vector store attach files on creation."""
|
||||||
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
||||||
|
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
|
||||||
|
pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient")
|
||||||
|
|
||||||
|
compat_client = compat_client_with_empty_stores
|
||||||
|
|
||||||
|
# Create some files and attach them to the vector store
|
||||||
|
valid_file_ids = []
|
||||||
|
for i in range(3):
|
||||||
|
with BytesIO(f"This is a test file {i}".encode()) as file_buffer:
|
||||||
|
file_buffer.name = f"openai_test_{i}.txt"
|
||||||
|
file = compat_client.files.create(file=file_buffer, purpose="assistants")
|
||||||
|
valid_file_ids.append(file.id)
|
||||||
|
|
||||||
|
# include an invalid file ID so we can test failed status
|
||||||
|
file_ids = valid_file_ids + ["invalid_file_id"]
|
||||||
|
|
||||||
|
# Create a vector store
|
||||||
|
vector_store = compat_client.vector_stores.create(
|
||||||
|
name="test_store",
|
||||||
|
file_ids=file_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert vector_store.file_counts.completed == len(valid_file_ids)
|
||||||
|
assert vector_store.file_counts.total == len(file_ids)
|
||||||
|
assert vector_store.file_counts.cancelled == 0
|
||||||
|
assert vector_store.file_counts.failed == len(file_ids) - len(valid_file_ids)
|
||||||
|
assert vector_store.file_counts.in_progress == 0
|
||||||
|
|
||||||
|
files_list = compat_client.vector_stores.files.list(vector_store_id=vector_store.id)
|
||||||
|
assert len(files_list.data) == len(file_ids)
|
||||||
|
assert set(file_ids) == {file.id for file in files_list.data}
|
||||||
|
for file in files_list.data:
|
||||||
|
if file.id in valid_file_ids:
|
||||||
|
assert file.status == "completed"
|
||||||
|
else:
|
||||||
|
assert file.status == "failed"
|
||||||
|
|
||||||
|
# Delete the invalid file
|
||||||
|
delete_response = compat_client.vector_stores.files.delete(
|
||||||
|
vector_store_id=vector_store.id, file_id="invalid_file_id"
|
||||||
|
)
|
||||||
|
assert delete_response.id == "invalid_file_id"
|
||||||
|
|
||||||
|
updated_vector_store = compat_client.vector_stores.retrieve(vector_store_id=vector_store.id)
|
||||||
|
assert updated_vector_store.file_counts.completed == len(valid_file_ids)
|
||||||
|
assert updated_vector_store.file_counts.total == len(valid_file_ids)
|
||||||
|
assert updated_vector_store.file_counts.failed == 0
|
||||||
|
|
||||||
|
|
||||||
def test_openai_vector_store_list_files(compat_client_with_empty_stores, client_with_models):
|
def test_openai_vector_store_list_files(compat_client_with_empty_stores, client_with_models):
|
||||||
"""Test OpenAI vector store list files."""
|
"""Test OpenAI vector store list files."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
@ -511,7 +564,7 @@ def test_openai_vector_store_list_files(compat_client_with_empty_stores, client_
|
||||||
assert files_list.object == "list"
|
assert files_list.object == "list"
|
||||||
assert files_list.data
|
assert files_list.data
|
||||||
assert len(files_list.data) == 3
|
assert len(files_list.data) == 3
|
||||||
assert file_ids == [file.id for file in files_list.data]
|
assert set(file_ids) == {file.id for file in files_list.data}
|
||||||
assert files_list.data[0].object == "vector_store.file"
|
assert files_list.data[0].object == "vector_store.file"
|
||||||
assert files_list.data[0].vector_store_id == vector_store.id
|
assert files_list.data[0].vector_store_id == vector_store.id
|
||||||
assert files_list.data[0].status == "completed"
|
assert files_list.data[0].status == "completed"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue