mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: Add openAI compatible APIs to QDrant
Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
0ddb293d77
commit
61bddfe70e
6 changed files with 100 additions and 13 deletions
|
@ -4,14 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import QdrantVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference])
|
||||
assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
files_api = deps.get(Api.files)
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -15,6 +15,7 @@ from llama_stack.schema_utils import json_schema_type
|
|||
@json_schema_type
|
||||
class QdrantVectorIOConfig(BaseModel):
|
||||
path: str
|
||||
distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -23,6 +23,7 @@ class QdrantVectorIOConfig(BaseModel):
|
|||
prefix: str | None = None
|
||||
timeout: int | None = None
|
||||
host: str | None = None
|
||||
distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
|
||||
|
|
|
@ -12,6 +12,7 @@ from numpy.typing import NDArray
|
|||
from qdrant_client import AsyncQdrantClient, models
|
||||
from qdrant_client.models import PointStruct
|
||||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
|
@ -31,6 +32,7 @@ from llama_stack.apis.vector_io import (
|
|||
)
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
|
@ -40,6 +42,7 @@ from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
|||
|
||||
log = logging.getLogger(__name__)
|
||||
CHUNK_ID_KEY = "_chunk_id"
|
||||
OPENAI_VECTOR_STORES_METADATA_COLLECTION = "openai_vector_stores_metadata"
|
||||
|
||||
|
||||
def convert_id(_id: str) -> str:
|
||||
|
@ -54,9 +57,10 @@ def convert_id(_id: str) -> str:
|
|||
|
||||
|
||||
class QdrantIndex(EmbeddingIndex):
|
||||
def __init__(self, client: AsyncQdrantClient, collection_name: str):
|
||||
def __init__(self, client: AsyncQdrantClient, collection_name: str, distance_metric: str = "COSINE"):
|
||||
self.client = client
|
||||
self.collection_name = collection_name
|
||||
self.distance_metric = distance_metric
|
||||
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), (
|
||||
|
@ -64,9 +68,12 @@ class QdrantIndex(EmbeddingIndex):
|
|||
)
|
||||
|
||||
if not await self.client.collection_exists(self.collection_name):
|
||||
# Get distance metric, defaulting to COSINE
|
||||
distance = getattr(models.Distance, self.distance_metric, models.Distance.COSINE)
|
||||
|
||||
await self.client.create_collection(
|
||||
self.collection_name,
|
||||
vectors_config=models.VectorParams(size=len(embeddings[0]), distance=models.Distance.COSINE),
|
||||
vectors_config=models.VectorParams(size=len(embeddings[0]), distance=distance),
|
||||
)
|
||||
|
||||
points = []
|
||||
|
@ -132,28 +139,100 @@ class QdrantIndex(EmbeddingIndex):
|
|||
await self.client.delete_collection(collection_name=self.collection_name)
|
||||
|
||||
|
||||
class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
self, config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig, inference_api: Api.inference
|
||||
self,
|
||||
config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
|
||||
inference_api: Api.inference,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.client: AsyncQdrantClient = None
|
||||
self.cache = {}
|
||||
self.inference_api = inference_api
|
||||
self.files_api = files_api
|
||||
self.vector_db_store = None
|
||||
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
|
||||
# Load existing OpenAI vector stores using the mixin method
|
||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await self.client.close()
|
||||
|
||||
# OpenAI Vector Store Mixin abstract method implementations
|
||||
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
|
||||
"""Save vector store metadata to Qdrant collection metadata."""
|
||||
# Store metadata in a special collection for vector store metadata
|
||||
metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION
|
||||
|
||||
# Create metadata collection if it doesn't exist
|
||||
if not await self.client.collection_exists(metadata_collection):
|
||||
# Get distance metric from config, defaulting to COSINE for backward compatibility
|
||||
distance_metric = getattr(self.config, "distance_metric", "COSINE")
|
||||
distance = getattr(models.Distance, distance_metric, models.Distance.COSINE)
|
||||
|
||||
await self.client.create_collection(
|
||||
collection_name=metadata_collection,
|
||||
vectors_config=models.VectorParams(size=1, distance=distance),
|
||||
)
|
||||
|
||||
# Store metadata as a point with dummy vector
|
||||
await self.client.upsert(
|
||||
collection_name=metadata_collection,
|
||||
points=[
|
||||
models.PointStruct(
|
||||
id=convert_id(store_id),
|
||||
vector=[0.0], # Dummy vector
|
||||
payload={"metadata": store_info},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
|
||||
"""Load all vector store metadata from Qdrant."""
|
||||
metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION
|
||||
|
||||
if not await self.client.collection_exists(metadata_collection):
|
||||
return {}
|
||||
|
||||
# Get all points from metadata collection
|
||||
points = await self.client.scroll(
|
||||
collection_name=metadata_collection,
|
||||
limit=1000, # Reasonable limit for metadata
|
||||
with_payload=True,
|
||||
)
|
||||
|
||||
stores = {}
|
||||
for point in points[0]: # points[0] contains the actual points
|
||||
if point.payload and "metadata" in point.payload:
|
||||
store_info = point.payload["metadata"]
|
||||
stores[store_info["id"]] = store_info
|
||||
|
||||
return stores
|
||||
|
||||
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
|
||||
"""Update vector store metadata in Qdrant."""
|
||||
await self._save_openai_vector_store(store_id, store_info)
|
||||
|
||||
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
|
||||
"""Delete vector store metadata from Qdrant."""
|
||||
metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION
|
||||
|
||||
if await self.client.collection_exists(metadata_collection):
|
||||
await self.client.delete(
|
||||
collection_name=metadata_collection, points_selector=models.PointIdsList(points=[convert_id(store_id)])
|
||||
)
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=QdrantIndex(self.client, vector_db.identifier),
|
||||
index=QdrantIndex(self.client, vector_db.identifier, self.config.distance_metric),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
|
@ -174,7 +253,9 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=QdrantIndex(client=self.client, collection_name=vector_db.identifier),
|
||||
index=QdrantIndex(
|
||||
client=self.client, collection_name=vector_db.identifier, distance_metric=self.config.distance_metric
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db_id] = index
|
||||
|
|
|
@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
|
|||
def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
|
||||
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
|
||||
for p in vector_io_providers:
|
||||
if p.provider_type in ["inline::faiss", "inline::sqlite-vec"]:
|
||||
if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::qdrant"]:
|
||||
return
|
||||
|
||||
pytest.skip("OpenAI vector stores are not supported by any provider")
|
||||
|
|
|
@ -70,7 +70,7 @@ def mock_api_service(sample_embeddings):
|
|||
|
||||
@pytest_asyncio.fixture
|
||||
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
|
||||
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
|
||||
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service, files_api=None)
|
||||
adapter.vector_db_store = mock_vector_db_store
|
||||
await adapter.initialize()
|
||||
yield adapter
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue