[WIP] Configurable distance_metric:

- Configurable distance_metric enabled for PGVector.
- Added plumbing to support configuring more distance metrics for each vector provider.
This commit is contained in:
ChristianZaccaria 2025-10-23 15:29:21 +01:00
parent 658fb2c777
commit 9658581cf7
11 changed files with 187 additions and 18 deletions

View file

@ -45,10 +45,14 @@ async def maybe_await(result):
class ChromaIndex(EmbeddingIndex):
def __init__(self, client: ChromaClientType, collection, kvstore: KVStore | None = None):
def __init__(
self, client: ChromaClientType, collection, kvstore: KVStore | None = None, distance_metric: str = "COSINE"
):
self.client = client
self.collection = collection
self.kvstore = kvstore
self._check_distance_metric_support(distance_metric)
self.distance_metric = distance_metric
async def initialize(self):
pass
@ -102,6 +106,22 @@ class ChromaIndex(EmbeddingIndex):
ids = [f"{chunk.document_id}:{chunk.chunk_id}" for chunk in chunks_for_deletion]
await maybe_await(self.collection.delete(ids=ids))
def _check_distance_metric_support(self, distance_metric: str) -> None:
"""Check if the distance metric is supported by Chroma.
Args:
distance_metric: The distance metric to check
Raises:
NotImplementedError: If the distance metric is not supported yet
"""
if distance_metric != "COSINE":
# TODO: Implement support for other distance metrics in Chroma
raise NotImplementedError(
f"Distance metric '{distance_metric}' is not yet supported by the Chroma provider. "
f"Currently only 'COSINE' is supported. Please use 'COSINE' or switch to a different provider."
)
async def query_hybrid(
self,
embedding: NDArray,
@ -157,8 +177,9 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
name=vector_store.identifier, metadata={"vector_store": vector_store.model_dump_json()}
)
)
distance_metric = vector_store.distance_metric or "COSINE"
self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_store, ChromaIndex(self.client, collection), self.inference_api
vector_store, ChromaIndex(self.client, collection, distance_metric=distance_metric), self.inference_api
)
async def unregister_vector_store(self, vector_store_id: str) -> None: