[WIP] Configurable distance_metric:

- Configurable distance_metric enabled for PGVector.
- Added plumbing to support configuring more distance metrics for each vector provider.
This commit is contained in:
ChristianZaccaria 2025-10-23 15:29:21 +01:00
parent 658fb2c777
commit 9658581cf7
11 changed files with 187 additions and 18 deletions

View file

@ -56,9 +56,11 @@ def convert_id(_id: str) -> str:
class QdrantIndex(EmbeddingIndex):
def __init__(self, client: AsyncQdrantClient, collection_name: str):
def __init__(self, client: AsyncQdrantClient, collection_name: str, distance_metric: str = "COSINE"):
self.client = client
self.collection_name = collection_name
self._check_distance_metric_support(distance_metric)
self.distance_metric = distance_metric
async def initialize(self) -> None:
# Qdrant collections are created on-demand in add_chunks
@ -144,6 +146,22 @@ class QdrantIndex(EmbeddingIndex):
async def delete(self):
await self.client.delete_collection(collection_name=self.collection_name)
def _check_distance_metric_support(self, distance_metric: str) -> None:
"""Check if the distance metric is supported by Qdrant.
Args:
distance_metric: The distance metric to check
Raises:
NotImplementedError: If the distance metric is not supported yet
"""
if distance_metric != "COSINE":
# TODO: Implement support for other distance metrics in Qdrant
raise NotImplementedError(
f"Distance metric '{distance_metric}' is not yet supported by the Qdrant provider. "
f"Currently only 'COSINE' is supported. Please use 'COSINE' or switch to a different provider."
)
class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__(
@ -187,9 +205,10 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
await self.kvstore.set(key=key, value=vector_store.model_dump_json())
distance_metric = vector_store.distance_metric or "COSINE"
index = VectorStoreWithIndex(
vector_store=vector_store,
index=QdrantIndex(self.client, vector_store.identifier),
index=QdrantIndex(self.client, vector_store.identifier, distance_metric=distance_metric),
inference_api=self.inference_api,
)