feat: Add openAI compatible APIs to QDrant

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha Prasad Narsing 2025-06-17 16:38:02 -07:00
parent 0ddb293d77
commit 61bddfe70e
6 changed files with 100 additions and 13 deletions

View file

@ -4,14 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.providers.datatypes import Api, ProviderSpec from typing import Any
from llama_stack.providers.datatypes import Api
from .config import QdrantVectorIOConfig from .config import QdrantVectorIOConfig
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]): async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
impl = QdrantVectorIOAdapter(config, deps[Api.inference]) assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
files_api = deps.get(Api.files)
impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any from typing import Any, Literal
from pydantic import BaseModel from pydantic import BaseModel
@ -15,6 +15,7 @@ from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class QdrantVectorIOConfig(BaseModel): class QdrantVectorIOConfig(BaseModel):
path: str path: str
distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE"
@classmethod @classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any from typing import Any, Literal
from pydantic import BaseModel from pydantic import BaseModel
@ -23,6 +23,7 @@ class QdrantVectorIOConfig(BaseModel):
prefix: str | None = None prefix: str | None = None
timeout: int | None = None timeout: int | None = None
host: str | None = None host: str | None = None
distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE"
@classmethod @classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]: def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:

View file

@ -12,6 +12,7 @@ from numpy.typing import NDArray
from qdrant_client import AsyncQdrantClient, models from qdrant_client import AsyncQdrantClient, models
from qdrant_client.models import PointStruct from qdrant_client.models import PointStruct
from llama_stack.apis.files import Files
from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import ( from llama_stack.apis.vector_io import (
@ -31,6 +32,7 @@ from llama_stack.apis.vector_io import (
) )
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex, EmbeddingIndex,
VectorDBWithIndex, VectorDBWithIndex,
@ -40,6 +42,7 @@ from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id" CHUNK_ID_KEY = "_chunk_id"
OPENAI_VECTOR_STORES_METADATA_COLLECTION = "openai_vector_stores_metadata"
def convert_id(_id: str) -> str: def convert_id(_id: str) -> str:
@ -54,9 +57,10 @@ def convert_id(_id: str) -> str:
class QdrantIndex(EmbeddingIndex): class QdrantIndex(EmbeddingIndex):
def __init__(self, client: AsyncQdrantClient, collection_name: str): def __init__(self, client: AsyncQdrantClient, collection_name: str, distance_metric: str = "COSINE"):
self.client = client self.client = client
self.collection_name = collection_name self.collection_name = collection_name
self.distance_metric = distance_metric
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), ( assert len(chunks) == len(embeddings), (
@ -64,9 +68,12 @@ class QdrantIndex(EmbeddingIndex):
) )
if not await self.client.collection_exists(self.collection_name): if not await self.client.collection_exists(self.collection_name):
# Get distance metric, defaulting to COSINE
distance = getattr(models.Distance, self.distance_metric, models.Distance.COSINE)
await self.client.create_collection( await self.client.create_collection(
self.collection_name, self.collection_name,
vectors_config=models.VectorParams(size=len(embeddings[0]), distance=models.Distance.COSINE), vectors_config=models.VectorParams(size=len(embeddings[0]), distance=distance),
) )
points = [] points = []
@ -132,28 +139,100 @@ class QdrantIndex(EmbeddingIndex):
await self.client.delete_collection(collection_name=self.collection_name) await self.client.delete_collection(collection_name=self.collection_name)
class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__( def __init__(
self, config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig, inference_api: Api.inference self,
config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
inference_api: Api.inference,
files_api: Files | None,
) -> None: ) -> None:
self.config = config self.config = config
self.client: AsyncQdrantClient = None self.client: AsyncQdrantClient = None
self.cache = {} self.cache = {}
self.inference_api = inference_api self.inference_api = inference_api
self.files_api = files_api
self.vector_db_store = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
async def initialize(self) -> None: async def initialize(self) -> None:
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
# Load existing OpenAI vector stores using the mixin method
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None: async def shutdown(self) -> None:
await self.client.close() await self.client.close()
# OpenAI Vector Store Mixin abstract method implementations
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to Qdrant collection metadata."""
# Store metadata in a special collection for vector store metadata
metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION
# Create metadata collection if it doesn't exist
if not await self.client.collection_exists(metadata_collection):
# Get distance metric from config, defaulting to COSINE for backward compatibility
distance_metric = getattr(self.config, "distance_metric", "COSINE")
distance = getattr(models.Distance, distance_metric, models.Distance.COSINE)
await self.client.create_collection(
collection_name=metadata_collection,
vectors_config=models.VectorParams(size=1, distance=distance),
)
# Store metadata as a point with dummy vector
await self.client.upsert(
collection_name=metadata_collection,
points=[
models.PointStruct(
id=convert_id(store_id),
vector=[0.0], # Dummy vector
payload={"metadata": store_info},
)
],
)
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from Qdrant."""
metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION
if not await self.client.collection_exists(metadata_collection):
return {}
# Get all points from metadata collection
points = await self.client.scroll(
collection_name=metadata_collection,
limit=1000, # Reasonable limit for metadata
with_payload=True,
)
stores = {}
for point in points[0]: # points[0] contains the actual points
if point.payload and "metadata" in point.payload:
store_info = point.payload["metadata"]
stores[store_info["id"]] = store_info
return stores
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in Qdrant."""
await self._save_openai_vector_store(store_id, store_info)
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from Qdrant."""
metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION
if await self.client.collection_exists(metadata_collection):
await self.client.delete(
collection_name=metadata_collection, points_selector=models.PointIdsList(points=[convert_id(store_id)])
)
async def register_vector_db( async def register_vector_db(
self, self,
vector_db: VectorDB, vector_db: VectorDB,
) -> None: ) -> None:
index = VectorDBWithIndex( index = VectorDBWithIndex(
vector_db=vector_db, vector_db=vector_db,
index=QdrantIndex(self.client, vector_db.identifier), index=QdrantIndex(self.client, vector_db.identifier, self.config.distance_metric),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
@ -174,7 +253,9 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
index = VectorDBWithIndex( index = VectorDBWithIndex(
vector_db=vector_db, vector_db=vector_db,
index=QdrantIndex(client=self.client, collection_name=vector_db.identifier), index=QdrantIndex(
client=self.client, collection_name=vector_db.identifier, distance_metric=self.config.distance_metric
),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
self.cache[vector_db_id] = index self.cache[vector_db_id] = index

View file

@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"] vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
for p in vector_io_providers: for p in vector_io_providers:
if p.provider_type in ["inline::faiss", "inline::sqlite-vec"]: if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::qdrant"]:
return return
pytest.skip("OpenAI vector stores are not supported by any provider") pytest.skip("OpenAI vector stores are not supported by any provider")

View file

@ -70,7 +70,7 @@ def mock_api_service(sample_embeddings):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter: async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service) adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service, files_api=None)
adapter.vector_db_store = mock_vector_db_store adapter.vector_db_store = mock_vector_db_store
await adapter.initialize() await adapter.initialize()
yield adapter yield adapter