From d55100d9b7971aae6414ce2134b8198edb20dc29 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Wed, 11 Jun 2025 15:40:57 -0700 Subject: [PATCH] feat: OpenAIVectorIOMixin for vector_stores common logic (#2427) Extracts common OpenAI vector-store code into its own mixin so that all providers can share the same core logic. This also makes it easy for Llama Stack to support both vector-stores and Llama Stack APIs in the interim so that both share the same underlying vector-dbs. Each provider contains storage specific logic to `create / edit / delete / list` vector dbs while the plumbing logic is standardized in the common code. Ensured that this works well with both faiss and sqllite-vec. ### Test Plan ``` llama stack run starter pytest -sv --stack-config http://localhost:8321 tests/integration/vector_io/test_openai_vector_stores.py --embedding-model all-MiniLM-L6-v2 ``` --- docs/_static/llama-stack-spec.html | 3 + docs/_static/llama-stack-spec.yaml | 2 + llama_stack/apis/vector_io/vector_io.py | 2 +- llama_stack/distribution/resolver.py | 10 +- llama_stack/distribution/routers/vector_io.py | 2 +- .../providers/inline/vector_io/faiss/faiss.py | 316 ++----------- .../inline/vector_io/sqlite_vec/sqlite_vec.py | 422 ++++-------------- .../remote/vector_io/chroma/chroma.py | 2 +- .../remote/vector_io/milvus/milvus.py | 2 +- .../remote/vector_io/qdrant/qdrant.py | 2 +- .../utils/memory/openai_vector_store_mixin.py | 354 +++++++++++++++ 11 files changed, 484 insertions(+), 633 deletions(-) create mode 100644 llama_stack/providers/utils/memory/openai_vector_store_mixin.py diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index f5958754e..18f4f49b9 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -12583,6 +12583,9 @@ } }, "additionalProperties": false, + "required": [ + "name" + ], "title": "OpenaiCreateVectorStoreRequest" }, "VectorStoreObject": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index c302e58fb..4dbcd3ac9 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -8791,6 +8791,8 @@ components: description: >- The provider-specific vector database ID. additionalProperties: false + required: + - name title: OpenaiCreateVectorStoreRequest VectorStoreObject: type: object diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py index 229d62386..5f54539a4 100644 --- a/llama_stack/apis/vector_io/vector_io.py +++ b/llama_stack/apis/vector_io/vector_io.py @@ -165,7 +165,7 @@ class VectorIO(Protocol): @webmethod(route="/openai/v1/vector_stores", method="POST") async def openai_create_vector_store( self, - name: str | None = None, + name: str, file_ids: list[str] | None = None, expires_after: dict[str, Any] | None = None, chunking_strategy: dict[str, Any] | None = None, diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 6e7bb5edd..e71ff8092 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -394,9 +394,13 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None: logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}") missing_methods.append((name, "signature_mismatch")) else: - # Check if the method is actually implemented in the class - method_owner = next((cls for cls in mro if name in cls.__dict__), None) - if method_owner is None or method_owner.__name__ == protocol.__name__: + # Check if the method has a concrete implementation (not just a protocol stub) + # Find all classes in MRO that define this method + method_owners = [cls for cls in mro if name in cls.__dict__] + + # Allow methods from mixins/parents, only reject if ONLY the protocol defines it + if len(method_owners) == 1 and method_owners[0].__name__ == protocol.__name__: + # Only reject if the method is ONLY defined in the protocol itself (abstract stub) missing_methods.append((name, "not_actually_implemented")) if missing_methods: diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py index 92ddcbcfa..30b19c436 100644 --- a/llama_stack/distribution/routers/vector_io.py +++ b/llama_stack/distribution/routers/vector_io.py @@ -108,7 +108,7 @@ class VectorIORouter(VectorIO): # OpenAI Vector Stores API endpoints async def openai_create_vector_store( self, - name: str | None = None, + name: str, file_ids: list[str] | None = None, expires_after: dict[str, Any] | None = None, chunking_strategy: dict[str, Any] | None = None, diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index d0b195cbe..5e9155011 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -9,9 +9,7 @@ import base64 import io import json import logging -import time -import uuid -from typing import Any, Literal +from typing import Any import faiss import numpy as np @@ -24,14 +22,11 @@ from llama_stack.apis.vector_io import ( Chunk, QueryChunksResponse, VectorIO, - VectorStoreDeleteResponse, - VectorStoreListResponse, - VectorStoreObject, - VectorStoreSearchResponse, ) from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore +from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, @@ -47,10 +42,6 @@ FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::" OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::" -# In faiss, since we do -CHUNK_MULTIPLIER = 5 - - class FaissIndex(EmbeddingIndex): def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None): self.index = faiss.IndexFlatL2(dimension) @@ -140,7 +131,7 @@ class FaissIndex(EmbeddingIndex): raise NotImplementedError("Keyword search is not supported in FAISS") -class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): +class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None: self.config = config self.inference_api = inference_api @@ -164,14 +155,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): ) self.cache[vector_db.identifier] = index - # Load existing OpenAI vector stores - start_key = OPENAI_VECTOR_STORES_PREFIX - end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff" - stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key) - - for store_data in stored_openai_stores: - store_info = json.loads(store_data) - self.openai_vector_stores[store_info["id"]] = store_info + # Load existing OpenAI vector stores using the mixin method + self.openai_vector_stores = await self._load_openai_vector_stores() async def shutdown(self) -> None: # Cleanup if needed @@ -234,285 +219,34 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): return await index.query_chunks(query, params) - # OpenAI Vector Stores API endpoints implementation - async def openai_create_vector_store( - self, - name: str | None = None, - file_ids: list[str] | None = None, - expires_after: dict[str, Any] | None = None, - chunking_strategy: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - embedding_model: str | None = None, - embedding_dimension: int | None = 384, - provider_id: str | None = None, - provider_vector_db_id: str | None = None, - ) -> VectorStoreObject: - """Creates a vector store.""" + # OpenAI Vector Store Mixin abstract method implementations + async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Save vector store metadata to kvstore.""" assert self.kvstore is not None - # store and vector_db have the same id - store_id = name or str(uuid.uuid4()) - created_at = int(time.time()) - - if provider_id is None: - raise ValueError("Provider ID is required") - - if embedding_model is None: - raise ValueError("Embedding model is required") - - # Use provided embedding dimension or default to 384 - if embedding_dimension is None: - raise ValueError("Embedding dimension is required") - - provider_vector_db_id = provider_vector_db_id or store_id - vector_db = VectorDB( - identifier=store_id, - embedding_dimension=embedding_dimension, - embedding_model=embedding_model, - provider_id=provider_id, - provider_resource_id=provider_vector_db_id, - ) - - # Register the vector DB - await self.register_vector_db(vector_db) - - # Create OpenAI vector store metadata - store_info = { - "id": store_id, - "object": "vector_store", - "created_at": created_at, - "name": store_id, - "usage_bytes": 0, - "file_counts": {}, - "status": "completed", - "expires_after": expires_after, - "expires_at": None, - "last_active_at": created_at, - "file_ids": file_ids or [], - "chunking_strategy": chunking_strategy, - } - - # Add provider information to metadata if provided - metadata = metadata or {} - if provider_id: - metadata["provider_id"] = provider_id - if provider_vector_db_id: - metadata["provider_vector_db_id"] = provider_vector_db_id - store_info["metadata"] = metadata - - # Store in kvstore key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" await self.kvstore.set(key=key, value=json.dumps(store_info)) - # Store in memory cache - self.openai_vector_stores[store_id] = store_info - - return VectorStoreObject( - id=store_id, - created_at=created_at, - name=store_id, - usage_bytes=0, - file_counts={}, - status="completed", - expires_after=expires_after, - expires_at=None, - last_active_at=created_at, - metadata=metadata, - ) - - async def openai_list_vector_stores( - self, - limit: int = 20, - order: str = "desc", - after: str | None = None, - before: str | None = None, - ) -> VectorStoreListResponse: - """Returns a list of vector stores.""" - # Get all vector stores - all_stores = list(self.openai_vector_stores.values()) - - # Sort by created_at - reverse_order = order == "desc" - all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order) - - # Apply cursor-based pagination - if after: - after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1) - if after_index >= 0: - all_stores = all_stores[after_index + 1 :] - - if before: - before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores)) - all_stores = all_stores[:before_index] - - # Apply limit - limited_stores = all_stores[:limit] - # Convert to VectorStoreObject instances - data = [VectorStoreObject(**store) for store in limited_stores] - - # Determine pagination info - has_more = len(all_stores) > limit - first_id = data[0].id if data else None - last_id = data[-1].id if data else None - - return VectorStoreListResponse( - data=data, - has_more=has_more, - first_id=first_id, - last_id=last_id, - ) - - async def openai_retrieve_vector_store( - self, - vector_store_id: str, - ) -> VectorStoreObject: - """Retrieves a vector store.""" - if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") - - store_info = self.openai_vector_stores[vector_store_id] - return VectorStoreObject(**store_info) - - async def openai_update_vector_store( - self, - vector_store_id: str, - name: str | None = None, - expires_after: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - ) -> VectorStoreObject: - """Modifies a vector store.""" + async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: + """Load all vector store metadata from kvstore.""" assert self.kvstore is not None - if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") + start_key = OPENAI_VECTOR_STORES_PREFIX + end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff" + stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key) - store_info = self.openai_vector_stores[vector_store_id].copy() + stores = {} + for store_data in stored_openai_stores: + store_info = json.loads(store_data) + stores[store_info["id"]] = store_info + return stores - # Update fields if provided - if name is not None: - store_info["name"] = name - if expires_after is not None: - store_info["expires_after"] = expires_after - if metadata is not None: - store_info["metadata"] = metadata - - # Update last_active_at - store_info["last_active_at"] = int(time.time()) - - # Save to kvstore - key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}" + async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Update vector store metadata in kvstore.""" + assert self.kvstore is not None + key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" await self.kvstore.set(key=key, value=json.dumps(store_info)) - # Update in-memory cache - self.openai_vector_stores[vector_store_id] = store_info - - return VectorStoreObject(**store_info) - - async def openai_delete_vector_store( - self, - vector_store_id: str, - ) -> VectorStoreDeleteResponse: - """Delete a vector store.""" + async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: + """Delete vector store metadata from kvstore.""" assert self.kvstore is not None - if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") - - # Delete from kvstore - key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}" + key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" await self.kvstore.delete(key) - - # Delete from in-memory cache - del self.openai_vector_stores[vector_store_id] - - # Also delete the underlying vector DB - try: - await self.unregister_vector_db(vector_store_id) - except Exception as e: - logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}") - - return VectorStoreDeleteResponse( - id=vector_store_id, - deleted=True, - ) - - async def openai_search_vector_store( - self, - vector_store_id: str, - query: str | list[str], - filters: dict[str, Any] | None = None, - max_num_results: int = 10, - ranking_options: dict[str, Any] | None = None, - rewrite_query: bool = False, - search_mode: Literal["keyword", "vector", "hybrid"] = "vector", - ) -> VectorStoreSearchResponse: - """Search for chunks in a vector store.""" - if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") - - if isinstance(query, list): - search_query = " ".join(query) - else: - search_query = query - - try: - score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0 - params = { - "max_chunks": max_num_results * CHUNK_MULTIPLIER, - "score_threshold": score_threshold, - "mode": search_mode, - } - # TODO: Add support for ranking_options.ranker - - response = await self.query_chunks( - vector_db_id=vector_store_id, - query=search_query, - params=params, - ) - - # Convert response to OpenAI format - data = [] - for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)): - # Apply score based filtering - if score < score_threshold: - continue - - # Apply filters if provided - if filters: - # Simple metadata filtering - if not self._matches_filters(chunk.metadata, filters): - continue - - chunk_data = { - "id": f"chunk_{i}", - "object": "vector_store.search_result", - "score": score, - "content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content), - "metadata": chunk.metadata, - } - data.append(chunk_data) - if len(data) >= max_num_results: - break - - return VectorStoreSearchResponse( - search_query=search_query, - data=data, - has_more=False, # For simplicity, we don't implement pagination here - next_page=None, - ) - - except Exception as e: - logger.error(f"Error searching vector store {vector_store_id}: {e}") - # Return empty results on error - return VectorStoreSearchResponse( - search_query=search_query, - data=[], - has_more=False, - next_page=None, - ) - - def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool: - """Check if metadata matches the provided filters.""" - for key, value in filters.items(): - if key not in metadata: - return False - if metadata[key] != value: - return False - return True diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 0a0f0e653..02f04e766 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -10,9 +10,8 @@ import json import logging import sqlite3 import struct -import time import uuid -from typing import Any, Literal +from typing import Any import numpy as np import sqlite_vec @@ -24,12 +23,9 @@ from llama_stack.apis.vector_io import ( Chunk, QueryChunksResponse, VectorIO, - VectorStoreDeleteResponse, - VectorStoreListResponse, - VectorStoreObject, - VectorStoreSearchResponse, ) from llama_stack.providers.datatypes import VectorDBsProtocolPrivate +from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex logger = logging.getLogger(__name__) @@ -39,11 +35,6 @@ VECTOR_SEARCH = "vector" KEYWORD_SEARCH = "keyword" SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH} -# Constants for OpenAI vector stores (similar to faiss) -VERSION = "v3" -OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::" -CHUNK_MULTIPLIER = 5 - def serialize_vector(vector: list[float]) -> bytes: """Serialize a list of floats into a compact binary representation.""" @@ -303,7 +294,7 @@ class SQLiteVecIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) -class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): +class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): """ A VectorIO implementation using SQLite + sqlite_vec. This class handles vector database registration (with metadata stored in a table named `vector_dbs`) @@ -340,15 +331,12 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): # Load any existing vector DB registrations. cur.execute("SELECT metadata FROM vector_dbs") vector_db_rows = cur.fetchall() - # Load any existing OpenAI vector stores. - cur.execute("SELECT metadata FROM openai_vector_stores") - openai_store_rows = cur.fetchall() - return vector_db_rows, openai_store_rows + return vector_db_rows finally: cur.close() connection.close() - vector_db_rows, openai_store_rows = await asyncio.to_thread(_setup_connection) + vector_db_rows = await asyncio.to_thread(_setup_connection) # Load existing vector DBs for row in vector_db_rows: @@ -359,11 +347,8 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): ) self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) - # Load existing OpenAI vector stores - for row in openai_store_rows: - store_data = row[0] - store_info = json.loads(store_data) - self.openai_vector_stores[store_info["id"]] = store_info + # Load existing OpenAI vector stores using the mixin method + self.openai_vector_stores = await self._load_openai_vector_stores() async def shutdown(self) -> None: # nothing to do since we don't maintain a persistent connection @@ -409,6 +394,87 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): await asyncio.to_thread(_delete_vector_db_from_registry) + # OpenAI Vector Store Mixin abstract method implementations + async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Save vector store metadata to SQLite database.""" + + def _store(): + connection = _create_sqlite_connection(self.config.db_path) + cur = connection.cursor() + try: + cur.execute( + "INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)", + (store_id, json.dumps(store_info)), + ) + connection.commit() + except Exception as e: + logger.error(f"Error saving openai vector store {store_id}: {e}") + raise + finally: + cur.close() + connection.close() + + try: + await asyncio.to_thread(_store) + except Exception as e: + logger.error(f"Error saving openai vector store {store_id}: {e}") + raise + + async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: + """Load all vector store metadata from SQLite database.""" + + def _load(): + connection = _create_sqlite_connection(self.config.db_path) + cur = connection.cursor() + try: + cur.execute("SELECT metadata FROM openai_vector_stores") + rows = cur.fetchall() + return rows + finally: + cur.close() + connection.close() + + rows = await asyncio.to_thread(_load) + stores = {} + for row in rows: + store_data = row[0] + store_info = json.loads(store_data) + stores[store_info["id"]] = store_info + return stores + + async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Update vector store metadata in SQLite database.""" + + def _update(): + connection = _create_sqlite_connection(self.config.db_path) + cur = connection.cursor() + try: + cur.execute( + "UPDATE openai_vector_stores SET metadata = ? WHERE id = ?", + (json.dumps(store_info), store_id), + ) + connection.commit() + finally: + cur.close() + connection.close() + + await asyncio.to_thread(_update) + + async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: + """Delete vector store metadata from SQLite database.""" + + def _delete(): + connection = _create_sqlite_connection(self.config.db_path) + cur = connection.cursor() + try: + cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (store_id,)) + connection.commit() + finally: + cur.close() + connection.close() + + await asyncio.to_thread(_delete) + async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None: if vector_db_id not in self.cache: raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}") @@ -423,318 +489,6 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): raise ValueError(f"Vector DB {vector_db_id} not found") return await self.cache[vector_db_id].query_chunks(query, params) - async def openai_create_vector_store( - self, - name: str | None = None, - file_ids: list[str] | None = None, - expires_after: dict[str, Any] | None = None, - chunking_strategy: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - embedding_model: str | None = None, - embedding_dimension: int | None = 384, - provider_id: str | None = None, - provider_vector_db_id: str | None = None, - ) -> VectorStoreObject: - """Creates a vector store.""" - # store and vector_db have the same id - store_id = name or str(uuid.uuid4()) - created_at = int(time.time()) - - if provider_id is None: - raise ValueError("Provider ID is required") - - if embedding_model is None: - raise ValueError("Embedding model is required") - - # Use provided embedding dimension or default to 384 - if embedding_dimension is None: - raise ValueError("Embedding dimension is required") - - provider_vector_db_id = provider_vector_db_id or store_id - vector_db = VectorDB( - identifier=store_id, - embedding_dimension=embedding_dimension, - embedding_model=embedding_model, - provider_id=provider_id, - provider_resource_id=provider_vector_db_id, - ) - - # Register the vector DB - await self.register_vector_db(vector_db) - - # Create OpenAI vector store metadata - store_info = { - "id": store_id, - "object": "vector_store", - "created_at": created_at, - "name": store_id, - "usage_bytes": 0, - "file_counts": {}, - "status": "completed", - "expires_after": expires_after, - "expires_at": None, - "last_active_at": created_at, - "file_ids": file_ids or [], - "chunking_strategy": chunking_strategy, - } - - # Add provider information to metadata if provided - metadata = metadata or {} - if provider_id: - metadata["provider_id"] = provider_id - if provider_vector_db_id: - metadata["provider_vector_db_id"] = provider_vector_db_id - store_info["metadata"] = metadata - - # Store in SQLite database - def _store_openai_vector_store(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute( - "INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)", - (store_id, json.dumps(store_info)), - ) - connection.commit() - finally: - cur.close() - connection.close() - - await asyncio.to_thread(_store_openai_vector_store) - - # Store in memory cache - self.openai_vector_stores[store_id] = store_info - - return VectorStoreObject( - id=store_id, - created_at=created_at, - name=store_id, - usage_bytes=0, - file_counts={}, - status="completed", - expires_after=expires_after, - expires_at=None, - last_active_at=created_at, - metadata=metadata, - ) - - async def openai_list_vector_stores( - self, - limit: int = 20, - order: str = "desc", - after: str | None = None, - before: str | None = None, - ) -> VectorStoreListResponse: - """Returns a list of vector stores.""" - # Get all vector stores - all_stores = list(self.openai_vector_stores.values()) - - # Sort by created_at - reverse_order = order == "desc" - all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order) - - # Apply cursor-based pagination - if after: - after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1) - if after_index >= 0: - all_stores = all_stores[after_index + 1 :] - - if before: - before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores)) - all_stores = all_stores[:before_index] - - # Apply limit - limited_stores = all_stores[:limit] - # Convert to VectorStoreObject instances - data = [VectorStoreObject(**store) for store in limited_stores] - - # Determine pagination info - has_more = len(all_stores) > limit - first_id = data[0].id if data else None - last_id = data[-1].id if data else None - - return VectorStoreListResponse( - data=data, - has_more=has_more, - first_id=first_id, - last_id=last_id, - ) - - async def openai_retrieve_vector_store( - self, - vector_store_id: str, - ) -> VectorStoreObject: - """Retrieves a vector store.""" - if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") - - store_info = self.openai_vector_stores[vector_store_id] - return VectorStoreObject(**store_info) - - async def openai_update_vector_store( - self, - vector_store_id: str, - name: str | None = None, - expires_after: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - ) -> VectorStoreObject: - """Modifies a vector store.""" - if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") - - store_info = self.openai_vector_stores[vector_store_id].copy() - - # Update fields if provided - if name is not None: - store_info["name"] = name - if expires_after is not None: - store_info["expires_after"] = expires_after - if metadata is not None: - store_info["metadata"] = metadata - - # Update last_active_at - store_info["last_active_at"] = int(time.time()) - - # Save to SQLite database - def _update_openai_vector_store(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute( - "UPDATE openai_vector_stores SET metadata = ? WHERE id = ?", - (json.dumps(store_info), vector_store_id), - ) - connection.commit() - finally: - cur.close() - connection.close() - - await asyncio.to_thread(_update_openai_vector_store) - - # Update in-memory cache - self.openai_vector_stores[vector_store_id] = store_info - - return VectorStoreObject(**store_info) - - async def openai_delete_vector_store( - self, - vector_store_id: str, - ) -> VectorStoreDeleteResponse: - """Delete a vector store.""" - if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") - - # Delete from SQLite database - def _delete_openai_vector_store(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (vector_store_id,)) - connection.commit() - finally: - cur.close() - connection.close() - - await asyncio.to_thread(_delete_openai_vector_store) - - # Delete from in-memory cache - del self.openai_vector_stores[vector_store_id] - - # Also delete the underlying vector DB - try: - await self.unregister_vector_db(vector_store_id) - except Exception as e: - logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}") - - return VectorStoreDeleteResponse( - id=vector_store_id, - deleted=True, - ) - - async def openai_search_vector_store( - self, - vector_store_id: str, - query: str | list[str], - filters: dict[str, Any] | None = None, - max_num_results: int = 10, - ranking_options: dict[str, Any] | None = None, - rewrite_query: bool = False, - search_mode: Literal["keyword", "vector", "hybrid"] = "vector", - ) -> VectorStoreSearchResponse: - """Search for chunks in a vector store.""" - if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") - - if isinstance(query, list): - search_query = " ".join(query) - else: - search_query = query - - try: - score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0 - params = { - "max_chunks": max_num_results * CHUNK_MULTIPLIER, - "score_threshold": score_threshold, - "mode": search_mode, - } - # TODO: Add support for ranking_options.ranker - - response = await self.query_chunks( - vector_db_id=vector_store_id, - query=search_query, - params=params, - ) - - # Convert response to OpenAI format - data = [] - for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)): - # Apply score based filtering - if score < score_threshold: - continue - - # Apply filters if provided - if filters: - # Simple metadata filtering - if not self._matches_filters(chunk.metadata, filters): - continue - - chunk_data = { - "id": f"chunk_{i}", - "object": "vector_store.search_result", - "score": score, - "content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content), - "metadata": chunk.metadata, - } - data.append(chunk_data) - if len(data) >= max_num_results: - break - - return VectorStoreSearchResponse( - search_query=search_query, - data=data, - has_more=False, # For simplicity, we don't implement pagination here - next_page=None, - ) - - except Exception as e: - logger.error(f"Error searching vector store {vector_store_id}: {e}") - # Return empty results on error - return VectorStoreSearchResponse( - search_query=search_query, - data=[], - has_more=False, - next_page=None, - ) - - def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool: - """Check if metadata matches the provided filters.""" - for key, value in filters.items(): - if key not in metadata: - return False - if metadata[key] != value: - return False - return True - def generate_chunk_id(document_id: str, chunk_text: str) -> str: """Generate a unique chunk ID using a hash of document ID and chunk text.""" diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 50b864fcd..de41f388c 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -189,7 +189,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): async def openai_create_vector_store( self, - name: str | None = None, + name: str, file_ids: list[str] | None = None, expires_after: dict[str, Any] | None = None, chunking_strategy: dict[str, Any] | None = None, diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 9e1ee9f6d..9360ef36a 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -187,7 +187,7 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): async def openai_create_vector_store( self, - name: str | None = None, + name: str, file_ids: list[str] | None = None, expires_after: dict[str, Any] | None = None, chunking_strategy: dict[str, Any] | None = None, diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index a11513174..cff62bff5 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -189,7 +189,7 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): async def openai_create_vector_store( self, - name: str | None = None, + name: str, file_ids: list[str] | None = None, expires_after: dict[str, Any] | None = None, chunking_strategy: dict[str, Any] | None = None, diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py new file mode 100644 index 000000000..345171828 --- /dev/null +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -0,0 +1,354 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import logging +import time +import uuid +from abc import ABC, abstractmethod +from typing import Any, Literal + +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.apis.vector_io import ( + QueryChunksResponse, + VectorStoreDeleteResponse, + VectorStoreListResponse, + VectorStoreObject, + VectorStoreSearchResponse, +) + +logger = logging.getLogger(__name__) + +# Constants for OpenAI vector stores +CHUNK_MULTIPLIER = 5 + + +class OpenAIVectorStoreMixin(ABC): + """ + Mixin class that provides common OpenAI Vector Store API implementation. + Providers need to implement the abstract storage methods and maintain + an openai_vector_stores in-memory cache. + """ + + # These should be provided by the implementing class + openai_vector_stores: dict[str, dict[str, Any]] + + @abstractmethod + async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Save vector store metadata to persistent storage.""" + pass + + @abstractmethod + async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: + """Load all vector store metadata from persistent storage.""" + pass + + @abstractmethod + async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Update vector store metadata in persistent storage.""" + pass + + @abstractmethod + async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: + """Delete vector store metadata from persistent storage.""" + pass + + @abstractmethod + async def register_vector_db(self, vector_db: VectorDB) -> None: + """Register a vector database (provider-specific implementation).""" + pass + + @abstractmethod + async def unregister_vector_db(self, vector_db_id: str) -> None: + """Unregister a vector database (provider-specific implementation).""" + pass + + @abstractmethod + async def query_chunks( + self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None + ) -> QueryChunksResponse: + """Query chunks from a vector database (provider-specific implementation).""" + pass + + async def openai_create_vector_store( + self, + name: str, + file_ids: list[str] | None = None, + expires_after: dict[str, Any] | None = None, + chunking_strategy: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + embedding_model: str | None = None, + embedding_dimension: int | None = 384, + provider_id: str | None = None, + provider_vector_db_id: str | None = None, + ) -> VectorStoreObject: + """Creates a vector store.""" + print("IN OPENAI VECTOR STORE MIXIN, openai_create_vector_store") + # store and vector_db have the same id + store_id = name or str(uuid.uuid4()) + created_at = int(time.time()) + + if provider_id is None: + raise ValueError("Provider ID is required") + + if embedding_model is None: + raise ValueError("Embedding model is required") + + # Use provided embedding dimension or default to 384 + if embedding_dimension is None: + raise ValueError("Embedding dimension is required") + + provider_vector_db_id = provider_vector_db_id or store_id + vector_db = VectorDB( + identifier=store_id, + embedding_dimension=embedding_dimension, + embedding_model=embedding_model, + provider_id=provider_id, + provider_resource_id=provider_vector_db_id, + ) + from rich.pretty import pprint + + print("VECTOR DB") + pprint(vector_db) + + # Register the vector DB + await self.register_vector_db(vector_db) + + # Create OpenAI vector store metadata + store_info = { + "id": store_id, + "object": "vector_store", + "created_at": created_at, + "name": store_id, + "usage_bytes": 0, + "file_counts": {}, + "status": "completed", + "expires_after": expires_after, + "expires_at": None, + "last_active_at": created_at, + "file_ids": file_ids or [], + "chunking_strategy": chunking_strategy, + } + + # Add provider information to metadata if provided + metadata = metadata or {} + if provider_id: + metadata["provider_id"] = provider_id + if provider_vector_db_id: + metadata["provider_vector_db_id"] = provider_vector_db_id + store_info["metadata"] = metadata + + # Save to persistent storage (provider-specific) + await self._save_openai_vector_store(store_id, store_info) + + # Store in memory cache + self.openai_vector_stores[store_id] = store_info + + return VectorStoreObject( + id=store_id, + created_at=created_at, + name=store_id, + usage_bytes=0, + file_counts={}, + status="completed", + expires_after=expires_after, + expires_at=None, + last_active_at=created_at, + metadata=metadata, + ) + + async def openai_list_vector_stores( + self, + limit: int = 20, + order: str = "desc", + after: str | None = None, + before: str | None = None, + ) -> VectorStoreListResponse: + """Returns a list of vector stores.""" + # Get all vector stores + all_stores = list(self.openai_vector_stores.values()) + + # Sort by created_at + reverse_order = order == "desc" + all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order) + + # Apply cursor-based pagination + if after: + after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1) + if after_index >= 0: + all_stores = all_stores[after_index + 1 :] + + if before: + before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores)) + all_stores = all_stores[:before_index] + + # Apply limit + limited_stores = all_stores[:limit] + # Convert to VectorStoreObject instances + data = [VectorStoreObject(**store) for store in limited_stores] + + # Determine pagination info + has_more = len(all_stores) > limit + first_id = data[0].id if data else None + last_id = data[-1].id if data else None + + return VectorStoreListResponse( + data=data, + has_more=has_more, + first_id=first_id, + last_id=last_id, + ) + + async def openai_retrieve_vector_store( + self, + vector_store_id: str, + ) -> VectorStoreObject: + """Retrieves a vector store.""" + if vector_store_id not in self.openai_vector_stores: + raise ValueError(f"Vector store {vector_store_id} not found") + + store_info = self.openai_vector_stores[vector_store_id] + return VectorStoreObject(**store_info) + + async def openai_update_vector_store( + self, + vector_store_id: str, + name: str | None = None, + expires_after: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + ) -> VectorStoreObject: + """Modifies a vector store.""" + if vector_store_id not in self.openai_vector_stores: + raise ValueError(f"Vector store {vector_store_id} not found") + + store_info = self.openai_vector_stores[vector_store_id].copy() + + # Update fields if provided + if name is not None: + store_info["name"] = name + if expires_after is not None: + store_info["expires_after"] = expires_after + if metadata is not None: + store_info["metadata"] = metadata + + # Update last_active_at + store_info["last_active_at"] = int(time.time()) + + # Save to persistent storage (provider-specific) + await self._update_openai_vector_store(vector_store_id, store_info) + + # Update in-memory cache + self.openai_vector_stores[vector_store_id] = store_info + + return VectorStoreObject(**store_info) + + async def openai_delete_vector_store( + self, + vector_store_id: str, + ) -> VectorStoreDeleteResponse: + """Delete a vector store.""" + if vector_store_id not in self.openai_vector_stores: + raise ValueError(f"Vector store {vector_store_id} not found") + + # Delete from persistent storage (provider-specific) + await self._delete_openai_vector_store_from_storage(vector_store_id) + + # Delete from in-memory cache + del self.openai_vector_stores[vector_store_id] + + # Also delete the underlying vector DB + try: + await self.unregister_vector_db(vector_store_id) + except Exception as e: + logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}") + + return VectorStoreDeleteResponse( + id=vector_store_id, + deleted=True, + ) + + async def openai_search_vector_store( + self, + vector_store_id: str, + query: str | list[str], + filters: dict[str, Any] | None = None, + max_num_results: int = 10, + ranking_options: dict[str, Any] | None = None, + rewrite_query: bool = False, + search_mode: Literal["keyword", "vector", "hybrid"] = "vector", + ) -> VectorStoreSearchResponse: + """Search for chunks in a vector store.""" + if vector_store_id not in self.openai_vector_stores: + raise ValueError(f"Vector store {vector_store_id} not found") + + if isinstance(query, list): + search_query = " ".join(query) + else: + search_query = query + + try: + score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0 + params = { + "max_chunks": max_num_results * CHUNK_MULTIPLIER, + "score_threshold": score_threshold, + "mode": search_mode, + } + # TODO: Add support for ranking_options.ranker + + response = await self.query_chunks( + vector_db_id=vector_store_id, + query=search_query, + params=params, + ) + + # Convert response to OpenAI format + data = [] + for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)): + # Apply score based filtering + if score < score_threshold: + continue + + # Apply filters if provided + if filters: + # Simple metadata filtering + if not self._matches_filters(chunk.metadata, filters): + continue + + chunk_data = { + "id": f"chunk_{i}", + "object": "vector_store.search_result", + "score": score, + "content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content), + "metadata": chunk.metadata, + } + data.append(chunk_data) + if len(data) >= max_num_results: + break + + return VectorStoreSearchResponse( + search_query=search_query, + data=data, + has_more=False, # For simplicity, we don't implement pagination here + next_page=None, + ) + + except Exception as e: + logger.error(f"Error searching vector store {vector_store_id}: {e}") + # Return empty results on error + return VectorStoreSearchResponse( + search_query=search_query, + data=[], + has_more=False, + next_page=None, + ) + + def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool: + """Check if metadata matches the provided filters.""" + for key, value in filters.items(): + if key not in metadata: + return False + if metadata[key] != value: + return False + return True