mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 16:02:38 +00:00
[WIP] Configurable distance_metric:
- Configurable distance_metric enabled for PGVector. - Added plumbing to support configuring more distance metrics for each vector provider.
This commit is contained in:
parent
658fb2c777
commit
9658581cf7
11 changed files with 187 additions and 18 deletions
|
|
@ -45,10 +45,18 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
|
|||
|
||||
|
||||
class WeaviateIndex(EmbeddingIndex):
|
||||
def __init__(self, client: weaviate.WeaviateClient, collection_name: str, kvstore: KVStore | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
client: weaviate.WeaviateClient,
|
||||
collection_name: str,
|
||||
kvstore: KVStore | None = None,
|
||||
distance_metric: str = "COSINE",
|
||||
):
|
||||
self.client = client
|
||||
self.collection_name = sanitize_collection_name(collection_name, weaviate_format=True)
|
||||
self.kvstore = kvstore
|
||||
self._check_distance_metric_support(distance_metric)
|
||||
self.distance_metric = distance_metric
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
|
@ -82,6 +90,22 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
chunk_ids = [chunk.chunk_id for chunk in chunks_for_deletion]
|
||||
collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))
|
||||
|
||||
def _check_distance_metric_support(self, distance_metric: str) -> None:
|
||||
"""Check if the distance metric is supported by Weaviate.
|
||||
|
||||
Args:
|
||||
distance_metric: The distance metric to check
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the distance metric is not supported yet
|
||||
"""
|
||||
if distance_metric != "COSINE":
|
||||
# TODO: Implement support for other distance metrics in Weaviate
|
||||
raise NotImplementedError(
|
||||
f"Distance metric '{distance_metric}' is not yet supported by the Weaviate provider. "
|
||||
f"Currently only 'COSINE' is supported. Please use 'COSINE' or switch to a different provider."
|
||||
)
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs vector search using Weaviate's built-in vector search.
|
||||
|
|
@ -329,8 +353,11 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
|
|||
],
|
||||
)
|
||||
|
||||
distance_metric = vector_store.distance_metric or "COSINE"
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
|
||||
vector_store,
|
||||
WeaviateIndex(client=client, collection_name=sanitized_collection_name, distance_metric=distance_metric),
|
||||
self.inference_api,
|
||||
)
|
||||
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue