Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)

feat: Add openAI compatible APIs to QDrant

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>

Commit: 61bddfe70e (parent: 0ddb293d77)
6 changed files with 100 additions and 13 deletions
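In brief: QdrantVectorIOAdapter now inherits OpenAIVectorStoreMixin and persists OpenAI vector-store metadata in a dedicated openai_vector_stores_metadata collection, using one-dimensional dummy vectors so the JSON payload can ride on ordinary Qdrant points. The adapter's constructor gains an optional Files API dependency, resolved with deps.get(Api.files) so stacks without a files backend keep working. Both the inline and remote configs gain a distance_metric field (COSINE, DOT, EUCLID, or MANHATTAN, defaulting to COSINE), which is threaded through to QdrantIndex and collection creation.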
File 1 of 6 — llama_stack/providers/inline/vector_io/qdrant/__init__.py (path inferred from the imports in this diff):

@@ -4,14 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.datatypes import Api, ProviderSpec
+from typing import Any
+
+from llama_stack.providers.datatypes import Api
 
 from .config import QdrantVectorIOConfig
 
 
-async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
+async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
     from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
 
-    impl = QdrantVectorIOAdapter(config, deps[Api.inference])
+    assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
+    files_api = deps.get(Api.files)
+    impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api)
     await impl.initialize()
     return impl
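The switch from deps[Api.files] to deps.get(Api.files) is what makes the Files dependency optional. A minimal sketch of that lookup pattern, using a plain object() as a stand-in for a real inference provider (the stand-in is an assumption for illustration):

import asyncio
from typing import Any

from llama_stack.providers.datatypes import Api

# A deps mapping with no Files provider registered; object() stands in
# for a real inference implementation.
deps: dict[Api, Any] = {Api.inference: object()}

files_api = deps.get(Api.files)  # returns None instead of raising KeyError
assert files_api is None  # the adapter accepts files_api: Files | None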
File 2 of 6 — llama_stack/providers/inline/vector_io/qdrant/config.py (path inferred from the imports in this diff):

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
-from typing import Any
+from typing import Any, Literal
 
 from pydantic import BaseModel
 
 
@@ -15,6 +15,7 @@ from llama_stack.schema_utils import json_schema_type
 @json_schema_type
 class QdrantVectorIOConfig(BaseModel):
     path: str
+    distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE"
 
     @classmethod
     def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
File 3 of 6 — llama_stack/providers/remote/vector_io/qdrant/config.py (path inferred from the imports in this diff):

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any
+from typing import Any, Literal
 
 from pydantic import BaseModel
 
@@ -23,6 +23,7 @@ class QdrantVectorIOConfig(BaseModel):
     prefix: str | None = None
     timeout: int | None = None
     host: str | None = None
+    distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE"
 
     @classmethod
     def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
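Because distance_metric is typed as a Literal, pydantic rejects unsupported values at construction time, and QdrantIndex later maps the validated string onto models.Distance via getattr with a COSINE fallback. A minimal sketch against the inline config (the path value is an illustrative placeholder):

from pydantic import ValidationError

from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig

# Valid: one of the four supported metrics.
config = QdrantVectorIOConfig(path="/tmp/qdrant_store", distance_metric="EUCLID")
print(config.distance_metric)  # EUCLID

# Invalid: rejected by the Literal annotation before any client is built.
try:
    QdrantVectorIOConfig(path="/tmp/qdrant_store", distance_metric="HAMMING")
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # literal_error (under pydantic v2)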
File 4 of 6 — llama_stack/providers/remote/vector_io/qdrant/qdrant.py (path named verbatim by the import in file 1):

@@ -12,6 +12,7 @@ from numpy.typing import NDArray
 from qdrant_client import AsyncQdrantClient, models
 from qdrant_client.models import PointStruct
 
+from llama_stack.apis.files import Files
 from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
@@ -31,6 +32,7 @@ from llama_stack.apis.vector_io import (
 )
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
@@ -40,6 +42,7 @@ from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
 
 log = logging.getLogger(__name__)
 CHUNK_ID_KEY = "_chunk_id"
+OPENAI_VECTOR_STORES_METADATA_COLLECTION = "openai_vector_stores_metadata"
 
 
 def convert_id(_id: str) -> str:
@@ -54,9 +57,10 @@ def convert_id(_id: str) -> str:
 
 
 class QdrantIndex(EmbeddingIndex):
-    def __init__(self, client: AsyncQdrantClient, collection_name: str):
+    def __init__(self, client: AsyncQdrantClient, collection_name: str, distance_metric: str = "COSINE"):
         self.client = client
         self.collection_name = collection_name
+        self.distance_metric = distance_metric
 
     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (
@@ -64,9 +68,12 @@ class QdrantIndex(EmbeddingIndex):
         )
 
         if not await self.client.collection_exists(self.collection_name):
+            # Get distance metric, defaulting to COSINE
+            distance = getattr(models.Distance, self.distance_metric, models.Distance.COSINE)
+
             await self.client.create_collection(
                 self.collection_name,
-                vectors_config=models.VectorParams(size=len(embeddings[0]), distance=models.Distance.COSINE),
+                vectors_config=models.VectorParams(size=len(embeddings[0]), distance=distance),
             )
 
         points = []
@@ -132,28 +139,100 @@ class QdrantIndex(EmbeddingIndex):
         await self.client.delete_collection(collection_name=self.collection_name)
 
 
-class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
+class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     def __init__(
-        self, config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig, inference_api: Api.inference
+        self,
+        config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
+        inference_api: Api.inference,
+        files_api: Files | None,
     ) -> None:
         self.config = config
         self.client: AsyncQdrantClient = None
         self.cache = {}
         self.inference_api = inference_api
+        self.files_api = files_api
+        self.vector_db_store = None
+        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
 
     async def initialize(self) -> None:
         self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
+        # Load existing OpenAI vector stores using the mixin method
+        self.openai_vector_stores = await self._load_openai_vector_stores()
 
     async def shutdown(self) -> None:
         await self.client.close()
 
+    # OpenAI Vector Store Mixin abstract method implementations
+    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+        """Save vector store metadata to Qdrant collection metadata."""
+        # Store metadata in a special collection for vector store metadata
+        metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION
+
+        # Create metadata collection if it doesn't exist
+        if not await self.client.collection_exists(metadata_collection):
+            # Get distance metric from config, defaulting to COSINE for backward compatibility
+            distance_metric = getattr(self.config, "distance_metric", "COSINE")
+            distance = getattr(models.Distance, distance_metric, models.Distance.COSINE)
+
+            await self.client.create_collection(
+                collection_name=metadata_collection,
+                vectors_config=models.VectorParams(size=1, distance=distance),
+            )
+
+        # Store metadata as a point with dummy vector
+        await self.client.upsert(
+            collection_name=metadata_collection,
+            points=[
+                models.PointStruct(
+                    id=convert_id(store_id),
+                    vector=[0.0],  # Dummy vector
+                    payload={"metadata": store_info},
+                )
+            ],
+        )
+
+    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
+        """Load all vector store metadata from Qdrant."""
+        metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION
+
+        if not await self.client.collection_exists(metadata_collection):
+            return {}
+
+        # Get all points from metadata collection
+        points = await self.client.scroll(
+            collection_name=metadata_collection,
+            limit=1000,  # Reasonable limit for metadata
+            with_payload=True,
+        )
+
+        stores = {}
+        for point in points[0]:  # points[0] contains the actual points
+            if point.payload and "metadata" in point.payload:
+                store_info = point.payload["metadata"]
+                stores[store_info["id"]] = store_info
+
+        return stores
+
+    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+        """Update vector store metadata in Qdrant."""
+        await self._save_openai_vector_store(store_id, store_info)
+
+    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
+        """Delete vector store metadata from Qdrant."""
+        metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION
+
+        if await self.client.collection_exists(metadata_collection):
+            await self.client.delete(
+                collection_name=metadata_collection, points_selector=models.PointIdsList(points=[convert_id(store_id)])
+            )
+
     async def register_vector_db(
         self,
         vector_db: VectorDB,
     ) -> None:
         index = VectorDBWithIndex(
             vector_db=vector_db,
-            index=QdrantIndex(self.client, vector_db.identifier),
+            index=QdrantIndex(self.client, vector_db.identifier, self.config.distance_metric),
             inference_api=self.inference_api,
         )
 
@@ -174,7 +253,9 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
 
         index = VectorDBWithIndex(
             vector_db=vector_db,
-            index=QdrantIndex(client=self.client, collection_name=vector_db.identifier),
+            index=QdrantIndex(
+                client=self.client, collection_name=vector_db.identifier, distance_metric=self.config.distance_metric
+            ),
             inference_api=self.inference_api,
         )
         self.cache[vector_db_id] = index
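The persistence scheme above stores each vector store's JSON metadata as the payload of a point with a one-dimensional dummy vector, so a single Qdrant instance holds both embeddings and bookkeeping. A self-contained sketch of that round trip using qdrant-client's in-memory mode (assumes a recent qdrant-client with AsyncQdrantClient local mode and collection_exists; to_point_id is a hypothetical stand-in for the adapter's convert_id helper):

import asyncio
import uuid

from qdrant_client import AsyncQdrantClient, models

METADATA_COLLECTION = "openai_vector_stores_metadata"


def to_point_id(store_id: str) -> str:
    # Qdrant point IDs must be ints or UUIDs; derive a stable UUID from the store ID.
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, store_id))


async def main() -> None:
    client = AsyncQdrantClient(location=":memory:")

    # Save: lazily create the metadata collection, then upsert one point per store.
    if not await client.collection_exists(METADATA_COLLECTION):
        await client.create_collection(
            collection_name=METADATA_COLLECTION,
            vectors_config=models.VectorParams(size=1, distance=models.Distance.COSINE),
        )
    store_info = {"id": "vs_123", "name": "demo", "file_ids": []}
    await client.upsert(
        collection_name=METADATA_COLLECTION,
        points=[
            models.PointStruct(
                id=to_point_id(store_info["id"]),
                vector=[0.0],  # dummy vector; only the payload matters
                payload={"metadata": store_info},
            )
        ],
    )

    # Load: scroll the collection and rebuild the {store_id: info} mapping.
    records, _next_offset = await client.scroll(
        collection_name=METADATA_COLLECTION, limit=1000, with_payload=True
    )
    stores = {
        r.payload["metadata"]["id"]: r.payload["metadata"]
        for r in records
        if r.payload and "metadata" in r.payload
    }
    print(stores["vs_123"]["name"])  # demo


asyncio.run(main())

Updates can simply reuse the save path, as the adapter does, because Qdrant's upsert overwrites a point with the same ID; deletion selects the point by the same converted store ID via models.PointIdsList.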
File 5 of 6 — integration tests: OpenAI vector-store provider support check:

@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
 def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
     vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
     for p in vector_io_providers:
-        if p.provider_type in ["inline::faiss", "inline::sqlite-vec"]:
+        if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::qdrant"]:
             return
 
     pytest.skip("OpenAI vector stores are not supported by any provider")
File 6 of 6 — unit tests: Qdrant adapter fixture:

@@ -70,7 +70,7 @@ def mock_api_service(sample_embeddings):
 
 @pytest_asyncio.fixture
 async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
-    adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
+    adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service, files_api=None)
     adapter.vector_db_store = mock_vector_db_store
     await adapter.initialize()
     yield adapter