feat: Add openAI compatible APIs to QDrant

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha Prasad Narsing 2025-06-17 16:38:02 -07:00
parent 0ddb293d77
commit 61bddfe70e
6 changed files with 100 additions and 13 deletions

View file

@ -4,14 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.datatypes import Api, ProviderSpec
from typing import Any
from llama_stack.providers.datatypes import Api
from .config import QdrantVectorIOConfig
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
impl = QdrantVectorIOAdapter(config, deps[Api.inference])
assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
files_api = deps.get(Api.files)
impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api)
await impl.initialize()
return impl

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from typing import Any
from typing import Any, Literal
from pydantic import BaseModel
@ -15,6 +15,7 @@ from llama_stack.schema_utils import json_schema_type
@json_schema_type
class QdrantVectorIOConfig(BaseModel):
path: str
distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE"
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from typing import Any, Literal
from pydantic import BaseModel
@ -23,6 +23,7 @@ class QdrantVectorIOConfig(BaseModel):
prefix: str | None = None
timeout: int | None = None
host: str | None = None
distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE"
@classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:

View file

@ -12,6 +12,7 @@ from numpy.typing import NDArray
from qdrant_client import AsyncQdrantClient, models
from qdrant_client.models import PointStruct
from llama_stack.apis.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
@ -31,6 +32,7 @@ from llama_stack.apis.vector_io import (
)
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
@ -40,6 +42,7 @@ from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id"
OPENAI_VECTOR_STORES_METADATA_COLLECTION = "openai_vector_stores_metadata"
def convert_id(_id: str) -> str:
@ -54,9 +57,10 @@ def convert_id(_id: str) -> str:
class QdrantIndex(EmbeddingIndex):
def __init__(self, client: AsyncQdrantClient, collection_name: str):
def __init__(self, client: AsyncQdrantClient, collection_name: str, distance_metric: str = "COSINE"):
self.client = client
self.collection_name = collection_name
self.distance_metric = distance_metric
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
@ -64,9 +68,12 @@ class QdrantIndex(EmbeddingIndex):
)
if not await self.client.collection_exists(self.collection_name):
# Get distance metric, defaulting to COSINE
distance = getattr(models.Distance, self.distance_metric, models.Distance.COSINE)
await self.client.create_collection(
self.collection_name,
vectors_config=models.VectorParams(size=len(embeddings[0]), distance=models.Distance.COSINE),
vectors_config=models.VectorParams(size=len(embeddings[0]), distance=distance),
)
points = []
@ -132,28 +139,100 @@ class QdrantIndex(EmbeddingIndex):
await self.client.delete_collection(collection_name=self.collection_name)
class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(
self, config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig, inference_api: Api.inference
self,
config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
inference_api: Api.inference,
files_api: Files | None,
) -> None:
self.config = config
self.client: AsyncQdrantClient = None
self.cache = {}
self.inference_api = inference_api
self.files_api = files_api
self.vector_db_store = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
async def initialize(self) -> None:
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
# Load existing OpenAI vector stores using the mixin method
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
await self.client.close()
# OpenAI Vector Store Mixin abstract method implementations
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to Qdrant collection metadata."""
# Store metadata in a special collection for vector store metadata
metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION
# Create metadata collection if it doesn't exist
if not await self.client.collection_exists(metadata_collection):
# Get distance metric from config, defaulting to COSINE for backward compatibility
distance_metric = getattr(self.config, "distance_metric", "COSINE")
distance = getattr(models.Distance, distance_metric, models.Distance.COSINE)
await self.client.create_collection(
collection_name=metadata_collection,
vectors_config=models.VectorParams(size=1, distance=distance),
)
# Store metadata as a point with dummy vector
await self.client.upsert(
collection_name=metadata_collection,
points=[
models.PointStruct(
id=convert_id(store_id),
vector=[0.0], # Dummy vector
payload={"metadata": store_info},
)
],
)
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from Qdrant."""
metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION
if not await self.client.collection_exists(metadata_collection):
return {}
# Get all points from metadata collection
points = await self.client.scroll(
collection_name=metadata_collection,
limit=1000, # Reasonable limit for metadata
with_payload=True,
)
stores = {}
for point in points[0]: # points[0] contains the actual points
if point.payload and "metadata" in point.payload:
store_info = point.payload["metadata"]
stores[store_info["id"]] = store_info
return stores
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in Qdrant."""
await self._save_openai_vector_store(store_id, store_info)
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from Qdrant."""
metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION
if await self.client.collection_exists(metadata_collection):
await self.client.delete(
collection_name=metadata_collection, points_selector=models.PointIdsList(points=[convert_id(store_id)])
)
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
index = VectorDBWithIndex(
vector_db=vector_db,
index=QdrantIndex(self.client, vector_db.identifier),
index=QdrantIndex(self.client, vector_db.identifier, self.config.distance_metric),
inference_api=self.inference_api,
)
@ -174,7 +253,9 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
index = VectorDBWithIndex(
vector_db=vector_db,
index=QdrantIndex(client=self.client, collection_name=vector_db.identifier),
index=QdrantIndex(
client=self.client, collection_name=vector_db.identifier, distance_metric=self.config.distance_metric
),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index

View file

@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
for p in vector_io_providers:
if p.provider_type in ["inline::faiss", "inline::sqlite-vec"]:
if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::qdrant"]:
return
pytest.skip("OpenAI vector stores are not supported by any provider")

View file

@ -70,7 +70,7 @@ def mock_api_service(sample_embeddings):
@pytest_asyncio.fixture
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service, files_api=None)
adapter.vector_db_store = mock_vector_db_store
await adapter.initialize()
yield adapter