From b87968113de0dfdc87417a92d07e3c92656672e0 Mon Sep 17 00:00:00 2001
From: Varsha Prasad Narsing
Date: Tue, 17 Jun 2025 16:38:02 -0700
Subject: [PATCH 1/2] feat: add OpenAI-compatible APIs to Qdrant

Signed-off-by: Varsha Prasad Narsing
---
 .../inline/vector_io/qdrant/__init__.py       | 10 +-
 .../inline/vector_io/qdrant/config.py         |  3 +-
 .../remote/vector_io/qdrant/config.py         |  3 +-
 .../remote/vector_io/qdrant/qdrant.py         | 93 +++++++++++++++++--
 .../vector_io/test_openai_vector_stores.py    |  1 +
 tests/unit/providers/vector_io/test_qdrant.py |  2 +-
 6 files changed, 100 insertions(+), 12 deletions(-)

diff --git a/llama_stack/providers/inline/vector_io/qdrant/__init__.py b/llama_stack/providers/inline/vector_io/qdrant/__init__.py
index ee33b3797..bc9014c68 100644
--- a/llama_stack/providers/inline/vector_io/qdrant/__init__.py
+++ b/llama_stack/providers/inline/vector_io/qdrant/__init__.py
@@ -4,14 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.datatypes import Api, ProviderSpec
+from typing import Any
+
+from llama_stack.providers.datatypes import Api
 
 from .config import QdrantVectorIOConfig
 
 
-async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
+async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
     from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
 
-    impl = QdrantVectorIOAdapter(config, deps[Api.inference])
+    assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
+    files_api = deps.get(Api.files)
+    impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api)
     await impl.initialize()
     return impl
 
diff --git a/llama_stack/providers/inline/vector_io/qdrant/config.py b/llama_stack/providers/inline/vector_io/qdrant/config.py
index 7cc91d918..61d026984 100644
--- a/llama_stack/providers/inline/vector_io/qdrant/config.py
+++ b/llama_stack/providers/inline/vector_io/qdrant/config.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 
-from typing import Any
+from typing import Any, Literal
 
 from pydantic import BaseModel
 
@@ -15,6 +15,7 @@ from llama_stack.schema_utils import json_schema_type
 @json_schema_type
 class QdrantVectorIOConfig(BaseModel):
     path: str
+    distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE"
 
     @classmethod
     def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
diff --git a/llama_stack/providers/remote/vector_io/qdrant/config.py b/llama_stack/providers/remote/vector_io/qdrant/config.py
index 314d3f5f1..0d8e08663 100644
--- a/llama_stack/providers/remote/vector_io/qdrant/config.py
+++ b/llama_stack/providers/remote/vector_io/qdrant/config.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any +from typing import Any, Literal from pydantic import BaseModel @@ -23,6 +23,7 @@ class QdrantVectorIOConfig(BaseModel): prefix: str | None = None timeout: int | None = None host: str | None = None + distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE" @classmethod def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]: diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 3df3da27f..48f027527 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -12,6 +12,7 @@ from numpy.typing import NDArray from qdrant_client import AsyncQdrantClient, models from qdrant_client.models import PointStruct +from llama_stack.apis.files import Files from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import ( @@ -31,6 +32,7 @@ from llama_stack.apis.vector_io import ( ) from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig +from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, @@ -40,6 +42,7 @@ from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig log = logging.getLogger(__name__) CHUNK_ID_KEY = "_chunk_id" +OPENAI_VECTOR_STORES_METADATA_COLLECTION = "openai_vector_stores_metadata" def convert_id(_id: str) -> str: @@ -54,9 +57,10 @@ def convert_id(_id: str) -> str: class QdrantIndex(EmbeddingIndex): - def __init__(self, client: AsyncQdrantClient, collection_name: str): + def __init__(self, client: AsyncQdrantClient, collection_name: str, distance_metric: str = "COSINE"): self.client = client self.collection_name = collection_name + self.distance_metric = distance_metric async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( @@ -64,9 +68,12 @@ class QdrantIndex(EmbeddingIndex): ) if not await self.client.collection_exists(self.collection_name): + # Get distance metric, defaulting to COSINE + distance = getattr(models.Distance, self.distance_metric, models.Distance.COSINE) + await self.client.create_collection( self.collection_name, - vectors_config=models.VectorParams(size=len(embeddings[0]), distance=models.Distance.COSINE), + vectors_config=models.VectorParams(size=len(embeddings[0]), distance=distance), ) points = [] @@ -135,28 +142,100 @@ class QdrantIndex(EmbeddingIndex): await self.client.delete_collection(collection_name=self.collection_name) -class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): +class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): def __init__( - self, config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig, inference_api: Api.inference + self, + config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig, + inference_api: Api.inference, + files_api: Files | None, ) -> None: self.config = config self.client: AsyncQdrantClient = None self.cache = {} self.inference_api = inference_api + self.files_api = files_api + self.vector_db_store = None + self.openai_vector_stores: dict[str, dict[str, Any]] = {} async def initialize(self) -> None: self.client = 
AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) + # Load existing OpenAI vector stores using the mixin method + self.openai_vector_stores = await self._load_openai_vector_stores() async def shutdown(self) -> None: await self.client.close() + # OpenAI Vector Store Mixin abstract method implementations + async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Save vector store metadata to Qdrant collection metadata.""" + # Store metadata in a special collection for vector store metadata + metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION + + # Create metadata collection if it doesn't exist + if not await self.client.collection_exists(metadata_collection): + # Get distance metric from config, defaulting to COSINE for backward compatibility + distance_metric = getattr(self.config, "distance_metric", "COSINE") + distance = getattr(models.Distance, distance_metric, models.Distance.COSINE) + + await self.client.create_collection( + collection_name=metadata_collection, + vectors_config=models.VectorParams(size=1, distance=distance), + ) + + # Store metadata as a point with dummy vector + await self.client.upsert( + collection_name=metadata_collection, + points=[ + models.PointStruct( + id=convert_id(store_id), + vector=[0.0], # Dummy vector + payload={"metadata": store_info}, + ) + ], + ) + + async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: + """Load all vector store metadata from Qdrant.""" + metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION + + if not await self.client.collection_exists(metadata_collection): + return {} + + # Get all points from metadata collection + points = await self.client.scroll( + collection_name=metadata_collection, + limit=1000, # Reasonable limit for metadata + with_payload=True, + ) + + stores = {} + for point in points[0]: # points[0] contains the actual points + if point.payload and "metadata" in point.payload: + store_info = point.payload["metadata"] + stores[store_info["id"]] = store_info + + return stores + + async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Update vector store metadata in Qdrant.""" + await self._save_openai_vector_store(store_id, store_info) + + async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: + """Delete vector store metadata from Qdrant.""" + metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION + + if await self.client.collection_exists(metadata_collection): + await self.client.delete( + collection_name=metadata_collection, points_selector=models.PointIdsList(points=[convert_id(store_id)]) + ) + async def register_vector_db( self, vector_db: VectorDB, ) -> None: index = VectorDBWithIndex( vector_db=vector_db, - index=QdrantIndex(self.client, vector_db.identifier), + index=QdrantIndex(self.client, vector_db.identifier, self.config.distance_metric), inference_api=self.inference_api, ) @@ -177,7 +256,9 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): index = VectorDBWithIndex( vector_db=vector_db, - index=QdrantIndex(client=self.client, collection_name=vector_db.identifier), + index=QdrantIndex( + client=self.client, collection_name=vector_db.identifier, distance_metric=self.config.distance_metric + ), inference_api=self.inference_api, ) self.cache[vector_db_id] = index diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index a34c5b410..6d47a5684 100644 
--- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -22,6 +22,7 @@ logger = logging.getLogger(__name__) def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"] for p in vector_io_providers: + if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "inline::qdrant", "remote::pgvector"]: if p.provider_type in [ "inline::faiss", "inline::sqlite-vec", diff --git a/tests/unit/providers/vector_io/test_qdrant.py b/tests/unit/providers/vector_io/test_qdrant.py index d3ffe711c..b52043e86 100644 --- a/tests/unit/providers/vector_io/test_qdrant.py +++ b/tests/unit/providers/vector_io/test_qdrant.py @@ -69,7 +69,7 @@ def mock_api_service(sample_embeddings): @pytest.fixture async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter: - adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service) + adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service, files_api=None) adapter.vector_db_store = mock_vector_db_store await adapter.initialize() yield adapter From c9dad596868a2552dc9d5be268262593bf0c9717 Mon Sep 17 00:00:00 2001 From: Varsha Prasad Narsing Date: Wed, 25 Jun 2025 16:59:29 -0700 Subject: [PATCH 2/2] feat: rebase and implement file API methods Signed-off-by: Varsha Prasad Narsing --- .../workflows/integration-vector-io-tests.yml | 32 ++- .../providers/vector_io/inline_qdrant.md | 4 + .../providers/vector_io/remote_qdrant.md | 5 +- .../inline/vector_io/qdrant/config.py | 12 +- .../inline/vector_io/sqlite_vec/sqlite_vec.py | 4 +- llama_stack/providers/registry/vector_io.py | 2 + .../remote/vector_io/qdrant/__init__.py | 3 +- .../remote/vector_io/qdrant/config.py | 15 +- .../remote/vector_io/qdrant/qdrant.py | 265 +++++------------- .../utils/memory/openai_vector_store_mixin.py | 14 +- .../vector_io/test_openai_vector_stores.py | 22 +- tests/integration/vector_io/test_vector_io.py | 4 + tests/unit/providers/vector_io/conftest.py | 58 +++- tests/unit/providers/vector_io/test_qdrant.py | 9 +- .../test_vector_io_openai_vector_stores.py | 6 +- 15 files changed, 240 insertions(+), 215 deletions(-) diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index 9a02bbcf8..7b9f0c459 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -24,7 +24,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector"] + vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::qdrant"] python-version: ["3.12", "3.13"] fail-fast: false # we want to run all tests regardless of failure @@ -78,6 +78,29 @@ jobs: PGPASSWORD=llamastack psql -h localhost -U llamastack -d llamastack \ -c "CREATE EXTENSION IF NOT EXISTS vector;" + - name: Setup Qdrant + if: matrix.vector-io-provider == 'remote::qdrant' + run: | + docker run --rm -d --pull always \ + --name qdrant \ + -p 6333:6333 \ + qdrant/qdrant + + - name: Wait for Qdrant to be ready + if: matrix.vector-io-provider == 'remote::qdrant' + run: | + echo "Waiting for Qdrant to be ready..." 
+ for i in {1..30}; do + if curl -s http://localhost:6333/collections | grep -q '"status":"ok"'; then + echo "Qdrant is ready!" + exit 0 + fi + sleep 2 + done + echo "Qdrant failed to start" + docker logs qdrant + exit 1 + - name: Wait for ChromaDB to be ready if: matrix.vector-io-provider == 'remote::chromadb' run: | @@ -113,6 +136,8 @@ jobs: PGVECTOR_DB: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }} PGVECTOR_USER: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }} PGVECTOR_PASSWORD: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }} + ENABLE_QDRANT: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'true' || '' }} + QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }} run: | uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ tests/integration/vector_io \ @@ -134,6 +159,11 @@ jobs: run: | docker logs chromadb > chromadb.log + - name: Write Qdrant logs to file + if: ${{ always() && matrix.vector-io-provider == 'remote::qdrant' }} + run: | + docker logs qdrant > qdrant.log + - name: Upload all logs to artifacts if: ${{ always() }} uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 diff --git a/docs/source/providers/vector_io/inline_qdrant.md b/docs/source/providers/vector_io/inline_qdrant.md index 63e2d81d8..e989a3554 100644 --- a/docs/source/providers/vector_io/inline_qdrant.md +++ b/docs/source/providers/vector_io/inline_qdrant.md @@ -51,11 +51,15 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `path` | `` | No | PydanticUndefined | | +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | | ## Sample Configuration ```yaml path: ${env.QDRANT_PATH:=~/.llama/~/.llama/dummy}/qdrant.db +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/qdrant_registry.db ``` diff --git a/docs/source/providers/vector_io/remote_qdrant.md b/docs/source/providers/vector_io/remote_qdrant.md index 14c821f35..7dac0a3d1 100644 --- a/docs/source/providers/vector_io/remote_qdrant.md +++ b/docs/source/providers/vector_io/remote_qdrant.md @@ -20,11 +20,14 @@ Please refer to the inline provider documentation. | `prefix` | `str \| None` | No | | | | `timeout` | `int \| None` | No | | | | `host` | `str \| None` | No | | | +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | | ## Sample Configuration ```yaml -api_key: ${env.QDRANT_API_KEY} +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/qdrant_registry.db ``` diff --git a/llama_stack/providers/inline/vector_io/qdrant/config.py b/llama_stack/providers/inline/vector_io/qdrant/config.py index 61d026984..c23bb9608 100644 --- a/llama_stack/providers/inline/vector_io/qdrant/config.py +++ b/llama_stack/providers/inline/vector_io/qdrant/config.py @@ -5,20 +5,28 @@ # the root directory of this source tree. 
-from typing import Any, Literal +from typing import Any from pydantic import BaseModel +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) from llama_stack.schema_utils import json_schema_type @json_schema_type class QdrantVectorIOConfig(BaseModel): path: str - distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE" + kvstore: KVStoreConfig @classmethod def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: return { "path": "${env.QDRANT_PATH:=~/.llama/" + __distro_dir__ + "}/" + "qdrant.db", + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="qdrant_registry.db", + ), } diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index cfa4e2263..a09c14721 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -192,7 +192,9 @@ class SQLiteVecIndex(EmbeddingIndex): await asyncio.to_thread(_drop_tables) - async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, batch_size: int = 500): + async def add_chunks( + self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None, batch_size: int = 500 + ): """ Add new chunks along with their embeddings using batch inserts. For each chunk, we insert its JSON into the metadata table and then insert its diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 063b382df..846f7b88e 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -460,6 +460,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more module="llama_stack.providers.inline.vector_io.qdrant", config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig", api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], description=r""" [Qdrant](https://qdrant.tech/documentation/) is an inline and remote vector database provider for Llama Stack. It allows you to store and query vectors directly in memory. @@ -516,6 +517,7 @@ Please refer to the inline provider documentation. """, ), api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], ), remote_provider_spec( Api.vector_io, diff --git a/llama_stack/providers/remote/vector_io/qdrant/__init__.py b/llama_stack/providers/remote/vector_io/qdrant/__init__.py index 029de285f..6ce98b17c 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/__init__.py +++ b/llama_stack/providers/remote/vector_io/qdrant/__init__.py @@ -12,6 +12,7 @@ from .config import QdrantVectorIOConfig async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]): from .qdrant import QdrantVectorIOAdapter - impl = QdrantVectorIOAdapter(config, deps[Api.inference]) + files_api = deps.get(Api.files) + impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/vector_io/qdrant/config.py b/llama_stack/providers/remote/vector_io/qdrant/config.py index 0d8e08663..71d51c062 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/config.py +++ b/llama_stack/providers/remote/vector_io/qdrant/config.py @@ -4,10 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, Literal +from typing import Any from pydantic import BaseModel +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) from llama_stack.schema_utils import json_schema_type @@ -23,10 +27,13 @@ class QdrantVectorIOConfig(BaseModel): prefix: str | None = None timeout: int | None = None host: str | None = None - distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE" + kvstore: KVStoreConfig @classmethod - def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return { - "api_key": "${env.QDRANT_API_KEY}", + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="qdrant_registry.db", + ), } diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 48f027527..85d6fe05b 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -18,20 +18,11 @@ from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import ( Chunk, QueryChunksResponse, - SearchRankingOptions, VectorIO, - VectorStoreChunkingStrategy, - VectorStoreDeleteResponse, - VectorStoreFileContentsResponse, - VectorStoreFileObject, - VectorStoreFileStatus, - VectorStoreListFilesResponse, - VectorStoreListResponse, - VectorStoreObject, - VectorStoreSearchResponsePage, ) from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig +from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, @@ -42,7 +33,10 @@ from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig log = logging.getLogger(__name__) CHUNK_ID_KEY = "_chunk_id" -OPENAI_VECTOR_STORES_METADATA_COLLECTION = "openai_vector_stores_metadata" + +# KV store prefixes for vector databases +VERSION = "v3" +VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::" def convert_id(_id: str) -> str: @@ -57,10 +51,14 @@ def convert_id(_id: str) -> str: class QdrantIndex(EmbeddingIndex): - def __init__(self, client: AsyncQdrantClient, collection_name: str, distance_metric: str = "COSINE"): + def __init__(self, client: AsyncQdrantClient, collection_name: str): self.client = client self.collection_name = collection_name - self.distance_metric = distance_metric + + async def initialize(self) -> None: + # Qdrant collections are created on-demand in add_chunks + # If the collection does not exist, it will be created in add_chunks. 
+ pass async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( @@ -68,12 +66,9 @@ class QdrantIndex(EmbeddingIndex): ) if not await self.client.collection_exists(self.collection_name): - # Get distance metric, defaulting to COSINE - distance = getattr(models.Distance, self.distance_metric, models.Distance.COSINE) - await self.client.create_collection( self.collection_name, - vectors_config=models.VectorParams(size=len(embeddings[0]), distance=distance), + vectors_config=models.VectorParams(size=len(embeddings[0]), distance=models.Distance.COSINE), ) points = [] @@ -90,7 +85,15 @@ class QdrantIndex(EmbeddingIndex): await self.client.upsert(collection_name=self.collection_name, points=points) async def delete_chunk(self, chunk_id: str) -> None: - raise NotImplementedError("delete_chunk is not supported in qdrant") + """Remove a chunk from the Qdrant collection.""" + try: + await self.client.delete( + collection_name=self.collection_name, + points_selector=models.PointIdsList(points=[convert_id(chunk_id)]), + ) + except Exception as e: + log.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}") + raise async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: results = ( @@ -155,87 +158,55 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP self.inference_api = inference_api self.files_api = files_api self.vector_db_store = None + self.kvstore: KVStore | None = None self.openai_vector_stores: dict[str, dict[str, Any]] = {} async def initialize(self) -> None: - self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) - # Load existing OpenAI vector stores using the mixin method + # Close existing client if it exists + # Qdrant doesn't allow multiple clients to access the same storage path simultaneously + # This prevents "Storage folder is already accessed by another instance" errors during re-initialization + if self.client is not None: + await self.client.close() + self.client = None + + # Create client config excluding kvstore (which is used for metadata storage, not Qdrant client connection) + client_config = self.config.model_dump(exclude_none=True, exclude={"kvstore"}) + self.client = AsyncQdrantClient(**client_config) + self.kvstore = await kvstore_impl(self.config.kvstore) + + # Load existing vector DBs from kvstore + start_key = VECTOR_DBS_PREFIX + end_key = f"{VECTOR_DBS_PREFIX}\xff" + stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key) + + for vector_db_data in stored_vector_dbs: + vector_db = VectorDB.model_validate_json(vector_db_data) + index = VectorDBWithIndex( + vector_db, + QdrantIndex(self.client, vector_db.identifier), + self.inference_api, + ) + self.cache[vector_db.identifier] = index + + # Load OpenAI vector stores as before self.openai_vector_stores = await self._load_openai_vector_stores() async def shutdown(self) -> None: await self.client.close() - # OpenAI Vector Store Mixin abstract method implementations - async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: - """Save vector store metadata to Qdrant collection metadata.""" - # Store metadata in a special collection for vector store metadata - metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION - - # Create metadata collection if it doesn't exist - if not await self.client.collection_exists(metadata_collection): - # Get distance metric from config, 
defaulting to COSINE for backward compatibility - distance_metric = getattr(self.config, "distance_metric", "COSINE") - distance = getattr(models.Distance, distance_metric, models.Distance.COSINE) - - await self.client.create_collection( - collection_name=metadata_collection, - vectors_config=models.VectorParams(size=1, distance=distance), - ) - - # Store metadata as a point with dummy vector - await self.client.upsert( - collection_name=metadata_collection, - points=[ - models.PointStruct( - id=convert_id(store_id), - vector=[0.0], # Dummy vector - payload={"metadata": store_info}, - ) - ], - ) - - async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: - """Load all vector store metadata from Qdrant.""" - metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION - - if not await self.client.collection_exists(metadata_collection): - return {} - - # Get all points from metadata collection - points = await self.client.scroll( - collection_name=metadata_collection, - limit=1000, # Reasonable limit for metadata - with_payload=True, - ) - - stores = {} - for point in points[0]: # points[0] contains the actual points - if point.payload and "metadata" in point.payload: - store_info = point.payload["metadata"] - stores[store_info["id"]] = store_info - - return stores - - async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: - """Update vector store metadata in Qdrant.""" - await self._save_openai_vector_store(store_id, store_info) - - async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: - """Delete vector store metadata from Qdrant.""" - metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION - - if await self.client.collection_exists(metadata_collection): - await self.client.delete( - collection_name=metadata_collection, points_selector=models.PointIdsList(points=[convert_id(store_id)]) - ) - async def register_vector_db( self, vector_db: VectorDB, ) -> None: + # Save to kvstore + assert self.kvstore is not None + key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}" + await self.kvstore.set(key=key, value=vector_db.model_dump_json()) + + # Store in cache index = VectorDBWithIndex( vector_db=vector_db, - index=QdrantIndex(self.client, vector_db.identifier, self.config.distance_metric), + index=QdrantIndex(self.client, vector_db.identifier), inference_api=self.inference_api, ) @@ -246,19 +217,24 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP await self.cache[vector_db_id].index.delete() del self.cache[vector_db_id] + # Remove from kvstore + assert self.kvstore is not None + await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}") + async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None: if vector_db_id in self.cache: return self.cache[vector_db_id] + if self.vector_db_store is None: + raise ValueError(f"Vector DB {vector_db_id} not found") + vector_db = await self.vector_db_store.get_vector_db(vector_db_id) if not vector_db: raise ValueError(f"Vector DB {vector_db_id} not found") index = VectorDBWithIndex( vector_db=vector_db, - index=QdrantIndex( - client=self.client, collection_name=vector_db.identifier, distance_metric=self.config.distance_metric - ), + index=QdrantIndex(client=self.client, collection_name=vector_db.identifier), inference_api=self.inference_api, ) self.cache[vector_db_id] = index @@ -273,7 +249,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP index = await 
self._get_and_cache_vector_db_index(vector_db_id) if not index: raise ValueError(f"Vector DB {vector_db_id} not found") - await index.insert_chunks(chunks) async def query_chunks( @@ -288,109 +263,11 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP return await index.query_chunks(query, params) - async def openai_create_vector_store( - self, - name: str, - file_ids: list[str] | None = None, - expires_after: dict[str, Any] | None = None, - chunking_strategy: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - embedding_model: str | None = None, - embedding_dimension: int | None = 384, - provider_id: str | None = None, - ) -> VectorStoreObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_list_vector_stores( - self, - limit: int | None = 20, - order: str | None = "desc", - after: str | None = None, - before: str | None = None, - ) -> VectorStoreListResponse: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_retrieve_vector_store( - self, - vector_store_id: str, - ) -> VectorStoreObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_update_vector_store( - self, - vector_store_id: str, - name: str | None = None, - expires_after: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - ) -> VectorStoreObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_delete_vector_store( - self, - vector_store_id: str, - ) -> VectorStoreDeleteResponse: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_search_vector_store( - self, - vector_store_id: str, - query: str | list[str], - filters: dict[str, Any] | None = None, - max_num_results: int | None = 10, - ranking_options: SearchRankingOptions | None = None, - rewrite_query: bool | None = False, - search_mode: str | None = "vector", - ) -> VectorStoreSearchResponsePage: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_attach_file_to_vector_store( - self, - vector_store_id: str, - file_id: str, - attributes: dict[str, Any] | None = None, - chunking_strategy: VectorStoreChunkingStrategy | None = None, - ) -> VectorStoreFileObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_list_files_in_vector_store( - self, - vector_store_id: str, - limit: int | None = 20, - order: str | None = "desc", - after: str | None = None, - before: str | None = None, - filter: VectorStoreFileStatus | None = None, - ) -> VectorStoreListFilesResponse: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_retrieve_vector_store_file( - self, - vector_store_id: str, - file_id: str, - ) -> VectorStoreFileObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_retrieve_vector_store_file_contents( - self, - vector_store_id: str, - file_id: str, - ) -> VectorStoreFileContentsResponse: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_update_vector_store_file( - self, - vector_store_id: str, - file_id: str, - attributes: dict[str, Any] | None = None, - ) -> VectorStoreFileObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async 
def openai_delete_vector_store_file( - self, - vector_store_id: str, - file_id: str, - ) -> VectorStoreFileObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") + """Delete chunks from a Qdrant vector store.""" + index = await self._get_and_cache_vector_db_index(store_id) + if not index: + raise ValueError(f"Vector DB {store_id} not found") + + for chunk_id in chunk_ids: + await index.index.delete_chunk(chunk_id) diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index ee69d7c52..fad708e56 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -264,8 +264,18 @@ class OpenAIVectorStoreMixin(ABC): # Now that our vector store is created, attach any files that were provided file_ids = file_ids or [] - tasks = [self.openai_attach_file_to_vector_store(vector_db_id, file_id) for file_id in file_ids] - await asyncio.gather(*tasks) + + # Try concurrent processing first, fall back to sequential if it fails + if file_ids: + try: + # Process files concurrently for better performance + tasks = [self.openai_attach_file_to_vector_store(vector_db_id, file_id) for file_id in file_ids] + await asyncio.gather(*tasks) + except Exception as e: + logger.warning(f"Concurrent file processing failed: {e}. Falling back to sequential processing.") + # Fall back to sequential processing if concurrent processing fails + for file_id in file_ids: + await self.openai_attach_file_to_vector_store(vector_db_id, file_id) # Get the updated store info and return it store_info = self.openai_vector_stores[vector_db_id] diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 6d47a5684..c806d9c77 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -22,14 +22,30 @@ logger = logging.getLogger(__name__) def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"] for p in vector_io_providers: - if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "inline::qdrant", "remote::pgvector"]: if p.provider_type in [ "inline::faiss", "inline::sqlite-vec", "inline::milvus", - "inline::chromadb", + "inline::qdrant", "remote::pgvector", - "remote::chromadb", + "inline::chromadb", + "remote::qdrant", + ]: + return + + pytest.skip("OpenAI vector stores are not supported by any provider") + + +def skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models): + vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"] + for p in vector_io_providers: + if p.provider_type in [ + "inline::faiss", + "inline::sqlite-vec", + "inline::milvus", + "inline::qdrant", + "remote::pgvector", + "remote::qdrant", ]: return diff --git a/tests/integration/vector_io/test_vector_io.py b/tests/integration/vector_io/test_vector_io.py index 9cd4fc18c..07faa0db1 100644 --- a/tests/integration/vector_io/test_vector_io.py +++ b/tests/integration/vector_io/test_vector_io.py @@ -125,6 +125,8 @@ def test_insert_chunks(client_with_empty_registry, 
embedding_model_id, embedding def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id, embedding_dimension): vector_io_provider_params_dict = { "inline::milvus": {"score_threshold": -1.0}, + "remote::qdrant": {"score_threshold": -1.0}, + "inline::qdrant": {"score_threshold": -1.0}, } vector_db_id = "test_precomputed_embeddings_db" client_with_empty_registry.vector_dbs.register( @@ -168,6 +170,8 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb( ): vector_io_provider_params_dict = { "inline::milvus": {"score_threshold": 0.0}, + "remote::qdrant": {"score_threshold": 0.0}, + "inline::qdrant": {"score_threshold": 0.0}, } vector_db_id = "test_precomputed_embeddings_db" client_with_empty_registry.vector_dbs.register( diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index bcba06140..e52d53d48 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -17,10 +17,12 @@ from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOC from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig +from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter +from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter EMBEDDING_DIMENSION = 384 COLLECTION_PREFIX = "test_collection" @@ -134,7 +136,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory): await index.initialize() index.db_path = db_path yield index - index.delete() + await index.delete() @pytest.fixture @@ -280,14 +282,66 @@ async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_d await adapter.shutdown() +@pytest.fixture +def qdrant_vec_db_path(tmp_path_factory): + import uuid + + db_path = str(tmp_path_factory.getbasetemp() / f"test_qdrant_{uuid.uuid4()}.db") + return db_path + + +@pytest.fixture +async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension): + import uuid + + config = QdrantVectorIOConfig( + path=qdrant_vec_db_path, + kvstore=SqliteKVStoreConfig(), + ) + adapter = QdrantVectorIOAdapter( + config=config, + inference_api=mock_inference_api, + files_api=None, + ) + collection_id = f"qdrant_test_collection_{uuid.uuid4()}" + await adapter.initialize() + await adapter.register_vector_db( + VectorDB( + identifier=collection_id, + provider_id="test_provider", + embedding_model="test_model", + embedding_dimension=embedding_dimension, + ) + ) + adapter.test_collection_id = collection_id + yield adapter + await adapter.shutdown() + + +@pytest.fixture +async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension): + import uuid + + from qdrant_client import AsyncQdrantClient + + from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantIndex + + client = AsyncQdrantClient(path=qdrant_vec_db_path) 
+ collection_name = f"qdrant_test_collection_{uuid.uuid4()}" + index = QdrantIndex(client, collection_name) + yield index + await index.delete() + + @pytest.fixture def vector_io_adapter(vector_provider, request): """Returns the appropriate vector IO adapter based on the provider parameter.""" vector_provider_dict = { "milvus": "milvus_vec_adapter", "faiss": "faiss_vec_adapter", - "sqlite_vec": "sqlite_vec_adapter", + "qdrant": "qdrant_vec_adapter", "chroma": "chroma_vec_adapter", + "sqlite_vec": "sqlite_vec_adapter", } return request.getfixturevalue(vector_provider_dict[vector_provider]) diff --git a/tests/unit/providers/vector_io/test_qdrant.py b/tests/unit/providers/vector_io/test_qdrant.py index b52043e86..b16aa9b89 100644 --- a/tests/unit/providers/vector_io/test_qdrant.py +++ b/tests/unit/providers/vector_io/test_qdrant.py @@ -23,6 +23,7 @@ from llama_stack.providers.inline.vector_io.qdrant.config import ( from llama_stack.providers.remote.vector_io.qdrant.qdrant import ( QdrantVectorIOAdapter, ) +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig # This test is a unit test for the QdrantVectorIOAdapter class. This should only contain # tests which are specific to this class. More general (API-level) tests should be placed in @@ -36,7 +37,9 @@ from llama_stack.providers.remote.vector_io.qdrant.qdrant import ( @pytest.fixture def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig: - return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db")) + kvstore_config = SqliteKVStoreConfig(db_name=os.path.join(tmp_path, "test_kvstore.db")) + + return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"), kvstore=kvstore_config) @pytest.fixture(scope="session") @@ -50,6 +53,10 @@ def mock_vector_db(vector_db_id) -> MagicMock: mock_vector_db.embedding_model = "embedding_model" mock_vector_db.identifier = vector_db_id mock_vector_db.embedding_dimension = 384 + # Mock model_dump_json to return a proper JSON string for kvstore persistence + mock_vector_db.model_dump_json.return_value = ( + '{"identifier": "' + vector_db_id + '", "embedding_model": "embedding_model", "embedding_dimension": 384}' + ) return mock_vector_db diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py index 98889f38e..9669d3922 100644 --- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py +++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py @@ -30,12 +30,12 @@ async def test_initialize_index(vector_index): async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings): - vector_index.delete() - vector_index.initialize() + await vector_index.delete() + await vector_index.initialize() await vector_index.add_chunks(sample_chunks, sample_embeddings) resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1) assert resp.chunks[0].content == sample_chunks[0].content - vector_index.delete() + await vector_index.delete() async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
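
A few implementation notes on the two patches above, with small standalone sketches; the helper names introduced below (`resolve_distance` and friends) are illustrative, not part of the PR.

Patch 1 threads a configured `distance_metric` string down to collection creation and resolves it against the `qdrant_client` enum with a COSINE fallback (patch 2 later drops the setting in favor of the default). The lookup pattern in isolation:

```python
from qdrant_client import models


def resolve_distance(metric_name: str) -> models.Distance:
    # getattr falls back to COSINE when metric_name is not a member of the
    # Distance enum (e.g. a typo in the run config), which preserves the old
    # behavior for existing deployments.
    return getattr(models.Distance, metric_name, models.Distance.COSINE)


assert resolve_distance("DOT") is models.Distance.DOT
assert resolve_distance("bogus") is models.Distance.COSINE
```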
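Patch 1's `_load_openai_vector_stores` reads a single `scroll` page of up to 1000 points, which silently truncates larger deployments; `client.scroll` returns a `(points, next_page_offset)` tuple, and following the offset cursor reads the collection to exhaustion. A sketch of the paginated variant, assuming the same payload shape the patch writes:

```python
from qdrant_client import AsyncQdrantClient

OPENAI_VECTOR_STORES_METADATA_COLLECTION = "openai_vector_stores_metadata"


async def load_all_stores(client: AsyncQdrantClient) -> dict[str, dict]:
    stores: dict[str, dict] = {}
    offset = None
    while True:
        points, offset = await client.scroll(
            collection_name=OPENAI_VECTOR_STORES_METADATA_COLLECTION,
            limit=100,
            offset=offset,
            with_payload=True,
        )
        for point in points:
            if point.payload and "metadata" in point.payload:
                store_info = point.payload["metadata"]
                stores[store_info["id"]] = store_info
        if offset is None:  # Qdrant signals the last page with a null offset
            break
    return stores
```

This becomes moot once patch 2 moves the metadata into the stack's KV store, but the caveat applies to any Qdrant `scroll` usage.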
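Both patches lean on `convert_id`, whose body sits outside this diff: Qdrant only accepts UUIDs or unsigned integers as point IDs, so arbitrary chunk and store IDs must be mapped deterministically, i.e. the same input must always address the same point across upserts and deletes. A minimal sketch of such a mapping, assuming a namespaced UUIDv5 hash; the in-tree helper may differ:

```python
import uuid


def convert_id(_id: str) -> str:
    # uuid5 is a deterministic, namespaced SHA-1 hash, so repeated calls with
    # the same string yield the same UUID: a delete by convert_id(chunk_id)
    # addresses the point written by an earlier upsert of that chunk.
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, _id))
```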
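Patch 2's registry persistence keys vector DBs under a versioned prefix and reloads them with a lexicographic range scan. The `"\xff"` sentinel works because KV keys are ordered as strings: the half-open range `[prefix, prefix + "\xff")` covers exactly the keys that begin with the prefix. In isolation, against the `KVStore` interface used in the diff:

```python
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::"


async def stored_vector_dbs(kvstore) -> list[str]:
    # "\xff" sorts after every printable character, so this range scan
    # returns the serialized VectorDB for every registered identifier.
    start_key = VECTOR_DBS_PREFIX
    end_key = f"{VECTOR_DBS_PREFIX}\xff"
    return await kvstore.values_in_range(start_key, end_key)
```

A side effect of the scheme: bumping `VERSION` invalidates previously stored entries without a migration, since old keys fall outside the new range.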
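Finally, the mixin change: `asyncio.gather` propagates the first exception while sibling tasks keep running, so the sequential fallback can re-attach files whose concurrent attachment already succeeded. The pattern reduced to its shape, with the hypothetical `attach` standing in for `openai_attach_file_to_vector_store`:

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def attach_all(attach, vector_db_id: str, file_ids: list[str]) -> None:
    if not file_ids:
        return
    try:
        # Fan out: one attachment task per file, run concurrently.
        await asyncio.gather(*(attach(vector_db_id, fid) for fid in file_ids))
    except Exception as e:
        logger.warning(f"Concurrent file processing failed: {e}. Falling back to sequential processing.")
        # Files that attached successfully above are attached again here, so
        # attach() must be idempotent for this fallback to be safe.
        for fid in file_ids:
            await attach(vector_db_id, fid)
```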