# What does this PR do?

OpenAI processes file attachments asynchronously, so files should not be marked as
"completed" immediately after attachment. Instead:

1. Return the status from OpenAI's API response when attaching files
2. Override openai_retrieve_vector_store_file() to check the actual status from OpenAI
   when the cached status is "in_progress", and update the cached status
3. Update the file counts in the vector store metadata when the status changes

This lets clients poll the file status and get accurate processing updates instead of
an incorrect "completed" status before OpenAI has finished processing.

## Test Plan

test_openai_vector_store_attach_file now polls the file status until it reaches
"completed" (with a 10-second timeout) instead of asserting "completed" immediately
after attaching the file; see the sketch below. A new skip_if_provider_is_openai_vector_store
fixture skips direct chunk insertion/query tests for the remote::openai provider.
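Below is a rough sketch of the client-side flow this enables, using an OpenAI-compatible
client and mirroring the polling added to test_openai_vector_store_attach_file. The file
name, polling interval, and timeout are illustrative, not taken from the PR.

```python
import time

from openai import OpenAI

# Sketch only: an OpenAI-compatible client, pointed either at OpenAI or at a
# Llama Stack distribution exposing this provider.
client = OpenAI()

vector_store = client.vector_stores.create(name="docs")
file = client.files.create(file=open("notes.txt", "rb"), purpose="assistants")

attached = client.vector_stores.files.create(
    vector_store_id=vector_store.id,
    file_id=file.id,
)

# The attach call may now return "in_progress"; poll until processing settles.
deadline = time.time() + 10
while attached.status == "in_progress" and time.time() < deadline:
    time.sleep(0.5)
    attached = client.vector_stores.files.retrieve(
        vector_store_id=vector_store.id,
        file_id=file.id,
    )

assert attached.status == "completed"
```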
Commit 2367a4ff80 (parent 715d4f8d8c) by Eric Huang, 2025-11-03 21:17:51 -08:00
8 changed files with 786 additions and 59 deletions


@@ -260,7 +260,7 @@ class VectorStoreSearchResponsePage(BaseModel):
    """
    object: str = "vector_store.search_results.page"
-   search_query: str
+   search_query: str | list[str]
    data: list[VectorStoreSearchResponse]
    has_more: bool = False
    next_page: str | None = None
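For illustration, a small sketch of what the widened `search_query` type permits (the
import path matches the one used by the new adapter below; the query strings are made up):

```python
from llama_stack.apis.vector_io import VectorStoreSearchResponsePage

# The response page may now echo back a single query string or a list of query strings.
single = VectorStoreSearchResponsePage(search_query="deployment checklist", data=[])
multi = VectorStoreSearchResponsePage(search_query=["deployment checklist", "rollout plan"], data=[])
```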


@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.providers.datatypes import Api

from .config import OpenAIVectorIOConfig


async def get_provider_impl(config: OpenAIVectorIOConfig, deps: dict[Api, Any]):
    from .openai import OpenAIVectorIOAdapter

    assert isinstance(config, OpenAIVectorIOConfig), f"Unexpected config type: {type(config)}"

    impl = OpenAIVectorIOAdapter(
        config,
        deps[Api.inference],
        deps.get(Api.files),
    )
    await impl.initialize()
    return impl


@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type


@json_schema_type
class OpenAIVectorIOConfig(BaseModel):
    api_key: str | None = Field(
        None,
        description="OpenAI API key. If not provided, will use OPENAI_API_KEY environment variable.",
    )
    persistence: KVStoreReference = Field(
        description="KVStore reference for persisting vector store metadata.",
    )

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
        return {
            "api_key": "${OPENAI_API_KEY}",
            "persistence": KVStoreReference(
                backend="kv_default",
                namespace="vector_io::openai",
            ).model_dump(exclude_none=True),
        }


@@ -0,0 +1,512 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.files import Files, OpenAIFileObject
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
VectorStoreFileDeleteResponse,
VectorStoreFileLastError,
VectorStoreFileObject,
VectorStoreSearchResponsePage,
)
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, VectorStoresProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion
from .config import OpenAIVectorIOConfig
logger = get_logger(name=__name__, category="vector_io")
# Prefix for storing the mapping from Llama Stack vector store IDs to OpenAI vector store IDs
VECTOR_STORE_ID_MAPPING_PREFIX = "openai_vector_store_id_mapping::"
# Prefix for storing the mapping from Llama Stack file IDs to OpenAI file IDs
FILE_ID_MAPPING_PREFIX = "openai_file_id_mapping::"
class OpenAIVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__(
self,
config: OpenAIVectorIOConfig,
inference_api: Inference,
files_api: Files | None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.openai_client = None
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.persistence)
# Initialize OpenAI client
try:
from openai import AsyncOpenAI
except ImportError as e:
raise RuntimeError(
"OpenAI Python client library is not installed. Please install it with: pip install openai"
) from e
api_key = self.config.api_key or None
if api_key == "${OPENAI_API_KEY}":
api_key = None
try:
self.openai_client = AsyncOpenAI(api_key=api_key)
except Exception as e:
raise RuntimeError("Failed to initialize OpenAI client") from e
# Load existing OpenAI vector stores into the in-memory cache
await self.initialize_openai_vector_stores()
async def _store_vector_store_id_mapping(self, llama_stack_id: str, openai_id: str) -> None:
"""Store mapping from Llama Stack vector store ID to OpenAI vector store ID."""
if self.kvstore:
key = f"{VECTOR_STORE_ID_MAPPING_PREFIX}{llama_stack_id}"
await self.kvstore.set(key, openai_id)
async def _get_openai_vector_store_id(self, llama_stack_id: str) -> str:
"""Get OpenAI vector store ID from Llama Stack vector store ID.
Raises ValueError if mapping is not found.
"""
if self.kvstore:
key = f"{VECTOR_STORE_ID_MAPPING_PREFIX}{llama_stack_id}"
try:
openai_id = await self.kvstore.get(key)
if openai_id:
return openai_id
except Exception:
pass
# If not found in mapping, raise an error instead of assuming
raise ValueError(f"No OpenAI vector store mapping found for Llama Stack ID: {llama_stack_id}")
async def _delete_vector_store_id_mapping(self, llama_stack_id: str) -> None:
"""Delete mapping for a vector store ID."""
if self.kvstore:
key = f"{VECTOR_STORE_ID_MAPPING_PREFIX}{llama_stack_id}"
try:
await self.kvstore.delete(key)
except Exception:
pass
async def _store_file_id_mapping(self, llama_stack_file_id: str, openai_file_id: str) -> None:
"""Store mapping from Llama Stack file ID to OpenAI file ID."""
if self.kvstore:
key = f"{FILE_ID_MAPPING_PREFIX}{llama_stack_file_id}"
await self.kvstore.set(key, openai_file_id)
async def _get_openai_file_id(self, llama_stack_file_id: str) -> str | None:
"""Get OpenAI file ID from Llama Stack file ID. Returns None if not found."""
if self.kvstore:
key = f"{FILE_ID_MAPPING_PREFIX}{llama_stack_file_id}"
try:
openai_id = await self.kvstore.get(key)
if openai_id:
return openai_id
except Exception:
pass
return None
async def _get_llama_stack_file_id(self, openai_file_id: str) -> str | None:
"""Get Llama Stack file ID from OpenAI file ID. Returns None if not found."""
if self.kvstore:
# For reverse lookup, we need to search through all mappings
prefix = FILE_ID_MAPPING_PREFIX
start_key = prefix
end_key = f"{prefix}\xff"
try:
items = await self.kvstore.items_in_range(start_key, end_key)
for key, value in items:
if value == openai_file_id:
# Extract the Llama Stack file ID from the key
return key[len(prefix) :]
except Exception:
pass
return None
async def _delete_file_id_mapping(self, llama_stack_file_id: str) -> None:
"""Delete mapping for a file ID."""
if self.kvstore:
key = f"{FILE_ID_MAPPING_PREFIX}{llama_stack_file_id}"
try:
await self.kvstore.delete(key)
except Exception:
pass
async def shutdown(self) -> None:
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to OpenAI API.
"""
try:
if self.openai_client is None:
return HealthResponse(
status=HealthStatus.ERROR,
message="OpenAI client not initialized",
)
# Try to list models as a simple health check
await self.openai_client.models.list()
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(
status=HealthStatus.ERROR,
message=f"Health check failed: {str(e)}",
)
async def register_vector_store(self, vector_store: VectorStore) -> None:
"""Register a vector store by creating it in OpenAI's API."""
if self.openai_client is None:
raise RuntimeError("OpenAI client not initialized")
# Create vector store in OpenAI
created_store = await self.openai_client.vector_stores.create(
name=vector_store.vector_store_name or vector_store.identifier,
)
# Store mapping from Llama Stack ID to OpenAI ID
await self._store_vector_store_id_mapping(vector_store.identifier, created_store.id)
logger.info(f"Created OpenAI vector store: {created_store.id} for identifier: {vector_store.identifier}")
async def unregister_vector_store(self, vector_store_id: str) -> None:
"""Delete a vector store from OpenAI's API."""
if self.openai_client is None:
raise RuntimeError("OpenAI client not initialized")
try:
# Look up the OpenAI ID from our mapping
if self.kvstore:
key = f"{VECTOR_STORE_ID_MAPPING_PREFIX}{vector_store_id}"
try:
openai_vector_store_id = await self.kvstore.get(key)
if openai_vector_store_id:
await self.openai_client.vector_stores.delete(openai_vector_store_id)
logger.info(
f"Deleted OpenAI vector store: {openai_vector_store_id} for identifier: {vector_store_id}"
)
else:
logger.warning(f"No OpenAI vector store mapping found for {vector_store_id}, skipping deletion")
except Exception as e:
logger.warning(f"Failed to delete vector store {vector_store_id} from OpenAI: {e}", exc_info=True)
# Clean up the mapping
await self._delete_vector_store_id_mapping(vector_store_id)
except Exception as e:
logger.warning(f"Error in unregister_vector_store for {vector_store_id}: {e}", exc_info=True)
async def insert_chunks(
self,
vector_store_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
"""
OpenAI Vector Stores API doesn't support direct chunk insertion.
Use file attachment instead via openai_attach_file_to_vector_store.
"""
raise NotImplementedError(
"Direct chunk insertion is not supported by OpenAI Vector Stores API. "
"Please use file attachment instead via the openai_attach_file_to_vector_store endpoint."
)
async def query_chunks(
self,
vector_store_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
"""
OpenAI Vector Stores API doesn't support direct chunk queries.
Use the OpenAI vector store search API instead.
"""
raise NotImplementedError(
"Direct chunk querying is not supported by OpenAI Vector Stores API. "
"Please use the openai_search_vector_store endpoint instead."
)
async def delete_chunks(
self,
store_id: str,
chunks_for_deletion: list[ChunkForDeletion],
) -> None:
"""
OpenAI Vector Stores API doesn't support direct chunk deletion.
Delete files from the vector store instead.
"""
raise NotImplementedError(
"Direct chunk deletion is not supported by OpenAI Vector Stores API. "
"Please delete files from the vector store instead via openai_delete_vector_store_file."
)
async def _prepare_and_attach_file_chunks(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
chunking_strategy: Any,
created_at: int,
) -> tuple[Any, list[Chunk], Any]:
"""
Override to download file from Llama Stack, upload to OpenAI,
and attach to OpenAI vector store instead of storing chunks locally.
Returns: (VectorStoreFileObject, empty chunks list, file response)
"""
# Translate Llama Stack ID to OpenAI ID
try:
openai_vector_store_id = await self._get_openai_vector_store_id(vector_store_id)
except ValueError as e:
logger.error(f"Cannot attach file to vector store {vector_store_id}: {e}")
return (
VectorStoreFileObject(
id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
created_at=created_at,
status="failed",
vector_store_id=vector_store_id,
last_error=VectorStoreFileLastError(
code="server_error",
message=str(e),
),
),
[],
None,
)
vector_store_file_object = VectorStoreFileObject(
id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
created_at=created_at,
status="in_progress",
vector_store_id=vector_store_id,
)
# Prepare file: download from Llama Stack if needed, upload to OpenAI
try:
file_obj: OpenAIFileObject = await self.files_api.openai_retrieve_file(file_id)
file_content_response = await self.files_api.openai_retrieve_file_content(file_id)
file_data = file_content_response.body
import io
file_buffer = io.BytesIO(file_data)
file_buffer.name = file_obj.filename
openai_file = await self.openai_client.files.create(
file=file_buffer,
purpose="assistants",
)
logger.info(f"Uploaded file {file_id} to OpenAI as {openai_file.id}")
openai_file_id = openai_file.id
# Store mapping for later lookup
await self._store_file_id_mapping(file_id, openai_file_id)
except Exception as e:
logger.debug(f"Could not retrieve file {file_id} from Llama Stack: {e}. Using file_id directly.")
openai_file_id = file_id
# Attach file to OpenAI vector store
try:
attached_file = await self.openai_client.vector_stores.files.create(
vector_store_id=openai_vector_store_id,
file_id=openai_file_id,
)
logger.info(
f"Attached file {openai_file_id} to OpenAI vector store {openai_vector_store_id}, "
f"status: {attached_file.status}"
)
# Use the status from OpenAI's response, don't assume it's completed
vector_store_file_object.status = attached_file.status
file_response = file_obj if "file_obj" in locals() else None
except Exception as e:
logger.error(f"Failed to attach file {openai_file_id} to vector store: {e}")
vector_store_file_object.status = "failed"
vector_store_file_object.last_error = VectorStoreFileLastError(
code="server_error",
message=str(e),
)
file_response = file_obj if "file_obj" in locals() else None
# Return VectorStoreFileObject and empty chunks (OpenAI handles storage)
return vector_store_file_object, [], file_response
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
"""Search a vector store using OpenAI's native search API."""
assert self.openai_client is not None
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
openai_vector_store_id = await self._get_openai_vector_store_id(vector_store_id)
# raise ValueError(f"openai_vector_store_id: {openai_vector_store_id}")
logger.info(f"openai_vector_store_id: {openai_vector_store_id}")
response = await self.openai_client.vector_stores.search(
vector_store_id=openai_vector_store_id,
query=query,
filters=filters,
max_num_results=max_num_results,
ranking_options=ranking_options,
rewrite_query=rewrite_query,
)
payload = response.model_dump()
logger.info(f"payload: {payload}")
# Remap OpenAI file IDs back to Llama Stack file IDs in results
if payload.get("data"):
for result in payload["data"]:
if result.get("file_id"):
llama_stack_file_id = await self._get_llama_stack_file_id(result["file_id"])
if llama_stack_file_id:
result["file_id"] = llama_stack_file_id
return VectorStoreSearchResponsePage(**payload)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
"""Delete a file from a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
if self.openai_client is None:
raise RuntimeError("OpenAI client not initialized")
try:
# Get the OpenAI file ID for this Llama Stack file ID
openai_file_id = await self._get_openai_file_id(file_id)
if not openai_file_id:
# If no mapping, use the file_id as-is (may be native OpenAI file ID)
openai_file_id = file_id
# Get OpenAI vector store ID
openai_vector_store_id = await self._get_openai_vector_store_id(vector_store_id)
# Delete file from OpenAI vector store
await self.openai_client.vector_stores.files.delete(
vector_store_id=openai_vector_store_id,
file_id=openai_file_id,
)
logger.info(f"Deleted file {openai_file_id} from OpenAI vector store {openai_vector_store_id}")
# Delete the file from OpenAI if it was created by us
if await self._get_openai_file_id(file_id):
try:
await self.openai_client.files.delete(openai_file_id)
logger.info(f"Deleted OpenAI file {openai_file_id}")
except Exception as e:
logger.debug(f"Could not delete OpenAI file {openai_file_id}: {e}")
# Clean up mappings
await self._delete_file_id_mapping(file_id)
# Update vector store metadata
store_info = self.openai_vector_stores[vector_store_id].copy()
if file_id in store_info["file_ids"]:
store_info["file_ids"].remove(file_id)
store_info["file_counts"]["total"] -= 1
store_info["file_counts"]["completed"] -= 1
self.openai_vector_stores[vector_store_id] = store_info
await self._save_openai_vector_store(vector_store_id, store_info)
return VectorStoreFileDeleteResponse(
id=file_id,
deleted=True,
)
except Exception as e:
logger.error(f"Error deleting file {file_id} from vector store {vector_store_id}: {e}")
raise
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
"""Retrieve a vector store file and check status from OpenAI if still in_progress."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
if self.openai_client is None:
raise RuntimeError("OpenAI client not initialized")
# Get the cached file info
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
file_object = VectorStoreFileObject(**file_info)
# If status is still in_progress, check the actual status from OpenAI
if file_object.status == "in_progress":
try:
# Get OpenAI file ID for this Llama Stack file ID
openai_file_id = await self._get_openai_file_id(file_id)
if not openai_file_id:
openai_file_id = file_id
# Get OpenAI vector store ID
openai_vector_store_id = await self._get_openai_vector_store_id(vector_store_id)
# Retrieve the file status from OpenAI
openai_file = await self.openai_client.vector_stores.files.retrieve(
vector_store_id=openai_vector_store_id,
file_id=openai_file_id,
)
# Update the status from OpenAI
file_object.status = openai_file.status
# If status changed, update it in storage
if openai_file.status != "in_progress":
file_info["status"] = openai_file.status
# Update file counts in vector store metadata
store_info = self.openai_vector_stores[vector_store_id].copy()
if file_object.status == "completed":
store_info["file_counts"]["in_progress"] = max(
0, store_info["file_counts"].get("in_progress", 0) - 1
)
store_info["file_counts"]["completed"] = (
store_info["file_counts"].get("completed", 0) + 1
)
elif file_object.status == "failed":
store_info["file_counts"]["in_progress"] = max(
0, store_info["file_counts"].get("in_progress", 0) - 1
)
store_info["file_counts"]["failed"] = store_info["file_counts"].get("failed", 0) + 1
self.openai_vector_stores[vector_store_id] = store_info
await self._save_openai_vector_store_file(vector_store_id, file_id, file_info)
await self._save_openai_vector_store(vector_store_id, store_info)
except Exception as e:
logger.debug(f"Could not retrieve file status from OpenAI: {e}. Using cached status.")
return file_object
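The adapter above keeps two KVStore mappings, one per prefix. Here is a small
self-contained sketch, with a plain dict standing in for the KVStore, of how the forward
lookups and the range-scan reverse lookup behave; the IDs are made up.

```python
# Prefixes as defined at the top of the adapter.
VECTOR_STORE_ID_MAPPING_PREFIX = "openai_vector_store_id_mapping::"
FILE_ID_MAPPING_PREFIX = "openai_file_id_mapping::"

# A plain dict stands in for the KVStore; IDs are illustrative.
kv: dict[str, str] = {
    f"{VECTOR_STORE_ID_MAPPING_PREFIX}vs_local_123": "vs_openai_abc",
    f"{FILE_ID_MAPPING_PREFIX}file_local_1": "file-openai-xyz",
}

# Forward lookups are a single key read.
openai_file_id = kv[f"{FILE_ID_MAPPING_PREFIX}file_local_1"]

# The reverse lookup (OpenAI file ID -> Llama Stack file ID) has no dedicated key,
# so the adapter scans every entry under the file-mapping prefix.
llama_stack_file_id = next(
    key[len(FILE_ID_MAPPING_PREFIX):]
    for key, value in kv.items()
    if key.startswith(FILE_ID_MAPPING_PREFIX) and value == openai_file_id
)
assert llama_stack_file_id == "file_local_1"
```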


@@ -78,6 +78,75 @@ pip install faiss-cpu
## Documentation
See [Faiss' documentation](https://faiss.ai/) or the [Faiss Wiki](https://github.com/facebookresearch/faiss/wiki) for
more details about Faiss in general.
""",
    ),
    RemoteProviderSpec(
        api=Api.vector_io,
        adapter_type="openai",
        provider_type="remote::openai",
        pip_packages=["openai"] + DEFAULT_VECTOR_IO_DEPS,
        module="llama_stack.providers.remote.vector_io.openai",
        config_class="llama_stack.providers.inline.vector_io.openai.OpenAIVectorIOConfig",
        api_dependencies=[Api.inference],
        optional_api_dependencies=[Api.files, Api.models],
        description="""
[OpenAI Vector Stores](https://platform.openai.com/docs/guides/vector-stores) is a remote vector database provider for Llama Stack that uses OpenAI's Vector Stores API.
It allows you to store and query vectors using OpenAI's embeddings and vector store infrastructure.

## Features

- Direct integration with OpenAI's Vector Stores API
- File attachment support for batch vector store operations
- OpenAI-compatible API endpoints
- Vector search capabilities
- Metadata filtering

## Search Modes

**Supported:**
- **Vector Search** (`mode="vector"`): Performs vector similarity search using OpenAI embeddings

**Not Supported:**
- **Keyword Search** (`mode="keyword"`): Not supported by OpenAI Vector Stores API
- **Hybrid Search** (`mode="hybrid"`): Not supported by OpenAI Vector Stores API

## Configuration

To use this provider, you need to provide:

1. **API Key**: Either pass `api_key` in the config or set the `OPENAI_API_KEY` environment variable
2. **Persistence**: A KVStore backend for storing metadata

### Example Configuration

```yaml
vector_io:
  - provider_id: openai
    provider_type: remote::openai
    config:
      api_key: ${OPENAI_API_KEY}
      persistence:
        backend: kv_default
        namespace: vector_io::openai
```

## Installation

Install the OpenAI Python client:

```bash
pip install openai
```

## Limitations

- Content must be added as file attachments (the OpenAI Vector Stores API is file-based)
- Direct chunk insertion and querying are not supported; attach files and use the vector store search API instead
- For queries, only vector search mode is natively supported

## Documentation

See [OpenAI Vector Stores API documentation](https://platform.openai.com/docs/guides/vector-stores) for more details.
""",
    ),
    # NOTE: sqlite-vec cannot be bundled into the container image because it does not have a
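For the search mode described in the provider documentation above, here is a brief usage
sketch of the underlying OpenAI search call the adapter delegates to. The store ID and
query are placeholders; the Llama Stack provider exposes this through its
OpenAI-compatible vector store search endpoint after translating store IDs via its
KVStore mapping.

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

# Placeholder vector store ID; the adapter maps Llama Stack IDs to OpenAI IDs
# before making this call.
results = client.vector_stores.search(
    vector_store_id="vs_abc123",
    query="How do I configure persistence?",
    max_num_results=5,
)

for hit in results.data:
    print(hit.file_id, hit.score)
```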


@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.vector_io.openai.config import OpenAIVectorIOConfig


async def get_adapter_impl(config: OpenAIVectorIOConfig, deps: dict[Api, Any]):
    """Remote adapter for OpenAI Vector Store provider - delegates to inline implementation."""
    from llama_stack.providers.inline.vector_io.openai.openai import OpenAIVectorIOAdapter

    assert isinstance(config, OpenAIVectorIOConfig), f"Unexpected config type: {type(config)}"

    impl = OpenAIVectorIOAdapter(
        config,
        deps[Api.inference],
        deps.get(Api.files),
    )
    await impl.initialize()
    return impl


@@ -729,6 +729,70 @@ class OpenAIVectorStoreMixin(ABC):
]
return content
async def _prepare_and_attach_file_chunks(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
chunking_strategy: VectorStoreChunkingStrategy,
created_at: int,
) -> tuple[VectorStoreFileObject, list[Chunk], OpenAIFileObject | None]:
"""
Implementation-specific method for preparing and attaching file chunks to vector store.
Subclasses can override this to customize how files are prepared and attached.
Returns: (VectorStoreFileObject, chunks, file_response) tuple
"""
if isinstance(chunking_strategy, VectorStoreChunkingStrategyStatic):
max_chunk_size_tokens = chunking_strategy.static.max_chunk_size_tokens
chunk_overlap_tokens = chunking_strategy.static.chunk_overlap_tokens
else:
# Default values from OpenAI API spec
max_chunk_size_tokens = 800
chunk_overlap_tokens = 400
file_response = await self.files_api.openai_retrieve_file(file_id)
mime_type, _ = mimetypes.guess_type(file_response.filename)
content_response = await self.files_api.openai_retrieve_file_content(file_id)
content = content_from_data_and_mime_type(content_response.body, mime_type)
chunk_attributes = attributes.copy()
chunk_attributes["filename"] = file_response.filename
chunks = make_overlapped_chunks(
file_id,
content,
max_chunk_size_tokens,
chunk_overlap_tokens,
chunk_attributes,
)
vector_store_file_object = VectorStoreFileObject(
id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
created_at=created_at,
status="in_progress",
vector_store_id=vector_store_id,
)
if not chunks:
vector_store_file_object.status = "failed"
vector_store_file_object.last_error = VectorStoreFileLastError(
code="server_error",
message="No chunks were generated from the file",
)
else:
# Default implementation: insert chunks directly
await self.insert_chunks(
vector_store_id=vector_store_id,
chunks=chunks,
)
vector_store_file_object.status = "completed"
return vector_store_file_object, chunks, file_response
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
@@ -750,69 +814,43 @@
attributes = attributes or {}
chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()
created_at = int(time.time())
chunks: list[Chunk] = []
file_response: OpenAIFileObject | None = None
vector_store_file_object = VectorStoreFileObject(
id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
created_at=created_at,
status="in_progress",
vector_store_id=vector_store_id,
)
if not hasattr(self, "files_api") or not self.files_api:
vector_store_file_object.status = "failed"
vector_store_file_object.last_error = VectorStoreFileLastError(
code="server_error",
message="Files API is not available",
vector_store_file_object = VectorStoreFileObject(
id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
created_at=created_at,
status="failed",
vector_store_id=vector_store_id,
last_error=VectorStoreFileLastError(
code="server_error",
message="Files API is not available",
),
)
return vector_store_file_object
if isinstance(chunking_strategy, VectorStoreChunkingStrategyStatic):
max_chunk_size_tokens = chunking_strategy.static.max_chunk_size_tokens
chunk_overlap_tokens = chunking_strategy.static.chunk_overlap_tokens
else:
# Default values from OpenAI API spec
max_chunk_size_tokens = 800
chunk_overlap_tokens = 400
try:
file_response = await self.files_api.openai_retrieve_file(file_id)
mime_type, _ = mimetypes.guess_type(file_response.filename)
content_response = await self.files_api.openai_retrieve_file_content(file_id)
content = content_from_data_and_mime_type(content_response.body, mime_type)
chunk_attributes = attributes.copy()
chunk_attributes["filename"] = file_response.filename
chunks = make_overlapped_chunks(
file_id,
content,
max_chunk_size_tokens,
chunk_overlap_tokens,
chunk_attributes,
vector_store_file_object, chunks, file_response = await self._prepare_and_attach_file_chunks(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
created_at=created_at,
)
if not chunks:
vector_store_file_object.status = "failed"
vector_store_file_object.last_error = VectorStoreFileLastError(
code="server_error",
message="No chunks were generated from the file",
)
else:
await self.insert_chunks(
vector_store_id=vector_store_id,
chunks=chunks,
)
vector_store_file_object.status = "completed"
except Exception as e:
logger.exception("Error attaching file to vector store")
vector_store_file_object.status = "failed"
vector_store_file_object.last_error = VectorStoreFileLastError(
code="server_error",
message=str(e),
vector_store_file_object = VectorStoreFileObject(
id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
created_at=created_at,
status="failed",
vector_store_id=vector_store_id,
last_error=VectorStoreFileLastError(
code="server_error",
message=str(e),
),
)
# Create OpenAI vector store file metadata
@@ -820,8 +858,8 @@
file_info["filename"] = file_response.filename if file_response else ""
# Save vector store file to persistent storage (provider-specific)
dict_chunks = [c.model_dump() for c in chunks]
# This should be updated to include chunk_id
# Only save chunks if they were generated (some providers like OpenAI handle storage remotely)
dict_chunks = [c.model_dump() for c in chunks] if chunks else []
await self._save_openai_vector_store_file(vector_store_id, file_id, file_info, dict_chunks)
# Update file_ids and file_counts in vector store metadata
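The hook introduced above is the extension point the OpenAI adapter overrides. A minimal
sketch of the override contract follows, assuming only the import paths shown elsewhere in
this PR; the class name is hypothetical and a real provider must also implement the
mixin's other abstract methods.

```python
from typing import Any

from llama_stack.apis.vector_io import Chunk, VectorStoreFileObject
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin


class RemoteOnlyVectorIO(OpenAIVectorStoreMixin):
    async def _prepare_and_attach_file_chunks(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any],
        chunking_strategy: Any,
        created_at: int,
    ) -> tuple[VectorStoreFileObject, list[Chunk], Any]:
        # Delegate storage to the remote service: report "in_progress" and return no
        # local chunks; openai_retrieve_vector_store_file() reconciles the status later.
        file_object = VectorStoreFileObject(
            id=file_id,
            attributes=attributes,
            chunking_strategy=chunking_strategy,
            created_at=created_at,
            status="in_progress",
            vector_store_id=vector_store_id,
        )
        return file_object, [], None
```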


@@ -80,6 +80,17 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
)
def skip_if_provider_is_openai_vector_store(client_with_models):
"""Skip tests that require direct chunk insertion/querying (not supported by OpenAI)."""
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
for p in vector_io_providers:
if p.provider_type == "remote::openai":
pytest.skip(
"OpenAI Vector Stores provider does not support direct chunk insertion/querying. "
"Use file attachment instead."
)
@pytest.fixture(scope="session")
def sample_chunks():
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
@@ -144,8 +155,8 @@ def compat_client_with_empty_stores(compat_client):
yield compat_client
# Clean up after the test
clear_vector_stores()
clear_files()
# clear_vector_stores()
# clear_files()
@vector_provider_wrapper
@@ -365,6 +376,7 @@
):
"""Test vector store functionality with actual chunks using both OpenAI and native APIs."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
skip_if_provider_is_openai_vector_store(client_with_models)
compat_client = compat_client_with_empty_stores
llama_client = client_with_models
@@ -430,6 +442,7 @@
):
"""Test that OpenAI vector store search returns relevant results for different queries."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
skip_if_provider_is_openai_vector_store(client_with_models)
compat_client = compat_client_with_empty_stores
llama_client = client_with_models
@@ -482,6 +495,7 @@
):
"""Test OpenAI vector store search with ranking options."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
skip_if_provider_is_openai_vector_store(client_with_models)
compat_client = compat_client_with_empty_stores
llama_client = client_with_models
@@ -542,6 +556,7 @@
):
"""Test that searching with text very similar to a document and high score threshold returns only that document."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
skip_if_provider_is_openai_vector_store(client_with_models)
compat_client = compat_client_with_empty_stores
llama_client = client_with_models
@@ -608,6 +623,7 @@
):
"""Test OpenAI vector store search with max_num_results."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
skip_if_provider_is_openai_vector_store(client_with_models)
compat_client = compat_client_with_empty_stores
llama_client = client_with_models
@@ -678,6 +694,13 @@
assert file_attach_response.object == "vector_store.file"
assert file_attach_response.id == file.id
assert file_attach_response.vector_store_id == vector_store.id
start_time = time.time()
while file_attach_response.status != "completed" and time.time() - start_time < 10:
file_attach_response = compat_client.vector_stores.files.retrieve(
vector_store_id=vector_store.id,
file_id=file.id,
)
assert file_attach_response.status == "completed"
assert file_attach_response.chunking_strategy.type == "auto"
assert file_attach_response.created_at > 0
@@ -1178,6 +1201,7 @@
):
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_models, search_mode)
skip_if_provider_is_openai_vector_store(client_with_models)
vector_store = llama_stack_client.vector_stores.create(
name=f"search_mode_test_{search_mode}",