diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index f5958754e..18f4f49b9 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -12583,6 +12583,9 @@
}
},
"additionalProperties": false,
+ "required": [
+ "name"
+ ],
"title": "OpenaiCreateVectorStoreRequest"
},
"VectorStoreObject": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index c302e58fb..4dbcd3ac9 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -8791,6 +8791,8 @@ components:
description: >-
The provider-specific vector database ID.
additionalProperties: false
+ required:
+ - name
title: OpenaiCreateVectorStoreRequest
VectorStoreObject:
type: object
diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py
index 229d62386..5f54539a4 100644
--- a/llama_stack/apis/vector_io/vector_io.py
+++ b/llama_stack/apis/vector_io/vector_io.py
@@ -165,7 +165,7 @@ class VectorIO(Protocol):
@webmethod(route="/openai/v1/vector_stores", method="POST")
async def openai_create_vector_store(
self,
- name: str | None = None,
+ name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 6e7bb5edd..e71ff8092 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -394,9 +394,13 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
- # Check if the method is actually implemented in the class
- method_owner = next((cls for cls in mro if name in cls.__dict__), None)
- if method_owner is None or method_owner.__name__ == protocol.__name__:
+ # Check if the method has a concrete implementation (not just a protocol stub)
+ # Find all classes in MRO that define this method
+ method_owners = [cls for cls in mro if name in cls.__dict__]
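+            # e.g. FaissVectorIOAdapter resolves openai_create_vector_store via
+            # OpenAIVectorStoreMixin, a parent in its MRO, so it counts as implemented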
+
+            # Allow implementations inherited from mixins/parents; reject only when
+            # the protocol itself is the sole definer (i.e. the method is just an abstract stub)
+            if len(method_owners) == 1 and method_owners[0].__name__ == protocol.__name__:
missing_methods.append((name, "not_actually_implemented"))
if missing_methods:
diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py
index 92ddcbcfa..30b19c436 100644
--- a/llama_stack/distribution/routers/vector_io.py
+++ b/llama_stack/distribution/routers/vector_io.py
@@ -108,7 +108,7 @@ class VectorIORouter(VectorIO):
# OpenAI Vector Stores API endpoints
async def openai_create_vector_store(
self,
- name: str | None = None,
+ name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index d0b195cbe..5e9155011 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -9,9 +9,7 @@ import base64
import io
import json
import logging
-import time
-import uuid
-from typing import Any, Literal
+from typing import Any
import faiss
import numpy as np
@@ -24,14 +22,11 @@ from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
- VectorStoreDeleteResponse,
- VectorStoreListResponse,
- VectorStoreObject,
- VectorStoreSearchResponse,
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
@@ -47,10 +42,6 @@ FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
-# In faiss, since we do
-CHUNK_MULTIPLIER = 5
-
-
class FaissIndex(EmbeddingIndex):
def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
self.index = faiss.IndexFlatL2(dimension)
@@ -140,7 +131,7 @@ class FaissIndex(EmbeddingIndex):
raise NotImplementedError("Keyword search is not supported in FAISS")
-class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
+class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
self.config = config
self.inference_api = inference_api
@@ -164,14 +155,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
)
self.cache[vector_db.identifier] = index
- # Load existing OpenAI vector stores
- start_key = OPENAI_VECTOR_STORES_PREFIX
- end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
- stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
-
- for store_data in stored_openai_stores:
- store_info = json.loads(store_data)
- self.openai_vector_stores[store_info["id"]] = store_info
+ # Load existing OpenAI vector stores using the mixin method
+ self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
# Cleanup if needed
@@ -234,285 +219,34 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
return await index.query_chunks(query, params)
- # OpenAI Vector Stores API endpoints implementation
- async def openai_create_vector_store(
- self,
- name: str | None = None,
- file_ids: list[str] | None = None,
- expires_after: dict[str, Any] | None = None,
- chunking_strategy: dict[str, Any] | None = None,
- metadata: dict[str, Any] | None = None,
- embedding_model: str | None = None,
- embedding_dimension: int | None = 384,
- provider_id: str | None = None,
- provider_vector_db_id: str | None = None,
- ) -> VectorStoreObject:
- """Creates a vector store."""
+ # OpenAI Vector Store Mixin abstract method implementations
+ async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+ """Save vector store metadata to kvstore."""
assert self.kvstore is not None
- # store and vector_db have the same id
- store_id = name or str(uuid.uuid4())
- created_at = int(time.time())
-
- if provider_id is None:
- raise ValueError("Provider ID is required")
-
- if embedding_model is None:
- raise ValueError("Embedding model is required")
-
- # Use provided embedding dimension or default to 384
- if embedding_dimension is None:
- raise ValueError("Embedding dimension is required")
-
- provider_vector_db_id = provider_vector_db_id or store_id
- vector_db = VectorDB(
- identifier=store_id,
- embedding_dimension=embedding_dimension,
- embedding_model=embedding_model,
- provider_id=provider_id,
- provider_resource_id=provider_vector_db_id,
- )
-
- # Register the vector DB
- await self.register_vector_db(vector_db)
-
- # Create OpenAI vector store metadata
- store_info = {
- "id": store_id,
- "object": "vector_store",
- "created_at": created_at,
- "name": store_id,
- "usage_bytes": 0,
- "file_counts": {},
- "status": "completed",
- "expires_after": expires_after,
- "expires_at": None,
- "last_active_at": created_at,
- "file_ids": file_ids or [],
- "chunking_strategy": chunking_strategy,
- }
-
- # Add provider information to metadata if provided
- metadata = metadata or {}
- if provider_id:
- metadata["provider_id"] = provider_id
- if provider_vector_db_id:
- metadata["provider_vector_db_id"] = provider_vector_db_id
- store_info["metadata"] = metadata
-
- # Store in kvstore
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
- # Store in memory cache
- self.openai_vector_stores[store_id] = store_info
-
- return VectorStoreObject(
- id=store_id,
- created_at=created_at,
- name=store_id,
- usage_bytes=0,
- file_counts={},
- status="completed",
- expires_after=expires_after,
- expires_at=None,
- last_active_at=created_at,
- metadata=metadata,
- )
-
- async def openai_list_vector_stores(
- self,
- limit: int = 20,
- order: str = "desc",
- after: str | None = None,
- before: str | None = None,
- ) -> VectorStoreListResponse:
- """Returns a list of vector stores."""
- # Get all vector stores
- all_stores = list(self.openai_vector_stores.values())
-
- # Sort by created_at
- reverse_order = order == "desc"
- all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
-
- # Apply cursor-based pagination
- if after:
- after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
- if after_index >= 0:
- all_stores = all_stores[after_index + 1 :]
-
- if before:
- before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
- all_stores = all_stores[:before_index]
-
- # Apply limit
- limited_stores = all_stores[:limit]
- # Convert to VectorStoreObject instances
- data = [VectorStoreObject(**store) for store in limited_stores]
-
- # Determine pagination info
- has_more = len(all_stores) > limit
- first_id = data[0].id if data else None
- last_id = data[-1].id if data else None
-
- return VectorStoreListResponse(
- data=data,
- has_more=has_more,
- first_id=first_id,
- last_id=last_id,
- )
-
- async def openai_retrieve_vector_store(
- self,
- vector_store_id: str,
- ) -> VectorStoreObject:
- """Retrieves a vector store."""
- if vector_store_id not in self.openai_vector_stores:
- raise ValueError(f"Vector store {vector_store_id} not found")
-
- store_info = self.openai_vector_stores[vector_store_id]
- return VectorStoreObject(**store_info)
-
- async def openai_update_vector_store(
- self,
- vector_store_id: str,
- name: str | None = None,
- expires_after: dict[str, Any] | None = None,
- metadata: dict[str, Any] | None = None,
- ) -> VectorStoreObject:
- """Modifies a vector store."""
+ async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
+ """Load all vector store metadata from kvstore."""
assert self.kvstore is not None
- if vector_store_id not in self.openai_vector_stores:
- raise ValueError(f"Vector store {vector_store_id} not found")
+ start_key = OPENAI_VECTOR_STORES_PREFIX
+ end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
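+        # "\xff" is a lexicographic upper bound: the scan returns every key that
+        # starts with the OPENAI_VECTOR_STORES_PREFIX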
+ stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
- store_info = self.openai_vector_stores[vector_store_id].copy()
+ stores = {}
+ for store_data in stored_openai_stores:
+ store_info = json.loads(store_data)
+ stores[store_info["id"]] = store_info
+ return stores
- # Update fields if provided
- if name is not None:
- store_info["name"] = name
- if expires_after is not None:
- store_info["expires_after"] = expires_after
- if metadata is not None:
- store_info["metadata"] = metadata
-
- # Update last_active_at
- store_info["last_active_at"] = int(time.time())
-
- # Save to kvstore
- key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}"
+ async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+ """Update vector store metadata in kvstore."""
+ assert self.kvstore is not None
+ key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
- # Update in-memory cache
- self.openai_vector_stores[vector_store_id] = store_info
-
- return VectorStoreObject(**store_info)
-
- async def openai_delete_vector_store(
- self,
- vector_store_id: str,
- ) -> VectorStoreDeleteResponse:
- """Delete a vector store."""
+ async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
+ """Delete vector store metadata from kvstore."""
assert self.kvstore is not None
- if vector_store_id not in self.openai_vector_stores:
- raise ValueError(f"Vector store {vector_store_id} not found")
-
- # Delete from kvstore
- key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}"
+ key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.delete(key)
-
- # Delete from in-memory cache
- del self.openai_vector_stores[vector_store_id]
-
- # Also delete the underlying vector DB
- try:
- await self.unregister_vector_db(vector_store_id)
- except Exception as e:
- logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
-
- return VectorStoreDeleteResponse(
- id=vector_store_id,
- deleted=True,
- )
-
- async def openai_search_vector_store(
- self,
- vector_store_id: str,
- query: str | list[str],
- filters: dict[str, Any] | None = None,
- max_num_results: int = 10,
- ranking_options: dict[str, Any] | None = None,
- rewrite_query: bool = False,
- search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
- ) -> VectorStoreSearchResponse:
- """Search for chunks in a vector store."""
- if vector_store_id not in self.openai_vector_stores:
- raise ValueError(f"Vector store {vector_store_id} not found")
-
- if isinstance(query, list):
- search_query = " ".join(query)
- else:
- search_query = query
-
- try:
- score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
- params = {
- "max_chunks": max_num_results * CHUNK_MULTIPLIER,
- "score_threshold": score_threshold,
- "mode": search_mode,
- }
- # TODO: Add support for ranking_options.ranker
-
- response = await self.query_chunks(
- vector_db_id=vector_store_id,
- query=search_query,
- params=params,
- )
-
- # Convert response to OpenAI format
- data = []
- for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)):
- # Apply score based filtering
- if score < score_threshold:
- continue
-
- # Apply filters if provided
- if filters:
- # Simple metadata filtering
- if not self._matches_filters(chunk.metadata, filters):
- continue
-
- chunk_data = {
- "id": f"chunk_{i}",
- "object": "vector_store.search_result",
- "score": score,
- "content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content),
- "metadata": chunk.metadata,
- }
- data.append(chunk_data)
- if len(data) >= max_num_results:
- break
-
- return VectorStoreSearchResponse(
- search_query=search_query,
- data=data,
- has_more=False, # For simplicity, we don't implement pagination here
- next_page=None,
- )
-
- except Exception as e:
- logger.error(f"Error searching vector store {vector_store_id}: {e}")
- # Return empty results on error
- return VectorStoreSearchResponse(
- search_query=search_query,
- data=[],
- has_more=False,
- next_page=None,
- )
-
- def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
- """Check if metadata matches the provided filters."""
- for key, value in filters.items():
- if key not in metadata:
- return False
- if metadata[key] != value:
- return False
- return True
diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
index 0a0f0e653..02f04e766 100644
--- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
+++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
@@ -10,9 +10,8 @@ import json
import logging
import sqlite3
import struct
-import time
import uuid
-from typing import Any, Literal
+from typing import Any
import numpy as np
import sqlite_vec
@@ -24,12 +23,9 @@ from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
- VectorStoreDeleteResponse,
- VectorStoreListResponse,
- VectorStoreObject,
- VectorStoreSearchResponse,
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
logger = logging.getLogger(__name__)
@@ -39,11 +35,6 @@ VECTOR_SEARCH = "vector"
KEYWORD_SEARCH = "keyword"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
-# Constants for OpenAI vector stores (similar to faiss)
-VERSION = "v3"
-OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
-CHUNK_MULTIPLIER = 5
-
def serialize_vector(vector: list[float]) -> bytes:
"""Serialize a list of floats into a compact binary representation."""
@@ -303,7 +294,7 @@ class SQLiteVecIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
-class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
+class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
"""
A VectorIO implementation using SQLite + sqlite_vec.
This class handles vector database registration (with metadata stored in a table named `vector_dbs`)
@@ -340,15 +331,12 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
# Load any existing vector DB registrations.
cur.execute("SELECT metadata FROM vector_dbs")
vector_db_rows = cur.fetchall()
- # Load any existing OpenAI vector stores.
- cur.execute("SELECT metadata FROM openai_vector_stores")
- openai_store_rows = cur.fetchall()
- return vector_db_rows, openai_store_rows
+ return vector_db_rows
finally:
cur.close()
connection.close()
- vector_db_rows, openai_store_rows = await asyncio.to_thread(_setup_connection)
+ vector_db_rows = await asyncio.to_thread(_setup_connection)
# Load existing vector DBs
for row in vector_db_rows:
@@ -359,11 +347,8 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
- # Load existing OpenAI vector stores
- for row in openai_store_rows:
- store_data = row[0]
- store_info = json.loads(store_data)
- self.openai_vector_stores[store_info["id"]] = store_info
+ # Load existing OpenAI vector stores using the mixin method
+ self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
# nothing to do since we don't maintain a persistent connection
@@ -409,6 +394,87 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
await asyncio.to_thread(_delete_vector_db_from_registry)
+ # OpenAI Vector Store Mixin abstract method implementations
+ async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+ """Save vector store metadata to SQLite database."""
+
+ def _store():
+ connection = _create_sqlite_connection(self.config.db_path)
+ cur = connection.cursor()
+ try:
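+                # INSERT OR REPLACE keeps saves idempotent: re-saving an existing
+                # store id simply overwrites its metadata row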
+ cur.execute(
+ "INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
+ (store_id, json.dumps(store_info)),
+ )
+ connection.commit()
+ finally:
+ cur.close()
+ connection.close()
+
+ try:
+ await asyncio.to_thread(_store)
+ except Exception as e:
+ logger.error(f"Error saving openai vector store {store_id}: {e}")
+ raise
+
+ async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
+ """Load all vector store metadata from SQLite database."""
+
+ def _load():
+ connection = _create_sqlite_connection(self.config.db_path)
+ cur = connection.cursor()
+ try:
+ cur.execute("SELECT metadata FROM openai_vector_stores")
+ rows = cur.fetchall()
+ return rows
+ finally:
+ cur.close()
+ connection.close()
+
+ rows = await asyncio.to_thread(_load)
+ stores = {}
+ for row in rows:
+ store_data = row[0]
+ store_info = json.loads(store_data)
+ stores[store_info["id"]] = store_info
+ return stores
+
+ async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+ """Update vector store metadata in SQLite database."""
+
+ def _update():
+ connection = _create_sqlite_connection(self.config.db_path)
+ cur = connection.cursor()
+ try:
+ cur.execute(
+ "UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
+ (json.dumps(store_info), store_id),
+ )
+ connection.commit()
+ finally:
+ cur.close()
+ connection.close()
+
+ await asyncio.to_thread(_update)
+
+ async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
+ """Delete vector store metadata from SQLite database."""
+
+ def _delete():
+ connection = _create_sqlite_connection(self.config.db_path)
+ cur = connection.cursor()
+ try:
+ cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (store_id,))
+ connection.commit()
+ finally:
+ cur.close()
+ connection.close()
+
+ await asyncio.to_thread(_delete)
+
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
@@ -423,318 +489,6 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
raise ValueError(f"Vector DB {vector_db_id} not found")
return await self.cache[vector_db_id].query_chunks(query, params)
- async def openai_create_vector_store(
- self,
- name: str | None = None,
- file_ids: list[str] | None = None,
- expires_after: dict[str, Any] | None = None,
- chunking_strategy: dict[str, Any] | None = None,
- metadata: dict[str, Any] | None = None,
- embedding_model: str | None = None,
- embedding_dimension: int | None = 384,
- provider_id: str | None = None,
- provider_vector_db_id: str | None = None,
- ) -> VectorStoreObject:
- """Creates a vector store."""
- # store and vector_db have the same id
- store_id = name or str(uuid.uuid4())
- created_at = int(time.time())
-
- if provider_id is None:
- raise ValueError("Provider ID is required")
-
- if embedding_model is None:
- raise ValueError("Embedding model is required")
-
- # Use provided embedding dimension or default to 384
- if embedding_dimension is None:
- raise ValueError("Embedding dimension is required")
-
- provider_vector_db_id = provider_vector_db_id or store_id
- vector_db = VectorDB(
- identifier=store_id,
- embedding_dimension=embedding_dimension,
- embedding_model=embedding_model,
- provider_id=provider_id,
- provider_resource_id=provider_vector_db_id,
- )
-
- # Register the vector DB
- await self.register_vector_db(vector_db)
-
- # Create OpenAI vector store metadata
- store_info = {
- "id": store_id,
- "object": "vector_store",
- "created_at": created_at,
- "name": store_id,
- "usage_bytes": 0,
- "file_counts": {},
- "status": "completed",
- "expires_after": expires_after,
- "expires_at": None,
- "last_active_at": created_at,
- "file_ids": file_ids or [],
- "chunking_strategy": chunking_strategy,
- }
-
- # Add provider information to metadata if provided
- metadata = metadata or {}
- if provider_id:
- metadata["provider_id"] = provider_id
- if provider_vector_db_id:
- metadata["provider_vector_db_id"] = provider_vector_db_id
- store_info["metadata"] = metadata
-
- # Store in SQLite database
- def _store_openai_vector_store():
- connection = _create_sqlite_connection(self.config.db_path)
- cur = connection.cursor()
- try:
- cur.execute(
- "INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
- (store_id, json.dumps(store_info)),
- )
- connection.commit()
- finally:
- cur.close()
- connection.close()
-
- await asyncio.to_thread(_store_openai_vector_store)
-
- # Store in memory cache
- self.openai_vector_stores[store_id] = store_info
-
- return VectorStoreObject(
- id=store_id,
- created_at=created_at,
- name=store_id,
- usage_bytes=0,
- file_counts={},
- status="completed",
- expires_after=expires_after,
- expires_at=None,
- last_active_at=created_at,
- metadata=metadata,
- )
-
- async def openai_list_vector_stores(
- self,
- limit: int = 20,
- order: str = "desc",
- after: str | None = None,
- before: str | None = None,
- ) -> VectorStoreListResponse:
- """Returns a list of vector stores."""
- # Get all vector stores
- all_stores = list(self.openai_vector_stores.values())
-
- # Sort by created_at
- reverse_order = order == "desc"
- all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
-
- # Apply cursor-based pagination
- if after:
- after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
- if after_index >= 0:
- all_stores = all_stores[after_index + 1 :]
-
- if before:
- before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
- all_stores = all_stores[:before_index]
-
- # Apply limit
- limited_stores = all_stores[:limit]
- # Convert to VectorStoreObject instances
- data = [VectorStoreObject(**store) for store in limited_stores]
-
- # Determine pagination info
- has_more = len(all_stores) > limit
- first_id = data[0].id if data else None
- last_id = data[-1].id if data else None
-
- return VectorStoreListResponse(
- data=data,
- has_more=has_more,
- first_id=first_id,
- last_id=last_id,
- )
-
- async def openai_retrieve_vector_store(
- self,
- vector_store_id: str,
- ) -> VectorStoreObject:
- """Retrieves a vector store."""
- if vector_store_id not in self.openai_vector_stores:
- raise ValueError(f"Vector store {vector_store_id} not found")
-
- store_info = self.openai_vector_stores[vector_store_id]
- return VectorStoreObject(**store_info)
-
- async def openai_update_vector_store(
- self,
- vector_store_id: str,
- name: str | None = None,
- expires_after: dict[str, Any] | None = None,
- metadata: dict[str, Any] | None = None,
- ) -> VectorStoreObject:
- """Modifies a vector store."""
- if vector_store_id not in self.openai_vector_stores:
- raise ValueError(f"Vector store {vector_store_id} not found")
-
- store_info = self.openai_vector_stores[vector_store_id].copy()
-
- # Update fields if provided
- if name is not None:
- store_info["name"] = name
- if expires_after is not None:
- store_info["expires_after"] = expires_after
- if metadata is not None:
- store_info["metadata"] = metadata
-
- # Update last_active_at
- store_info["last_active_at"] = int(time.time())
-
- # Save to SQLite database
- def _update_openai_vector_store():
- connection = _create_sqlite_connection(self.config.db_path)
- cur = connection.cursor()
- try:
- cur.execute(
- "UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
- (json.dumps(store_info), vector_store_id),
- )
- connection.commit()
- finally:
- cur.close()
- connection.close()
-
- await asyncio.to_thread(_update_openai_vector_store)
-
- # Update in-memory cache
- self.openai_vector_stores[vector_store_id] = store_info
-
- return VectorStoreObject(**store_info)
-
- async def openai_delete_vector_store(
- self,
- vector_store_id: str,
- ) -> VectorStoreDeleteResponse:
- """Delete a vector store."""
- if vector_store_id not in self.openai_vector_stores:
- raise ValueError(f"Vector store {vector_store_id} not found")
-
- # Delete from SQLite database
- def _delete_openai_vector_store():
- connection = _create_sqlite_connection(self.config.db_path)
- cur = connection.cursor()
- try:
- cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (vector_store_id,))
- connection.commit()
- finally:
- cur.close()
- connection.close()
-
- await asyncio.to_thread(_delete_openai_vector_store)
-
- # Delete from in-memory cache
- del self.openai_vector_stores[vector_store_id]
-
- # Also delete the underlying vector DB
- try:
- await self.unregister_vector_db(vector_store_id)
- except Exception as e:
- logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
-
- return VectorStoreDeleteResponse(
- id=vector_store_id,
- deleted=True,
- )
-
- async def openai_search_vector_store(
- self,
- vector_store_id: str,
- query: str | list[str],
- filters: dict[str, Any] | None = None,
- max_num_results: int = 10,
- ranking_options: dict[str, Any] | None = None,
- rewrite_query: bool = False,
- search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
- ) -> VectorStoreSearchResponse:
- """Search for chunks in a vector store."""
- if vector_store_id not in self.openai_vector_stores:
- raise ValueError(f"Vector store {vector_store_id} not found")
-
- if isinstance(query, list):
- search_query = " ".join(query)
- else:
- search_query = query
-
- try:
- score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
- params = {
- "max_chunks": max_num_results * CHUNK_MULTIPLIER,
- "score_threshold": score_threshold,
- "mode": search_mode,
- }
- # TODO: Add support for ranking_options.ranker
-
- response = await self.query_chunks(
- vector_db_id=vector_store_id,
- query=search_query,
- params=params,
- )
-
- # Convert response to OpenAI format
- data = []
- for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)):
- # Apply score based filtering
- if score < score_threshold:
- continue
-
- # Apply filters if provided
- if filters:
- # Simple metadata filtering
- if not self._matches_filters(chunk.metadata, filters):
- continue
-
- chunk_data = {
- "id": f"chunk_{i}",
- "object": "vector_store.search_result",
- "score": score,
- "content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content),
- "metadata": chunk.metadata,
- }
- data.append(chunk_data)
- if len(data) >= max_num_results:
- break
-
- return VectorStoreSearchResponse(
- search_query=search_query,
- data=data,
- has_more=False, # For simplicity, we don't implement pagination here
- next_page=None,
- )
-
- except Exception as e:
- logger.error(f"Error searching vector store {vector_store_id}: {e}")
- # Return empty results on error
- return VectorStoreSearchResponse(
- search_query=search_query,
- data=[],
- has_more=False,
- next_page=None,
- )
-
- def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
- """Check if metadata matches the provided filters."""
- for key, value in filters.items():
- if key not in metadata:
- return False
- if metadata[key] != value:
- return False
- return True
-
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py
index 50b864fcd..de41f388c 100644
--- a/llama_stack/providers/remote/vector_io/chroma/chroma.py
+++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -189,7 +189,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
async def openai_create_vector_store(
self,
- name: str | None = None,
+ name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index 9e1ee9f6d..9360ef36a 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -187,7 +187,7 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
async def openai_create_vector_store(
self,
- name: str | None = None,
+ name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
index a11513174..cff62bff5 100644
--- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
+++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
@@ -189,7 +189,7 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
async def openai_create_vector_store(
self,
- name: str | None = None,
+ name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
new file mode 100644
index 000000000..345171828
--- /dev/null
+++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
@@ -0,0 +1,354 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import logging
+import time
+import uuid
+from abc import ABC, abstractmethod
+from typing import Any, Literal
+
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.apis.vector_io import (
+ QueryChunksResponse,
+ VectorStoreDeleteResponse,
+ VectorStoreListResponse,
+ VectorStoreObject,
+ VectorStoreSearchResponse,
+)
+
+logger = logging.getLogger(__name__)
+
+# Constants for OpenAI vector stores
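+# We over-fetch candidate chunks by this factor so that score-threshold and
+# metadata filtering can still fill max_num_results after dropping non-matches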
+CHUNK_MULTIPLIER = 5
+
+
+class OpenAIVectorStoreMixin(ABC):
+ """
+ Mixin class that provides common OpenAI Vector Store API implementation.
+ Providers need to implement the abstract storage methods and maintain
+ an openai_vector_stores in-memory cache.
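+
+    A minimal provider sketch (class name and helper bodies here are illustrative,
+    not prescribed by the mixin):
+
+        class MyVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
+            def __init__(self) -> None:
+                self.openai_vector_stores: dict[str, dict[str, Any]] = {}
+
+            async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+                ...  # persist store_info under store_id
+
+            async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
+                ...  # return {store_id: store_info} from persistent storage
+
+            # plus _update_openai_vector_store, _delete_openai_vector_store_from_storage,
+            # register_vector_db, unregister_vector_db, and query_chunks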
+ """
+
+    # The implementing class must provide this in-memory cache
+ openai_vector_stores: dict[str, dict[str, Any]]
+
+ @abstractmethod
+ async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+ """Save vector store metadata to persistent storage."""
+ pass
+
+ @abstractmethod
+ async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
+ """Load all vector store metadata from persistent storage."""
+ pass
+
+ @abstractmethod
+ async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+ """Update vector store metadata in persistent storage."""
+ pass
+
+ @abstractmethod
+ async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
+ """Delete vector store metadata from persistent storage."""
+ pass
+
+ @abstractmethod
+ async def register_vector_db(self, vector_db: VectorDB) -> None:
+ """Register a vector database (provider-specific implementation)."""
+ pass
+
+ @abstractmethod
+ async def unregister_vector_db(self, vector_db_id: str) -> None:
+ """Unregister a vector database (provider-specific implementation)."""
+ pass
+
+ @abstractmethod
+ async def query_chunks(
+ self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
+ ) -> QueryChunksResponse:
+ """Query chunks from a vector database (provider-specific implementation)."""
+ pass
+
+ async def openai_create_vector_store(
+ self,
+ name: str,
+ file_ids: list[str] | None = None,
+ expires_after: dict[str, Any] | None = None,
+ chunking_strategy: dict[str, Any] | None = None,
+ metadata: dict[str, Any] | None = None,
+ embedding_model: str | None = None,
+ embedding_dimension: int | None = 384,
+ provider_id: str | None = None,
+ provider_vector_db_id: str | None = None,
+ ) -> VectorStoreObject:
+ """Creates a vector store."""
+ print("IN OPENAI VECTOR STORE MIXIN, openai_create_vector_store")
+ # store and vector_db have the same id
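+        # (name is required by the signature, so the uuid fallback only covers "")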
+ store_id = name or str(uuid.uuid4())
+ created_at = int(time.time())
+
+ if provider_id is None:
+ raise ValueError("Provider ID is required")
+
+ if embedding_model is None:
+ raise ValueError("Embedding model is required")
+
+        # embedding_dimension defaults to 384 in the signature; reject an explicit None
+ if embedding_dimension is None:
+ raise ValueError("Embedding dimension is required")
+
+ provider_vector_db_id = provider_vector_db_id or store_id
+ vector_db = VectorDB(
+ identifier=store_id,
+ embedding_dimension=embedding_dimension,
+ embedding_model=embedding_model,
+ provider_id=provider_id,
+ provider_resource_id=provider_vector_db_id,
+ )
+
+ # Register the vector DB
+ await self.register_vector_db(vector_db)
+
+ # Create OpenAI vector store metadata
+ store_info = {
+ "id": store_id,
+ "object": "vector_store",
+ "created_at": created_at,
+ "name": store_id,
+ "usage_bytes": 0,
+ "file_counts": {},
+ "status": "completed",
+ "expires_after": expires_after,
+ "expires_at": None,
+ "last_active_at": created_at,
+ "file_ids": file_ids or [],
+ "chunking_strategy": chunking_strategy,
+ }
+
+ # Add provider information to metadata if provided
+ metadata = metadata or {}
+ if provider_id:
+ metadata["provider_id"] = provider_id
+ if provider_vector_db_id:
+ metadata["provider_vector_db_id"] = provider_vector_db_id
+ store_info["metadata"] = metadata
+
+ # Save to persistent storage (provider-specific)
+ await self._save_openai_vector_store(store_id, store_info)
+
+ # Store in memory cache
+ self.openai_vector_stores[store_id] = store_info
+
+ return VectorStoreObject(
+ id=store_id,
+ created_at=created_at,
+ name=store_id,
+ usage_bytes=0,
+ file_counts={},
+ status="completed",
+ expires_after=expires_after,
+ expires_at=None,
+ last_active_at=created_at,
+ metadata=metadata,
+ )
+
+ async def openai_list_vector_stores(
+ self,
+ limit: int = 20,
+ order: str = "desc",
+ after: str | None = None,
+ before: str | None = None,
+ ) -> VectorStoreListResponse:
+ """Returns a list of vector stores."""
+ # Get all vector stores
+ all_stores = list(self.openai_vector_stores.values())
+
+ # Sort by created_at
+ reverse_order = order == "desc"
+ all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
+
+ # Apply cursor-based pagination
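+        # e.g. stores [s1, s2, s3, s4]: after="s2" narrows to [s3, s4], and
+        # before="s4" then trims the window to [s3]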
+ if after:
+ after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
+ if after_index >= 0:
+ all_stores = all_stores[after_index + 1 :]
+
+ if before:
+ before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
+ all_stores = all_stores[:before_index]
+
+ # Apply limit
+ limited_stores = all_stores[:limit]
+ # Convert to VectorStoreObject instances
+ data = [VectorStoreObject(**store) for store in limited_stores]
+
+ # Determine pagination info
+ has_more = len(all_stores) > limit
+ first_id = data[0].id if data else None
+ last_id = data[-1].id if data else None
+
+ return VectorStoreListResponse(
+ data=data,
+ has_more=has_more,
+ first_id=first_id,
+ last_id=last_id,
+ )
+
+ async def openai_retrieve_vector_store(
+ self,
+ vector_store_id: str,
+ ) -> VectorStoreObject:
+ """Retrieves a vector store."""
+ if vector_store_id not in self.openai_vector_stores:
+ raise ValueError(f"Vector store {vector_store_id} not found")
+
+ store_info = self.openai_vector_stores[vector_store_id]
+ return VectorStoreObject(**store_info)
+
+ async def openai_update_vector_store(
+ self,
+ vector_store_id: str,
+ name: str | None = None,
+ expires_after: dict[str, Any] | None = None,
+ metadata: dict[str, Any] | None = None,
+ ) -> VectorStoreObject:
+ """Modifies a vector store."""
+ if vector_store_id not in self.openai_vector_stores:
+ raise ValueError(f"Vector store {vector_store_id} not found")
+
+ store_info = self.openai_vector_stores[vector_store_id].copy()
+
+ # Update fields if provided
+ if name is not None:
+ store_info["name"] = name
+ if expires_after is not None:
+ store_info["expires_after"] = expires_after
+ if metadata is not None:
+ store_info["metadata"] = metadata
+
+ # Update last_active_at
+ store_info["last_active_at"] = int(time.time())
+
+ # Save to persistent storage (provider-specific)
+ await self._update_openai_vector_store(vector_store_id, store_info)
+
+ # Update in-memory cache
+ self.openai_vector_stores[vector_store_id] = store_info
+
+ return VectorStoreObject(**store_info)
+
+ async def openai_delete_vector_store(
+ self,
+ vector_store_id: str,
+ ) -> VectorStoreDeleteResponse:
+ """Delete a vector store."""
+ if vector_store_id not in self.openai_vector_stores:
+ raise ValueError(f"Vector store {vector_store_id} not found")
+
+ # Delete from persistent storage (provider-specific)
+ await self._delete_openai_vector_store_from_storage(vector_store_id)
+
+ # Delete from in-memory cache
+ del self.openai_vector_stores[vector_store_id]
+
+ # Also delete the underlying vector DB
+ try:
+ await self.unregister_vector_db(vector_store_id)
+ except Exception as e:
+ logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
+
+ return VectorStoreDeleteResponse(
+ id=vector_store_id,
+ deleted=True,
+ )
+
+ async def openai_search_vector_store(
+ self,
+ vector_store_id: str,
+ query: str | list[str],
+ filters: dict[str, Any] | None = None,
+ max_num_results: int = 10,
+ ranking_options: dict[str, Any] | None = None,
+ rewrite_query: bool = False,
+ search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
+ ) -> VectorStoreSearchResponse:
+ """Search for chunks in a vector store."""
+ if vector_store_id not in self.openai_vector_stores:
+ raise ValueError(f"Vector store {vector_store_id} not found")
+
+ if isinstance(query, list):
+ search_query = " ".join(query)
+ else:
+ search_query = query
+
+ try:
+ score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
+ params = {
+ "max_chunks": max_num_results * CHUNK_MULTIPLIER,
+ "score_threshold": score_threshold,
+ "mode": search_mode,
+ }
+ # TODO: Add support for ranking_options.ranker
+
+ response = await self.query_chunks(
+ vector_db_id=vector_store_id,
+ query=search_query,
+ params=params,
+ )
+
+ # Convert response to OpenAI format
+ data = []
+ for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)):
+ # Apply score based filtering
+ if score < score_threshold:
+ continue
+
+ # Apply filters if provided
+ if filters:
+ # Simple metadata filtering
+ if not self._matches_filters(chunk.metadata, filters):
+ continue
+
+ chunk_data = {
+ "id": f"chunk_{i}",
+ "object": "vector_store.search_result",
+ "score": score,
+ "content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content),
+ "metadata": chunk.metadata,
+ }
+ data.append(chunk_data)
+ if len(data) >= max_num_results:
+ break
+
+ return VectorStoreSearchResponse(
+ search_query=search_query,
+ data=data,
+ has_more=False, # For simplicity, we don't implement pagination here
+ next_page=None,
+ )
+
+ except Exception as e:
+ logger.error(f"Error searching vector store {vector_store_id}: {e}")
+ # Return empty results on error
+ return VectorStoreSearchResponse(
+ search_query=search_query,
+ data=[],
+ has_more=False,
+ next_page=None,
+ )
+
+ def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
+ """Check if metadata matches the provided filters."""
+ for key, value in filters.items():
+ if key not in metadata:
+ return False
+ if metadata[key] != value:
+ return False
+ return True