From 61bddfe70e3b48856068e6d8c5d6e26da52729c8 Mon Sep 17 00:00:00 2001 From: Varsha Prasad Narsing Date: Tue, 17 Jun 2025 16:38:02 -0700 Subject: [PATCH] feat: Add openAI compatible APIs to QDrant Signed-off-by: Varsha Prasad Narsing --- .../inline/vector_io/qdrant/__init__.py | 10 +- .../inline/vector_io/qdrant/config.py | 3 +- .../remote/vector_io/qdrant/config.py | 3 +- .../remote/vector_io/qdrant/qdrant.py | 93 +++++++++++++++++-- .../vector_io/test_openai_vector_stores.py | 2 +- tests/unit/providers/vector_io/test_qdrant.py | 2 +- 6 files changed, 100 insertions(+), 13 deletions(-) diff --git a/llama_stack/providers/inline/vector_io/qdrant/__init__.py b/llama_stack/providers/inline/vector_io/qdrant/__init__.py index ee33b3797..bc9014c68 100644 --- a/llama_stack/providers/inline/vector_io/qdrant/__init__.py +++ b/llama_stack/providers/inline/vector_io/qdrant/__init__.py @@ -4,14 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.providers.datatypes import Api, ProviderSpec +from typing import Any + +from llama_stack.providers.datatypes import Api from .config import QdrantVectorIOConfig -async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]): +async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]): from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter - impl = QdrantVectorIOAdapter(config, deps[Api.inference]) + assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}" + files_api = deps.get(Api.files) + impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/vector_io/qdrant/config.py b/llama_stack/providers/inline/vector_io/qdrant/config.py index 7cc91d918..61d026984 100644 --- a/llama_stack/providers/inline/vector_io/qdrant/config.py +++ b/llama_stack/providers/inline/vector_io/qdrant/config.py @@ -5,7 +5,7 @@ # the root directory of this source tree. -from typing import Any +from typing import Any, Literal from pydantic import BaseModel @@ -15,6 +15,7 @@ from llama_stack.schema_utils import json_schema_type @json_schema_type class QdrantVectorIOConfig(BaseModel): path: str + distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE" @classmethod def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: diff --git a/llama_stack/providers/remote/vector_io/qdrant/config.py b/llama_stack/providers/remote/vector_io/qdrant/config.py index 314d3f5f1..0d8e08663 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/config.py +++ b/llama_stack/providers/remote/vector_io/qdrant/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any +from typing import Any, Literal from pydantic import BaseModel @@ -23,6 +23,7 @@ class QdrantVectorIOConfig(BaseModel): prefix: str | None = None timeout: int | None = None host: str | None = None + distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE" @classmethod def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]: diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 09ea08fa0..8787a4900 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -12,6 +12,7 @@ from numpy.typing import NDArray from qdrant_client import AsyncQdrantClient, models from qdrant_client.models import PointStruct +from llama_stack.apis.files import Files from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import ( @@ -31,6 +32,7 @@ from llama_stack.apis.vector_io import ( ) from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig +from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, @@ -40,6 +42,7 @@ from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig log = logging.getLogger(__name__) CHUNK_ID_KEY = "_chunk_id" +OPENAI_VECTOR_STORES_METADATA_COLLECTION = "openai_vector_stores_metadata" def convert_id(_id: str) -> str: @@ -54,9 +57,10 @@ def convert_id(_id: str) -> str: class QdrantIndex(EmbeddingIndex): - def __init__(self, client: AsyncQdrantClient, collection_name: str): + def __init__(self, client: AsyncQdrantClient, collection_name: str, distance_metric: str = "COSINE"): self.client = client self.collection_name = collection_name + self.distance_metric = distance_metric async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( @@ -64,9 +68,12 @@ class QdrantIndex(EmbeddingIndex): ) if not await self.client.collection_exists(self.collection_name): + # Get distance metric, defaulting to COSINE + distance = getattr(models.Distance, self.distance_metric, models.Distance.COSINE) + await self.client.create_collection( self.collection_name, - vectors_config=models.VectorParams(size=len(embeddings[0]), distance=models.Distance.COSINE), + vectors_config=models.VectorParams(size=len(embeddings[0]), distance=distance), ) points = [] @@ -132,28 +139,100 @@ class QdrantIndex(EmbeddingIndex): await self.client.delete_collection(collection_name=self.collection_name) -class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): +class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): def __init__( - self, config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig, inference_api: Api.inference + self, + config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig, + inference_api: Api.inference, + files_api: Files | None, ) -> None: self.config = config self.client: AsyncQdrantClient = None self.cache = {} self.inference_api = inference_api + self.files_api = files_api + self.vector_db_store = None + self.openai_vector_stores: dict[str, dict[str, Any]] = {} async def initialize(self) -> None: self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) + # Load existing OpenAI vector stores using the mixin method + self.openai_vector_stores = await self._load_openai_vector_stores() async def shutdown(self) -> None: await self.client.close() + # OpenAI Vector Store Mixin abstract method implementations + async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Save vector store metadata to Qdrant collection metadata.""" + # Store metadata in a special collection for vector store metadata + metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION + + # Create metadata collection if it doesn't exist + if not await self.client.collection_exists(metadata_collection): + # Get distance metric from config, defaulting to COSINE for backward compatibility + distance_metric = getattr(self.config, "distance_metric", "COSINE") + distance = getattr(models.Distance, distance_metric, models.Distance.COSINE) + + await self.client.create_collection( + collection_name=metadata_collection, + vectors_config=models.VectorParams(size=1, distance=distance), + ) + + # Store metadata as a point with dummy vector + await self.client.upsert( + collection_name=metadata_collection, + points=[ + models.PointStruct( + id=convert_id(store_id), + vector=[0.0], # Dummy vector + payload={"metadata": store_info}, + ) + ], + ) + + async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: + """Load all vector store metadata from Qdrant.""" + metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION + + if not await self.client.collection_exists(metadata_collection): + return {} + + # Get all points from metadata collection + points = await self.client.scroll( + collection_name=metadata_collection, + limit=1000, # Reasonable limit for metadata + with_payload=True, + ) + + stores = {} + for point in points[0]: # points[0] contains the actual points + if point.payload and "metadata" in point.payload: + store_info = point.payload["metadata"] + stores[store_info["id"]] = store_info + + return stores + + async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Update vector store metadata in Qdrant.""" + await self._save_openai_vector_store(store_id, store_info) + + async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: + """Delete vector store metadata from Qdrant.""" + metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION + + if await self.client.collection_exists(metadata_collection): + await self.client.delete( + collection_name=metadata_collection, points_selector=models.PointIdsList(points=[convert_id(store_id)]) + ) + async def register_vector_db( self, vector_db: VectorDB, ) -> None: index = VectorDBWithIndex( vector_db=vector_db, - index=QdrantIndex(self.client, vector_db.identifier), + index=QdrantIndex(self.client, vector_db.identifier, self.config.distance_metric), inference_api=self.inference_api, ) @@ -174,7 +253,9 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): index = VectorDBWithIndex( vector_db=vector_db, - index=QdrantIndex(client=self.client, collection_name=vector_db.identifier), + index=QdrantIndex( + client=self.client, collection_name=vector_db.identifier, distance_metric=self.config.distance_metric + ), inference_api=self.inference_api, ) self.cache[vector_db_id] = index diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 4c061f519..93ae27a8e 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"] for p in vector_io_providers: - if p.provider_type in ["inline::faiss", "inline::sqlite-vec"]: + if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::qdrant"]: return pytest.skip("OpenAI vector stores are not supported by any provider") diff --git a/tests/unit/providers/vector_io/test_qdrant.py b/tests/unit/providers/vector_io/test_qdrant.py index 6902c8850..87cd18ce3 100644 --- a/tests/unit/providers/vector_io/test_qdrant.py +++ b/tests/unit/providers/vector_io/test_qdrant.py @@ -70,7 +70,7 @@ def mock_api_service(sample_embeddings): @pytest_asyncio.fixture async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter: - adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service) + adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service, files_api=None) adapter.vector_db_store = mock_vector_db_store await adapter.initialize() yield adapter